"""
|
|
Kademlia DHT implementation for py-libp2p.
|
|
|
|
This module provides a complete Distributed Hash Table (DHT)
|
|
implementation based on the Kademlia algorithm and protocol.
|
|
"""
|
|
|
|
from collections.abc import Awaitable, Callable
from enum import Enum
import logging
import time

from multiaddr import Multiaddr
import trio
import varint

from libp2p.abc import IHost
from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager
from libp2p.kad_dht.utils import maybe_consume_signed_record
from libp2p.network.stream.net_stream import INetStream
from libp2p.peer.envelope import Envelope
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo
from libp2p.peer.peerstore import env_to_send_in_RPC
from libp2p.tools.async_service import Service

from .common import ALPHA, PROTOCOL_ID, QUERY_TIMEOUT
from .pb.kademlia_pb2 import Message
from .peer_routing import PeerRouting
from .provider_store import ProviderStore
from .routing_table import RoutingTable
from .value_store import ValueStore

logger = logging.getLogger("libp2p.kademlia")

# Default parameters
ROUTING_TABLE_REFRESH_INTERVAL = 60  # seconds; shortened to 1 minute for testing


class DHTMode(Enum):
    """DHT operation modes."""

    CLIENT = "CLIENT"
    SERVER = "SERVER"


class KadDHT(Service):
    """
    Kademlia DHT implementation for libp2p.

    This class provides a DHT that combines routing table management,
    peer discovery, content routing, and value storage.

    The optional random-walk feature enhances peer discovery by
    automatically performing periodic random queries to discover new
    peers and maintain routing table health.

    Example:
        # Basic DHT without random walk (default)
        dht = KadDHT(host, DHTMode.SERVER)

        # DHT with random walk enabled for enhanced peer discovery
        dht = KadDHT(host, DHTMode.SERVER, enable_random_walk=True)
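
        # Run as a background service (a sketch; assumes the trio-based
        # runner exported by libp2p.tools.async_service)
        async with background_trio_service(dht):
            await dht.put_value(b"key", b"value")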
    """

    def __init__(self, host: IHost, mode: DHTMode, enable_random_walk: bool = False):
        """
        Initialize a new Kademlia DHT node.

        :param host: The libp2p host.
        :param mode: The DHT mode (client or server); must be a DHTMode enum.
        :param enable_random_walk: Whether to enable automatic random walk.
        """
        super().__init__()

        self.host = host
        self.local_peer_id = host.get_id()

        # Validate that mode is a DHTMode enum
        if not isinstance(mode, DHTMode):
            raise TypeError(f"mode must be DHTMode enum, got {type(mode)}")

        self.mode = mode
        self.enable_random_walk = enable_random_walk

        # Initialize the routing table
        self.routing_table = RoutingTable(self.local_peer_id, self.host)

        # Initialize peer routing
        self.peer_routing = PeerRouting(host, self.routing_table)

        # Initialize value store
        self.value_store = ValueStore(host=host, local_peer_id=self.local_peer_id)

        # Initialize provider store with host and peer_routing references
        self.provider_store = ProviderStore(host=host, peer_routing=self.peer_routing)

        # Last time we republished provider records
        self._last_provider_republish = time.time()

        # Initialize RT Refresh Manager (only if random walk is enabled)
        self.rt_refresh_manager: RTRefreshManager | None = None
        if self.enable_random_walk:
            self.rt_refresh_manager = RTRefreshManager(
                host=self.host,
                routing_table=self.routing_table,
                local_peer_id=self.local_peer_id,
                query_function=self._create_query_function(),
                enable_auto_refresh=True,
            )

        # Set protocol handlers
        host.set_stream_handler(PROTOCOL_ID, self.handle_stream)
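        # (all inbound DHT RPCs arrive on PROTOCOL_ID and are dispatched by
        # handle_stream below)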

    def _create_query_function(self) -> Callable[[bytes], Awaitable[list[ID]]]:
        """
        Create a query function that wraps peer_routing.find_closest_peers_network.

        This function is used by the RandomWalk module to query for peers without
        directly importing PeerRouting, avoiding circular import issues.

        Returns:
            Callable that takes target-key bytes and returns a list of peer IDs.

        """

        async def query_function(target_key: bytes) -> list[ID]:
            """Query for the closest peers to the target key."""
            return await self.peer_routing.find_closest_peers_network(target_key)

        return query_function

    async def run(self) -> None:
        """Run the DHT service."""
        logger.info(f"Starting Kademlia DHT with peer ID {self.local_peer_id}")

        # Start the RT Refresh Manager in parallel with the main DHT service
        async with trio.open_nursery() as nursery:
            # Start the RT Refresh Manager only if random walk is enabled
            if self.rt_refresh_manager is not None:
                nursery.start_soon(self.rt_refresh_manager.start)
                logger.info("RT Refresh Manager started - Random Walk is now active")
            else:
                logger.info("Random Walk is disabled - RT Refresh Manager not started")

            # Start the main DHT service loop
            nursery.start_soon(self._run_main_loop)

    async def _run_main_loop(self) -> None:
        """Run the main DHT service loop."""
        while self.manager.is_running:
            # Periodically refresh the routing table
            await self.refresh_routing_table()

            # Check if it's time to republish provider records
            current_time = time.time()
            # await self._republish_provider_records()
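            # NOTE: provider republishing is currently disabled; only the
            # timestamp below advances each cycle.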
            self._last_provider_republish = current_time

            # Clean up expired values and provider records
            expired_values = self.value_store.cleanup_expired()
            if expired_values > 0:
                logger.debug(f"Cleaned up {expired_values} expired values")

            self.provider_store.cleanup_expired()

            # Wait before the next maintenance cycle
            await trio.sleep(ROUTING_TABLE_REFRESH_INTERVAL)

    async def stop(self) -> None:
        """Stop the DHT service and clean up resources."""
        logger.info("Stopping Kademlia DHT")

        # Stop the RT Refresh Manager only if it was started
        if self.rt_refresh_manager is not None:
            await self.rt_refresh_manager.stop()
            logger.info("RT Refresh Manager stopped")
        else:
            logger.info("RT Refresh Manager was not running (Random Walk disabled)")

    async def switch_mode(self, new_mode: DHTMode) -> DHTMode:
        """
        Switch the DHT mode.

        :param new_mode: The new mode; must be a DHTMode enum.
        :return: The new mode as a DHTMode enum.
        """
        # Validate that new_mode is a DHTMode enum
        if not isinstance(new_mode, DHTMode):
            raise TypeError(f"new_mode must be DHTMode enum, got {type(new_mode)}")

        # Client nodes do not serve the DHT, so drop the routing table
        if new_mode == DHTMode.CLIENT:
            self.routing_table.cleanup_routing_table()
        self.mode = new_mode
        logger.info(f"Switched to {new_mode.value} mode")
        return self.mode

    async def handle_stream(self, stream: INetStream) -> None:
        """Handle an incoming DHT stream using varint length prefixes."""
        if self.mode == DHTMode.CLIENT:
            await stream.close()
            return

        peer_id = stream.muxed_conn.peer_id
        logger.debug(f"Received DHT stream from peer {peer_id}")
        await self.add_peer(peer_id)
        logger.debug(f"Added peer {peer_id} to routing table")

        closer_peer_envelope: Envelope | None = None
        provider_peer_envelope: Envelope | None = None

        try:
            # Read the varint-prefixed length of the message
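            # (each byte of the varint carries 7 bits of the length; the high
            # bit is a continuation flag, so the prefix ends at the first
            # byte with the high bit clear)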
            length_prefix = b""
            while True:
                byte = await stream.read(1)
                if not byte:
                    logger.warning("Stream closed while reading varint length")
                    await stream.close()
                    return
                length_prefix += byte
                if byte[0] & 0x80 == 0:
                    break
            msg_length = varint.decode_bytes(length_prefix)

            # Read the message bytes
            msg_bytes = await stream.read(msg_length)
            if len(msg_bytes) < msg_length:
                logger.warning("Failed to read full message from stream")
                await stream.close()
                return

            try:
                # Parse as protobuf
                message = Message()
                message.ParseFromString(msg_bytes)
                logger.debug(
                    f"Received DHT message from {peer_id}, type: {message.type}"
                )

                # Handle FIND_NODE message
                if message.type == Message.MessageType.FIND_NODE:
                    # Get the target key directly from the protobuf
                    target_key = message.key

                    # Find the closest peers to the target key
                    closest_peers = self.routing_table.find_local_closest_peers(
                        target_key, 20
                    )
                    logger.debug(f"Found {len(closest_peers)} peers close to target")

                    # Consume the sender's signed peer record if sent
                    if not maybe_consume_signed_record(message, self.host, peer_id):
                        logger.error(
                            "Received an invalid signed record, dropping the stream"
                        )
                        await stream.close()
                        return

                    # Build the response message with protobuf
                    response = Message()
                    response.type = Message.MessageType.FIND_NODE

                    # Add the closest peers to the response
                    for peer in closest_peers:
                        # Skip if the peer is the requester
                        if peer == peer_id:
                            continue

                        # Add the peer to the closerPeers field
                        peer_proto = response.closerPeers.add()
                        peer_proto.id = peer.to_bytes()
                        peer_proto.connection = Message.ConnectionType.CAN_CONNECT

                        # Add addresses if available
                        try:
                            addrs = self.host.get_peerstore().addrs(peer)
                            if addrs:
                                for addr in addrs:
                                    peer_proto.addrs.append(addr.to_bytes())
                        except Exception:
                            pass

                        # Add the signed peer record for each peer in the
                        # peer proto, if cached in the peerstore
                        closer_peer_envelope = (
                            self.host.get_peerstore().get_peer_record(peer)
                        )
                        if closer_peer_envelope is not None:
                            peer_proto.signedRecord = (
                                closer_peer_envelope.marshal_envelope()
                            )

                    # Create the sender's signed peer record
                    envelope_bytes, _ = env_to_send_in_RPC(self.host)
                    response.senderRecord = envelope_bytes

                    # Serialize and send the response
                    response_bytes = response.SerializeToString()
                    await stream.write(varint.encode(len(response_bytes)))
                    await stream.write(response_bytes)
                    logger.debug(
                        f"Sent FIND_NODE response with "
                        f"{len(response.closerPeers)} peers"
                    )

                # Handle ADD_PROVIDER message
                elif message.type == Message.MessageType.ADD_PROVIDER:
                    key = message.key
                    logger.debug(f"Received ADD_PROVIDER for key {key.hex()}")

                    # Consume the sender's signed peer record if sent
                    if not maybe_consume_signed_record(message, self.host, peer_id):
                        logger.error(
                            "Received an invalid signed record, dropping the stream"
                        )
                        await stream.close()
                        return

                    # Extract provider information
                    for provider_proto in message.providerPeers:
                        try:
                            # Validate that the provider is the sender
                            provider_id = ID(provider_proto.id)
                            if provider_id != peer_id:
                                logger.warning(
                                    f"Provider ID {provider_id} doesn't "
                                    f"match sender {peer_id}, ignoring"
                                )
                                continue

                            # Convert addresses to Multiaddr
                            addrs = []
                            for addr_bytes in provider_proto.addrs:
                                try:
                                    addrs.append(Multiaddr(addr_bytes))
                                except Exception as e:
                                    logger.warning(f"Failed to parse address: {e}")

                            # Add to the provider store
                            provider_info = PeerInfo(provider_id, addrs)
                            self.provider_store.add_provider(key, provider_info)
                            logger.debug(
                                f"Added provider {provider_id} for key {key.hex()}"
                            )

                            # Process the provider's signed record if sent
                            if not maybe_consume_signed_record(
                                provider_proto, self.host
                            ):
                                logger.error(
                                    "Received an invalid signed record, "
                                    "dropping the stream"
                                )
                                await stream.close()
                                return
                        except Exception as e:
                            logger.warning(f"Failed to process provider info: {e}")

                    # Send an acknowledgement
                    response = Message()
                    response.type = Message.MessageType.ADD_PROVIDER
                    response.key = key

                    # Add the sender's signed peer record
                    envelope_bytes, _ = env_to_send_in_RPC(self.host)
                    response.senderRecord = envelope_bytes

                    response_bytes = response.SerializeToString()
                    await stream.write(varint.encode(len(response_bytes)))
                    await stream.write(response_bytes)
                    logger.debug("Sent ADD_PROVIDER acknowledgement")

                # Handle GET_PROVIDERS message
                elif message.type == Message.MessageType.GET_PROVIDERS:
                    key = message.key
                    logger.debug(f"Received GET_PROVIDERS request for key {key.hex()}")

                    # Consume the sender's signed peer record if sent
                    if not maybe_consume_signed_record(message, self.host, peer_id):
                        logger.error(
                            "Received an invalid signed record, dropping the stream"
                        )
                        await stream.close()
                        return

                    # Find providers for the key
                    providers = self.provider_store.get_providers(key)
                    logger.debug(
                        f"Found {len(providers)} providers for key {key.hex()}"
                    )

                    # Create the response
                    response = Message()
                    response.type = Message.MessageType.GET_PROVIDERS
                    response.key = key

                    # Create the sender's signed peer record for the response
                    envelope_bytes, _ = env_to_send_in_RPC(self.host)
                    response.senderRecord = envelope_bytes

                    # Add provider information to the response
                    for provider_info in providers:
                        provider_proto = response.providerPeers.add()
                        provider_proto.id = provider_info.peer_id.to_bytes()
                        provider_proto.connection = Message.ConnectionType.CAN_CONNECT

                        # Add the provider's signed record if cached
                        provider_peer_envelope = (
                            self.host.get_peerstore().get_peer_record(
                                provider_info.peer_id
                            )
                        )
                        if provider_peer_envelope is not None:
                            provider_proto.signedRecord = (
                                provider_peer_envelope.marshal_envelope()
                            )

                        # Add addresses if available
                        for addr in provider_info.addrs:
                            provider_proto.addrs.append(addr.to_bytes())

                    # Also include the closest peers if we have no providers
                    if not providers:
                        closest_peers = self.routing_table.find_local_closest_peers(
                            key, 20
                        )
                        logger.debug(
                            f"No providers found, including {len(closest_peers)} "
                            "closest peers"
                        )

                        for peer in closest_peers:
                            # Skip if the peer is the requester
                            if peer == peer_id:
                                continue

                            peer_proto = response.closerPeers.add()
                            peer_proto.id = peer.to_bytes()
                            peer_proto.connection = Message.ConnectionType.CAN_CONNECT

                            # Add the signed records of the closest peers if cached
                            closer_peer_envelope = (
                                self.host.get_peerstore().get_peer_record(peer)
                            )
                            if closer_peer_envelope is not None:
                                peer_proto.signedRecord = (
                                    closer_peer_envelope.marshal_envelope()
                                )

                            # Add addresses if available
                            try:
                                addrs = self.host.get_peerstore().addrs(peer)
                                for addr in addrs:
                                    peer_proto.addrs.append(addr.to_bytes())
                            except Exception:
                                pass

                    # Serialize and send the response
                    response_bytes = response.SerializeToString()
                    await stream.write(varint.encode(len(response_bytes)))
                    await stream.write(response_bytes)
                    logger.debug("Sent GET_PROVIDERS response")

                # Handle GET_VALUE message
                elif message.type == Message.MessageType.GET_VALUE:
                    key = message.key
                    logger.debug(f"Received GET_VALUE request for key {key.hex()}")

                    # Consume the sender's signed peer record if sent
                    if not maybe_consume_signed_record(message, self.host, peer_id):
                        logger.error(
                            "Received an invalid signed record, dropping the stream"
                        )
                        await stream.close()
                        return

                    value = self.value_store.get(key)
                    if value:
                        logger.debug(f"Found value for key {key.hex()}")

                        # Create the response using protobuf
                        response = Message()
                        response.type = Message.MessageType.GET_VALUE

                        # Create the record
                        response.key = key
                        response.record.key = key
                        response.record.value = value
                        response.record.timeReceived = str(time.time())

                        # Create the sender's signed peer record
                        envelope_bytes, _ = env_to_send_in_RPC(self.host)
                        response.senderRecord = envelope_bytes

                        # Serialize and send the response
                        response_bytes = response.SerializeToString()
                        await stream.write(varint.encode(len(response_bytes)))
                        await stream.write(response_bytes)
                        logger.debug("Sent GET_VALUE response")
                    else:
                        logger.debug(f"No value found for key {key.hex()}")

                        # Respond with the closest peers when no value is found
                        response = Message()
                        response.type = Message.MessageType.GET_VALUE
                        response.key = key

                        # Create the sender's signed peer record for the response
                        envelope_bytes, _ = env_to_send_in_RPC(self.host)
                        response.senderRecord = envelope_bytes

                        # Add the closest peers to the key
                        closest_peers = self.routing_table.find_local_closest_peers(
                            key, 20
                        )
                        logger.debug(
                            f"No value found, including {len(closest_peers)} "
                            "closest peers"
                        )

                        for peer in closest_peers:
                            # Skip if the peer is the requester
                            if peer == peer_id:
                                continue

                            peer_proto = response.closerPeers.add()
                            peer_proto.id = peer.to_bytes()
                            peer_proto.connection = Message.ConnectionType.CAN_CONNECT

                            # Add the signed records of closer peers if cached
                            closer_peer_envelope = (
                                self.host.get_peerstore().get_peer_record(peer)
                            )
                            if closer_peer_envelope is not None:
                                peer_proto.signedRecord = (
                                    closer_peer_envelope.marshal_envelope()
                                )

                            # Add addresses if available
                            try:
                                addrs = self.host.get_peerstore().addrs(peer)
                                for addr in addrs:
                                    peer_proto.addrs.append(addr.to_bytes())
                            except Exception:
                                pass

                        # Serialize and send the response
                        response_bytes = response.SerializeToString()
                        await stream.write(varint.encode(len(response_bytes)))
                        await stream.write(response_bytes)
                        logger.debug("Sent GET_VALUE response with closest peers")

                # Handle PUT_VALUE message
                elif message.type == Message.MessageType.PUT_VALUE and message.HasField(
                    "record"
                ):
                    key = message.record.key
                    value = message.record.value
                    success = False

                    # Consume the sender's signed peer record if sent
                    if not maybe_consume_signed_record(message, self.host, peer_id):
                        logger.error(
                            "Received an invalid signed record, dropping the stream"
                        )
                        await stream.close()
                        return

                    try:
                        if not (key and value):
                            raise ValueError(
                                "Missing key or value in PUT_VALUE message"
                            )

                        self.value_store.put(key, value)
                        logger.debug(f"Stored value {value.hex()} for key {key.hex()}")
                        success = True
                    except Exception as e:
                        logger.warning(
                            f"Failed to store value {value.hex()} for key "
                            f"{key.hex()}: {e}"
                        )
                    finally:
                        # Send an acknowledgement
                        response = Message()
                        response.type = Message.MessageType.PUT_VALUE
                        if success:
                            response.key = key

                        # Create the sender's signed peer record for the response
                        envelope_bytes, _ = env_to_send_in_RPC(self.host)
                        response.senderRecord = envelope_bytes

                        # Serialize and send the response
                        response_bytes = response.SerializeToString()
                        await stream.write(varint.encode(len(response_bytes)))
                        await stream.write(response_bytes)
                        logger.debug("Sent PUT_VALUE acknowledgement")

            except Exception as proto_err:
                logger.warning(f"Failed to parse protobuf message: {proto_err}")

            await stream.close()
        except Exception as e:
            logger.error(f"Error handling DHT stream: {e}")
            await stream.close()

    async def refresh_routing_table(self) -> None:
        """Refresh the routing table."""
        logger.debug("Refreshing routing table")
        await self.peer_routing.refresh_routing_table()

    # Peer routing methods

    async def find_peer(self, peer_id: ID) -> PeerInfo | None:
        """Find a peer with the given ID."""
        logger.debug(f"Finding peer: {peer_id}")
        return await self.peer_routing.find_peer(peer_id)

    # Value storage and retrieval methods

    async def put_value(self, key: bytes, value: bytes) -> None:
        """Store a value in the DHT."""
        logger.debug(f"Storing value for key {key.hex()}")

        # 1. Store locally first
        self.value_store.put(key, value)
        try:
            decoded_value = value.decode("utf-8")
        except UnicodeDecodeError:
            decoded_value = value.hex()
        logger.debug(
            f"Stored value locally for key {key.hex()} with value {decoded_value}"
        )

        # 2. Get the closest peers, excluding self
        closest_peers = [
            peer
            for peer in self.routing_table.find_local_closest_peers(key)
            if peer != self.local_peer_id
        ]
        logger.debug(f"Found {len(closest_peers)} peers to store value at")

        # 3. Store at remote peers in batches of ALPHA, in parallel
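        # (ALPHA is Kademlia's concurrency parameter: each batch issues up to
        # ALPHA store RPCs at once, each bounded by QUERY_TIMEOUT)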
        stored_count = 0
        for i in range(0, len(closest_peers), ALPHA):
            batch = closest_peers[i : i + ALPHA]
            batch_results = [False] * len(batch)

            async def store_one(idx: int, peer: ID) -> None:
                try:
                    with trio.move_on_after(QUERY_TIMEOUT):
                        success = await self.value_store._store_at_peer(
                            peer, key, value
                        )
                        batch_results[idx] = success
                        if success:
                            logger.debug(f"Stored value at peer {peer}")
                        else:
                            logger.debug(f"Failed to store value at peer {peer}")
                except Exception as e:
                    logger.debug(f"Error storing value at peer {peer}: {e}")

            async with trio.open_nursery() as nursery:
                for idx, peer in enumerate(batch):
                    nursery.start_soon(store_one, idx, peer)

            stored_count += sum(batch_results)

        logger.info(f"Successfully stored value at {stored_count} peers")

    async def get_value(self, key: bytes) -> bytes | None:
        """Retrieve a value from the DHT, checking the local store first."""
        logger.debug(f"Getting value for key: {key.hex()}")

        # 1. Check the local store first
        value = self.value_store.get(key)
        if value:
            logger.debug("Found value locally")
            return value

        # 2. Get the closest peers, excluding self
        closest_peers = [
            peer
            for peer in self.routing_table.find_local_closest_peers(key)
            if peer != self.local_peer_id
        ]
        logger.debug(f"Searching {len(closest_peers)} peers for value")

        # 3. Query ALPHA peers at a time in parallel
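        # (first non-None response wins: once found_value is set, later
        # responses in the same batch are ignored)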
        for i in range(0, len(closest_peers), ALPHA):
            batch = closest_peers[i : i + ALPHA]
            found_value = None

            async def query_one(peer: ID) -> None:
                nonlocal found_value
                try:
                    with trio.move_on_after(QUERY_TIMEOUT):
                        value = await self.value_store._get_from_peer(peer, key)
                        if value is not None and found_value is None:
                            found_value = value
                            logger.debug(f"Found value at peer {peer}")
                except Exception as e:
                    logger.debug(f"Error querying peer {peer}: {e}")

            async with trio.open_nursery() as nursery:
                for peer in batch:
                    nursery.start_soon(query_one, peer)

            if found_value is not None:
                self.value_store.put(key, found_value)
                logger.info("Successfully retrieved value from network")
                return found_value

        # 4. Not found
        logger.warning(f"Value not found for key {key.hex()}")
        return None

    # Utility methods

    async def add_peer(self, peer_id: ID) -> bool:
        """
        Add a peer to the routing table.

        :param peer_id: The peer ID to add.

        Returns
        -------
        bool
            True if the peer was added or updated, False otherwise.

        """
        return await self.routing_table.add_peer(peer_id)

    async def provide(self, key: bytes) -> bool:
        """Convenience wrapper around provider_store.provide."""
        return await self.provider_store.provide(key)

    async def find_providers(self, key: bytes, count: int = 20) -> list[PeerInfo]:
        """Convenience wrapper around provider_store.find_providers."""
        return await self.provider_store.find_providers(key, count)

    def get_routing_table_size(self) -> int:
        """
        Get the number of peers in the routing table.

        Returns
        -------
        int
            Number of peers.

        """
        return self.routing_table.size()

    def get_value_store_size(self) -> int:
        """
        Get the number of items in the value store.

        Returns
        -------
        int
            Number of items.

        """
        return self.value_store.size()

    def is_random_walk_enabled(self) -> bool:
        """
        Check if random-walk peer discovery is enabled.

        Returns
        -------
        bool
            True if random walk is enabled, False otherwise.

        """
        return self.enable_random_walk