pass test

This commit is contained in:
zixuanzh
2019-04-28 17:58:14 -04:00
parent db7be2d561
commit e1d6fdae73
3 changed files with 32 additions and 17 deletions

View File

@ -166,6 +166,20 @@ class KademliaServer:
dkey = digest(key) dkey = digest(key)
return await self.set_digest(dkey, value) return await self.set_digest(dkey, value)
async def provide(self, key):
"""
publish to the network that it provides for a particular key
"""
neighbors = self.protocol.router.find_neighbors(self.node)
return [await self.protocol.call_add_provider(n, key, self.node.peer_id) for n in neighbors]
async def get_providers(self, key):
"""
get the list of providers for a key
"""
neighbors = self.protocol.router.find_neighbors(self.node)
return [await self.protocol.call_get_providers(n, key) for n in neighbors]
async def set_digest(self, dkey, value): async def set_digest(self, dkey, value):
""" """
Set the given SHA1 digest key (bytes) to the given value in the Set the given SHA1 digest key (bytes) to the given value in the

View File

@ -5,7 +5,6 @@ import logging
from rpcudp.protocol import RPCProtocol from rpcudp.protocol import RPCProtocol
from .kad_peerinfo import create_kad_peerinfo from .kad_peerinfo import create_kad_peerinfo
from .routing import RoutingTable from .routing import RoutingTable
from .utils import validate_provider_id
log = logging.getLogger(__name__) # pylint: disable=invalid-name log = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -74,35 +73,39 @@ class KademliaProtocol(RPCProtocol):
return self.rpc_find_node(sender, nodeid, key) return self.rpc_find_node(sender, nodeid, key)
return {'value': value} return {'value': value}
def rpc_add_provider(self, sender, nodeid, key, provider_peerinfo): def rpc_add_provider(self, sender, nodeid, key, provider_id):
# pylint: disable=unused-argument
""" """
rpc when receiving an add_provider call rpc when receiving an add_provider call
should validate received PeerInfo matches sender nodeid should validate received PeerInfo matches sender nodeid
if it does, receipient must store a record in its datastore if it does, receipient must store a record in its datastore
we store a map of content_id to peer_id (non xor)
""" """
log.info("adding provider for key %s in local table", if nodeid == provider_id:
str(key)) log.info("adding provider %s for key %s in local table",
if validate_provider_id(nodeid, provider_peerinfo): provider_id, str(key))
source = create_kad_peerinfo(nodeid, sender[0], sender[1]) self.storage[key] = provider_id
# TODO differentiate this from key, value
self.storage[key] = provider_peerinfo
return True return True
return False return False
def rpc_get_providers(self, sender, key): def rpc_get_providers(self, sender, key):
# pylint: disable=unused-argument
""" """
rpc when receiving a get_providers call rpc when receiving a get_providers call
should look up key in data store and respond with records should look up key in data store and respond with records
plus a list of closer peers in itrs routing table plus a list of closer peers in its routing table
""" """
providers = [] providers = []
record = self.storage.get(key, None) record = self.storage.get(key, None)
if record: if record:
providers.append(record) providers.append(record)
keynode = create_kad_peerinfo(key) keynode = create_kad_peerinfo(key)
neighbors = self.router.find_neighbors(keynode) neighbors = self.router.find_neighbors(keynode)
providers.extend(neighbors) for neighbor in neighbors:
if neighbor.peer_id != record:
providers.append(neighbor.peer_id)
return providers return providers
@ -128,13 +131,15 @@ class KademliaProtocol(RPCProtocol):
result = await self.store(address, self.source_node.peer_id, key, value) result = await self.store(address, self.source_node.peer_id, key, value)
return self.handle_call_response(result, node_to_ask) return self.handle_call_response(result, node_to_ask)
def call_add_provider(self, node_to_ask, key): async def call_add_provider(self, node_to_ask, key, provider_id):
address = (node_to_ask.ip, node_to_ask.port) address = (node_to_ask.ip, node_to_ask.port)
result = await self.add_provider(address, result = await self.add_provider(address,
self.source_node.peer_id, key) self.source_node.peer_id,
key, provider_id)
return self.handle_call_response(result, node_to_ask) return self.handle_call_response(result, node_to_ask)
def call_get_providers(self, node_to_ask, key): async def call_get_providers(self, node_to_ask, key):
address = (node_to_ask.ip, node_to_ask.port) address = (node_to_ask.ip, node_to_ask.port)
result = await self.get_providers(address, key) result = await self.get_providers(address, key)
return self.handle_call_response(result, node_to_ask) return self.handle_call_response(result, node_to_ask)

View File

@ -55,7 +55,3 @@ def shared_prefix(args):
def bytes_to_bit_string(bites): def bytes_to_bit_string(bites):
bits = [bin(bite)[2:].rjust(8, '0') for bite in bites] bits = [bin(bite)[2:].rjust(8, '0') for bite in bites]
return "".join(bits) return "".join(bits)
def validate_provider_id(sender_id, sender_peerinfo):
return sender_id == sender_peerinfo.peer_id