diff --git a/kad-dht/__init__.py b/kad-dht/__init__.py
deleted file mode 100644
index bb56f29a..00000000
--- a/kad-dht/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-"""
-A Kademlia DHT implemention on py-libp2p
-"""
-__version__ = "0.0"
\ No newline at end of file
diff --git a/kad-dht/lookup.py b/kad-dht/lookup.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/kad-dht/network.py b/kad-dht/network.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/kad-dht/node.py b/kad-dht/node.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/kad-dht/protocol.py b/kad-dht/protocol.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/kad-dht/routing.py b/kad-dht/routing.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/kad-dht/storage.py b/kad-dht/storage.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/kademlia/__init__.py b/kademlia/__init__.py
new file mode 100644
index 00000000..7999ecae
--- /dev/null
+++ b/kademlia/__init__.py
@@ -0,0 +1,5 @@
+"""
+Kademlia is a Python implementation of the Kademlia protocol which
+uses the asyncio library.
+"""
+__version__ = "1.1"
diff --git a/kademlia/crawling.py b/kademlia/crawling.py
new file mode 100644
index 00000000..b1460c25
--- /dev/null
+++ b/kademlia/crawling.py
@@ -0,0 +1,181 @@
+from collections import Counter
+import logging
+
+from kademlia.node import Node, NodeHeap
+from kademlia.utils import gather_dict
+
+log = logging.getLogger(__name__)
+
+
+class SpiderCrawl(object):
+    """
+    Crawl the network and look for given 160-bit keys.
+    """
+    def __init__(self, protocol, node, peers, ksize, alpha):
+        """
+        Create a new SpiderCrawl.
+
+        Args:
+            protocol: A :class:`~kademlia.protocol.KademliaProtocol` instance.
+            node: A :class:`~kademlia.node.Node` representing the key we're
+                  looking for
+            peers: A list of :class:`~kademlia.node.Node` instances that
+                   provide the entry point for the network
+            ksize: The value for k based on the paper
+            alpha: The value for alpha based on the paper
+        """
+        self.protocol = protocol
+        self.ksize = ksize
+        self.alpha = alpha
+        self.node = node
+        self.nearest = NodeHeap(self.node, self.ksize)
+        self.lastIDsCrawled = []
+        log.info("creating spider with peers: %s", peers)
+        self.nearest.push(peers)
+
+    async def _find(self, rpcmethod):
+        """
+        Get either a value or a list of nodes.
+
+        Args:
+            rpcmethod: The protocol's callFindValue or callFindNode.
+
+        The process:
+          1. Call find_* on the current alpha nearest nodes that have not
+             already been queried, adding the results to the current
+             nearest list of k nodes.
+          2. The current nearest list keeps track of who has already been
+             queried, stays sorted by distance, and is capped at ksize.
+          3. If the list is the same as last time, the next call goes to
+             every node not yet queried.
+          4. Repeat until every node on the nearest list has been queried.
+        """
+        log.info("crawling network with nearest: %s", str(tuple(self.nearest)))
+        count = self.alpha
+        if self.nearest.getIDs() == self.lastIDsCrawled:
+            count = len(self.nearest)
+        self.lastIDsCrawled = self.nearest.getIDs()
+
+        ds = {}
+        for peer in self.nearest.getUncontacted()[:count]:
+            ds[peer.id] = rpcmethod(peer, self.node)
+            self.nearest.markContacted(peer)
+        found = await gather_dict(ds)
+        return await self._nodesFound(found)
+
+    async def _nodesFound(self, responses):
+        raise NotImplementedError
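The `_find` loop above is the heart of the iterative lookup: each round fans out to at most alpha uncontacted peers and gathers their replies keyed by peer id. A minimal standalone sketch of that fan-out pattern, using `gather_dict` from `kademlia.utils` with a hypothetical `fake_rpc` coroutine standing in for a real call:

```python
import asyncio

from kademlia.utils import gather_dict


async def fake_rpc(peer_id):
    # hypothetical stand-in for callFindNode/callFindValue; real calls
    # return an rpcudp-style (responded, payload) tuple
    await asyncio.sleep(0.01)
    return (True, [])


async def one_round(uncontacted, alpha=3):
    # query up to alpha peers concurrently; results come back as a
    # dict keyed by peer id, just like in SpiderCrawl._find
    ds = {pid: fake_rpc(pid) for pid in uncontacted[:alpha]}
    return await gather_dict(ds)

loop = asyncio.get_event_loop()
print(loop.run_until_complete(one_round([b'p1', b'p2', b'p3', b'p4'])))
```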
+
+
+class ValueSpiderCrawl(SpiderCrawl):
+    def __init__(self, protocol, node, peers, ksize, alpha):
+        SpiderCrawl.__init__(self, protocol, node, peers, ksize, alpha)
+        # keep track of the single nearest node without value - per
+        # section 2.3 so we can set the key there if found
+        self.nearestWithoutValue = NodeHeap(self.node, 1)
+
+    async def find(self):
+        """
+        Find either the closest nodes or the value requested.
+        """
+        return await self._find(self.protocol.callFindValue)
+
+    async def _nodesFound(self, responses):
+        """
+        Handle the result of an iteration in _find.
+        """
+        toremove = []
+        foundValues = []
+        for peerid, response in responses.items():
+            response = RPCFindResponse(response)
+            if not response.happened():
+                toremove.append(peerid)
+            elif response.hasValue():
+                foundValues.append(response.getValue())
+            else:
+                peer = self.nearest.getNodeById(peerid)
+                self.nearestWithoutValue.push(peer)
+                self.nearest.push(response.getNodeList())
+        self.nearest.remove(toremove)
+
+        if len(foundValues) > 0:
+            return await self._handleFoundValues(foundValues)
+        if self.nearest.allBeenContacted():
+            # not found!
+            return None
+        return await self.find()
+
+    async def _handleFoundValues(self, values):
+        """
+        We got some values!  Make sure they all agree (warn if they
+        don't), and tell the nearest node that *didn't* have the value
+        to store it.
+        """
+        valueCounts = Counter(values)
+        if len(valueCounts) != 1:
+            log.warning("Got multiple values for key %i: %s",
+                        self.node.long_id, str(values))
+        value = valueCounts.most_common(1)[0][0]
+
+        peerToSaveTo = self.nearestWithoutValue.popleft()
+        if peerToSaveTo is not None:
+            await self.protocol.callStore(peerToSaveTo, self.node.id, value)
+        return value
+
+
+class NodeSpiderCrawl(SpiderCrawl):
+    async def find(self):
+        """
+        Find the closest nodes.
+        """
+        return await self._find(self.protocol.callFindNode)
+
+    async def _nodesFound(self, responses):
+        """
+        Handle the result of an iteration in _find.
+        """
+        toremove = []
+        for peerid, response in responses.items():
+            response = RPCFindResponse(response)
+            if not response.happened():
+                toremove.append(peerid)
+            else:
+                self.nearest.push(response.getNodeList())
+        self.nearest.remove(toremove)
+
+        if self.nearest.allBeenContacted():
+            return list(self.nearest)
+        return await self.find()
+
+
+class RPCFindResponse(object):
+    def __init__(self, response):
+        """
+        A wrapper for the result of an RPC find.
+
+        Args:
+            response: This will be a tuple of (<response received>, <value>)
+                      where <value> will be a list of (id, ip, port) tuples
+                      if the value was not found, or a dictionary of
+                      {'value': v} where v is the value desired
+        """
+        self.response = response
+
+    def happened(self):
+        """
+        Did the other host actually respond?
+        """
+        return self.response[0]
+
+    def hasValue(self):
+        return isinstance(self.response[1], dict)
+
+    def getValue(self):
+        return self.response[1]['value']
+
+    def getNodeList(self):
+        """
+        Get the node list in the response.  If no value was found, the
+        payload is a list of (id, ip, port) tuples, one per neighbor.
+        """
+        nodelist = self.response[1] or []
+        return [Node(*nodeple) for nodeple in nodelist]
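For reference, the two payload shapes `RPCFindResponse` distinguishes, plus the timeout case; a short sketch assuming the classes above are importable:

```python
from kademlia.crawling import RPCFindResponse

# a value response: (responded, {'value': v})
found = RPCFindResponse((True, {'value': b'hello'}))
assert found.happened() and found.hasValue()
assert found.getValue() == b'hello'

# a node-list response: (responded, [(id, ip, port), ...])
nodes = RPCFindResponse((True, [(b'\x01' * 20, '10.0.0.5', 8468)]))
assert nodes.happened() and not nodes.hasValue()
print(nodes.getNodeList())  # -> a list with one Node

# a timeout: the first element is False
dead = RPCFindResponse((False, None))
assert not dead.happened()
```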
diff --git a/kademlia/network.py b/kademlia/network.py
new file mode 100644
index 00000000..ca87d3e5
--- /dev/null
+++ b/kademlia/network.py
@@ -0,0 +1,258 @@
+"""
+Package for interacting on the network at a high level.
+"""
+import random
+import pickle
+import asyncio
+import logging
+
+from kademlia.protocol import KademliaProtocol
+from kademlia.utils import digest
+from kademlia.storage import ForgetfulStorage
+from kademlia.node import Node
+from kademlia.crawling import ValueSpiderCrawl
+from kademlia.crawling import NodeSpiderCrawl
+
+log = logging.getLogger(__name__)
+
+
+class Server(object):
+    """
+    High level view of a node instance.  This is the object that should be
+    created to start listening as an active node on the network.
+    """
+
+    protocol_class = KademliaProtocol
+
+    def __init__(self, ksize=20, alpha=3, node_id=None, storage=None):
+        """
+        Create a server instance.  Note that this does not start
+        listening; call :meth:`listen` for that.
+
+        Args:
+            ksize (int): The k parameter from the paper
+            alpha (int): The alpha parameter from the paper
+            node_id: The id for this node on the network.
+            storage: An instance that implements
+                     :class:`~kademlia.storage.IStorage`
+        """
+        self.ksize = ksize
+        self.alpha = alpha
+        self.storage = storage or ForgetfulStorage()
+        self.node = Node(node_id or digest(random.getrandbits(255)))
+        self.transport = None
+        self.protocol = None
+        self.refresh_loop = None
+        self.save_state_loop = None
+
+    def stop(self):
+        if self.transport is not None:
+            self.transport.close()
+
+        if self.refresh_loop:
+            self.refresh_loop.cancel()
+
+        if self.save_state_loop:
+            self.save_state_loop.cancel()
+
+    def _create_protocol(self):
+        return self.protocol_class(self.node, self.storage, self.ksize)
+
+    def listen(self, port, interface='0.0.0.0'):
+        """
+        Start listening on the given port.
+
+        Provide interface="::" to accept IPv6 addresses.
+        """
+        loop = asyncio.get_event_loop()
+        listen = loop.create_datagram_endpoint(self._create_protocol,
+                                               local_addr=(interface, port))
+        log.info("Node %i listening on %s:%i",
+                 self.node.long_id, interface, port)
+        self.transport, self.protocol = loop.run_until_complete(listen)
+        # finally, schedule refreshing table
+        self.refresh_table()
+
+    def refresh_table(self):
+        log.debug("Refreshing routing table")
+        asyncio.ensure_future(self._refresh_table())
+        loop = asyncio.get_event_loop()
+        self.refresh_loop = loop.call_later(3600, self.refresh_table)
+
+    async def _refresh_table(self):
+        """
+        Refresh buckets that haven't had any lookups in the last hour
+        (per section 2.3 of the paper).
+        """
+        ds = []
+        for node_id in self.protocol.getRefreshIDs():
+            node = Node(node_id)
+            nearest = self.protocol.router.findNeighbors(node, self.alpha)
+            spider = NodeSpiderCrawl(self.protocol, node, nearest,
+                                     self.ksize, self.alpha)
+            ds.append(spider.find())
+
+        # do our crawling
+        await asyncio.gather(*ds)
+
+        # now republish keys older than one hour
+        for dkey, value in self.storage.iteritemsOlderThan(3600):
+            await self.set_digest(dkey, value)
+
+    def bootstrappableNeighbors(self):
+        """
+        Get a :class:`list` of (ip, port) :class:`tuple` pairs suitable for
+        use as an argument to the bootstrap method.
+
+        The server should have been bootstrapped already - this is just a
+        utility for getting some neighbors and then storing them if this
+        server is going down for a while.  When it comes back up, the list
+        of nodes can be used to bootstrap.
+        """
+        neighbors = self.protocol.router.findNeighbors(self.node)
+        return [tuple(n)[-2:] for n in neighbors]
+
+    async def bootstrap(self, addrs):
+        """
+        Bootstrap the server by connecting to other known nodes in the
+        network.
+
+        Args:
+            addrs: A `list` of (ip, port) `tuple` pairs.  Note that only IP
+                   addresses are acceptable - hostnames will cause an error.
+        """
+        log.debug("Attempting to bootstrap node with %i initial contacts",
+                  len(addrs))
+        cos = list(map(self.bootstrap_node, addrs))
+        gathered = await asyncio.gather(*cos)
+        nodes = [node for node in gathered if node is not None]
+        spider = NodeSpiderCrawl(self.protocol, self.node, nodes,
+                                 self.ksize, self.alpha)
+        return await spider.find()
+
+    async def bootstrap_node(self, addr):
+        result = await self.protocol.ping(addr, self.node.id)
+        return Node(result[1], addr[0], addr[1]) if result[0] else None
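Taken together, `listen` plus `bootstrap` is all it takes to join a network. A minimal sketch (the address 1.2.3.4:8468 is a placeholder for a known peer):

```python
import asyncio

from kademlia.network import Server

server = Server()
server.listen(8468)

loop = asyncio.get_event_loop()
# only (ip, port) tuples are accepted here - hostnames will cause an error
loop.run_until_complete(server.bootstrap([("1.2.3.4", 8468)]))
loop.run_forever()
```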
+
+    async def get(self, key):
+        """
+        Get a key if the network has it.
+
+        Returns:
+            :class:`None` if not found, the value otherwise.
+        """
+        log.info("Looking up key %s", key)
+        dkey = digest(key)
+        # if this node has it, return it
+        if self.storage.get(dkey) is not None:
+            return self.storage.get(dkey)
+        node = Node(dkey)
+        nearest = self.protocol.router.findNeighbors(node)
+        if len(nearest) == 0:
+            log.warning("There are no known neighbors to get key %s", key)
+            return None
+        spider = ValueSpiderCrawl(self.protocol, node, nearest,
+                                  self.ksize, self.alpha)
+        return await spider.find()
+
+    async def set(self, key, value):
+        """
+        Set the given string key to the given value in the network.
+        """
+        if not check_dht_value_type(value):
+            raise TypeError(
+                "Value must be of type int, float, bool, str, or bytes"
+            )
+        log.info("setting '%s' = '%s' on network", key, value)
+        dkey = digest(key)
+        return await self.set_digest(dkey, value)
+
+    async def set_digest(self, dkey, value):
+        """
+        Set the given SHA1 digest key (bytes) to the given value in the
+        network.
+        """
+        node = Node(dkey)
+
+        nearest = self.protocol.router.findNeighbors(node)
+        if len(nearest) == 0:
+            log.warning("There are no known neighbors to set key %s",
+                        dkey.hex())
+            return False
+
+        spider = NodeSpiderCrawl(self.protocol, node, nearest,
+                                 self.ksize, self.alpha)
+        nodes = await spider.find()
+        log.info("setting '%s' on %s", dkey.hex(), list(map(str, nodes)))
+
+        # if this node is close too, then store here as well
+        biggest = max([n.distanceTo(node) for n in nodes])
+        if self.node.distanceTo(node) < biggest:
+            self.storage[dkey] = value
+        ds = [self.protocol.callStore(n, dkey, value) for n in nodes]
+        # return true only if at least one store call succeeded
+        return any(await asyncio.gather(*ds))
+
+    def saveState(self, fname):
+        """
+        Save the state of this node (the alpha/ksize/id/immediate neighbors)
+        to a cache file with the given fname.
+        """
+        log.info("Saving state to %s", fname)
+        data = {
+            'ksize': self.ksize,
+            'alpha': self.alpha,
+            'id': self.node.id,
+            'neighbors': self.bootstrappableNeighbors()
+        }
+        if len(data['neighbors']) == 0:
+            log.warning("No known neighbors, so not writing to cache.")
+            return
+        with open(fname, 'wb') as f:
+            pickle.dump(data, f)
+
+    @classmethod
+    def loadState(cls, fname):
+        """
+        Load the state of this node (the alpha/ksize/id/immediate neighbors)
+        from a cache file with the given fname.
+        """
+        log.info("Loading state from %s", fname)
+        with open(fname, 'rb') as f:
+            data = pickle.load(f)
+        s = cls(data['ksize'], data['alpha'], data['id'])
+        if len(data['neighbors']) > 0:
+            # bootstrap is a coroutine, so schedule it rather than
+            # calling it directly from this synchronous method
+            asyncio.ensure_future(s.bootstrap(data['neighbors']))
+        return s
+
+    def saveStateRegularly(self, fname, frequency=600):
+        """
+        Save the state of the node with a given regularity to the given
+        filename.
+
+        Args:
+            fname: File name to save regularly to
+            frequency: Frequency in seconds that the state should be saved.
+                       By default, 10 minutes.
+        """
+        self.saveState(fname)
+        loop = asyncio.get_event_loop()
+        self.save_state_loop = loop.call_later(frequency,
+                                               self.saveStateRegularly,
+                                               fname,
+                                               frequency)
+
+
+def check_dht_value_type(value):
+    """
+    Checks to see if the type of the value is a valid type for
+    placing in the dht.
+    """
+    typeset = set(
+        [
+            int,
+            float,
+            bool,
+            str,
+            bytes,
+        ]
+    )
+    return type(value) in typeset
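Once a node is listening and bootstrapped, a `set`/`get` round trip looks roughly like this. The peer address is again a placeholder, and `get` returns None if no peer (nor this node) stored the key:

```python
import asyncio

from kademlia.network import Server


async def demo(server):
    # values must be int, float, bool, str, or bytes
    # (enforced by check_dht_value_type)
    await server.set("my-key", "hello world")
    print(await server.get("my-key"))  # -> 'hello world'

server = Server()
server.listen(8469)
loop = asyncio.get_event_loop()
loop.run_until_complete(server.bootstrap([("1.2.3.4", 8468)]))
loop.run_until_complete(demo(server))
server.stop()
```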
diff --git a/kademlia/node.py b/kademlia/node.py
new file mode 100644
index 00000000..b25e9f87
--- /dev/null
+++ b/kademlia/node.py
@@ -0,0 +1,115 @@
+from operator import itemgetter
+import heapq
+
+
+class Node:
+    def __init__(self, node_id, ip=None, port=None):
+        self.id = node_id
+        self.ip = ip
+        self.port = port
+        self.long_id = int(node_id.hex(), 16)
+
+    def sameHomeAs(self, node):
+        return self.ip == node.ip and self.port == node.port
+
+    def distanceTo(self, node):
+        """
+        Get the distance between this node and another.
+        """
+        return self.long_id ^ node.long_id
+
+    def __iter__(self):
+        """
+        Enables use of Node as a tuple - i.e., tuple(node) works.
+        """
+        return iter([self.id, self.ip, self.port])
+
+    def __repr__(self):
+        return repr([self.long_id, self.ip, self.port])
+
+    def __str__(self):
+        return "%s:%s" % (self.ip, str(self.port))
+
+
+class NodeHeap(object):
+    """
+    A heap of nodes ordered by distance to a given node.
+    """
+    def __init__(self, node, maxsize):
+        """
+        Constructor.
+
+        @param node: The node to measure all distances from.
+        @param maxsize: The maximum size that this heap can grow to.
+        """
+        self.node = node
+        self.heap = []
+        self.contacted = set()
+        self.maxsize = maxsize
+
+    def remove(self, peerIDs):
+        """
+        Remove a list of peer ids from this heap.  Note that while this
+        heap retains a constant visible size (based on the iterator), its
+        actual size may be quite a bit larger than what's exposed.
+        Therefore, removal of nodes may not change the visible size, as
+        previously added nodes suddenly become visible.
+        """
+        peerIDs = set(peerIDs)
+        if len(peerIDs) == 0:
+            return
+        nheap = []
+        for distance, node in self.heap:
+            if node.id not in peerIDs:
+                heapq.heappush(nheap, (distance, node))
+        self.heap = nheap
+
+    def getNodeById(self, node_id):
+        for _, node in self.heap:
+            if node.id == node_id:
+                return node
+        return None
+
+    def allBeenContacted(self):
+        return len(self.getUncontacted()) == 0
+
+    def getIDs(self):
+        return [n.id for n in self]
+
+    def markContacted(self, node):
+        self.contacted.add(node.id)
+
+    def popleft(self):
+        if len(self) > 0:
+            return heapq.heappop(self.heap)[1]
+        return None
+
+    def push(self, nodes):
+        """
+        Push nodes onto the heap.
+
+        @param nodes: This can be a single item or a C{list}.
+        """
+        if not isinstance(nodes, list):
+            nodes = [nodes]
+
+        for node in nodes:
+            if node not in self:
+                distance = self.node.distanceTo(node)
+                heapq.heappush(self.heap, (distance, node))
+
+    def __len__(self):
+        return min(len(self.heap), self.maxsize)
+
+    def __iter__(self):
+        nodes = heapq.nsmallest(self.maxsize, self.heap)
+        return iter(map(itemgetter(1), nodes))
+
+    def __contains__(self, node):
+        for _, n in self.heap:
+            if node.id == n.id:
+                return True
+        return False
+
+    def getUncontacted(self):
+        return [n for n in self if n.id not in self.contacted]
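A quick illustration of the XOR metric and the heap ordering defined above (the ids are arbitrary 20-byte values chosen for readable long_ids):

```python
from kademlia.node import Node, NodeHeap

a = Node(b'\x00' * 19 + b'\x01')   # long_id 1
b = Node(b'\x00' * 19 + b'\x02')   # long_id 2
c = Node(b'\x00' * 19 + b'\x07')   # long_id 7

print(a.distanceTo(b))  # 3, i.e. 1 ^ 2
print(a.distanceTo(c))  # 6, i.e. 1 ^ 7

heap = NodeHeap(a, 2)              # nodes ordered by distance to `a`
heap.push([b, c])
print([n.long_id for n in heap])   # nearest first: [2, 7]
```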
diff --git a/kademlia/protocol.py b/kademlia/protocol.py
new file mode 100644
index 00000000..6e22f8ba
--- /dev/null
+++ b/kademlia/protocol.py
@@ -0,0 +1,128 @@
+import random
+import asyncio
+import logging
+
+from rpcudp.protocol import RPCProtocol
+
+from kademlia.node import Node
+from kademlia.routing import RoutingTable
+from kademlia.utils import digest
+
+log = logging.getLogger(__name__)
+
+
+class KademliaProtocol(RPCProtocol):
+    def __init__(self, sourceNode, storage, ksize):
+        RPCProtocol.__init__(self)
+        self.router = RoutingTable(self, ksize, sourceNode)
+        self.storage = storage
+        self.sourceNode = sourceNode
+
+    def getRefreshIDs(self):
+        """
+        Get ids to search for to keep old buckets up to date.
+        """
+        ids = []
+        for bucket in self.router.getLonelyBuckets():
+            rid = random.randint(*bucket.range).to_bytes(20, byteorder='big')
+            ids.append(rid)
+        return ids
+
+    def rpc_stun(self, sender):
+        return sender
+
+    def rpc_ping(self, sender, nodeid):
+        source = Node(nodeid, sender[0], sender[1])
+        self.welcomeIfNewNode(source)
+        return self.sourceNode.id
+
+    def rpc_store(self, sender, nodeid, key, value):
+        source = Node(nodeid, sender[0], sender[1])
+        self.welcomeIfNewNode(source)
+        log.debug("got a store request from %s, storing '%s'='%s'",
+                  sender, key.hex(), value)
+        self.storage[key] = value
+        return True
+
+    def rpc_find_node(self, sender, nodeid, key):
+        log.info("finding neighbors of %i in local table",
+                 int(nodeid.hex(), 16))
+        source = Node(nodeid, sender[0], sender[1])
+        self.welcomeIfNewNode(source)
+        node = Node(key)
+        neighbors = self.router.findNeighbors(node, exclude=source)
+        return list(map(tuple, neighbors))
+
+    def rpc_find_value(self, sender, nodeid, key):
+        source = Node(nodeid, sender[0], sender[1])
+        self.welcomeIfNewNode(source)
+        value = self.storage.get(key, None)
+        if value is None:
+            return self.rpc_find_node(sender, nodeid, key)
+        return {'value': value}
+
+    async def callFindNode(self, nodeToAsk, nodeToFind):
+        address = (nodeToAsk.ip, nodeToAsk.port)
+        result = await self.find_node(address, self.sourceNode.id,
+                                      nodeToFind.id)
+        return self.handleCallResponse(result, nodeToAsk)
+
+    async def callFindValue(self, nodeToAsk, nodeToFind):
+        address = (nodeToAsk.ip, nodeToAsk.port)
+        result = await self.find_value(address, self.sourceNode.id,
+                                       nodeToFind.id)
+        return self.handleCallResponse(result, nodeToAsk)
+
+    async def callPing(self, nodeToAsk):
+        address = (nodeToAsk.ip, nodeToAsk.port)
+        result = await self.ping(address, self.sourceNode.id)
+        return self.handleCallResponse(result, nodeToAsk)
+
+    async def callStore(self, nodeToAsk, key, value):
+        address = (nodeToAsk.ip, nodeToAsk.port)
+        result = await self.store(address, self.sourceNode.id, key, value)
+        return self.handleCallResponse(result, nodeToAsk)
+
+    def welcomeIfNewNode(self, node):
+        """
+        Given a new node, send it all the keys/values it should be storing,
+        then add it to the routing table.
+
+        @param node: A new node that just joined (or that we just found out
+        about).
+
+        Process:
+        For each key in storage, get the k closest nodes.  If the new node
+        is closer than the furthest in that list, and the node for this
+        server is closer than the closest in that list, then store the
+        key/value on the new node (per section 2.5 of the paper).
+        """
+        if not self.router.isNewNode(node):
+            return
+
+        log.info("never seen %s before, adding to router", node)
+        for key, value in self.storage.items():
+            keynode = Node(digest(key))
+            neighbors = self.router.findNeighbors(keynode)
+            if len(neighbors) > 0:
+                last = neighbors[-1].distanceTo(keynode)
+                newNodeClose = node.distanceTo(keynode) < last
+                first = neighbors[0].distanceTo(keynode)
+                thisNodeClosest = self.sourceNode.distanceTo(keynode) < first
+            if len(neighbors) == 0 or (newNodeClose and thisNodeClosest):
+                asyncio.ensure_future(self.callStore(node, key, value))
+        self.router.addContact(node)
+
+    def handleCallResponse(self, result, node):
+        """
+        If we get a response, add the node to the routing table.  If
+        we get no response, make sure it's removed from the routing table.
+        """
+        if not result[0]:
+            log.warning("no response from %s, removing from router", node)
+            self.router.removeContact(node)
+            return result
+
+        log.info("got successful response from %s", node)
+        self.welcomeIfNewNode(node)
+        return result
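Every `call*` wrapper returns the rpcudp-style result tuple that `handleCallResponse` inspects: index 0 says whether the peer answered, index 1 is the payload. A hedged sketch of what a caller does with it (assumes a live protocol instance and a reachable peer):

```python
async def ping_peer(protocol, peer):
    # callPing has already updated the routing table via
    # handleCallResponse; here we just unpack the tuple
    responded, peer_id = await protocol.callPing(peer)
    if not responded:
        return None  # peer timed out and was removed from the router
    return peer_id   # the remote node's 20-byte id
```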
diff --git a/kademlia/routing.py b/kademlia/routing.py
new file mode 100644
index 00000000..85f3ddaa
--- /dev/null
+++ b/kademlia/routing.py
@@ -0,0 +1,184 @@
+import heapq
+import time
+import operator
+import asyncio
+
+from collections import OrderedDict
+from kademlia.utils import OrderedSet, sharedPrefix, bytesToBitString
+
+
+class KBucket(object):
+    def __init__(self, rangeLower, rangeUpper, ksize):
+        self.range = (rangeLower, rangeUpper)
+        self.nodes = OrderedDict()
+        self.replacementNodes = OrderedSet()
+        self.touchLastUpdated()
+        self.ksize = ksize
+
+    def touchLastUpdated(self):
+        self.lastUpdated = time.monotonic()
+
+    def getNodes(self):
+        return list(self.nodes.values())
+
+    def split(self):
+        # integer division keeps the bucket range bounds as ints
+        midpoint = (self.range[0] + self.range[1]) // 2
+        one = KBucket(self.range[0], midpoint, self.ksize)
+        two = KBucket(midpoint + 1, self.range[1], self.ksize)
+        for node in self.nodes.values():
+            bucket = one if node.long_id <= midpoint else two
+            bucket.nodes[node.id] = node
+        return (one, two)
+
+    def removeNode(self, node):
+        if node.id not in self.nodes:
+            return
+
+        # delete node, and see if we can add a replacement
+        del self.nodes[node.id]
+        if len(self.replacementNodes) > 0:
+            newnode = self.replacementNodes.pop()
+            self.nodes[newnode.id] = newnode
+
+    def hasInRange(self, node):
+        return self.range[0] <= node.long_id <= self.range[1]
+
+    def isNewNode(self, node):
+        return node.id not in self.nodes
+
+    def addNode(self, node):
+        """
+        Add a C{Node} to the C{KBucket}.  Return True if successful,
+        False if the bucket is full.
+
+        If the bucket is full, keep track of the node in a replacement list,
+        per section 4.1 of the paper.
+ """ + if node.id in self.nodes: + del self.nodes[node.id] + self.nodes[node.id] = node + elif len(self) < self.ksize: + self.nodes[node.id] = node + else: + self.replacementNodes.push(node) + return False + return True + + def depth(self): + vals = self.nodes.values() + sp = sharedPrefix([bytesToBitString(n.id) for n in vals]) + return len(sp) + + def head(self): + return list(self.nodes.values())[0] + + def __getitem__(self, node_id): + return self.nodes.get(node_id, None) + + def __len__(self): + return len(self.nodes) + + +class TableTraverser(object): + def __init__(self, table, startNode): + index = table.getBucketFor(startNode) + table.buckets[index].touchLastUpdated() + self.currentNodes = table.buckets[index].getNodes() + self.leftBuckets = table.buckets[:index] + self.rightBuckets = table.buckets[(index + 1):] + self.left = True + + def __iter__(self): + return self + + def __next__(self): + """ + Pop an item from the left subtree, then right, then left, etc. + """ + if len(self.currentNodes) > 0: + return self.currentNodes.pop() + + if self.left and len(self.leftBuckets) > 0: + self.currentNodes = self.leftBuckets.pop().getNodes() + self.left = False + return next(self) + + if len(self.rightBuckets) > 0: + self.currentNodes = self.rightBuckets.pop(0).getNodes() + self.left = True + return next(self) + + raise StopIteration + + +class RoutingTable(object): + def __init__(self, protocol, ksize, node): + """ + @param node: The node that represents this server. It won't + be added to the routing table, but will be needed later to + determine which buckets to split or not. + """ + self.node = node + self.protocol = protocol + self.ksize = ksize + self.flush() + + def flush(self): + self.buckets = [KBucket(0, 2 ** 160, self.ksize)] + + def splitBucket(self, index): + one, two = self.buckets[index].split() + self.buckets[index] = one + self.buckets.insert(index + 1, two) + + def getLonelyBuckets(self): + """ + Get all of the buckets that haven't been updated in over + an hour. + """ + hrago = time.monotonic() - 3600 + return [b for b in self.buckets if b.lastUpdated < hrago] + + def removeContact(self, node): + index = self.getBucketFor(node) + self.buckets[index].removeNode(node) + + def isNewNode(self, node): + index = self.getBucketFor(node) + return self.buckets[index].isNewNode(node) + + def addContact(self, node): + index = self.getBucketFor(node) + bucket = self.buckets[index] + + # this will succeed unless the bucket is full + if bucket.addNode(node): + return + + # Per section 4.2 of paper, split if the bucket has the node + # in its range or if the depth is not congruent to 0 mod 5 + if bucket.hasInRange(self.node) or bucket.depth() % 5 != 0: + self.splitBucket(index) + self.addContact(node) + else: + asyncio.ensure_future(self.protocol.callPing(bucket.head())) + + def getBucketFor(self, node): + """ + Get the index of the bucket that the given node would fall into. 
+ """ + for index, bucket in enumerate(self.buckets): + if node.long_id < bucket.range[1]: + return index + + def findNeighbors(self, node, k=None, exclude=None): + k = k or self.ksize + nodes = [] + for neighbor in TableTraverser(self, node): + notexcluded = exclude is None or not neighbor.sameHomeAs(exclude) + if neighbor.id != node.id and notexcluded: + heapq.heappush(nodes, (node.distanceTo(neighbor), neighbor)) + if len(nodes) == k: + break + + return list(map(operator.itemgetter(1), heapq.nsmallest(k, nodes))) diff --git a/kademlia/storage.py b/kademlia/storage.py new file mode 100644 index 00000000..71865f23 --- /dev/null +++ b/kademlia/storage.py @@ -0,0 +1,97 @@ +import time +from itertools import takewhile +import operator +from collections import OrderedDict + + +class IStorage: + """ + Local storage for this node. + IStorage implementations of get must return the same type as put in by set + """ + + def __setitem__(self, key, value): + """ + Set a key to the given value. + """ + raise NotImplementedError + + def __getitem__(self, key): + """ + Get the given key. If item doesn't exist, raises C{KeyError} + """ + raise NotImplementedError + + def get(self, key, default=None): + """ + Get given key. If not found, return default. + """ + raise NotImplementedError + + def iteritemsOlderThan(self, secondsOld): + """ + Return the an iterator over (key, value) tuples for items older + than the given secondsOld. + """ + raise NotImplementedError + + def __iter__(self): + """ + Get the iterator for this storage, should yield tuple of (key, value) + """ + raise NotImplementedError + + +class ForgetfulStorage(IStorage): + def __init__(self, ttl=604800): + """ + By default, max age is a week. + """ + self.data = OrderedDict() + self.ttl = ttl + + def __setitem__(self, key, value): + if key in self.data: + del self.data[key] + self.data[key] = (time.monotonic(), value) + self.cull() + + def cull(self): + for _, _ in self.iteritemsOlderThan(self.ttl): + self.data.popitem(last=False) + + def get(self, key, default=None): + self.cull() + if key in self.data: + return self[key] + return default + + def __getitem__(self, key): + self.cull() + return self.data[key][1] + + def __iter__(self): + self.cull() + return iter(self.data) + + def __repr__(self): + self.cull() + return repr(self.data) + + def iteritemsOlderThan(self, secondsOld): + minBirthday = time.monotonic() - secondsOld + zipped = self._tripleIterable() + matches = takewhile(lambda r: minBirthday >= r[1], zipped) + return list(map(operator.itemgetter(0, 2), matches)) + + def _tripleIterable(self): + ikeys = self.data.keys() + ibirthday = map(operator.itemgetter(0), self.data.values()) + ivalues = map(operator.itemgetter(1), self.data.values()) + return zip(ikeys, ibirthday, ivalues) + + def items(self): + self.cull() + ikeys = self.data.keys() + ivalues = map(operator.itemgetter(1), self.data.values()) + return zip(ikeys, ivalues) diff --git a/kademlia/utils.py b/kademlia/utils.py new file mode 100644 index 00000000..9732bf5d --- /dev/null +++ b/kademlia/utils.py @@ -0,0 +1,57 @@ +""" +General catchall for functions that don't make sense as methods. 
+""" +import hashlib +import operator +import asyncio + + +async def gather_dict(d): + cors = list(d.values()) + results = await asyncio.gather(*cors) + return dict(zip(d.keys(), results)) + + +def digest(s): + if not isinstance(s, bytes): + s = str(s).encode('utf8') + return hashlib.sha1(s).digest() + + +class OrderedSet(list): + """ + Acts like a list in all ways, except in the behavior of the + :meth:`push` method. + """ + + def push(self, thing): + """ + 1. If the item exists in the list, it's removed + 2. The item is pushed to the end of the list + """ + if thing in self: + self.remove(thing) + self.append(thing) + + +def sharedPrefix(args): + """ + Find the shared prefix between the strings. + + For instance: + + sharedPrefix(['blahblah', 'blahwhat']) + + returns 'blah'. + """ + i = 0 + while i < min(map(len, args)): + if len(set(map(operator.itemgetter(i), args))) != 1: + break + i += 1 + return args[0][:i] + + +def bytesToBitString(bites): + bits = [bin(bite)[2:].rjust(8, '0') for bite in bites] + return "".join(bits)