mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Merge branch 'master' into feature/porting-to-trio
This commit is contained in:
@ -5,15 +5,12 @@ from libp2p.crypto.rsa import create_new_key_pair
|
||||
from libp2p.host.basic_host import BasicHost
|
||||
from libp2p.host.host_interface import IHost
|
||||
from libp2p.host.routed_host import RoutedHost
|
||||
from libp2p.kademlia.network import KademliaServer
|
||||
from libp2p.kademlia.storage import IStorage
|
||||
from libp2p.network.network_interface import INetwork
|
||||
from libp2p.network.swarm import Swarm
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerstore import PeerStore
|
||||
from libp2p.peer.peerstore_interface import IPeerStore
|
||||
from libp2p.routing.interfaces import IPeerRouting
|
||||
from libp2p.routing.kademlia.kademlia_peer_router import KadmeliaPeerRouter
|
||||
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||
import libp2p.security.secio.transport as secio
|
||||
from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex
|
||||
@ -32,31 +29,6 @@ def generate_peer_id_from(key_pair: KeyPair) -> ID:
|
||||
return ID.from_pubkey(public_key)
|
||||
|
||||
|
||||
def initialize_default_kademlia_router(
|
||||
ksize: int = 20, alpha: int = 3, id_opt: ID = None, storage: IStorage = None
|
||||
) -> KadmeliaPeerRouter:
|
||||
"""
|
||||
initialize kadmelia router when no kademlia router is passed in.
|
||||
|
||||
:param ksize: The k parameter from the paper
|
||||
:param alpha: The alpha parameter from the paper
|
||||
:param id_opt: optional id for host
|
||||
:param storage: An instance that implements
|
||||
:interface:`~kademlia.storage.IStorage`
|
||||
:return: return a default kademlia instance
|
||||
"""
|
||||
if not id_opt:
|
||||
key_pair = generate_new_rsa_identity()
|
||||
id_opt = generate_peer_id_from(key_pair)
|
||||
|
||||
node_id = id_opt.to_bytes()
|
||||
# ignore type for Kademlia module
|
||||
server = KademliaServer( # type: ignore
|
||||
ksize=ksize, alpha=alpha, node_id=node_id, storage=storage
|
||||
)
|
||||
return KadmeliaPeerRouter(server)
|
||||
|
||||
|
||||
def initialize_default_swarm(
|
||||
key_pair: KeyPair,
|
||||
id_opt: ID = None,
|
||||
@ -92,6 +64,9 @@ def initialize_default_swarm(
|
||||
)
|
||||
|
||||
peerstore = peerstore_opt or PeerStore()
|
||||
# Store our key pair in peerstore
|
||||
peerstore.add_key_pair(id_opt, key_pair)
|
||||
|
||||
# TODO: Initialize discovery if not presented
|
||||
return Swarm(id_opt, peerstore, upgrader, transport)
|
||||
|
||||
@ -138,8 +113,8 @@ def new_node(
|
||||
# TODO routing unimplemented
|
||||
host: IHost # If not explicitly typed, MyPy raises error
|
||||
if disc_opt:
|
||||
host = RoutedHost(key_pair.public_key, swarm_opt, disc_opt)
|
||||
host = RoutedHost(swarm_opt, disc_opt)
|
||||
else:
|
||||
host = BasicHost(key_pair.public_key, swarm_opt)
|
||||
host = BasicHost(swarm_opt)
|
||||
|
||||
return host
|
||||
|
||||
@ -1,12 +1,14 @@
|
||||
from typing import Callable, Tuple, cast
|
||||
|
||||
from fastecdsa.encoding.util import int_bytelen
|
||||
from fastecdsa.encoding import util
|
||||
|
||||
from libp2p.crypto.ecc import ECCPrivateKey, ECCPublicKey, create_new_key_pair
|
||||
from libp2p.crypto.keys import PublicKey
|
||||
|
||||
SharedKeyGenerator = Callable[[bytes], bytes]
|
||||
|
||||
int_bytelen = util.int_bytelen
|
||||
|
||||
|
||||
def create_ephemeral_key_pair(curve_type: str) -> Tuple[PublicKey, SharedKeyGenerator]:
|
||||
"""Facilitates ECDH key exchange."""
|
||||
|
||||
@ -8,3 +8,9 @@ class ValidationError(BaseLibp2pError):
|
||||
|
||||
class ParseError(BaseLibp2pError):
|
||||
pass
|
||||
|
||||
|
||||
class MultiError(BaseLibp2pError):
|
||||
"""Raised with multiple exceptions."""
|
||||
|
||||
# todo: find some way for this to fancy-print all encapsulated errors
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, List, Sequence
|
||||
|
||||
import multiaddr
|
||||
|
||||
from libp2p.crypto.keys import PublicKey
|
||||
from libp2p.crypto.keys import PrivateKey, PublicKey
|
||||
from libp2p.host.defaults import get_default_protocols
|
||||
from libp2p.host.exceptions import StreamFailure
|
||||
from libp2p.network.network_interface import INetwork
|
||||
@ -39,7 +39,6 @@ class BasicHost(IHost):
|
||||
right after a stream is initialized.
|
||||
"""
|
||||
|
||||
_public_key: PublicKey
|
||||
_network: INetwork
|
||||
peerstore: IPeerStore
|
||||
|
||||
@ -48,11 +47,9 @@ class BasicHost(IHost):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
public_key: PublicKey,
|
||||
network: INetwork,
|
||||
default_protocols: "OrderedDict[TProtocol, StreamHandlerFn]" = None,
|
||||
) -> None:
|
||||
self._public_key = public_key
|
||||
self._network = network
|
||||
self._network.set_stream_handler(self._swarm_stream_handler)
|
||||
self.peerstore = self._network.peerstore
|
||||
@ -68,7 +65,10 @@ class BasicHost(IHost):
|
||||
return self._network.get_peer_id()
|
||||
|
||||
def get_public_key(self) -> PublicKey:
|
||||
return self._public_key
|
||||
return self.peerstore.pubkey(self.get_id())
|
||||
|
||||
def get_private_key(self) -> PrivateKey:
|
||||
return self.peerstore.privkey(self.get_id())
|
||||
|
||||
def get_network(self) -> INetwork:
|
||||
"""
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import Any, List, Sequence
|
||||
|
||||
import multiaddr
|
||||
|
||||
from libp2p.crypto.keys import PublicKey
|
||||
from libp2p.crypto.keys import PrivateKey, PublicKey
|
||||
from libp2p.network.network_interface import INetwork
|
||||
from libp2p.network.stream.net_stream_interface import INetStream
|
||||
from libp2p.peer.id import ID
|
||||
@ -24,6 +24,12 @@ class IHost(ABC):
|
||||
:return: the public key belonging to the peer
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_private_key(self) -> PrivateKey:
|
||||
"""
|
||||
:return: the private key belonging to the peer
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_network(self) -> INetwork:
|
||||
"""
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from libp2p.crypto.keys import PublicKey
|
||||
from libp2p.host.basic_host import BasicHost
|
||||
from libp2p.host.exceptions import ConnectionFailure
|
||||
from libp2p.network.network_interface import INetwork
|
||||
@ -11,8 +10,8 @@ from libp2p.routing.interfaces import IPeerRouting
|
||||
class RoutedHost(BasicHost):
|
||||
_router: IPeerRouting
|
||||
|
||||
def __init__(self, public_key: PublicKey, network: INetwork, router: IPeerRouting):
|
||||
super().__init__(public_key, network)
|
||||
def __init__(self, network: INetwork, router: IPeerRouting):
|
||||
super().__init__(network)
|
||||
self._router = router
|
||||
|
||||
async def connect(self, peer_info: PeerInfo) -> None:
|
||||
|
||||
@ -1,173 +0,0 @@
|
||||
from collections import Counter
|
||||
import logging
|
||||
|
||||
from .kad_peerinfo import KadPeerHeap, create_kad_peerinfo
|
||||
from .utils import gather_dict
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SpiderCrawl:
|
||||
"""Crawl the network and look for given 160-bit keys."""
|
||||
|
||||
def __init__(self, protocol, node, peers, ksize, alpha):
|
||||
"""
|
||||
Create a new C{SpiderCrawl}er.
|
||||
|
||||
Args:
|
||||
protocol: A :class:`~kademlia.protocol.KademliaProtocol` instance.
|
||||
node: A :class:`~kademlia.node.Node` representing the key we're
|
||||
looking for
|
||||
peers: A list of :class:`~kademlia.node.Node` instances that
|
||||
provide the entry point for the network
|
||||
ksize: The value for k based on the paper
|
||||
alpha: The value for alpha based on the paper
|
||||
"""
|
||||
self.protocol = protocol
|
||||
self.ksize = ksize
|
||||
self.alpha = alpha
|
||||
self.node = node
|
||||
self.nearest = KadPeerHeap(self.node, self.ksize)
|
||||
self.last_ids_crawled = []
|
||||
log.info("creating spider with peers: %s", peers)
|
||||
self.nearest.push(peers)
|
||||
|
||||
async def _find(self, rpcmethod):
|
||||
"""
|
||||
Get either a value or list of nodes.
|
||||
|
||||
Args:
|
||||
rpcmethod: The protocol's callfindValue or call_find_node.
|
||||
|
||||
The process:
|
||||
1. calls find_* to current ALPHA nearest not already queried nodes,
|
||||
adding results to current nearest list of k nodes.
|
||||
2. current nearest list needs to keep track of who has been queried
|
||||
already sort by nearest, keep KSIZE
|
||||
3. if list is same as last time, next call should be to everyone not
|
||||
yet queried
|
||||
4. repeat, unless nearest list has all been queried, then ur done
|
||||
"""
|
||||
log.info("crawling network with nearest: %s", str(tuple(self.nearest)))
|
||||
count = self.alpha
|
||||
if self.nearest.get_ids() == self.last_ids_crawled:
|
||||
count = len(self.nearest)
|
||||
self.last_ids_crawled = self.nearest.get_ids()
|
||||
|
||||
dicts = {}
|
||||
for peer in self.nearest.get_uncontacted()[:count]:
|
||||
dicts[peer.peer_id_bytes] = rpcmethod(peer, self.node)
|
||||
self.nearest.mark_contacted(peer)
|
||||
found = await gather_dict(dicts)
|
||||
return await self._nodes_found(found)
|
||||
|
||||
async def _nodes_found(self, responses):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ValueSpiderCrawl(SpiderCrawl):
|
||||
def __init__(self, protocol, node, peers, ksize, alpha):
|
||||
SpiderCrawl.__init__(self, protocol, node, peers, ksize, alpha)
|
||||
# keep track of the single nearest node without value - per
|
||||
# section 2.3 so we can set the key there if found
|
||||
self.nearest_without_value = KadPeerHeap(self.node, 1)
|
||||
|
||||
async def find(self):
|
||||
"""Find either the closest nodes or the value requested."""
|
||||
return await self._find(self.protocol.call_find_value)
|
||||
|
||||
async def _nodes_found(self, responses):
|
||||
"""Handle the result of an iteration in _find."""
|
||||
toremove = []
|
||||
found_values = []
|
||||
for peerid, response in responses.items():
|
||||
response = RPCFindResponse(response)
|
||||
if not response.happened():
|
||||
toremove.append(peerid)
|
||||
elif response.has_value():
|
||||
found_values.append(response.get_value())
|
||||
else:
|
||||
peer = self.nearest.get_node(peerid)
|
||||
self.nearest_without_value.push(peer)
|
||||
self.nearest.push(response.get_node_list())
|
||||
self.nearest.remove(toremove)
|
||||
|
||||
if found_values:
|
||||
return await self._handle_found_values(found_values)
|
||||
if self.nearest.have_contacted_all():
|
||||
# not found!
|
||||
return None
|
||||
return await self.find()
|
||||
|
||||
async def _handle_found_values(self, values):
|
||||
"""
|
||||
We got some values!
|
||||
|
||||
Exciting. But let's make sure they're all the same or freak out
|
||||
a little bit. Also, make sure we tell the nearest node that
|
||||
*didn't* have the value to store it.
|
||||
"""
|
||||
value_counts = Counter(values)
|
||||
if len(value_counts) != 1:
|
||||
log.warning(
|
||||
"Got multiple values for key %i: %s", self.node.xor_id, str(values)
|
||||
)
|
||||
value = value_counts.most_common(1)[0][0]
|
||||
|
||||
peer = self.nearest_without_value.popleft()
|
||||
if peer:
|
||||
await self.protocol.call_store(peer, self.node.peer_id_bytes, value)
|
||||
return value
|
||||
|
||||
|
||||
class NodeSpiderCrawl(SpiderCrawl):
|
||||
async def find(self):
|
||||
"""Find the closest nodes."""
|
||||
return await self._find(self.protocol.call_find_node)
|
||||
|
||||
async def _nodes_found(self, responses):
|
||||
"""Handle the result of an iteration in _find."""
|
||||
toremove = []
|
||||
for peerid, response in responses.items():
|
||||
response = RPCFindResponse(response)
|
||||
if not response.happened():
|
||||
toremove.append(peerid)
|
||||
else:
|
||||
self.nearest.push(response.get_node_list())
|
||||
self.nearest.remove(toremove)
|
||||
|
||||
if self.nearest.have_contacted_all():
|
||||
return list(self.nearest)
|
||||
return await self.find()
|
||||
|
||||
|
||||
class RPCFindResponse:
|
||||
def __init__(self, response):
|
||||
"""
|
||||
A wrapper for the result of a RPC find.
|
||||
|
||||
Args:
|
||||
response: This will be a tuple of (<response received>, <value>)
|
||||
where <value> will be a list of tuples if not found or
|
||||
a dictionary of {'value': v} where v is the value desired
|
||||
"""
|
||||
self.response = response
|
||||
|
||||
def happened(self):
|
||||
"""Did the other host actually respond?"""
|
||||
return self.response[0]
|
||||
|
||||
def has_value(self):
|
||||
return isinstance(self.response[1], dict)
|
||||
|
||||
def get_value(self):
|
||||
return self.response[1]["value"]
|
||||
|
||||
def get_node_list(self):
|
||||
"""
|
||||
Get the node list in the response.
|
||||
|
||||
If there's no value, this should be set.
|
||||
"""
|
||||
nodelist = self.response[1] or []
|
||||
return [create_kad_peerinfo(*nodeple) for nodeple in nodelist]
|
||||
@ -1,153 +0,0 @@
|
||||
import heapq
|
||||
from operator import itemgetter
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
|
||||
from .utils import digest
|
||||
|
||||
P_IP = "ip4"
|
||||
P_UDP = "udp"
|
||||
|
||||
|
||||
class KadPeerInfo(PeerInfo):
|
||||
def __init__(self, peer_id, addrs):
|
||||
super(KadPeerInfo, self).__init__(peer_id, addrs)
|
||||
|
||||
self.peer_id_bytes = peer_id.to_bytes()
|
||||
self.xor_id = peer_id.xor_id
|
||||
|
||||
self.addrs = addrs
|
||||
|
||||
self.ip = self.addrs[0].value_for_protocol(P_IP) if addrs else None
|
||||
self.port = int(self.addrs[0].value_for_protocol(P_UDP)) if addrs else None
|
||||
|
||||
def same_home_as(self, node):
|
||||
return sorted(self.addrs) == sorted(node.addrs)
|
||||
|
||||
def distance_to(self, node):
|
||||
"""Get the distance between this node and another."""
|
||||
return self.xor_id ^ node.xor_id
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
Enables use of Node as a tuple - i.e., tuple(node) works.
|
||||
"""
|
||||
return iter([self.peer_id_bytes, self.ip, self.port])
|
||||
|
||||
def __repr__(self):
|
||||
return repr([self.xor_id, self.ip, self.port, self.peer_id_bytes])
|
||||
|
||||
def __str__(self):
|
||||
return "%s:%s" % (self.ip, str(self.port))
|
||||
|
||||
def encode(self):
|
||||
return (
|
||||
str(self.peer_id_bytes)
|
||||
+ "\n"
|
||||
+ str("/ip4/" + str(self.ip) + "/udp/" + str(self.port))
|
||||
)
|
||||
|
||||
|
||||
class KadPeerHeap:
|
||||
"""A heap of peers ordered by distance to a given node."""
|
||||
|
||||
def __init__(self, node, maxsize):
|
||||
"""
|
||||
Constructor.
|
||||
|
||||
@param node: The node to measure all distnaces from.
|
||||
@param maxsize: The maximum size that this heap can grow to.
|
||||
"""
|
||||
self.node = node
|
||||
self.heap = []
|
||||
self.contacted = set()
|
||||
self.maxsize = maxsize
|
||||
|
||||
def remove(self, peers):
|
||||
"""
|
||||
Remove a list of peer ids from this heap.
|
||||
|
||||
Note that while this heap retains a constant visible size (based
|
||||
on the iterator), it's actual size may be quite a bit larger
|
||||
than what's exposed. Therefore, removal of nodes may not change
|
||||
the visible size as previously added nodes suddenly become
|
||||
visible.
|
||||
"""
|
||||
peers = set(peers)
|
||||
if not peers:
|
||||
return
|
||||
nheap = []
|
||||
for distance, node in self.heap:
|
||||
if node.peer_id_bytes not in peers:
|
||||
heapq.heappush(nheap, (distance, node))
|
||||
self.heap = nheap
|
||||
|
||||
def get_node(self, node_id):
|
||||
for _, node in self.heap:
|
||||
if node.peer_id_bytes == node_id:
|
||||
return node
|
||||
return None
|
||||
|
||||
def have_contacted_all(self):
|
||||
return len(self.get_uncontacted()) == 0
|
||||
|
||||
def get_ids(self):
|
||||
return [n.peer_id_bytes for n in self]
|
||||
|
||||
def mark_contacted(self, node):
|
||||
self.contacted.add(node.peer_id_bytes)
|
||||
|
||||
def popleft(self):
|
||||
return heapq.heappop(self.heap)[1] if self else None
|
||||
|
||||
def push(self, nodes):
|
||||
"""
|
||||
Push nodes onto heap.
|
||||
|
||||
@param nodes: This can be a single item or a C{list}.
|
||||
"""
|
||||
if not isinstance(nodes, list):
|
||||
nodes = [nodes]
|
||||
|
||||
for node in nodes:
|
||||
if node not in self:
|
||||
distance = self.node.distance_to(node)
|
||||
heapq.heappush(self.heap, (distance, node))
|
||||
|
||||
def __len__(self):
|
||||
return min(len(self.heap), self.maxsize)
|
||||
|
||||
def __iter__(self):
|
||||
nodes = heapq.nsmallest(self.maxsize, self.heap)
|
||||
return iter(map(itemgetter(1), nodes))
|
||||
|
||||
def __contains__(self, node):
|
||||
for _, other in self.heap:
|
||||
if node.peer_id_bytes == other.peer_id_bytes:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_uncontacted(self):
|
||||
return [n for n in self if n.peer_id_bytes not in self.contacted]
|
||||
|
||||
|
||||
def create_kad_peerinfo(node_id_bytes=None, sender_ip=None, sender_port=None):
|
||||
node_id = (
|
||||
ID(node_id_bytes) if node_id_bytes else ID(digest(random.getrandbits(255)))
|
||||
)
|
||||
addrs: List[Multiaddr]
|
||||
if sender_ip and sender_port:
|
||||
addrs = [
|
||||
Multiaddr(
|
||||
"/" + P_IP + "/" + str(sender_ip) + "/" + P_UDP + "/" + str(sender_port)
|
||||
)
|
||||
]
|
||||
else:
|
||||
addrs = []
|
||||
|
||||
return KadPeerInfo(node_id, addrs)
|
||||
@ -1,251 +0,0 @@
|
||||
"""Package for interacting on the network at a high level."""
|
||||
import asyncio
|
||||
import logging
|
||||
import pickle
|
||||
|
||||
from .crawling import NodeSpiderCrawl, ValueSpiderCrawl
|
||||
from .kad_peerinfo import create_kad_peerinfo
|
||||
from .protocol import KademliaProtocol
|
||||
from .storage import ForgetfulStorage
|
||||
from .utils import digest
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KademliaServer:
|
||||
"""
|
||||
High level view of a node instance.
|
||||
|
||||
This is the object that should be created to start listening as an
|
||||
active node on the network.
|
||||
"""
|
||||
|
||||
protocol_class = KademliaProtocol
|
||||
|
||||
def __init__(self, ksize=20, alpha=3, node_id=None, storage=None):
|
||||
"""
|
||||
Create a server instance. This will start listening on the given port.
|
||||
|
||||
Args:
|
||||
ksize (int): The k parameter from the paper
|
||||
alpha (int): The alpha parameter from the paper
|
||||
node_id: The id for this node on the network.
|
||||
storage: An instance that implements
|
||||
:interface:`~kademlia.storage.IStorage`
|
||||
"""
|
||||
self.ksize = ksize
|
||||
self.alpha = alpha
|
||||
self.storage = storage or ForgetfulStorage()
|
||||
self.node = create_kad_peerinfo(node_id)
|
||||
self.transport = None
|
||||
self.protocol = None
|
||||
self.refresh_loop = None
|
||||
self.save_state_loop = None
|
||||
|
||||
def stop(self):
|
||||
if self.transport is not None:
|
||||
self.transport.close()
|
||||
|
||||
if self.refresh_loop:
|
||||
self.refresh_loop.cancel()
|
||||
|
||||
if self.save_state_loop:
|
||||
self.save_state_loop.cancel()
|
||||
|
||||
def _create_protocol(self):
|
||||
return self.protocol_class(self.node, self.storage, self.ksize)
|
||||
|
||||
async def listen(self, port=0, interface="0.0.0.0"):
|
||||
"""
|
||||
Start listening on the given port.
|
||||
|
||||
Provide interface="::" to accept ipv6 address
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
listen = loop.create_datagram_endpoint(
|
||||
self._create_protocol, local_addr=(interface, port)
|
||||
)
|
||||
self.transport, self.protocol = await listen
|
||||
socket = self.transport.get_extra_info("socket")
|
||||
self.address = socket.getsockname()
|
||||
log.info(
|
||||
"Node %i listening on %s:%i",
|
||||
self.node.xor_id,
|
||||
self.address[0],
|
||||
self.address[1],
|
||||
)
|
||||
# finally, schedule refreshing table
|
||||
self.refresh_table()
|
||||
|
||||
def refresh_table(self):
|
||||
log.debug("Refreshing routing table")
|
||||
asyncio.ensure_future(self._refresh_table())
|
||||
loop = asyncio.get_event_loop()
|
||||
self.refresh_loop = loop.call_later(3600, self.refresh_table)
|
||||
|
||||
async def _refresh_table(self):
|
||||
"""Refresh buckets that haven't had any lookups in the last hour (per
|
||||
section 2.3 of the paper)."""
|
||||
results = []
|
||||
for node_id in self.protocol.get_refresh_ids():
|
||||
node = create_kad_peerinfo(node_id)
|
||||
nearest = self.protocol.router.find_neighbors(node, self.alpha)
|
||||
spider = NodeSpiderCrawl(
|
||||
self.protocol, node, nearest, self.ksize, self.alpha
|
||||
)
|
||||
results.append(spider.find())
|
||||
|
||||
# do our crawling
|
||||
await asyncio.gather(*results)
|
||||
|
||||
# now republish keys older than one hour
|
||||
for dkey, value in self.storage.iter_older_than(3600):
|
||||
await self.set_digest(dkey, value)
|
||||
|
||||
def bootstrappable_neighbors(self):
|
||||
"""
|
||||
Get a :class:`list` of (ip, port) :class:`tuple` pairs suitable for use
|
||||
as an argument to the bootstrap method.
|
||||
|
||||
The server should have been bootstrapped
|
||||
already - this is just a utility for getting some neighbors and then
|
||||
storing them if this server is going down for a while. When it comes
|
||||
back up, the list of nodes can be used to bootstrap.
|
||||
"""
|
||||
neighbors = self.protocol.router.find_neighbors(self.node)
|
||||
return [tuple(n)[-2:] for n in neighbors]
|
||||
|
||||
async def bootstrap(self, addrs):
|
||||
"""
|
||||
Bootstrap the server by connecting to other known nodes in the network.
|
||||
|
||||
Args:
|
||||
addrs: A `list` of (ip, port) `tuple` pairs. Note that only IP
|
||||
addresses are acceptable - hostnames will cause an error.
|
||||
"""
|
||||
log.debug("Attempting to bootstrap node with %i initial contacts", len(addrs))
|
||||
cos = list(map(self.bootstrap_node, addrs))
|
||||
gathered = await asyncio.gather(*cos)
|
||||
nodes = [node for node in gathered if node is not None]
|
||||
spider = NodeSpiderCrawl(
|
||||
self.protocol, self.node, nodes, self.ksize, self.alpha
|
||||
)
|
||||
return await spider.find()
|
||||
|
||||
async def bootstrap_node(self, addr):
|
||||
result = await self.protocol.ping(addr, self.node.peer_id_bytes)
|
||||
return create_kad_peerinfo(result[1], addr[0], addr[1]) if result[0] else None
|
||||
|
||||
async def get(self, key):
|
||||
"""
|
||||
Get a key if the network has it.
|
||||
|
||||
Returns:
|
||||
:class:`None` if not found, the value otherwise.
|
||||
"""
|
||||
log.info("Looking up key %s", key)
|
||||
dkey = digest(key)
|
||||
# if this node has it, return it
|
||||
if self.storage.get(dkey) is not None:
|
||||
return self.storage.get(dkey)
|
||||
|
||||
node = create_kad_peerinfo(dkey)
|
||||
nearest = self.protocol.router.find_neighbors(node)
|
||||
if not nearest:
|
||||
log.warning("There are no known neighbors to get key %s", key)
|
||||
return None
|
||||
spider = ValueSpiderCrawl(self.protocol, node, nearest, self.ksize, self.alpha)
|
||||
return await spider.find()
|
||||
|
||||
async def set(self, key, value):
|
||||
"""Set the given string key to the given value in the network."""
|
||||
if not check_dht_value_type(value):
|
||||
raise TypeError("Value must be of type int, float, bool, str, or bytes")
|
||||
log.info("setting '%s' = '%s' on network", key, value)
|
||||
dkey = digest(key)
|
||||
return await self.set_digest(dkey, value)
|
||||
|
||||
async def provide(self, key):
|
||||
"""publish to the network that it provides for a particular key."""
|
||||
neighbors = self.protocol.router.find_neighbors(self.node)
|
||||
return [
|
||||
await self.protocol.call_add_provider(n, key, self.node.peer_id_bytes)
|
||||
for n in neighbors
|
||||
]
|
||||
|
||||
async def get_providers(self, key):
|
||||
"""get the list of providers for a key."""
|
||||
neighbors = self.protocol.router.find_neighbors(self.node)
|
||||
return [await self.protocol.call_get_providers(n, key) for n in neighbors]
|
||||
|
||||
async def set_digest(self, dkey, value):
|
||||
"""Set the given SHA1 digest key (bytes) to the given value in the
|
||||
network."""
|
||||
node = create_kad_peerinfo(dkey)
|
||||
|
||||
nearest = self.protocol.router.find_neighbors(node)
|
||||
if not nearest:
|
||||
log.warning("There are no known neighbors to set key %s", dkey.hex())
|
||||
return False
|
||||
|
||||
spider = NodeSpiderCrawl(self.protocol, node, nearest, self.ksize, self.alpha)
|
||||
nodes = await spider.find()
|
||||
log.info("setting '%s' on %s", dkey.hex(), list(map(str, nodes)))
|
||||
|
||||
# if this node is close too, then store here as well
|
||||
biggest = max([n.distance_to(node) for n in nodes])
|
||||
if self.node.distance_to(node) < biggest:
|
||||
self.storage[dkey] = value
|
||||
results = [self.protocol.call_store(n, dkey, value) for n in nodes]
|
||||
# return true only if at least one store call succeeded
|
||||
return any(await asyncio.gather(*results))
|
||||
|
||||
def save_state(self, fname):
|
||||
"""Save the state of this node (the alpha/ksize/id/immediate neighbors)
|
||||
to a cache file with the given fname."""
|
||||
log.info("Saving state to %s", fname)
|
||||
data = {
|
||||
"ksize": self.ksize,
|
||||
"alpha": self.alpha,
|
||||
"id": self.node.peer_id_bytes,
|
||||
"neighbors": self.bootstrappable_neighbors(),
|
||||
}
|
||||
if not data["neighbors"]:
|
||||
log.warning("No known neighbors, so not writing to cache.")
|
||||
return
|
||||
with open(fname, "wb") as file:
|
||||
pickle.dump(data, file)
|
||||
|
||||
@classmethod
|
||||
def load_state(cls, fname):
|
||||
"""Load the state of this node (the alpha/ksize/id/immediate neighbors)
|
||||
from a cache file with the given fname."""
|
||||
log.info("Loading state from %s", fname)
|
||||
with open(fname, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
svr = KademliaServer(data["ksize"], data["alpha"], data["id"])
|
||||
if data["neighbors"]:
|
||||
svr.bootstrap(data["neighbors"])
|
||||
return svr
|
||||
|
||||
def save_state_regularly(self, fname, frequency=600):
|
||||
"""
|
||||
Save the state of node with a given regularity to the given filename.
|
||||
|
||||
Args:
|
||||
fname: File name to save retularly to
|
||||
frequency: Frequency in seconds that the state should be saved.
|
||||
By default, 10 minutes.
|
||||
"""
|
||||
self.save_state(fname)
|
||||
loop = asyncio.get_event_loop()
|
||||
self.save_state_loop = loop.call_later(
|
||||
frequency, self.save_state_regularly, fname, frequency
|
||||
)
|
||||
|
||||
|
||||
def check_dht_value_type(value):
|
||||
"""Checks to see if the type of the value is a valid type for placing in
|
||||
the dht."""
|
||||
typeset = [int, float, bool, str, bytes]
|
||||
return type(value) in typeset
|
||||
@ -1,188 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
|
||||
from rpcudp.protocol import RPCProtocol
|
||||
|
||||
from .kad_peerinfo import create_kad_peerinfo
|
||||
from .routing import RoutingTable
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KademliaProtocol(RPCProtocol):
|
||||
"""
|
||||
There are four main RPCs in the Kademlia protocol PING, STORE, FIND_NODE,
|
||||
FIND_VALUE.
|
||||
|
||||
- PING probes if a node is still online
|
||||
- STORE instructs a node to store (key, value)
|
||||
- FIND_NODE takes a 160-bit ID and gets back
|
||||
(ip, udp_port, node_id) for k closest nodes to target
|
||||
- FIND_VALUE behaves like FIND_NODE unless a value is stored.
|
||||
"""
|
||||
|
||||
def __init__(self, source_node, storage, ksize):
|
||||
RPCProtocol.__init__(self)
|
||||
self.router = RoutingTable(self, ksize, source_node)
|
||||
self.storage = storage
|
||||
self.source_node = source_node
|
||||
|
||||
def get_refresh_ids(self):
|
||||
"""Get ids to search for to keep old buckets up to date."""
|
||||
ids = []
|
||||
for bucket in self.router.lonely_buckets():
|
||||
rid = random.randint(*bucket.range).to_bytes(20, byteorder="big")
|
||||
ids.append(rid)
|
||||
return ids
|
||||
|
||||
def rpc_stun(self, sender):
|
||||
return sender
|
||||
|
||||
def rpc_ping(self, sender, nodeid):
|
||||
source = create_kad_peerinfo(nodeid, sender[0], sender[1])
|
||||
|
||||
self.welcome_if_new(source)
|
||||
return self.source_node.peer_id_bytes
|
||||
|
||||
def rpc_store(self, sender, nodeid, key, value):
|
||||
source = create_kad_peerinfo(nodeid, sender[0], sender[1])
|
||||
|
||||
self.welcome_if_new(source)
|
||||
log.debug(
|
||||
"got a store request from %s, storing '%s'='%s'", sender, key.hex(), value
|
||||
)
|
||||
self.storage[key] = value
|
||||
return True
|
||||
|
||||
def rpc_find_node(self, sender, nodeid, key):
|
||||
log.info("finding neighbors of %i in local table", int(nodeid.hex(), 16))
|
||||
source = create_kad_peerinfo(nodeid, sender[0], sender[1])
|
||||
|
||||
self.welcome_if_new(source)
|
||||
node = create_kad_peerinfo(key)
|
||||
neighbors = self.router.find_neighbors(node, exclude=source)
|
||||
return list(map(tuple, neighbors))
|
||||
|
||||
def rpc_find_value(self, sender, nodeid, key):
|
||||
source = create_kad_peerinfo(nodeid, sender[0], sender[1])
|
||||
|
||||
self.welcome_if_new(source)
|
||||
value = self.storage.get(key, None)
|
||||
if value is None:
|
||||
return self.rpc_find_node(sender, nodeid, key)
|
||||
return {"value": value}
|
||||
|
||||
def rpc_add_provider(self, sender, nodeid, key, provider_id):
|
||||
"""rpc when receiving an add_provider call should validate received
|
||||
PeerInfo matches sender nodeid if it does, receipient must store a
|
||||
record in its datastore we store a map of content_id to peer_id (non
|
||||
xor)"""
|
||||
if nodeid == provider_id:
|
||||
log.info(
|
||||
"adding provider %s for key %s in local table", provider_id, str(key)
|
||||
)
|
||||
self.storage[key] = provider_id
|
||||
return True
|
||||
return False
|
||||
|
||||
def rpc_get_providers(self, sender, key):
|
||||
"""rpc when receiving a get_providers call should look up key in data
|
||||
store and respond with records plus a list of closer peers in its
|
||||
routing table."""
|
||||
providers = []
|
||||
record = self.storage.get(key, None)
|
||||
|
||||
if record:
|
||||
providers.append(record)
|
||||
|
||||
keynode = create_kad_peerinfo(key)
|
||||
neighbors = self.router.find_neighbors(keynode)
|
||||
for neighbor in neighbors:
|
||||
if neighbor.peer_id_bytes != record:
|
||||
providers.append(neighbor.peer_id_bytes)
|
||||
|
||||
return providers
|
||||
|
||||
async def call_find_node(self, node_to_ask, node_to_find):
|
||||
address = (node_to_ask.ip, node_to_ask.port)
|
||||
result = await self.find_node(
|
||||
address, self.source_node.peer_id_bytes, node_to_find.peer_id_bytes
|
||||
)
|
||||
return self.handle_call_response(result, node_to_ask)
|
||||
|
||||
async def call_find_value(self, node_to_ask, node_to_find):
|
||||
address = (node_to_ask.ip, node_to_ask.port)
|
||||
result = await self.find_value(
|
||||
address, self.source_node.peer_id_bytes, node_to_find.peer_id_bytes
|
||||
)
|
||||
return self.handle_call_response(result, node_to_ask)
|
||||
|
||||
async def call_ping(self, node_to_ask):
|
||||
address = (node_to_ask.ip, node_to_ask.port)
|
||||
result = await self.ping(address, self.source_node.peer_id_bytes)
|
||||
return self.handle_call_response(result, node_to_ask)
|
||||
|
||||
async def call_store(self, node_to_ask, key, value):
|
||||
address = (node_to_ask.ip, node_to_ask.port)
|
||||
result = await self.store(address, self.source_node.peer_id_bytes, key, value)
|
||||
return self.handle_call_response(result, node_to_ask)
|
||||
|
||||
async def call_add_provider(self, node_to_ask, key, provider_id):
|
||||
address = (node_to_ask.ip, node_to_ask.port)
|
||||
result = await self.add_provider(
|
||||
address, self.source_node.peer_id_bytes, key, provider_id
|
||||
)
|
||||
|
||||
return self.handle_call_response(result, node_to_ask)
|
||||
|
||||
async def call_get_providers(self, node_to_ask, key):
|
||||
address = (node_to_ask.ip, node_to_ask.port)
|
||||
result = await self.get_providers(address, key)
|
||||
return self.handle_call_response(result, node_to_ask)
|
||||
|
||||
def welcome_if_new(self, node):
|
||||
"""
|
||||
Given a new node, send it all the keys/values it should be storing,
|
||||
then add it to the routing table.
|
||||
|
||||
@param node: A new node that just joined (or that we just found out
|
||||
about).
|
||||
|
||||
Process:
|
||||
For each key in storage, get k closest nodes. If newnode is closer
|
||||
than the furtherst in that list, and the node for this server
|
||||
is closer than the closest in that list, then store the key/value
|
||||
on the new node (per section 2.5 of the paper)
|
||||
"""
|
||||
if not self.router.is_new_node(node):
|
||||
return
|
||||
|
||||
log.info("never seen %s before, adding to router", node)
|
||||
for key, value in self.storage:
|
||||
keynode = create_kad_peerinfo(key)
|
||||
neighbors = self.router.find_neighbors(keynode)
|
||||
if neighbors:
|
||||
last = neighbors[-1].distance_to(keynode)
|
||||
new_node_close = node.distance_to(keynode) < last
|
||||
first = neighbors[0].distance_to(keynode)
|
||||
this_closest = self.source_node.distance_to(keynode) < first
|
||||
if not neighbors or (new_node_close and this_closest):
|
||||
asyncio.ensure_future(self.call_store(node, key, value))
|
||||
self.router.add_contact(node)
|
||||
|
||||
def handle_call_response(self, result, node):
|
||||
"""
|
||||
If we get a response, add the node to the routing table.
|
||||
|
||||
If we get no response, make sure it's removed from the routing
|
||||
table.
|
||||
"""
|
||||
if not result[0]:
|
||||
log.warning("no response from %s, removing from router", node)
|
||||
self.router.remove_contact(node)
|
||||
return result
|
||||
|
||||
log.info("got successful response from %s", node)
|
||||
self.welcome_if_new(node)
|
||||
return result
|
||||
@ -1,184 +0,0 @@
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
import heapq
|
||||
import operator
|
||||
import time
|
||||
|
||||
from .utils import OrderedSet, bytes_to_bit_string, shared_prefix
|
||||
|
||||
|
||||
class KBucket:
|
||||
"""each node keeps a list of (ip, udp_port, node_id) for nodes of distance
|
||||
between 2^i and 2^(i+1) this list that every node keeps is a k-bucket each
|
||||
k-bucket implements a last seen eviction policy except that live nodes are
|
||||
never removed."""
|
||||
|
||||
def __init__(self, rangeLower, rangeUpper, ksize):
|
||||
self.range = (rangeLower, rangeUpper)
|
||||
self.nodes = OrderedDict()
|
||||
self.replacement_nodes = OrderedSet()
|
||||
self.touch_last_updated()
|
||||
self.ksize = ksize
|
||||
|
||||
def touch_last_updated(self):
|
||||
self.last_updated = time.monotonic()
|
||||
|
||||
def get_nodes(self):
|
||||
return list(self.nodes.values())
|
||||
|
||||
def split(self):
|
||||
midpoint = (self.range[0] + self.range[1]) / 2
|
||||
one = KBucket(self.range[0], midpoint, self.ksize)
|
||||
two = KBucket(midpoint + 1, self.range[1], self.ksize)
|
||||
for node in self.nodes.values():
|
||||
bucket = one if node.xor_id <= midpoint else two
|
||||
bucket.nodes[node.peer_id_bytes] = node
|
||||
return (one, two)
|
||||
|
||||
def remove_node(self, node):
|
||||
if node.peer_id_bytes not in self.nodes:
|
||||
return
|
||||
|
||||
# delete node, and see if we can add a replacement
|
||||
del self.nodes[node.peer_id_bytes]
|
||||
if self.replacement_nodes:
|
||||
newnode = self.replacement_nodes.pop()
|
||||
self.nodes[newnode.peer_id_bytes] = newnode
|
||||
|
||||
def has_in_range(self, node):
|
||||
return self.range[0] <= node.xor_id <= self.range[1]
|
||||
|
||||
def is_new_node(self, node):
|
||||
return node.peer_id_bytes not in self.nodes
|
||||
|
||||
def add_node(self, node):
|
||||
"""
|
||||
Add a C{Node} to the C{KBucket}. Return True if successful, False if
|
||||
the bucket is full.
|
||||
|
||||
If the bucket is full, keep track of node in a replacement list,
|
||||
per section 4.1 of the paper.
|
||||
"""
|
||||
if node.peer_id_bytes in self.nodes:
|
||||
del self.nodes[node.peer_id_bytes]
|
||||
self.nodes[node.peer_id_bytes] = node
|
||||
elif len(self) < self.ksize:
|
||||
self.nodes[node.peer_id_bytes] = node
|
||||
else:
|
||||
self.replacement_nodes.push(node)
|
||||
return False
|
||||
return True
|
||||
|
||||
def depth(self):
|
||||
vals = self.nodes.values()
|
||||
sprefix = shared_prefix([bytes_to_bit_string(n.peer_id_bytes) for n in vals])
|
||||
return len(sprefix)
|
||||
|
||||
def head(self):
|
||||
return list(self.nodes.values())[0]
|
||||
|
||||
def __getitem__(self, node_id):
|
||||
return self.nodes.get(node_id, None)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.nodes)
|
||||
|
||||
|
||||
class TableTraverser:
|
||||
def __init__(self, table, startNode):
|
||||
index = table.get_bucket_for(startNode)
|
||||
table.buckets[index].touch_last_updated()
|
||||
self.current_nodes = table.buckets[index].get_nodes()
|
||||
self.left_buckets = table.buckets[:index]
|
||||
self.right_buckets = table.buckets[(index + 1) :]
|
||||
self.left = True
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
"""Pop an item from the left subtree, then right, then left, etc."""
|
||||
if self.current_nodes:
|
||||
return self.current_nodes.pop()
|
||||
|
||||
if self.left and self.left_buckets:
|
||||
self.current_nodes = self.left_buckets.pop().get_nodes()
|
||||
self.left = False
|
||||
return next(self)
|
||||
|
||||
if self.right_buckets:
|
||||
self.current_nodes = self.right_buckets.pop(0).get_nodes()
|
||||
self.left = True
|
||||
return next(self)
|
||||
|
||||
raise StopIteration
|
||||
|
||||
|
||||
class RoutingTable:
|
||||
def __init__(self, protocol, ksize, node):
|
||||
"""
|
||||
@param node: The node that represents this server. It won't
|
||||
be added to the routing table, but will be needed later to
|
||||
determine which buckets to split or not.
|
||||
"""
|
||||
self.node = node
|
||||
self.protocol = protocol
|
||||
self.ksize = ksize
|
||||
self.flush()
|
||||
|
||||
def flush(self):
|
||||
self.buckets = [KBucket(0, 2 ** 160, self.ksize)]
|
||||
|
||||
def split_bucket(self, index):
|
||||
one, two = self.buckets[index].split()
|
||||
self.buckets[index] = one
|
||||
self.buckets.insert(index + 1, two)
|
||||
|
||||
def lonely_buckets(self):
|
||||
"""Get all of the buckets that haven't been updated in over an hour."""
|
||||
hrago = time.monotonic() - 3600
|
||||
return [b for b in self.buckets if b.last_updated < hrago]
|
||||
|
||||
def remove_contact(self, node):
|
||||
index = self.get_bucket_for(node)
|
||||
self.buckets[index].remove_node(node)
|
||||
|
||||
def is_new_node(self, node):
|
||||
index = self.get_bucket_for(node)
|
||||
return self.buckets[index].is_new_node(node)
|
||||
|
||||
def add_contact(self, node):
|
||||
index = self.get_bucket_for(node)
|
||||
bucket = self.buckets[index]
|
||||
|
||||
# this will succeed unless the bucket is full
|
||||
if bucket.add_node(node):
|
||||
return
|
||||
|
||||
# Per section 4.2 of paper, split if the bucket has the node
|
||||
# in its range or if the depth is not congruent to 0 mod 5
|
||||
if bucket.has_in_range(self.node) or bucket.depth() % 5 != 0:
|
||||
self.split_bucket(index)
|
||||
self.add_contact(node)
|
||||
else:
|
||||
asyncio.ensure_future(self.protocol.call_ping(bucket.head()))
|
||||
|
||||
def get_bucket_for(self, node):
|
||||
"""Get the index of the bucket that the given node would fall into."""
|
||||
for index, bucket in enumerate(self.buckets):
|
||||
if node.xor_id < bucket.range[1]:
|
||||
return index
|
||||
# we should never be here, but make linter happy
|
||||
return None
|
||||
|
||||
def find_neighbors(self, node, k=None, exclude=None):
|
||||
k = k or self.ksize
|
||||
nodes = []
|
||||
for neighbor in TableTraverser(self, node):
|
||||
notexcluded = exclude is None or not neighbor.same_home_as(exclude)
|
||||
if neighbor.peer_id_bytes != node.peer_id_bytes and notexcluded:
|
||||
heapq.heappush(nodes, (node.distance_to(neighbor), neighbor))
|
||||
if len(nodes) == k:
|
||||
break
|
||||
|
||||
return list(map(operator.itemgetter(1), heapq.nsmallest(k, nodes)))
|
||||
@ -1,78 +0,0 @@
|
||||
// Record represents a dht record that contains a value
|
||||
// for a key value pair
|
||||
message Record {
|
||||
// The key that references this record
|
||||
bytes key = 1;
|
||||
|
||||
// The actual value this record is storing
|
||||
bytes value = 2;
|
||||
|
||||
// Note: These fields were removed from the Record message
|
||||
// hash of the authors public key
|
||||
//optional string author = 3;
|
||||
// A PKI signature for the key+value+author
|
||||
//optional bytes signature = 4;
|
||||
|
||||
// Time the record was received, set by receiver
|
||||
string timeReceived = 5;
|
||||
};
|
||||
|
||||
message Message {
|
||||
enum MessageType {
|
||||
PUT_VALUE = 0;
|
||||
GET_VALUE = 1;
|
||||
ADD_PROVIDER = 2;
|
||||
GET_PROVIDERS = 3;
|
||||
FIND_NODE = 4;
|
||||
PING = 5;
|
||||
}
|
||||
|
||||
enum ConnectionType {
|
||||
// sender does not have a connection to peer, and no extra information (default)
|
||||
NOT_CONNECTED = 0;
|
||||
|
||||
// sender has a live connection to peer
|
||||
CONNECTED = 1;
|
||||
|
||||
// sender recently connected to peer
|
||||
CAN_CONNECT = 2;
|
||||
|
||||
// sender recently tried to connect to peer repeatedly but failed to connect
|
||||
// ("try" here is loose, but this should signal "made strong effort, failed")
|
||||
CANNOT_CONNECT = 3;
|
||||
}
|
||||
|
||||
message Peer {
|
||||
// ID of a given peer.
|
||||
bytes id = 1;
|
||||
|
||||
// multiaddrs for a given peer
|
||||
repeated bytes addrs = 2;
|
||||
|
||||
// used to signal the sender's connection capabilities to the peer
|
||||
ConnectionType connection = 3;
|
||||
}
|
||||
|
||||
// defines what type of message it is.
|
||||
MessageType type = 1;
|
||||
|
||||
// defines what coral cluster level this query/response belongs to.
|
||||
// in case we want to implement coral's cluster rings in the future.
|
||||
int32 clusterLevelRaw = 10; // NOT USED
|
||||
|
||||
// Used to specify the key associated with this message.
|
||||
// PUT_VALUE, GET_VALUE, ADD_PROVIDER, GET_PROVIDERS
|
||||
bytes key = 2;
|
||||
|
||||
// Used to return a value
|
||||
// PUT_VALUE, GET_VALUE
|
||||
Record record = 3;
|
||||
|
||||
// Used to return peers closer to a key in a query
|
||||
// GET_VALUE, GET_PROVIDERS, FIND_NODE
|
||||
repeated Peer closerPeers = 8;
|
||||
|
||||
// Used to return Providers
|
||||
// GET_VALUE, ADD_PROVIDER, GET_PROVIDERS
|
||||
repeated Peer providerPeers = 9;
|
||||
}
|
||||
@ -1,93 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import OrderedDict
|
||||
from itertools import takewhile
|
||||
import operator
|
||||
import time
|
||||
|
||||
|
||||
class IStorage(ABC):
|
||||
"""
|
||||
Local storage for this node.
|
||||
|
||||
IStorage implementations of get must return the same type as put in
|
||||
by set
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __setitem__(self, key, value):
|
||||
"""Set a key to the given value."""
|
||||
|
||||
@abstractmethod
|
||||
def __getitem__(self, key):
|
||||
"""
|
||||
Get the given key.
|
||||
|
||||
If item doesn't exist, raises C{KeyError}
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key, default=None):
|
||||
"""
|
||||
Get given key.
|
||||
|
||||
If not found, return default.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def iter_older_than(self, seconds_old):
|
||||
"""Return the an iterator over (key, value) tuples for items older than
|
||||
the given seconds_old."""
|
||||
|
||||
@abstractmethod
|
||||
def __iter__(self):
|
||||
"""Get the iterator for this storage, should yield tuple of (key,
|
||||
value)"""
|
||||
|
||||
|
||||
class ForgetfulStorage(IStorage):
|
||||
def __init__(self, ttl=604800):
|
||||
"""By default, max age is a week."""
|
||||
self.data = OrderedDict()
|
||||
self.ttl = ttl
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if key in self.data:
|
||||
del self.data[key]
|
||||
self.data[key] = (time.monotonic(), value)
|
||||
self.cull()
|
||||
|
||||
def cull(self):
|
||||
for _, _ in self.iter_older_than(self.ttl):
|
||||
self.data.popitem(last=False)
|
||||
|
||||
def get(self, key, default=None):
|
||||
self.cull()
|
||||
if key in self.data:
|
||||
return self[key]
|
||||
return default
|
||||
|
||||
def __getitem__(self, key):
|
||||
self.cull()
|
||||
return self.data[key][1]
|
||||
|
||||
def __repr__(self):
|
||||
self.cull()
|
||||
return repr(self.data)
|
||||
|
||||
def iter_older_than(self, seconds_old):
|
||||
min_birthday = time.monotonic() - seconds_old
|
||||
zipped = self._triple_iter()
|
||||
matches = takewhile(lambda r: min_birthday >= r[1], zipped)
|
||||
return list(map(operator.itemgetter(0, 2), matches))
|
||||
|
||||
def _triple_iter(self):
|
||||
ikeys = self.data.keys()
|
||||
ibirthday = map(operator.itemgetter(0), self.data.values())
|
||||
ivalues = map(operator.itemgetter(1), self.data.values())
|
||||
return zip(ikeys, ibirthday, ivalues)
|
||||
|
||||
def __iter__(self):
|
||||
self.cull()
|
||||
ikeys = self.data.keys()
|
||||
ivalues = map(operator.itemgetter(1), self.data.values())
|
||||
return zip(ikeys, ivalues)
|
||||
@ -1,56 +0,0 @@
|
||||
"""General catchall for functions that don't make sense as methods."""
|
||||
import asyncio
|
||||
import hashlib
|
||||
import operator
|
||||
|
||||
|
||||
async def gather_dict(dic):
|
||||
cors = list(dic.values())
|
||||
results = await asyncio.gather(*cors)
|
||||
return dict(zip(dic.keys(), results))
|
||||
|
||||
|
||||
def digest(string):
|
||||
if not isinstance(string, bytes):
|
||||
string = str(string).encode("utf8")
|
||||
return hashlib.sha1(string).digest()
|
||||
|
||||
|
||||
class OrderedSet(list):
|
||||
"""
|
||||
Acts like a list in all ways, except in the behavior of the.
|
||||
|
||||
:meth:`push` method.
|
||||
"""
|
||||
|
||||
def push(self, thing):
|
||||
"""
|
||||
1. If the item exists in the list, it's removed
|
||||
2. The item is pushed to the end of the list
|
||||
"""
|
||||
if thing in self:
|
||||
self.remove(thing)
|
||||
self.append(thing)
|
||||
|
||||
|
||||
def shared_prefix(args):
|
||||
"""
|
||||
Find the shared prefix between the strings.
|
||||
|
||||
For instance:
|
||||
|
||||
sharedPrefix(['blahblah', 'blahwhat'])
|
||||
|
||||
returns 'blah'.
|
||||
"""
|
||||
i = 0
|
||||
while i < min(map(len, args)):
|
||||
if len(set(map(operator.itemgetter(i), args))) != 1:
|
||||
break
|
||||
i += 1
|
||||
return args[0][:i]
|
||||
|
||||
|
||||
def bytes_to_bit_string(bites):
|
||||
bits = [bin(bite)[2:].rjust(8, "0") for bite in bites]
|
||||
return "".join(bits)
|
||||
@ -21,6 +21,7 @@ from libp2p.transport.transport_interface import ITransport
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
from libp2p.typing import StreamHandlerFn
|
||||
|
||||
from ..exceptions import MultiError
|
||||
from .connection.raw_connection import RawConnection
|
||||
from .connection.swarm_connection import SwarmConn
|
||||
from .exceptions import SwarmException
|
||||
@ -95,21 +96,51 @@ class Swarm(INetwork, Service):
|
||||
try:
|
||||
# Get peer info from peer store
|
||||
addrs = self.peerstore.addrs(peer_id)
|
||||
except PeerStoreError:
|
||||
raise SwarmException(f"No known addresses to peer {peer_id}")
|
||||
except PeerStoreError as error:
|
||||
raise SwarmException(f"No known addresses to peer {peer_id}") from error
|
||||
|
||||
if not addrs:
|
||||
raise SwarmException(f"No known addresses to peer {peer_id}")
|
||||
|
||||
multiaddr = addrs[0]
|
||||
exceptions: List[SwarmException] = []
|
||||
|
||||
# Try all known addresses
|
||||
for multiaddr in addrs:
|
||||
try:
|
||||
return await self.dial_addr(multiaddr, peer_id)
|
||||
except SwarmException as e:
|
||||
exceptions.append(e)
|
||||
logger.debug(
|
||||
"encountered swarm exception when trying to connect to %s, "
|
||||
"trying next address...",
|
||||
multiaddr,
|
||||
exc_info=e,
|
||||
)
|
||||
|
||||
# Tried all addresses, raising exception.
|
||||
raise SwarmException(
|
||||
f"unable to connect to {peer_id}, no addresses established a successful connection "
|
||||
"(with exceptions)"
|
||||
) from MultiError(exceptions)
|
||||
|
||||
async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn:
|
||||
"""
|
||||
dial_addr try to create a connection to peer_id with addr.
|
||||
|
||||
:param addr: the address we want to connect with
|
||||
:param peer_id: the peer we want to connect to
|
||||
:raises SwarmException: raised when an error occurs
|
||||
:return: network connection
|
||||
"""
|
||||
|
||||
# Dial peer (connection to peer does not yet exist)
|
||||
# Transport dials peer (gets back a raw conn)
|
||||
try:
|
||||
raw_conn = await self.transport.dial(multiaddr)
|
||||
raw_conn = await self.transport.dial(addr)
|
||||
except OpenConnectionError as error:
|
||||
logger.debug("fail to dial peer %s over base transport", peer_id)
|
||||
raise SwarmException(
|
||||
"fail to open connection to peer %s", peer_id
|
||||
f"fail to open connection to peer {peer_id}"
|
||||
) from error
|
||||
|
||||
logger.debug("dialed peer %s over base transport", peer_id)
|
||||
@ -146,7 +177,6 @@ class Swarm(INetwork, Service):
|
||||
async def new_stream(self, peer_id: ID) -> INetStream:
|
||||
"""
|
||||
:param peer_id: peer_id of destination
|
||||
:param protocol_id: protocol id
|
||||
:raises SwarmException: raised when an error occurs
|
||||
:return: net stream instance
|
||||
"""
|
||||
@ -164,13 +194,15 @@ class Swarm(INetwork, Service):
|
||||
:return: true if at least one success
|
||||
|
||||
For each multiaddr
|
||||
Check if a listener for multiaddr exists already
|
||||
If listener already exists, continue
|
||||
Otherwise:
|
||||
Capture multiaddr in conn handler
|
||||
Have conn handler delegate to stream handler
|
||||
Call listener listen with the multiaddr
|
||||
Map multiaddr to listener
|
||||
|
||||
- Check if a listener for multiaddr exists already
|
||||
- If listener already exists, continue
|
||||
- Otherwise:
|
||||
|
||||
- Capture multiaddr in conn handler
|
||||
- Have conn handler delegate to stream handler
|
||||
- Call listener listen with the multiaddr
|
||||
- Map multiaddr to listener
|
||||
"""
|
||||
for maddr in multiaddrs:
|
||||
if str(maddr) in self.listeners:
|
||||
@ -251,7 +283,7 @@ class Swarm(INetwork, Service):
|
||||
# TODO: Should be changed to close multisple connections,
|
||||
# if we have several connections per peer in the future.
|
||||
connection = self.connections[peer_id]
|
||||
# NOTE: `connection.close` will perform `del self.connections[peer_id]`
|
||||
# NOTE: `connection.close` will delete `peer_id` from `self.connections`
|
||||
# and `notify_disconnected` for us.
|
||||
await connection.close()
|
||||
|
||||
|
||||
@ -7,9 +7,6 @@ from .id import ID
|
||||
|
||||
|
||||
class IAddrBook(ABC):
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_addr(self, peer_id: ID, addr: Multiaddr, ttl: int) -> None:
|
||||
"""
|
||||
|
||||
@ -44,7 +44,7 @@ class ID:
|
||||
@property
|
||||
def xor_id(self) -> int:
|
||||
if not self._xor_id:
|
||||
self._xor_id = int(digest(self._bytes).hex(), 16)
|
||||
self._xor_id = int(sha256_digest(self._bytes).hex(), 16)
|
||||
return self._xor_id
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
@ -89,7 +89,7 @@ class ID:
|
||||
return cls(mh_digest.encode())
|
||||
|
||||
|
||||
def digest(data: Union[str, bytes]) -> bytes:
|
||||
def sha256_digest(data: Union[str, bytes]) -> bytes:
|
||||
if isinstance(data, str):
|
||||
data = data.encode("utf8")
|
||||
return hashlib.sha1(data).digest()
|
||||
return hashlib.sha256(data).digest()
|
||||
|
||||
@ -2,46 +2,107 @@ from typing import Any, Dict, List, Sequence
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.crypto.keys import PrivateKey, PublicKey
|
||||
|
||||
from .peerdata_interface import IPeerData
|
||||
|
||||
|
||||
class PeerData(IPeerData):
|
||||
|
||||
pubkey: PublicKey
|
||||
privkey: PrivateKey
|
||||
metadata: Dict[Any, Any]
|
||||
protocols: List[str]
|
||||
addrs: List[Multiaddr]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.pubkey = None
|
||||
self.privkey = None
|
||||
self.metadata = {}
|
||||
self.protocols = []
|
||||
self.addrs = []
|
||||
|
||||
def get_protocols(self) -> List[str]:
|
||||
"""
|
||||
:return: all protocols associated with given peer
|
||||
"""
|
||||
return self.protocols
|
||||
|
||||
def add_protocols(self, protocols: Sequence[str]) -> None:
|
||||
"""
|
||||
:param protocols: protocols to add
|
||||
"""
|
||||
self.protocols.extend(list(protocols))
|
||||
|
||||
def set_protocols(self, protocols: Sequence[str]) -> None:
|
||||
"""
|
||||
:param protocols: protocols to set
|
||||
"""
|
||||
self.protocols = list(protocols)
|
||||
|
||||
def add_addrs(self, addrs: Sequence[Multiaddr]) -> None:
|
||||
"""
|
||||
:param addrs: multiaddresses to add
|
||||
"""
|
||||
self.addrs.extend(addrs)
|
||||
|
||||
def get_addrs(self) -> List[Multiaddr]:
|
||||
"""
|
||||
:return: all multiaddresses
|
||||
"""
|
||||
return self.addrs
|
||||
|
||||
def clear_addrs(self) -> None:
|
||||
"""Clear all addresses."""
|
||||
self.addrs = []
|
||||
|
||||
def put_metadata(self, key: str, val: Any) -> None:
|
||||
"""
|
||||
:param key: key in KV pair
|
||||
:param val: val to associate with key
|
||||
"""
|
||||
self.metadata[key] = val
|
||||
|
||||
def get_metadata(self, key: str) -> Any:
|
||||
"""
|
||||
:param key: key in KV pair
|
||||
:return: val for key
|
||||
:raise PeerDataError: key not found
|
||||
"""
|
||||
if key in self.metadata:
|
||||
return self.metadata[key]
|
||||
raise PeerDataError("key not found")
|
||||
|
||||
def add_pubkey(self, pubkey: PublicKey) -> None:
|
||||
"""
|
||||
:param pubkey:
|
||||
"""
|
||||
self.pubkey = pubkey
|
||||
|
||||
def get_pubkey(self) -> PublicKey:
|
||||
"""
|
||||
:return: public key of the peer
|
||||
:raise PeerDataError: if public key not found
|
||||
"""
|
||||
if self.pubkey is None:
|
||||
raise PeerDataError("public key not found")
|
||||
return self.pubkey
|
||||
|
||||
def add_privkey(self, privkey: PrivateKey) -> None:
|
||||
"""
|
||||
:param privkey:
|
||||
"""
|
||||
self.privkey = privkey
|
||||
|
||||
def get_privkey(self) -> PrivateKey:
|
||||
"""
|
||||
:return: private key of the peer
|
||||
:raise PeerDataError: if private key not found
|
||||
"""
|
||||
if self.privkey is None:
|
||||
raise PeerDataError("private key not found")
|
||||
return self.privkey
|
||||
|
||||
|
||||
class PeerDataError(KeyError):
|
||||
"""Raised when a key is not found in peer metadata."""
|
||||
|
||||
@ -3,6 +3,8 @@ from typing import Any, List, Sequence
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.crypto.keys import PrivateKey, PublicKey
|
||||
|
||||
from .peermetadata_interface import IPeerMetadata
|
||||
|
||||
|
||||
@ -22,7 +24,7 @@ class IPeerData(ABC):
|
||||
@abstractmethod
|
||||
def set_protocols(self, protocols: Sequence[str]) -> None:
|
||||
"""
|
||||
:param protocols: protocols to add
|
||||
:param protocols: protocols to set
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@ -46,7 +48,6 @@ class IPeerData(ABC):
|
||||
"""
|
||||
:param key: key in KV pair
|
||||
:param val: val to associate with key
|
||||
:raise Exception: unsuccesful put
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@ -54,5 +55,31 @@ class IPeerData(ABC):
|
||||
"""
|
||||
:param key: key in KV pair
|
||||
:return: val for key
|
||||
:raise Exception: key not found
|
||||
:raise PeerDataError: key not found
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_pubkey(self, pubkey: PublicKey) -> None:
|
||||
"""
|
||||
:param pubkey:
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_pubkey(self) -> PublicKey:
|
||||
"""
|
||||
:return: public key of the peer
|
||||
:raise PeerDataError: if public key not found
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_privkey(self, privkey: PrivateKey) -> None:
|
||||
"""
|
||||
:param privkey:
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_privkey(self) -> PrivateKey:
|
||||
"""
|
||||
:return: private key of the peer
|
||||
:raise PeerDataError: if private key not found
|
||||
"""
|
||||
|
||||
@ -5,9 +5,6 @@ from .id import ID
|
||||
|
||||
|
||||
class IPeerMetadata(ABC):
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(self, peer_id: ID, key: str) -> Any:
|
||||
"""
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Sequence
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.crypto.keys import KeyPair, PrivateKey, PublicKey
|
||||
|
||||
from .id import ID
|
||||
from .peerdata import PeerData, PeerDataError
|
||||
from .peerinfo import PeerInfo
|
||||
@ -10,90 +13,185 @@ from .peerstore_interface import IPeerStore
|
||||
|
||||
class PeerStore(IPeerStore):
|
||||
|
||||
peer_map: Dict[ID, PeerData]
|
||||
peer_data_map: Dict[ID, PeerData]
|
||||
|
||||
def __init__(self) -> None:
|
||||
IPeerStore.__init__(self)
|
||||
self.peer_map = {}
|
||||
self.peer_data_map = defaultdict(PeerData)
|
||||
|
||||
def __create_or_get_peer(self, peer_id: ID) -> PeerData:
|
||||
def peer_info(self, peer_id: ID) -> PeerInfo:
|
||||
"""
|
||||
Returns the peer data for peer_id or creates a new peer data (and
|
||||
stores it in peer_map) if peer data for peer_id does not yet exist.
|
||||
|
||||
:param peer_id: peer ID
|
||||
:return: peer data
|
||||
:param peer_id: peer ID to get info for
|
||||
:return: peer info object
|
||||
"""
|
||||
if peer_id in self.peer_map:
|
||||
return self.peer_map[peer_id]
|
||||
data = PeerData()
|
||||
self.peer_map[peer_id] = data
|
||||
return self.peer_map[peer_id]
|
||||
|
||||
def peer_info(self, peer_id: ID) -> Optional[PeerInfo]:
|
||||
if peer_id in self.peer_map:
|
||||
peer_data = self.peer_map[peer_id]
|
||||
return PeerInfo(peer_id, peer_data.addrs)
|
||||
return None
|
||||
if peer_id in self.peer_data_map:
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
return PeerInfo(peer_id, peer_data.get_addrs())
|
||||
raise PeerStoreError("peer ID not found")
|
||||
|
||||
def get_protocols(self, peer_id: ID) -> List[str]:
|
||||
if peer_id in self.peer_map:
|
||||
return self.peer_map[peer_id].get_protocols()
|
||||
"""
|
||||
:param peer_id: peer ID to get protocols for
|
||||
:return: protocols (as list of strings)
|
||||
:raise PeerStoreError: if peer ID not found
|
||||
"""
|
||||
if peer_id in self.peer_data_map:
|
||||
return self.peer_data_map[peer_id].get_protocols()
|
||||
raise PeerStoreError("peer ID not found")
|
||||
|
||||
def add_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
|
||||
peer = self.__create_or_get_peer(peer_id)
|
||||
peer.add_protocols(list(protocols))
|
||||
"""
|
||||
:param peer_id: peer ID to add protocols for
|
||||
:param protocols: protocols to add
|
||||
"""
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
peer_data.add_protocols(list(protocols))
|
||||
|
||||
def set_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
|
||||
peer = self.__create_or_get_peer(peer_id)
|
||||
peer.set_protocols(list(protocols))
|
||||
"""
|
||||
:param peer_id: peer ID to set protocols for
|
||||
:param protocols: protocols to set
|
||||
"""
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
peer_data.set_protocols(list(protocols))
|
||||
|
||||
def peer_ids(self) -> List[ID]:
|
||||
return list(self.peer_map.keys())
|
||||
"""
|
||||
:return: all of the peer IDs stored in peer store
|
||||
"""
|
||||
return list(self.peer_data_map.keys())
|
||||
|
||||
def get(self, peer_id: ID, key: str) -> Any:
|
||||
if peer_id in self.peer_map:
|
||||
"""
|
||||
:param peer_id: peer ID to get peer data for
|
||||
:param key: the key to search value for
|
||||
:return: value corresponding to the key
|
||||
:raise PeerStoreError: if peer ID or value not found
|
||||
"""
|
||||
if peer_id in self.peer_data_map:
|
||||
try:
|
||||
val = self.peer_map[peer_id].get_metadata(key)
|
||||
val = self.peer_data_map[peer_id].get_metadata(key)
|
||||
except PeerDataError as error:
|
||||
raise PeerStoreError(error)
|
||||
return val
|
||||
raise PeerStoreError("peer ID not found")
|
||||
|
||||
def put(self, peer_id: ID, key: str, val: Any) -> None:
|
||||
# <<?>>
|
||||
# This can output an error, not sure what the possible errors are
|
||||
peer = self.__create_or_get_peer(peer_id)
|
||||
peer.put_metadata(key, val)
|
||||
"""
|
||||
:param peer_id: peer ID to put peer data for
|
||||
:param key:
|
||||
:param value:
|
||||
"""
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
peer_data.put_metadata(key, val)
|
||||
|
||||
def add_addr(self, peer_id: ID, addr: Multiaddr, ttl: int) -> None:
|
||||
"""
|
||||
:param peer_id: peer ID to add address for
|
||||
:param addr:
|
||||
:param ttl: time-to-live for the this record
|
||||
"""
|
||||
self.add_addrs(peer_id, [addr], ttl)
|
||||
|
||||
def add_addrs(self, peer_id: ID, addrs: Sequence[Multiaddr], ttl: int) -> None:
|
||||
"""
|
||||
:param peer_id: peer ID to add address for
|
||||
:param addrs:
|
||||
:param ttl: time-to-live for the this record
|
||||
"""
|
||||
# Ignore ttl for now
|
||||
peer = self.__create_or_get_peer(peer_id)
|
||||
peer.add_addrs(list(addrs))
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
peer_data.add_addrs(list(addrs))
|
||||
|
||||
def addrs(self, peer_id: ID) -> List[Multiaddr]:
|
||||
if peer_id in self.peer_map:
|
||||
return self.peer_map[peer_id].get_addrs()
|
||||
"""
|
||||
:param peer_id: peer ID to get addrs for
|
||||
:return: list of addrs
|
||||
:raise PeerStoreError: if peer ID not found
|
||||
"""
|
||||
if peer_id in self.peer_data_map:
|
||||
return self.peer_data_map[peer_id].get_addrs()
|
||||
raise PeerStoreError("peer ID not found")
|
||||
|
||||
def clear_addrs(self, peer_id: ID) -> None:
|
||||
"""
|
||||
:param peer_id: peer ID to clear addrs for
|
||||
"""
|
||||
# Only clear addresses if the peer is in peer map
|
||||
if peer_id in self.peer_map:
|
||||
self.peer_map[peer_id].clear_addrs()
|
||||
if peer_id in self.peer_data_map:
|
||||
self.peer_data_map[peer_id].clear_addrs()
|
||||
|
||||
def peers_with_addrs(self) -> List[ID]:
|
||||
"""
|
||||
:return: all of the peer IDs which has addrs stored in peer store
|
||||
"""
|
||||
# Add all peers with addrs at least 1 to output
|
||||
output: List[ID] = []
|
||||
|
||||
for peer_id in self.peer_map:
|
||||
if len(self.peer_map[peer_id].get_addrs()) >= 1:
|
||||
for peer_id in self.peer_data_map:
|
||||
if len(self.peer_data_map[peer_id].get_addrs()) >= 1:
|
||||
output.append(peer_id)
|
||||
return output
|
||||
|
||||
def add_pubkey(self, peer_id: ID, pubkey: PublicKey) -> None:
|
||||
"""
|
||||
:param peer_id: peer ID to add public key for
|
||||
:param pubkey:
|
||||
:raise PeerStoreError: if peer ID and pubkey does not match
|
||||
"""
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
if ID.from_pubkey(pubkey) != peer_id:
|
||||
raise PeerStoreError("peer ID and pubkey does not match")
|
||||
peer_data.add_pubkey(pubkey)
|
||||
|
||||
def pubkey(self, peer_id: ID) -> PublicKey:
|
||||
"""
|
||||
:param peer_id: peer ID to get public key for
|
||||
:return: public key of the peer
|
||||
:raise PeerStoreError: if peer ID or peer pubkey not found
|
||||
"""
|
||||
if peer_id in self.peer_data_map:
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
try:
|
||||
pubkey = peer_data.get_pubkey()
|
||||
except PeerDataError:
|
||||
raise PeerStoreError("peer pubkey not found")
|
||||
return pubkey
|
||||
raise PeerStoreError("peer ID not found")
|
||||
|
||||
def add_privkey(self, peer_id: ID, privkey: PrivateKey) -> None:
|
||||
"""
|
||||
:param peer_id: peer ID to add private key for
|
||||
:param privkey:
|
||||
:raise PeerStoreError: if peer ID or peer privkey not found
|
||||
"""
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
if ID.from_pubkey(privkey.get_public_key()) != peer_id:
|
||||
raise PeerStoreError("peer ID and privkey does not match")
|
||||
peer_data.add_privkey(privkey)
|
||||
|
||||
def privkey(self, peer_id: ID) -> PrivateKey:
|
||||
"""
|
||||
:param peer_id: peer ID to get private key for
|
||||
:return: private key of the peer
|
||||
:raise PeerStoreError: if peer ID or peer privkey not found
|
||||
"""
|
||||
if peer_id in self.peer_data_map:
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
try:
|
||||
privkey = peer_data.get_privkey()
|
||||
except PeerDataError:
|
||||
raise PeerStoreError("peer privkey not found")
|
||||
return privkey
|
||||
raise PeerStoreError("peer ID not found")
|
||||
|
||||
def add_key_pair(self, peer_id: ID, key_pair: KeyPair) -> None:
|
||||
"""
|
||||
:param peer_id: peer ID to add private key for
|
||||
:param key_pair:
|
||||
"""
|
||||
self.add_pubkey(peer_id, key_pair.public_key)
|
||||
self.add_privkey(peer_id, key_pair.private_key)
|
||||
|
||||
|
||||
class PeerStoreError(KeyError):
|
||||
"""Raised when peer ID is not found in peer store."""
|
||||
|
||||
@ -1,5 +1,9 @@
|
||||
from abc import abstractmethod
|
||||
from typing import List, Sequence
|
||||
from typing import Any, List, Sequence
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.crypto.keys import KeyPair, PrivateKey, PublicKey
|
||||
|
||||
from .addrbook_interface import IAddrBook
|
||||
from .id import ID
|
||||
@ -8,10 +12,6 @@ from .peermetadata_interface import IPeerMetadata
|
||||
|
||||
|
||||
class IPeerStore(IAddrBook, IPeerMetadata):
|
||||
def __init__(self) -> None:
|
||||
IPeerMetadata.__init__(self)
|
||||
IAddrBook.__init__(self)
|
||||
|
||||
@abstractmethod
|
||||
def peer_info(self, peer_id: ID) -> PeerInfo:
|
||||
"""
|
||||
@ -23,8 +23,8 @@ class IPeerStore(IAddrBook, IPeerMetadata):
|
||||
def get_protocols(self, peer_id: ID) -> List[str]:
|
||||
"""
|
||||
:param peer_id: peer ID to get protocols for
|
||||
:return: protocols (as strings)
|
||||
:raise Exception: peer ID not found exception
|
||||
:return: protocols (as list of strings)
|
||||
:raise PeerStoreError: if peer ID not found
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@ -32,7 +32,6 @@ class IPeerStore(IAddrBook, IPeerMetadata):
|
||||
"""
|
||||
:param peer_id: peer ID to add protocols for
|
||||
:param protocols: protocols to add
|
||||
:raise Exception: peer ID not found
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@ -40,7 +39,6 @@ class IPeerStore(IAddrBook, IPeerMetadata):
|
||||
"""
|
||||
:param peer_id: peer ID to set protocols for
|
||||
:param protocols: protocols to set
|
||||
:raise Exception: peer ID not found
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@ -48,3 +46,95 @@ class IPeerStore(IAddrBook, IPeerMetadata):
|
||||
"""
|
||||
:return: all of the peer IDs stored in peer store
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, peer_id: ID, key: str) -> Any:
|
||||
"""
|
||||
:param peer_id: peer ID to get peer data for
|
||||
:param key: the key to search value for
|
||||
:return: value corresponding to the key
|
||||
:raise PeerStoreError: if peer ID or value not found
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def put(self, peer_id: ID, key: str, val: Any) -> None:
|
||||
"""
|
||||
:param peer_id: peer ID to put peer data for
|
||||
:param key:
|
||||
:param value:
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_addr(self, peer_id: ID, addr: Multiaddr, ttl: int) -> None:
|
||||
"""
|
||||
:param peer_id: peer ID to add address for
|
||||
:param addr:
|
||||
:param ttl: time-to-live for the this record
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_addrs(self, peer_id: ID, addrs: Sequence[Multiaddr], ttl: int) -> None:
|
||||
"""
|
||||
:param peer_id: peer ID to add address for
|
||||
:param addrs:
|
||||
:param ttl: time-to-live for the this record
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def addrs(self, peer_id: ID) -> List[Multiaddr]:
|
||||
"""
|
||||
:param peer_id: peer ID to get addrs for
|
||||
:return: list of addrs
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def clear_addrs(self, peer_id: ID) -> None:
|
||||
"""
|
||||
:param peer_id: peer ID to clear addrs for
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def peers_with_addrs(self) -> List[ID]:
|
||||
"""
|
||||
:return: all of the peer IDs which has addrs stored in peer store
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_pubkey(self, peer_id: ID, pubkey: PublicKey) -> None:
|
||||
"""
|
||||
:param peer_id: peer ID to add public key for
|
||||
:param pubkey:
|
||||
:raise PeerStoreError: if peer ID already has pubkey set
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def pubkey(self, peer_id: ID) -> PublicKey:
|
||||
"""
|
||||
:param peer_id: peer ID to get public key for
|
||||
:return: public key of the peer
|
||||
:raise PeerStoreError: if peer ID not found
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_privkey(self, peer_id: ID, privkey: PrivateKey) -> None:
|
||||
"""
|
||||
:param peer_id: peer ID to add private key for
|
||||
:param privkey:
|
||||
:raise PeerStoreError: if peer ID already has privkey set
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def privkey(self, peer_id: ID) -> PrivateKey:
|
||||
"""
|
||||
:param peer_id: peer ID to get private key for
|
||||
:return: private key of the peer
|
||||
:raise PeerStoreError: if peer ID not found
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_key_pair(self, peer_id: ID, key_pair: KeyPair) -> None:
|
||||
"""
|
||||
:param peer_id: peer ID to add private key for
|
||||
:param key_pair:
|
||||
:raise PeerStoreError: if peer ID already has pubkey or privkey set
|
||||
"""
|
||||
|
||||
@ -81,16 +81,20 @@ class FloodSub(IPubsubRouter):
|
||||
:param pubsub_msg: pubsub message in protobuf.
|
||||
"""
|
||||
|
||||
peers_gen = self._get_peers_to_send(
|
||||
pubsub_msg.topicIDs,
|
||||
msg_forwarder=msg_forwarder,
|
||||
origin=ID(pubsub_msg.from_id),
|
||||
peers_gen = set(
|
||||
self._get_peers_to_send(
|
||||
pubsub_msg.topicIDs,
|
||||
msg_forwarder=msg_forwarder,
|
||||
origin=ID(pubsub_msg.from_id),
|
||||
)
|
||||
)
|
||||
rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg])
|
||||
|
||||
logger.debug("publishing message %s", pubsub_msg)
|
||||
|
||||
for peer_id in peers_gen:
|
||||
if peer_id not in self.pubsub.peers:
|
||||
continue
|
||||
stream = self.pubsub.peers[peer_id]
|
||||
# FIXME: We should add a `WriteMsg` similar to write delimited messages.
|
||||
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
|
||||
@ -98,6 +102,7 @@ class FloodSub(IPubsubRouter):
|
||||
await stream.write(encode_varint_prefixed(rpc_msg.SerializeToString()))
|
||||
except StreamClosed:
|
||||
logger.debug("Fail to publish message to %s: stream closed", peer_id)
|
||||
self.pubsub._handle_dead_peer(peer_id)
|
||||
|
||||
async def join(self, topic: str) -> None:
|
||||
"""
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from ast import literal_eval
|
||||
from collections import defaultdict
|
||||
import logging
|
||||
import random
|
||||
from typing import Any, Dict, Iterable, List, Sequence, Set
|
||||
from typing import Any, DefaultDict, Dict, Iterable, List, Sequence, Set, Tuple
|
||||
|
||||
from async_service import Service
|
||||
import trio
|
||||
@ -33,18 +34,18 @@ class GossipSub(IPubsubRouter, Service):
|
||||
|
||||
time_to_live: int
|
||||
|
||||
mesh: Dict[str, List[ID]]
|
||||
fanout: Dict[str, List[ID]]
|
||||
mesh: Dict[str, Set[ID]]
|
||||
fanout: Dict[str, Set[ID]]
|
||||
|
||||
peers_to_protocol: Dict[ID, str]
|
||||
# The protocol peer supports
|
||||
peer_protocol: Dict[ID, TProtocol]
|
||||
|
||||
time_since_last_publish: Dict[str, int]
|
||||
|
||||
peers_gossipsub: List[ID]
|
||||
peers_floodsub: List[ID]
|
||||
# TODO: Add `time_since_last_publish`
|
||||
# Create topic --> time since last publish map.
|
||||
|
||||
mcache: MessageCache
|
||||
|
||||
heartbeat_initial_delay: float
|
||||
heartbeat_interval: int
|
||||
|
||||
def __init__(
|
||||
@ -56,6 +57,7 @@ class GossipSub(IPubsubRouter, Service):
|
||||
time_to_live: int,
|
||||
gossip_window: int = 3,
|
||||
gossip_history: int = 5,
|
||||
heartbeat_initial_delay: float = 0.1,
|
||||
heartbeat_interval: int = 120,
|
||||
) -> None:
|
||||
self.protocols = list(protocols)
|
||||
@ -74,18 +76,13 @@ class GossipSub(IPubsubRouter, Service):
|
||||
self.fanout = {}
|
||||
|
||||
# Create peer --> protocol mapping
|
||||
self.peers_to_protocol = {}
|
||||
|
||||
# Create topic --> time since last publish map
|
||||
self.time_since_last_publish = {}
|
||||
|
||||
self.peers_gossipsub = []
|
||||
self.peers_floodsub = []
|
||||
self.peer_protocol = {}
|
||||
|
||||
# Create message cache
|
||||
self.mcache = MessageCache(gossip_window, gossip_history)
|
||||
|
||||
# Create heartbeat timer
|
||||
self.heartbeat_initial_delay = heartbeat_initial_delay
|
||||
self.heartbeat_interval = heartbeat_interval
|
||||
|
||||
async def run(self) -> None:
|
||||
@ -122,18 +119,13 @@ class GossipSub(IPubsubRouter, Service):
|
||||
"""
|
||||
logger.debug("adding peer %s with protocol %s", peer_id, protocol_id)
|
||||
|
||||
if protocol_id == PROTOCOL_ID:
|
||||
self.peers_gossipsub.append(peer_id)
|
||||
elif protocol_id == floodsub.PROTOCOL_ID:
|
||||
self.peers_floodsub.append(peer_id)
|
||||
else:
|
||||
if protocol_id not in (PROTOCOL_ID, floodsub.PROTOCOL_ID):
|
||||
# We should never enter here. Becuase the `protocol_id` is registered by your pubsub
|
||||
# instance in multistream-select, but it is not the protocol that gossipsub supports.
|
||||
# In this case, probably we registered gossipsub to a wrong `protocol_id`
|
||||
# in multistream-select, or wrong versions.
|
||||
# TODO: Better handling
|
||||
raise Exception(f"protocol is not supported: protocol_id={protocol_id}")
|
||||
self.peers_to_protocol[peer_id] = protocol_id
|
||||
raise ValueError(f"Protocol={protocol_id} is not supported.")
|
||||
self.peer_protocol[peer_id] = protocol_id
|
||||
|
||||
def remove_peer(self, peer_id: ID) -> None:
|
||||
"""
|
||||
@ -143,13 +135,12 @@ class GossipSub(IPubsubRouter, Service):
|
||||
"""
|
||||
logger.debug("removing peer %s", peer_id)
|
||||
|
||||
if peer_id in self.peers_gossipsub:
|
||||
self.peers_gossipsub.remove(peer_id)
|
||||
elif peer_id in self.peers_floodsub:
|
||||
self.peers_floodsub.remove(peer_id)
|
||||
for topic in self.mesh:
|
||||
self.mesh[topic].discard(peer_id)
|
||||
for topic in self.fanout:
|
||||
self.fanout[topic].discard(peer_id)
|
||||
|
||||
if peer_id in self.peers_to_protocol:
|
||||
del self.peers_to_protocol[peer_id]
|
||||
self.peer_protocol.pop(peer_id, None)
|
||||
|
||||
async def handle_rpc(self, rpc: rpc_pb2.RPC, sender_peer_id: ID) -> None:
|
||||
"""
|
||||
@ -189,6 +180,8 @@ class GossipSub(IPubsubRouter, Service):
|
||||
logger.debug("publishing message %s", pubsub_msg)
|
||||
|
||||
for peer_id in peers_gen:
|
||||
if peer_id not in self.pubsub.peers:
|
||||
continue
|
||||
stream = self.pubsub.peers[peer_id]
|
||||
# FIXME: We should add a `WriteMsg` similar to write delimited messages.
|
||||
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
|
||||
@ -215,36 +208,41 @@ class GossipSub(IPubsubRouter, Service):
|
||||
continue
|
||||
|
||||
# floodsub peers
|
||||
for peer_id in self.pubsub.peer_topics[topic]:
|
||||
# FIXME: `gossipsub.peers_floodsub` can be changed to `gossipsub.peers` in go.
|
||||
# This will improve the efficiency when searching for a peer's protocol id.
|
||||
if peer_id in self.peers_floodsub:
|
||||
send_to.add(peer_id)
|
||||
floodsub_peers: Set[ID] = set(
|
||||
peer_id
|
||||
for peer_id in self.pubsub.peer_topics[topic]
|
||||
if self.peer_protocol[peer_id] == floodsub.PROTOCOL_ID
|
||||
)
|
||||
send_to.update(floodsub_peers)
|
||||
|
||||
# gossipsub peers
|
||||
in_topic_gossipsub_peers: List[ID] = None
|
||||
# TODO: Do we need to check `topic in self.pubsub.my_topics`?
|
||||
gossipsub_peers: Set[ID] = set()
|
||||
if topic in self.mesh:
|
||||
in_topic_gossipsub_peers = self.mesh[topic]
|
||||
gossipsub_peers = self.mesh[topic]
|
||||
else:
|
||||
# TODO(robzajac): Is topic DEFINITELY supposed to be in fanout if we are not
|
||||
# subscribed?
|
||||
# I assume there could be short periods between heartbeats where topic may not
|
||||
# be but we should check that this path gets hit appropriately
|
||||
|
||||
if (topic not in self.fanout) or (len(self.fanout[topic]) == 0):
|
||||
# If no peers in fanout, choose some peers from gossipsub peers in topic.
|
||||
self.fanout[topic] = self._get_in_topic_gossipsub_peers_from_minus(
|
||||
topic, self.degree, []
|
||||
)
|
||||
in_topic_gossipsub_peers = self.fanout[topic]
|
||||
for peer_id in in_topic_gossipsub_peers:
|
||||
send_to.add(peer_id)
|
||||
# When we publish to a topic that we have not subscribe to, we randomly pick
|
||||
# `self.degree` number of peers who have subscribed to the topic and add them
|
||||
# as our `fanout` peers.
|
||||
topic_in_fanout: bool = topic in self.fanout
|
||||
fanout_peers: Set[ID] = self.fanout[topic] if topic_in_fanout else set()
|
||||
fanout_size = len(fanout_peers)
|
||||
if not topic_in_fanout or (
|
||||
topic_in_fanout and fanout_size < self.degree
|
||||
):
|
||||
if topic in self.pubsub.peer_topics:
|
||||
# Combine fanout peers with selected peers
|
||||
fanout_peers.update(
|
||||
self._get_in_topic_gossipsub_peers_from_minus(
|
||||
topic, self.degree - fanout_size, fanout_peers
|
||||
)
|
||||
)
|
||||
self.fanout[topic] = fanout_peers
|
||||
gossipsub_peers = fanout_peers
|
||||
send_to.update(gossipsub_peers)
|
||||
# Excludes `msg_forwarder` and `origin`
|
||||
yield from send_to.difference([msg_forwarder, origin])
|
||||
|
||||
async def join(self, topic: str) -> None:
|
||||
# Note: the comments here are the near-exact algorithm description from the spec
|
||||
"""
|
||||
Join notifies the router that we want to receive and forward messages
|
||||
in a topic. It is invoked after the subscription announcement.
|
||||
@ -256,10 +254,10 @@ class GossipSub(IPubsubRouter, Service):
|
||||
if topic in self.mesh:
|
||||
return
|
||||
# Create mesh[topic] if it does not yet exist
|
||||
self.mesh[topic] = []
|
||||
self.mesh[topic] = set()
|
||||
|
||||
topic_in_fanout: bool = topic in self.fanout
|
||||
fanout_peers: List[ID] = self.fanout[topic] if topic_in_fanout else []
|
||||
fanout_peers: Set[ID] = self.fanout[topic] if topic_in_fanout else set()
|
||||
fanout_size = len(fanout_peers)
|
||||
if not topic_in_fanout or (topic_in_fanout and fanout_size < self.degree):
|
||||
# There are less than D peers (let this number be x)
|
||||
@ -270,16 +268,14 @@ class GossipSub(IPubsubRouter, Service):
|
||||
topic, self.degree - fanout_size, fanout_peers
|
||||
)
|
||||
# Combine fanout peers with selected peers
|
||||
fanout_peers += selected_peers
|
||||
fanout_peers.update(selected_peers)
|
||||
|
||||
# Add fanout peers to mesh and notifies them with a GRAFT(topic) control message.
|
||||
for peer in fanout_peers:
|
||||
if peer not in self.mesh[topic]:
|
||||
self.mesh[topic].append(peer)
|
||||
await self.emit_graft(topic, peer)
|
||||
self.mesh[topic].add(peer)
|
||||
await self.emit_graft(topic, peer)
|
||||
|
||||
if topic_in_fanout:
|
||||
del self.fanout[topic]
|
||||
self.fanout.pop(topic, None)
|
||||
|
||||
async def leave(self, topic: str) -> None:
|
||||
# Note: the comments here are the near-exact algorithm description from the spec
|
||||
@ -298,7 +294,75 @@ class GossipSub(IPubsubRouter, Service):
|
||||
await self.emit_prune(topic, peer)
|
||||
|
||||
# Forget mesh[topic]
|
||||
del self.mesh[topic]
|
||||
self.mesh.pop(topic, None)
|
||||
|
||||
async def _emit_control_msgs(
|
||||
self,
|
||||
peers_to_graft: Dict[ID, List[str]],
|
||||
peers_to_prune: Dict[ID, List[str]],
|
||||
peers_to_gossip: Dict[ID, Dict[str, List[str]]],
|
||||
) -> None:
|
||||
graft_msgs: List[rpc_pb2.ControlGraft] = []
|
||||
prune_msgs: List[rpc_pb2.ControlPrune] = []
|
||||
ihave_msgs: List[rpc_pb2.ControlIHave] = []
|
||||
# Starting with GRAFT messages
|
||||
for peer, topics in peers_to_graft.items():
|
||||
for topic in topics:
|
||||
graft_msg: rpc_pb2.ControlGraft = rpc_pb2.ControlGraft(topicID=topic)
|
||||
graft_msgs.append(graft_msg)
|
||||
|
||||
# If there are also PRUNE messages to send to this peer
|
||||
if peer in peers_to_prune:
|
||||
for topic in peers_to_prune[peer]:
|
||||
prune_msg: rpc_pb2.ControlPrune = rpc_pb2.ControlPrune(
|
||||
topicID=topic
|
||||
)
|
||||
prune_msgs.append(prune_msg)
|
||||
del peers_to_prune[peer]
|
||||
|
||||
# If there are also IHAVE messages to send to this peer
|
||||
if peer in peers_to_gossip:
|
||||
for topic in peers_to_gossip[peer]:
|
||||
ihave_msg: rpc_pb2.ControlIHave = rpc_pb2.ControlIHave(
|
||||
messageIDs=peers_to_gossip[peer][topic], topicID=topic
|
||||
)
|
||||
ihave_msgs.append(ihave_msg)
|
||||
del peers_to_gossip[peer]
|
||||
|
||||
control_msg = self.pack_control_msgs(ihave_msgs, graft_msgs, prune_msgs)
|
||||
await self.emit_control_message(control_msg, peer)
|
||||
|
||||
# Next with PRUNE messages
|
||||
for peer, topics in peers_to_prune.items():
|
||||
prune_msgs = []
|
||||
for topic in topics:
|
||||
prune_msg = rpc_pb2.ControlPrune(topicID=topic)
|
||||
prune_msgs.append(prune_msg)
|
||||
|
||||
# If there are also IHAVE messages to send to this peer
|
||||
if peer in peers_to_gossip:
|
||||
ihave_msgs = []
|
||||
for topic in peers_to_gossip[peer]:
|
||||
ihave_msg = rpc_pb2.ControlIHave(
|
||||
messageIDs=peers_to_gossip[peer][topic], topicID=topic
|
||||
)
|
||||
ihave_msgs.append(ihave_msg)
|
||||
del peers_to_gossip[peer]
|
||||
|
||||
control_msg = self.pack_control_msgs(ihave_msgs, None, prune_msgs)
|
||||
await self.emit_control_message(control_msg, peer)
|
||||
|
||||
# Fianlly IHAVE messages
|
||||
for peer in peers_to_gossip:
|
||||
ihave_msgs = []
|
||||
for topic in peers_to_gossip[peer]:
|
||||
ihave_msg = rpc_pb2.ControlIHave(
|
||||
messageIDs=peers_to_gossip[peer][topic], topicID=topic
|
||||
)
|
||||
ihave_msgs.append(ihave_msg)
|
||||
|
||||
control_msg = self.pack_control_msgs(ihave_msgs, None, None)
|
||||
await self.emit_control_message(control_msg, peer)
|
||||
|
||||
# Heartbeat
|
||||
async def heartbeat(self) -> None:
|
||||
@ -308,16 +372,29 @@ class GossipSub(IPubsubRouter, Service):
|
||||
Note: the heartbeats are called with awaits because each heartbeat depends on the
|
||||
state changes in the preceding heartbeat
|
||||
"""
|
||||
# Start after a delay. Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/01b9825fbee1848751d90a8469e3f5f43bac8466/gossipsub.go#L410 # Noqa: E501
|
||||
await trio.sleep(self.heartbeat_initial_delay)
|
||||
while True:
|
||||
# Maintain mesh and keep track of which peers to send GRAFT or PRUNE to
|
||||
peers_to_graft, peers_to_prune = self.mesh_heartbeat()
|
||||
# Maintain fanout
|
||||
self.fanout_heartbeat()
|
||||
# Get the peers to send IHAVE to
|
||||
peers_to_gossip = self.gossip_heartbeat()
|
||||
# Pack GRAFT, PRUNE and IHAVE for the same peer into one control message and send it
|
||||
await self._emit_control_msgs(
|
||||
peers_to_graft, peers_to_prune, peers_to_gossip
|
||||
)
|
||||
|
||||
await self.mesh_heartbeat()
|
||||
await self.fanout_heartbeat()
|
||||
await self.gossip_heartbeat()
|
||||
self.mcache.shift()
|
||||
|
||||
await trio.sleep(self.heartbeat_interval)
|
||||
|
||||
async def mesh_heartbeat(self) -> None:
|
||||
# Note: the comments here are the exact pseudocode from the spec
|
||||
def mesh_heartbeat(
|
||||
self
|
||||
) -> Tuple[DefaultDict[ID, List[str]], DefaultDict[ID, List[str]]]:
|
||||
peers_to_graft: DefaultDict[ID, List[str]] = defaultdict(list)
|
||||
peers_to_prune: DefaultDict[ID, List[str]] = defaultdict(list)
|
||||
for topic in self.mesh:
|
||||
# Skip if no peers have subscribed to the topic
|
||||
if topic not in self.pubsub.peer_topics:
|
||||
@ -330,41 +407,43 @@ class GossipSub(IPubsubRouter, Service):
|
||||
topic, self.degree - num_mesh_peers_in_topic, self.mesh[topic]
|
||||
)
|
||||
|
||||
fanout_peers_not_in_mesh: List[ID] = [
|
||||
peer for peer in selected_peers if peer not in self.mesh[topic]
|
||||
]
|
||||
for peer in fanout_peers_not_in_mesh:
|
||||
for peer in selected_peers:
|
||||
# Add peer to mesh[topic]
|
||||
self.mesh[topic].append(peer)
|
||||
self.mesh[topic].add(peer)
|
||||
|
||||
# Emit GRAFT(topic) control message to peer
|
||||
await self.emit_graft(topic, peer)
|
||||
peers_to_graft[peer].append(topic)
|
||||
|
||||
if num_mesh_peers_in_topic > self.degree_high:
|
||||
# Select |mesh[topic]| - D peers from mesh[topic]
|
||||
selected_peers = self.select_from_minus(
|
||||
num_mesh_peers_in_topic - self.degree, self.mesh[topic], []
|
||||
num_mesh_peers_in_topic - self.degree, self.mesh[topic], set()
|
||||
)
|
||||
for peer in selected_peers:
|
||||
# Remove peer from mesh[topic]
|
||||
self.mesh[topic].remove(peer)
|
||||
self.mesh[topic].discard(peer)
|
||||
|
||||
# Emit PRUNE(topic) control message to peer
|
||||
await self.emit_prune(topic, peer)
|
||||
peers_to_prune[peer].append(topic)
|
||||
return peers_to_graft, peers_to_prune
|
||||
|
||||
async def fanout_heartbeat(self) -> None:
|
||||
def fanout_heartbeat(self) -> None:
|
||||
# Note: the comments here are the exact pseudocode from the spec
|
||||
for topic in self.fanout:
|
||||
# If time since last published > ttl
|
||||
# TODO: there's no way time_since_last_publish gets set anywhere yet
|
||||
if (
|
||||
topic in self.time_since_last_publish
|
||||
and self.time_since_last_publish[topic] > self.time_to_live
|
||||
):
|
||||
# Delete topic entry if it's not in `pubsub.peer_topics`
|
||||
# or (TODO) if it's time-since-last-published > ttl
|
||||
if topic not in self.pubsub.peer_topics:
|
||||
# Remove topic from fanout
|
||||
del self.fanout[topic]
|
||||
del self.time_since_last_publish[topic]
|
||||
else:
|
||||
# Check if fanout peers are still in the topic and remove the ones that are not
|
||||
# ref: https://github.com/libp2p/go-libp2p-pubsub/blob/01b9825fbee1848751d90a8469e3f5f43bac8466/gossipsub.go#L498-L504 # noqa: E501
|
||||
in_topic_fanout_peers = [
|
||||
peer
|
||||
for peer in self.fanout[topic]
|
||||
if peer in self.pubsub.peer_topics[topic]
|
||||
]
|
||||
self.fanout[topic] = set(in_topic_fanout_peers)
|
||||
num_fanout_peers_in_topic = len(self.fanout[topic])
|
||||
|
||||
# If |fanout[topic]| < D
|
||||
@ -376,53 +455,43 @@ class GossipSub(IPubsubRouter, Service):
|
||||
self.fanout[topic],
|
||||
)
|
||||
# Add the peers to fanout[topic]
|
||||
self.fanout[topic].extend(selected_peers)
|
||||
self.fanout[topic].update(selected_peers)
|
||||
|
||||
async def gossip_heartbeat(self) -> None:
|
||||
def gossip_heartbeat(self) -> DefaultDict[ID, Dict[str, List[str]]]:
|
||||
peers_to_gossip: DefaultDict[ID, Dict[str, List[str]]] = defaultdict(dict)
|
||||
for topic in self.mesh:
|
||||
msg_ids = self.mcache.window(topic)
|
||||
if msg_ids:
|
||||
# TODO: Make more efficient, possibly using a generator?
|
||||
# Get all pubsub peers in a topic and only add them if they are gossipsub peers too
|
||||
if topic in self.pubsub.peer_topics:
|
||||
# Select D peers from peers.gossipsub[topic]
|
||||
peers_to_emit_ihave_to = self._get_in_topic_gossipsub_peers_from_minus(
|
||||
topic, self.degree, []
|
||||
topic, self.degree, self.mesh[topic]
|
||||
)
|
||||
|
||||
msg_id_strs = [str(msg_id) for msg_id in msg_ids]
|
||||
for peer in peers_to_emit_ihave_to:
|
||||
# TODO: this line is a monster, can hopefully be simplified
|
||||
if (
|
||||
topic not in self.mesh or (peer not in self.mesh[topic])
|
||||
) and (
|
||||
topic not in self.fanout or (peer not in self.fanout[topic])
|
||||
):
|
||||
msg_id_strs = [str(msg_id) for msg_id in msg_ids]
|
||||
await self.emit_ihave(topic, msg_id_strs, peer)
|
||||
peers_to_gossip[peer][topic] = msg_id_strs
|
||||
|
||||
# TODO: Refactor and Dedup. This section is the roughly the same as the above.
|
||||
# Do the same for fanout, for all topics not already hit in mesh
|
||||
for topic in self.fanout:
|
||||
if topic not in self.mesh:
|
||||
msg_ids = self.mcache.window(topic)
|
||||
if msg_ids:
|
||||
# TODO: Make more efficient, possibly using a generator?
|
||||
# Get all pubsub peers in topic and only add if they are gossipsub peers also
|
||||
if topic in self.pubsub.peer_topics:
|
||||
# Select D peers from peers.gossipsub[topic]
|
||||
peers_to_emit_ihave_to = self._get_in_topic_gossipsub_peers_from_minus(
|
||||
topic, self.degree, []
|
||||
)
|
||||
for peer in peers_to_emit_ihave_to:
|
||||
if peer not in self.fanout[topic]:
|
||||
msg_id_strs = [str(msg) for msg in msg_ids]
|
||||
await self.emit_ihave(topic, msg_id_strs, peer)
|
||||
|
||||
self.mcache.shift()
|
||||
msg_ids = self.mcache.window(topic)
|
||||
if msg_ids:
|
||||
# Get all pubsub peers in topic and only add if they are gossipsub peers also
|
||||
if topic in self.pubsub.peer_topics:
|
||||
# Select D peers from peers.gossipsub[topic]
|
||||
peers_to_emit_ihave_to = self._get_in_topic_gossipsub_peers_from_minus(
|
||||
topic, self.degree, self.fanout[topic]
|
||||
)
|
||||
msg_id_strs = [str(msg) for msg in msg_ids]
|
||||
for peer in peers_to_emit_ihave_to:
|
||||
peers_to_gossip[peer][topic] = msg_id_strs
|
||||
return peers_to_gossip
|
||||
|
||||
@staticmethod
|
||||
def select_from_minus(
|
||||
num_to_select: int, pool: Sequence[Any], minus: Sequence[Any]
|
||||
num_to_select: int, pool: Iterable[Any], minus: Iterable[Any]
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Select at most num_to_select subset of elements from the set (pool - minus) randomly.
|
||||
@ -441,7 +510,7 @@ class GossipSub(IPubsubRouter, Service):
|
||||
|
||||
# If num_to_select > size(selection_pool), then return selection_pool (which has the most
|
||||
# possible elements s.t. the number of elements is less than num_to_select)
|
||||
if num_to_select > len(selection_pool):
|
||||
if num_to_select >= len(selection_pool):
|
||||
return selection_pool
|
||||
|
||||
# Random selection
|
||||
@ -450,16 +519,14 @@ class GossipSub(IPubsubRouter, Service):
|
||||
return selection
|
||||
|
||||
def _get_in_topic_gossipsub_peers_from_minus(
|
||||
self, topic: str, num_to_select: int, minus: Sequence[ID]
|
||||
self, topic: str, num_to_select: int, minus: Iterable[ID]
|
||||
) -> List[ID]:
|
||||
gossipsub_peers_in_topic = [
|
||||
gossipsub_peers_in_topic = set(
|
||||
peer_id
|
||||
for peer_id in self.pubsub.peer_topics[topic]
|
||||
if peer_id in self.peers_gossipsub
|
||||
]
|
||||
return self.select_from_minus(
|
||||
num_to_select, gossipsub_peers_in_topic, list(minus)
|
||||
if self.peer_protocol[peer_id] == PROTOCOL_ID
|
||||
)
|
||||
return self.select_from_minus(num_to_select, gossipsub_peers_in_topic, minus)
|
||||
|
||||
# RPC handlers
|
||||
|
||||
@ -517,6 +584,12 @@ class GossipSub(IPubsubRouter, Service):
|
||||
rpc_msg: bytes = packet.SerializeToString()
|
||||
|
||||
# 3) Get the stream to this peer
|
||||
if sender_peer_id not in self.pubsub.peers:
|
||||
logger.debug(
|
||||
"Fail to responed to iwant request from %s: peer record not exist",
|
||||
sender_peer_id,
|
||||
)
|
||||
return
|
||||
peer_stream = self.pubsub.peers[sender_peer_id]
|
||||
|
||||
# 4) And write the packet to the stream
|
||||
@ -537,7 +610,7 @@ class GossipSub(IPubsubRouter, Service):
|
||||
# Add peer to mesh for topic
|
||||
if topic in self.mesh:
|
||||
if sender_peer_id not in self.mesh[topic]:
|
||||
self.mesh[topic].append(sender_peer_id)
|
||||
self.mesh[topic].add(sender_peer_id)
|
||||
else:
|
||||
# Respond with PRUNE if not subscribed to the topic
|
||||
await self.emit_prune(topic, sender_peer_id)
|
||||
@ -547,12 +620,27 @@ class GossipSub(IPubsubRouter, Service):
|
||||
) -> None:
|
||||
topic: str = prune_msg.topicID
|
||||
|
||||
# Remove peer from mesh for topic, if peer is in topic
|
||||
if topic in self.mesh and sender_peer_id in self.mesh[topic]:
|
||||
self.mesh[topic].remove(sender_peer_id)
|
||||
# Remove peer from mesh for topic
|
||||
if topic in self.mesh:
|
||||
self.mesh[topic].discard(sender_peer_id)
|
||||
|
||||
# RPC emitters
|
||||
|
||||
def pack_control_msgs(
|
||||
self,
|
||||
ihave_msgs: List[rpc_pb2.ControlIHave],
|
||||
graft_msgs: List[rpc_pb2.ControlGraft],
|
||||
prune_msgs: List[rpc_pb2.ControlPrune],
|
||||
) -> rpc_pb2.ControlMessage:
|
||||
control_msg: rpc_pb2.ControlMessage = rpc_pb2.ControlMessage()
|
||||
if ihave_msgs:
|
||||
control_msg.ihave.extend(ihave_msgs)
|
||||
if graft_msgs:
|
||||
control_msg.graft.extend(graft_msgs)
|
||||
if prune_msgs:
|
||||
control_msg.prune.extend(prune_msgs)
|
||||
return control_msg
|
||||
|
||||
async def emit_ihave(self, topic: str, msg_ids: Any, to_peer: ID) -> None:
|
||||
"""Emit ihave message, sent to to_peer, for topic and msg_ids."""
|
||||
|
||||
@ -608,6 +696,11 @@ class GossipSub(IPubsubRouter, Service):
|
||||
rpc_msg: bytes = packet.SerializeToString()
|
||||
|
||||
# Get stream for peer from pubsub
|
||||
if to_peer not in self.pubsub.peers:
|
||||
logger.debug(
|
||||
"Fail to emit control message to %s: peer record not exist", to_peer
|
||||
)
|
||||
return
|
||||
peer_stream = self.pubsub.peers[to_peer]
|
||||
|
||||
# Write rpc to stream
|
||||
|
||||
@ -96,8 +96,7 @@ class MessageCache:
|
||||
last_entries: List[CacheEntry] = self.history[len(self.history) - 1]
|
||||
|
||||
for entry in last_entries:
|
||||
if entry.mid in self.msgs:
|
||||
del self.msgs[entry.mid]
|
||||
self.msgs.pop(entry.mid)
|
||||
|
||||
i: int = len(self.history) - 2
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ from typing import (
|
||||
KeysView,
|
||||
List,
|
||||
NamedTuple,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
@ -19,6 +20,7 @@ import base58
|
||||
from lru import LRU
|
||||
import trio
|
||||
|
||||
from libp2p.crypto.keys import PrivateKey
|
||||
from libp2p.exceptions import ParseError, ValidationError
|
||||
from libp2p.host.host_interface import IHost
|
||||
from libp2p.io.exceptions import IncompleteReadError
|
||||
@ -33,7 +35,7 @@ from .abc import IPubsub, ISubscriptionAPI
|
||||
from .pb import rpc_pb2
|
||||
from .pubsub_notifee import PubsubNotifee
|
||||
from .subscription import TrioSubscriptionAPI
|
||||
from .validators import signature_validator
|
||||
from .validators import PUBSUB_SIGNING_PREFIX, signature_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .abc import IPubsubRouter # noqa: F401
|
||||
@ -73,16 +75,23 @@ class Pubsub(IPubsub, Service):
|
||||
subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"]
|
||||
subscribed_topics_receive: Dict[str, "TrioSubscriptionAPI"]
|
||||
|
||||
peer_topics: Dict[str, List[ID]]
|
||||
peer_topics: Dict[str, Set[ID]]
|
||||
peers: Dict[ID, INetStream]
|
||||
|
||||
topic_validators: Dict[str, TopicValidator]
|
||||
|
||||
# TODO: Be sure it is increased atomically everytime.
|
||||
counter: int # uint64
|
||||
|
||||
# Indicate if we should enforce signature verification
|
||||
strict_signing: bool
|
||||
sign_key: PrivateKey
|
||||
|
||||
def __init__(
|
||||
self, host: IHost, router: "IPubsubRouter", cache_size: int = None
|
||||
self,
|
||||
host: IHost,
|
||||
router: "IPubsubRouter",
|
||||
cache_size: int = None,
|
||||
strict_signing: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Construct a new Pubsub object, which is responsible for handling all
|
||||
@ -126,6 +135,12 @@ class Pubsub(IPubsub, Service):
|
||||
else:
|
||||
self.cache_size = cache_size
|
||||
|
||||
self.strict_signing = strict_signing
|
||||
if strict_signing:
|
||||
self.sign_key = self.host.get_private_key()
|
||||
else:
|
||||
self.sign_key = None
|
||||
|
||||
self.seen_messages = LRU(self.cache_size)
|
||||
|
||||
# Map of topics we are subscribed to blocking queues
|
||||
@ -142,7 +157,7 @@ class Pubsub(IPubsub, Service):
|
||||
# Map of topic to topic validator
|
||||
self.topic_validators = {}
|
||||
|
||||
self.counter = time.time_ns()
|
||||
self.counter = int(time.time())
|
||||
|
||||
async def run(self) -> None:
|
||||
self.manager.run_daemon_task(self.handle_peer_queue)
|
||||
@ -239,8 +254,7 @@ class Pubsub(IPubsub, Service):
|
||||
|
||||
:param topic: the topic to remove validator from
|
||||
"""
|
||||
if topic in self.topic_validators:
|
||||
del self.topic_validators[topic]
|
||||
self.topic_validators.pop(topic, None)
|
||||
|
||||
def get_msg_validators(self, msg: rpc_pb2.Message) -> Tuple[TopicValidator, ...]:
|
||||
"""
|
||||
@ -282,24 +296,22 @@ class Pubsub(IPubsub, Service):
|
||||
logger.debug("fail to add new peer %s, error %s", peer_id, error)
|
||||
return
|
||||
|
||||
self.peers[peer_id] = stream
|
||||
|
||||
# Send hello packet
|
||||
hello = self.get_hello_packet()
|
||||
try:
|
||||
await stream.write(encode_varint_prefixed(hello.SerializeToString()))
|
||||
except StreamClosed:
|
||||
logger.debug("Fail to add new peer %s: stream closed", peer_id)
|
||||
del self.peers[peer_id]
|
||||
return
|
||||
# TODO: Check if the peer in black list.
|
||||
try:
|
||||
self.router.add_peer(peer_id, stream.get_protocol())
|
||||
except Exception as error:
|
||||
logger.debug("fail to add new peer %s, error %s", peer_id, error)
|
||||
del self.peers[peer_id]
|
||||
return
|
||||
|
||||
self.peers[peer_id] = stream
|
||||
|
||||
logger.debug("added new peer %s", peer_id)
|
||||
|
||||
def _handle_dead_peer(self, peer_id: ID) -> None:
|
||||
@ -309,19 +321,16 @@ class Pubsub(IPubsub, Service):
|
||||
|
||||
for topic in self.peer_topics:
|
||||
if peer_id in self.peer_topics[topic]:
|
||||
self.peer_topics[topic].remove(peer_id)
|
||||
self.peer_topics[topic].discard(peer_id)
|
||||
|
||||
self.router.remove_peer(peer_id)
|
||||
|
||||
logger.debug("removed dead peer %s", peer_id)
|
||||
|
||||
async def handle_peer_queue(self) -> None:
|
||||
"""
|
||||
Continuously read from peer channel and each time a new peer is found,
|
||||
open a stream to the peer using a supported pubsub protocol
|
||||
TODO: Handle failure for when the peer does not support any of the
|
||||
pubsub protocols we support
|
||||
"""
|
||||
"""Continuously read from peer queue and each time a new peer is found,
|
||||
open a stream to the peer using a supported pubsub protocol pubsub
|
||||
protocols we support."""
|
||||
async with self.peer_receive_channel:
|
||||
while self.manager.is_running:
|
||||
peer_id: ID = await self.peer_receive_channel.receive()
|
||||
@ -351,14 +360,14 @@ class Pubsub(IPubsub, Service):
|
||||
"""
|
||||
if sub_message.subscribe:
|
||||
if sub_message.topicid not in self.peer_topics:
|
||||
self.peer_topics[sub_message.topicid] = [origin_id]
|
||||
self.peer_topics[sub_message.topicid] = set([origin_id])
|
||||
elif origin_id not in self.peer_topics[sub_message.topicid]:
|
||||
# Add peer to topic
|
||||
self.peer_topics[sub_message.topicid].append(origin_id)
|
||||
self.peer_topics[sub_message.topicid].add(origin_id)
|
||||
else:
|
||||
if sub_message.topicid in self.peer_topics:
|
||||
if origin_id in self.peer_topics[sub_message.topicid]:
|
||||
self.peer_topics[sub_message.topicid].remove(origin_id)
|
||||
self.peer_topics[sub_message.topicid].discard(origin_id)
|
||||
|
||||
# FIXME(mhchia): Change the function name?
|
||||
async def handle_talk(self, publish_message: rpc_pb2.Message) -> None:
|
||||
@ -476,7 +485,13 @@ class Pubsub(IPubsub, Service):
|
||||
seqno=self._next_seqno(),
|
||||
)
|
||||
|
||||
# TODO: Sign with our signing key
|
||||
if self.strict_signing:
|
||||
priv_key = self.sign_key
|
||||
signature = priv_key.sign(
|
||||
PUBSUB_SIGNING_PREFIX.encode() + msg.SerializeToString()
|
||||
)
|
||||
msg.key = self.host.get_public_key().serialize()
|
||||
msg.signature = signature
|
||||
|
||||
await self.push_msg(self.my_id, msg)
|
||||
|
||||
@ -536,18 +551,17 @@ class Pubsub(IPubsub, Service):
|
||||
|
||||
# TODO: Check if the `from` is in the blacklist. If yes, reject.
|
||||
|
||||
# TODO: Check if signing is required and if so signature should be attached.
|
||||
|
||||
# If the message is processed before, return(i.e., don't further process the message).
|
||||
if self._is_msg_seen(msg):
|
||||
return
|
||||
|
||||
# TODO: - Validate the message. If failed, reject it.
|
||||
# Validate the signature of the message
|
||||
# FIXME: `signature_validator` is currently a stub.
|
||||
if not signature_validator(msg.key, msg.SerializeToString()):
|
||||
logger.debug("Signature validation failed for msg: %s", msg)
|
||||
return
|
||||
# Check if signing is required and if so validate the signature
|
||||
if self.strict_signing:
|
||||
# Validate the signature of the message
|
||||
if not signature_validator(msg):
|
||||
logger.debug("Signature validation failed for msg: %s", msg)
|
||||
return
|
||||
|
||||
# Validate the message with registered topic validators.
|
||||
# If the validation failed, return(i.e., don't further process the message).
|
||||
try:
|
||||
|
||||
@ -1,10 +1,41 @@
|
||||
# FIXME: Replace the type of `pubkey` with a custom type `Pubkey`
|
||||
def signature_validator(pubkey: bytes, msg: bytes) -> bool:
|
||||
import logging
|
||||
|
||||
from libp2p.crypto.serialization import deserialize_public_key
|
||||
from libp2p.peer.id import ID
|
||||
|
||||
from .pb import rpc_pb2
|
||||
|
||||
logger = logging.getLogger("libp2p.pubsub")
|
||||
|
||||
PUBSUB_SIGNING_PREFIX = "libp2p-pubsub:"
|
||||
|
||||
|
||||
def signature_validator(msg: rpc_pb2.Message) -> bool:
|
||||
"""
|
||||
Verify the message against the given public key.
|
||||
|
||||
:param pubkey: the public key which signs the message.
|
||||
:param msg: the message signed.
|
||||
"""
|
||||
# TODO: Implement the signature validation
|
||||
return True
|
||||
# Check if signature is attached
|
||||
if msg.signature == b"":
|
||||
logger.debug("Reject because no signature attached for msg: %s", msg)
|
||||
return False
|
||||
|
||||
# Validate if message sender matches message signer,
|
||||
# i.e., check if `msg.key` matches `msg.from_id`
|
||||
msg_pubkey = deserialize_public_key(msg.key)
|
||||
if ID.from_pubkey(msg_pubkey) != msg.from_id:
|
||||
logger.debug(
|
||||
"Reject because signing key does not match sender ID for msg: %s", msg
|
||||
)
|
||||
return False
|
||||
# First, construct the original payload that's signed by 'msg.key'
|
||||
msg_without_key_sig = rpc_pb2.Message(
|
||||
data=msg.data, topicIDs=msg.topicIDs, from_id=msg.from_id, seqno=msg.seqno
|
||||
)
|
||||
payload = PUBSUB_SIGNING_PREFIX.encode() + msg_without_key_sig.SerializeToString()
|
||||
try:
|
||||
return msg_pubkey.verify(payload, msg.signature)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@ -1,21 +0,0 @@
|
||||
from typing import Iterable
|
||||
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
from libp2p.routing.interfaces import IContentRouting
|
||||
|
||||
|
||||
class KadmeliaContentRouter(IContentRouting):
|
||||
def provide(self, cid: bytes, announce: bool = True) -> None:
|
||||
"""
|
||||
Provide adds the given cid to the content routing system.
|
||||
|
||||
If announce is True, it also announces it, otherwise it is just
|
||||
kept in the local accounting of which objects are being
|
||||
provided.
|
||||
"""
|
||||
# the DHT finds the closest peers to `key` using the `FIND_NODE` RPC
|
||||
# then sends a `ADD_PROVIDER` RPC with its own `PeerInfo` to each of these peers.
|
||||
|
||||
def find_provider_iter(self, cid: bytes, count: int) -> Iterable[PeerInfo]:
|
||||
"""Search for peers who are able to provide a given key returns an
|
||||
iterator of peer.PeerInfo."""
|
||||
@ -1,43 +0,0 @@
|
||||
import json
|
||||
|
||||
import multiaddr
|
||||
|
||||
from libp2p.kademlia.network import KademliaServer
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
from libp2p.routing.interfaces import IPeerRouting
|
||||
|
||||
|
||||
class KadmeliaPeerRouter(IPeerRouting):
|
||||
server: KademliaServer
|
||||
|
||||
def __init__(self, dht_server: KademliaServer) -> None:
|
||||
self.server = dht_server
|
||||
|
||||
async def find_peer(self, peer_id: ID) -> PeerInfo:
|
||||
"""
|
||||
Find a specific peer.
|
||||
|
||||
:param peer_id: peer to search for
|
||||
:return: PeerInfo of specified peer
|
||||
"""
|
||||
# switching peer_id to xor_id used by kademlia as node_id
|
||||
xor_id = peer_id.xor_id
|
||||
# ignore type for kad
|
||||
value = await self.server.get(xor_id) # type: ignore
|
||||
return (
|
||||
peer_info_from_str(value) if value else None
|
||||
) # TODO: should raise error if None?
|
||||
|
||||
|
||||
def peer_info_to_str(peer_info: PeerInfo) -> str:
|
||||
return json.dumps(
|
||||
[peer_info.peer_id.to_string(), list(map(lambda a: str(a), peer_info.addrs))]
|
||||
)
|
||||
|
||||
|
||||
def peer_info_from_str(string: str) -> PeerInfo:
|
||||
peer_id, raw_addrs = json.loads(string)
|
||||
return PeerInfo(
|
||||
ID.from_base58(peer_id), list(map(lambda a: multiaddr.Multiaddr(a), raw_addrs))
|
||||
)
|
||||
@ -50,8 +50,7 @@ class SecurityMultistream(ABC):
|
||||
:param transport: the corresponding transportation to the ``protocol``.
|
||||
"""
|
||||
# If protocol is already added before, remove it and add it again.
|
||||
if protocol in self.transports:
|
||||
del self.transports[protocol]
|
||||
self.transports.pop(protocol, None)
|
||||
self.transports[protocol] = transport
|
||||
# Note: None is added as the handler for the given protocol since
|
||||
# we only care about selecting the protocol, not any handler function
|
||||
|
||||
@ -292,8 +292,7 @@ class Mplex(IMuxedConn, Service):
|
||||
# the entry of this stream, to avoid others from accessing it.
|
||||
if is_local_closed:
|
||||
async with self.streams_lock:
|
||||
if stream_id in self.streams:
|
||||
del self.streams[stream_id]
|
||||
self.streams.pop(stream_id, None)
|
||||
|
||||
async def _handle_reset(self, stream_id: StreamID) -> None:
|
||||
async with self.streams_lock:
|
||||
@ -311,9 +310,8 @@ class Mplex(IMuxedConn, Service):
|
||||
if not stream.event_local_closed.is_set():
|
||||
stream.event_local_closed.set()
|
||||
async with self.streams_lock:
|
||||
if stream_id in self.streams:
|
||||
del self.streams[stream_id]
|
||||
del self.streams_msg_channels[stream_id]
|
||||
self.streams.pop(stream_id, None)
|
||||
self.streams_msg_channels.pop(stream_id, None)
|
||||
|
||||
async def _cleanup(self) -> None:
|
||||
if not self.event_shutting_down.is_set():
|
||||
|
||||
@ -170,8 +170,7 @@ class MplexStream(IMuxedStream):
|
||||
if _is_remote_closed:
|
||||
# Both sides are closed, we can safely remove the buffer from the dict.
|
||||
async with self.muxed_conn.streams_lock:
|
||||
if self.stream_id in self.muxed_conn.streams:
|
||||
del self.muxed_conn.streams[self.stream_id]
|
||||
self.muxed_conn.streams.pop(self.stream_id, None)
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""closes both ends of the stream tells this remote side to hang up."""
|
||||
@ -199,11 +198,8 @@ class MplexStream(IMuxedStream):
|
||||
await self.incoming_data_channel.aclose()
|
||||
|
||||
async with self.muxed_conn.streams_lock:
|
||||
if (
|
||||
self.muxed_conn.streams is not None
|
||||
and self.stream_id in self.muxed_conn.streams
|
||||
):
|
||||
del self.muxed_conn.streams[self.stream_id]
|
||||
if self.muxed_conn.streams is not None:
|
||||
self.muxed_conn.streams.pop(self.stream_id, None)
|
||||
|
||||
# TODO deadline not in use
|
||||
def set_deadline(self, ttl: int) -> bool:
|
||||
|
||||
@ -44,8 +44,7 @@ class MuxerMultistream:
|
||||
:param transport: the corresponding transportation to the ``protocol``.
|
||||
"""
|
||||
# If protocol is already added before, remove it and add it again.
|
||||
if protocol in self.transports:
|
||||
del self.transports[protocol]
|
||||
self.transports.pop(protocol, None)
|
||||
self.transports[protocol] = transport
|
||||
self.multiselect.add_handler(protocol, None)
|
||||
|
||||
|
||||
@ -24,6 +24,7 @@ class GossipsubParams(NamedTuple):
|
||||
time_to_live: int = 30
|
||||
gossip_window: int = 3
|
||||
gossip_history: int = 5
|
||||
heartbeat_initial_delay: float = 0.1
|
||||
heartbeat_interval: float = 0.5
|
||||
|
||||
|
||||
|
||||
@ -1,14 +1,18 @@
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from typing import Any, AsyncIterator, Dict, Sequence, Tuple, cast
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Any, AsyncIterator, Dict, List, Sequence, Tuple, cast
|
||||
|
||||
# NOTE: import ``asynccontextmanager`` from ``contextlib`` when support for python 3.6 is dropped.
|
||||
from async_generator import asynccontextmanager
|
||||
from async_service import background_trio_service
|
||||
import factory
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p import generate_new_rsa_identity, generate_peer_id_from
|
||||
from libp2p.crypto.keys import KeyPair
|
||||
from libp2p.host.basic_host import BasicHost
|
||||
from libp2p.host.host_interface import IHost
|
||||
from libp2p.host.routed_host import RoutedHost
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
from libp2p.network.connection.raw_connection import RawConnection
|
||||
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
||||
@ -16,11 +20,13 @@ from libp2p.network.connection.swarm_connection import SwarmConn
|
||||
from libp2p.network.stream.net_stream_interface import INetStream
|
||||
from libp2p.network.swarm import Swarm
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
from libp2p.peer.peerstore import PeerStore
|
||||
from libp2p.pubsub.abc import IPubsubRouter
|
||||
from libp2p.pubsub.floodsub import FloodSub
|
||||
from libp2p.pubsub.gossipsub import GossipSub
|
||||
from libp2p.pubsub.pubsub import Pubsub
|
||||
from libp2p.routing.interfaces import IPeerRouting
|
||||
from libp2p.security.base_transport import BaseSecureTransport
|
||||
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||
import libp2p.security.secio.transport as secio
|
||||
@ -45,6 +51,12 @@ class IDFactory(factory.Factory):
|
||||
)
|
||||
|
||||
|
||||
def initialize_peerstore_with_our_keypair(self_id: ID, key_pair: KeyPair) -> PeerStore:
|
||||
peer_store = PeerStore()
|
||||
peer_store.add_key_pair(self_id, key_pair)
|
||||
return peer_store
|
||||
|
||||
|
||||
def security_transport_factory(
|
||||
is_secure: bool, key_pair: KeyPair
|
||||
) -> Dict[TProtocol, BaseSecureTransport]:
|
||||
@ -60,10 +72,12 @@ async def raw_conn_factory(
|
||||
) -> AsyncIterator[Tuple[IRawConnection, IRawConnection]]:
|
||||
conn_0 = None
|
||||
conn_1 = None
|
||||
event = trio.Event()
|
||||
|
||||
async def tcp_stream_handler(stream: ReadWriteCloser) -> None:
|
||||
nonlocal conn_1
|
||||
conn_1 = RawConnection(stream, initiator=False)
|
||||
event.set()
|
||||
await trio.sleep_forever()
|
||||
|
||||
tcp_transport = TCP()
|
||||
@ -71,6 +85,7 @@ async def raw_conn_factory(
|
||||
await listener.listen(LISTEN_MADDR, nursery)
|
||||
listening_maddr = listener.get_addrs()[0]
|
||||
conn_0 = await tcp_transport.dial(listening_maddr)
|
||||
await event.wait()
|
||||
yield conn_0, conn_1
|
||||
|
||||
|
||||
@ -84,7 +99,9 @@ class SwarmFactory(factory.Factory):
|
||||
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex}
|
||||
|
||||
peer_id = factory.LazyAttribute(lambda o: generate_peer_id_from(o.key_pair))
|
||||
peerstore = factory.LazyFunction(PeerStore)
|
||||
peerstore = factory.LazyAttribute(
|
||||
lambda o: initialize_peerstore_with_our_keypair(o.peer_id, o.key_pair)
|
||||
)
|
||||
upgrader = factory.LazyAttribute(
|
||||
lambda o: TransportUpgrader(
|
||||
security_transport_factory(o.is_secure, o.key_pair), o.muxer_opt
|
||||
@ -133,31 +150,59 @@ class HostFactory(factory.Factory):
|
||||
is_secure = False
|
||||
key_pair = factory.LazyFunction(generate_new_rsa_identity)
|
||||
|
||||
public_key = factory.LazyAttribute(lambda o: o.key_pair.public_key)
|
||||
network = factory.LazyAttribute(
|
||||
lambda o: SwarmFactory(is_secure=o.is_secure, key_pair=o.key_pair)
|
||||
)
|
||||
network = factory.LazyAttribute(lambda o: SwarmFactory(is_secure=o.is_secure))
|
||||
|
||||
@classmethod
|
||||
@asynccontextmanager
|
||||
async def create_batch_and_listen(
|
||||
cls, is_secure: bool, number: int
|
||||
) -> AsyncIterator[Tuple[BasicHost, ...]]:
|
||||
key_pairs = [generate_new_rsa_identity() for _ in range(number)]
|
||||
async with AsyncExitStack() as stack:
|
||||
swarms = [
|
||||
await stack.enter_async_context(
|
||||
SwarmFactory.create_and_listen(is_secure, key_pair)
|
||||
)
|
||||
for key_pair in key_pairs
|
||||
]
|
||||
hosts = tuple(
|
||||
BasicHost(key_pair.public_key, swarm)
|
||||
for key_pair, swarm in zip(key_pairs, swarms)
|
||||
)
|
||||
async with SwarmFactory.create_batch_and_listen(is_secure, number) as swarms:
|
||||
hosts = tuple(BasicHost(swarm) for swarm in swarms)
|
||||
yield hosts
|
||||
|
||||
|
||||
class DummyRouter(IPeerRouting):
|
||||
_routing_table: Dict[ID, PeerInfo]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._routing_table = dict()
|
||||
|
||||
def _add_peer(self, peer_id: ID, addrs: List[Multiaddr]) -> None:
|
||||
self._routing_table[peer_id] = PeerInfo(peer_id, addrs)
|
||||
|
||||
async def find_peer(self, peer_id: ID) -> PeerInfo:
|
||||
await trio.hazmat.checkpoint()
|
||||
return self._routing_table.get(peer_id, None)
|
||||
|
||||
|
||||
class RoutedHostFactory(factory.Factory):
|
||||
class Meta:
|
||||
model = RoutedHost
|
||||
|
||||
class Params:
|
||||
is_secure = False
|
||||
|
||||
network = factory.LazyAttribute(
|
||||
lambda o: HostFactory(is_secure=o.is_secure).get_network()
|
||||
)
|
||||
router = factory.LazyFunction(DummyRouter)
|
||||
|
||||
@classmethod
|
||||
@asynccontextmanager
|
||||
async def create_batch_and_listen(
|
||||
cls, is_secure: bool, number: int
|
||||
) -> AsyncIterator[Tuple[RoutedHost, ...]]:
|
||||
routing_table = DummyRouter()
|
||||
async with HostFactory.create_batch_and_listen(is_secure, number) as hosts:
|
||||
for host in hosts:
|
||||
routing_table._add_peer(host.get_id(), host.get_addrs())
|
||||
routed_hosts = tuple(
|
||||
RoutedHost(host.get_network(), routing_table) for host in hosts
|
||||
)
|
||||
yield routed_hosts
|
||||
|
||||
|
||||
class FloodsubFactory(factory.Factory):
|
||||
class Meta:
|
||||
model = FloodSub
|
||||
@ -176,6 +221,7 @@ class GossipsubFactory(factory.Factory):
|
||||
time_to_live = GOSSIPSUB_PARAMS.time_to_live
|
||||
gossip_window = GOSSIPSUB_PARAMS.gossip_window
|
||||
gossip_history = GOSSIPSUB_PARAMS.gossip_history
|
||||
heartbeat_initial_delay = GOSSIPSUB_PARAMS.heartbeat_initial_delay
|
||||
heartbeat_interval = GOSSIPSUB_PARAMS.heartbeat_interval
|
||||
|
||||
|
||||
@ -186,13 +232,19 @@ class PubsubFactory(factory.Factory):
|
||||
host = factory.SubFactory(HostFactory)
|
||||
router = None
|
||||
cache_size = None
|
||||
strict_signing = False
|
||||
|
||||
@classmethod
|
||||
@asynccontextmanager
|
||||
async def create_and_start(
|
||||
cls, host: IHost, router: IPubsubRouter, cache_size: int
|
||||
cls, host: IHost, router: IPubsubRouter, cache_size: int, strict_signing: bool
|
||||
) -> AsyncIterator[Pubsub]:
|
||||
pubsub = PubsubFactory(host=host, router=router, cache_size=cache_size)
|
||||
pubsub = PubsubFactory(
|
||||
host=host,
|
||||
router=router,
|
||||
cache_size=cache_size,
|
||||
strict_signing=strict_signing,
|
||||
)
|
||||
async with background_trio_service(pubsub):
|
||||
yield pubsub
|
||||
|
||||
@ -204,13 +256,14 @@ class PubsubFactory(factory.Factory):
|
||||
routers: Sequence[IPubsubRouter],
|
||||
is_secure: bool = False,
|
||||
cache_size: int = None,
|
||||
strict_signing: bool = False,
|
||||
) -> AsyncIterator[Tuple[Pubsub, ...]]:
|
||||
async with HostFactory.create_batch_and_listen(is_secure, number) as hosts:
|
||||
# Pubsubs should exit before hosts
|
||||
async with AsyncExitStack() as stack:
|
||||
pubsubs = [
|
||||
await stack.enter_async_context(
|
||||
cls.create_and_start(host, router, cache_size)
|
||||
cls.create_and_start(host, router, cache_size, strict_signing)
|
||||
)
|
||||
for host, router in zip(hosts, routers)
|
||||
]
|
||||
@ -223,6 +276,7 @@ class PubsubFactory(factory.Factory):
|
||||
number: int,
|
||||
is_secure: bool = False,
|
||||
cache_size: int = None,
|
||||
strict_signing: bool = False,
|
||||
protocols: Sequence[TProtocol] = None,
|
||||
) -> AsyncIterator[Tuple[Pubsub, ...]]:
|
||||
if protocols is not None:
|
||||
@ -230,7 +284,7 @@ class PubsubFactory(factory.Factory):
|
||||
else:
|
||||
floodsubs = FloodsubFactory.create_batch(number)
|
||||
async with cls._create_batch_with_router(
|
||||
number, floodsubs, is_secure, cache_size
|
||||
number, floodsubs, is_secure, cache_size, strict_signing
|
||||
) as pubsubs:
|
||||
yield pubsubs
|
||||
|
||||
@ -242,6 +296,7 @@ class PubsubFactory(factory.Factory):
|
||||
*,
|
||||
is_secure: bool = False,
|
||||
cache_size: int = None,
|
||||
strict_signing: bool = False,
|
||||
protocols: Sequence[TProtocol] = None,
|
||||
degree: int = GOSSIPSUB_PARAMS.degree,
|
||||
degree_low: int = GOSSIPSUB_PARAMS.degree_low,
|
||||
@ -250,6 +305,7 @@ class PubsubFactory(factory.Factory):
|
||||
gossip_window: int = GOSSIPSUB_PARAMS.gossip_window,
|
||||
gossip_history: int = GOSSIPSUB_PARAMS.gossip_history,
|
||||
heartbeat_interval: float = GOSSIPSUB_PARAMS.heartbeat_interval,
|
||||
heartbeat_initial_delay: float = GOSSIPSUB_PARAMS.heartbeat_initial_delay,
|
||||
) -> AsyncIterator[Tuple[Pubsub, ...]]:
|
||||
if protocols is not None:
|
||||
gossipsubs = GossipsubFactory.create_batch(
|
||||
@ -274,7 +330,7 @@ class PubsubFactory(factory.Factory):
|
||||
)
|
||||
|
||||
async with cls._create_batch_with_router(
|
||||
number, gossipsubs, is_secure, cache_size
|
||||
number, gossipsubs, is_secure, cache_size, strict_signing
|
||||
) as pubsubs:
|
||||
async with AsyncExitStack() as stack:
|
||||
for router in gossipsubs:
|
||||
|
||||
@ -153,31 +153,34 @@ floodsub_protocol_pytest_params = [
|
||||
|
||||
async def perform_test_from_obj(obj, pubsub_factory) -> None:
|
||||
"""
|
||||
Perform pubsub tests from a test obj.
|
||||
test obj are composed as follows:
|
||||
Perform pubsub tests from a test object, which is composed as follows:
|
||||
|
||||
{
|
||||
"supported_protocols": ["supported/protocol/1.0.0",...],
|
||||
"adj_list": {
|
||||
"node1": ["neighbor1_of_node1", "neighbor2_of_node1", ...],
|
||||
"node2": ["neighbor1_of_node2", "neighbor2_of_node2", ...],
|
||||
...
|
||||
},
|
||||
"topic_map": {
|
||||
"topic1": ["node1_subscribed_to_topic1", "node2_subscribed_to_topic1", ...]
|
||||
},
|
||||
"messages": [
|
||||
{
|
||||
"topics": ["topic1_for_message", "topic2_for_message", ...],
|
||||
"data": b"some contents of the message (newlines are not supported)",
|
||||
"node_id": "message sender node id"
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
"supported_protocols": ["supported/protocol/1.0.0",...],
|
||||
"adj_list": {
|
||||
"node1": ["neighbor1_of_node1", "neighbor2_of_node1", ...],
|
||||
"node2": ["neighbor1_of_node2", "neighbor2_of_node2", ...],
|
||||
...
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
NOTE: In adj_list, for any neighbors A and B, only list B as a neighbor of A
|
||||
or B as a neighbor of A once. Do NOT list both A: ["B"] and B:["A"] as the behavior
|
||||
is undefined (even if it may work)
|
||||
"topic_map": {
|
||||
"topic1": ["node1_subscribed_to_topic1", "node2_subscribed_to_topic1", ...]
|
||||
},
|
||||
"messages": [
|
||||
{
|
||||
"topics": ["topic1_for_message", "topic2_for_message", ...],
|
||||
"data": b"some contents of the message (newlines are not supported)",
|
||||
"node_id": "message sender node id"
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
|
||||
.. note::
|
||||
In adj_list, for any neighbors A and B, only list B as a neighbor of A
|
||||
or B as a neighbor of A once. Do NOT list both A: ["B"] and B:["A"] as the behavior
|
||||
is undefined (even if it may work)
|
||||
"""
|
||||
|
||||
# Step 1) Create graph
|
||||
|
||||
@ -39,6 +39,3 @@ def create_echo_stream_handler(
|
||||
await stream.write(resp.encode())
|
||||
|
||||
return echo_stream_handler
|
||||
|
||||
|
||||
# TODO: Service `external_api`
|
||||
|
||||
@ -8,6 +8,7 @@ from trio_typing import TaskStatus
|
||||
from libp2p.io.trio import TrioTCPStream
|
||||
from libp2p.network.connection.raw_connection import RawConnection
|
||||
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
||||
from libp2p.transport.exceptions import OpenConnectionError
|
||||
from libp2p.transport.listener_interface import IListener
|
||||
from libp2p.transport.transport_interface import ITransport
|
||||
from libp2p.transport.typing import THandler
|
||||
@ -80,7 +81,10 @@ class TCP(ITransport):
|
||||
self.host = maddr.value_for_protocol("ip4")
|
||||
self.port = int(maddr.value_for_protocol("tcp"))
|
||||
|
||||
stream = await trio.open_tcp_stream(self.host, self.port)
|
||||
try:
|
||||
stream = await trio.open_tcp_stream(self.host, self.port)
|
||||
except OSError as error:
|
||||
raise OpenConnectionError from error
|
||||
read_write_closer = TrioTCPStream(stream)
|
||||
|
||||
return RawConnection(read_write_closer, True)
|
||||
|
||||
Reference in New Issue
Block a user