""" Kademlia DHT routing table implementation. """ from collections import ( OrderedDict, ) import logging import time import multihash import trio from libp2p.abc import ( IHost, ) from libp2p.kad_dht.utils import ( xor_distance, ) from libp2p.peer.id import ( ID, ) from libp2p.peer.peerinfo import ( PeerInfo, ) from .common import ( PROTOCOL_ID, ) from .pb.kademlia_pb2 import ( Message, ) # logger = logging.getLogger("libp2p.kademlia.routing_table") logger = logging.getLogger("kademlia-example.routing_table") # Default parameters BUCKET_SIZE = 20 # k in the Kademlia paper MAXIMUM_BUCKETS = 256 # Maximum number of buckets (for 256-bit keys) PEER_REFRESH_INTERVAL = 60 # Interval to refresh peers in seconds STALE_PEER_THRESHOLD = 3600 # Time in seconds after which a peer is considered stale def peer_id_to_key(peer_id: ID) -> bytes: """ Convert a peer ID to a 256-bit key for routing table operations. This normalizes all peer IDs to exactly 256 bits by hashing them with SHA-256. :param peer_id: The peer ID to convert :return: 32-byte (256-bit) key for routing table operations """ return multihash.digest(peer_id.to_bytes(), "sha2-256").digest def key_to_int(key: bytes) -> int: """Convert a 256-bit key to an integer for range calculations.""" return int.from_bytes(key, byteorder="big") class KBucket: """ A k-bucket implementation for the Kademlia DHT. Each k-bucket stores up to k (BUCKET_SIZE) peers, sorted by least-recently seen. """ def __init__( self, host: IHost, bucket_size: int = BUCKET_SIZE, min_range: int = 0, max_range: int = 2**256, ): """ Initialize a new k-bucket. :param host: The host this bucket belongs to :param bucket_size: Maximum number of peers to store in the bucket :param min_range: Lower boundary of the bucket's key range (inclusive) :param max_range: Upper boundary of the bucket's key range (exclusive) """ self.bucket_size = bucket_size self.host = host self.min_range = min_range self.max_range = max_range # Store PeerInfo objects along with last-seen timestamp self.peers: OrderedDict[ID, tuple[PeerInfo, float]] = OrderedDict() def peer_ids(self) -> list[ID]: """Get all peer IDs in the bucket.""" return list(self.peers.keys()) def peer_infos(self) -> list[PeerInfo]: """Get all PeerInfo objects in the bucket.""" return [info for info, _ in self.peers.values()] def get_oldest_peer(self) -> ID | None: """Get the least-recently seen peer.""" if not self.peers: return None return next(iter(self.peers.keys())) async def add_peer(self, peer_info: PeerInfo) -> bool: """ Add a peer to the bucket. Returns True if the peer was added or updated, False if the bucket is full. 
""" current_time = time.time() peer_id = peer_info.peer_id # If peer is already in the bucket, move it to the end (most recently seen) if peer_id in self.peers: self.refresh_peer_last_seen(peer_id) return True # If bucket has space, add the peer if len(self.peers) < self.bucket_size: self.peers[peer_id] = (peer_info, current_time) return True # If bucket is full, we need to replace the least-recently seen peer # Get the least-recently seen peer oldest_peer_id = self.get_oldest_peer() if oldest_peer_id is None: logger.warning("No oldest peer found when bucket is full") return False # Check if the old peer is responsive to ping request try: # Try to ping the oldest peer, not the new peer response = await self._ping_peer(oldest_peer_id) if response: # If the old peer is still alive, we will not add the new peer logger.debug( "Old peer %s is still alive, cannot add new peer %s", oldest_peer_id, peer_id, ) return False except Exception as e: # If the old peer is unresponsive, we can replace it with the new peer logger.debug( "Old peer %s is unresponsive, replacing with new peer %s: %s", oldest_peer_id, peer_id, str(e), ) self.peers.popitem(last=False) # Remove oldest peer self.peers[peer_id] = (peer_info, current_time) return True # If we got here, the oldest peer responded but we couldn't add the new peer return False def remove_peer(self, peer_id: ID) -> bool: """ Remove a peer from the bucket. Returns True if the peer was in the bucket, False otherwise. """ if peer_id in self.peers: del self.peers[peer_id] return True return False def has_peer(self, peer_id: ID) -> bool: """Check if the peer is in the bucket.""" return peer_id in self.peers def get_peer_info(self, peer_id: ID) -> PeerInfo | None: """Get the PeerInfo for a given peer ID if it exists in the bucket.""" if peer_id in self.peers: return self.peers[peer_id][0] return None def size(self) -> int: """Get the number of peers in the bucket.""" return len(self.peers) def get_stale_peers(self, stale_threshold_seconds: int = 3600) -> list[ID]: """ Get peers that haven't been pinged recently. 

        :param stale_threshold_seconds: Time in seconds after which a peer
            is considered stale

        Returns
        -------
        list[ID]
            List of peer IDs that need to be refreshed

        """
        current_time = time.time()
        stale_peers = []

        for peer_id, (_, last_seen) in self.peers.items():
            if current_time - last_seen > stale_threshold_seconds:
                stale_peers.append(peer_id)

        return stale_peers

    async def _periodic_peer_refresh(self) -> None:
        """Background task to periodically refresh peers."""
        try:
            while True:
                await trio.sleep(PEER_REFRESH_INTERVAL)

                # Find stale peers (not pinged within the stale threshold)
                stale_peers = self.get_stale_peers(
                    stale_threshold_seconds=STALE_PEER_THRESHOLD
                )
                if stale_peers:
                    logger.debug(f"Found {len(stale_peers)} stale peers to refresh")

                for peer_id in stale_peers:
                    try:
                        # Try to ping the peer
                        logger.debug("Pinging stale peer %s", peer_id)
                        response = await self._ping_peer(peer_id)
                        if response:
                            # Update the last seen time
                            self.refresh_peer_last_seen(peer_id)
                            logger.debug(f"Refreshed peer {peer_id}")
                        else:
                            # If the ping fails, remove the peer
                            logger.debug(f"Failed to ping peer {peer_id}")
                            self.remove_peer(peer_id)
                            logger.info(f"Removed unresponsive peer {peer_id}")
                    except Exception as e:
                        # If the ping raises, remove the peer
                        logger.debug(
                            "Failed to ping peer %s: %s",
                            peer_id,
                            e,
                        )
                        self.remove_peer(peer_id)
                        logger.info(f"Removed unresponsive peer {peer_id}")
        except trio.Cancelled:
            logger.debug("Peer refresh task cancelled")
        except Exception as e:
            logger.error(f"Error in peer refresh task: {e}", exc_info=True)

    async def _ping_peer(self, peer_id: ID) -> bool:
        """
        Ping a peer with a protobuf PING message to check whether it is
        still alive.

        :param peer_id: The ID of the peer to ping

        Returns
        -------
        bool
            True if the ping succeeded, False otherwise

        """
        result = False

        # Get peer info directly from the bucket
        peer_info = self.get_peer_info(peer_id)
        if not peer_info:
            raise ValueError(f"Peer {peer_id} not in bucket")

        try:
            # Open a stream to the peer with the DHT protocol
            stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
            try:
                # Create ping protobuf message
                ping_msg = Message()
                ping_msg.type = Message.PING

                # Serialize and send with length prefix (4 bytes big-endian)
                msg_bytes = ping_msg.SerializeToString()
                logger.debug(
                    f"Sending PING message to {peer_id}, size: {len(msg_bytes)} bytes"
                )
                await stream.write(len(msg_bytes).to_bytes(4, byteorder="big"))
                await stream.write(msg_bytes)

                # Wait for response with timeout
                with trio.move_on_after(2):  # 2 second timeout
                    # Read response length (4 bytes)
                    length_bytes = await stream.read(4)
                    if not length_bytes or len(length_bytes) < 4:
                        logger.warning(f"Peer {peer_id} disconnected during ping")
                        return False

                    msg_len = int.from_bytes(length_bytes, byteorder="big")
                    if (
                        msg_len <= 0 or msg_len > 1024 * 1024
                    ):  # Sanity check on message size
                        logger.warning(
                            f"Invalid message length from {peer_id}: {msg_len}"
                        )
                        return False

                    logger.debug(
                        f"Receiving response from {peer_id}, size: {msg_len} bytes"
                    )

                    # Read full message
                    response_bytes = await stream.read(msg_len)
                    if not response_bytes:
                        logger.warning(f"Failed to read response from {peer_id}")
                        return False

                    # Parse protobuf response
                    response = Message()
                    try:
                        response.ParseFromString(response_bytes)
                    except Exception as e:
                        logger.warning(
                            f"Failed to parse protobuf response from {peer_id}: {e}"
                        )
                        return False

                    if response.type == Message.PING:
                        # The peer responded with a PING message
                        logger.debug(f"Successfully pinged peer {peer_id}")
                        result = True
                        return result
                    else:
                        logger.warning(
                            f"Unexpected response type from {peer_id}: {response.type}"
                        )
                        return False

                # If we get here, the ping timed out
                logger.warning(f"Ping to peer {peer_id} timed out")
                return False
            finally:
                await stream.close()
            return result
        except Exception as e:
            logger.error(f"Error pinging peer {peer_id}: {str(e)}")
            return False

    def refresh_peer_last_seen(self, peer_id: ID) -> bool:
        """
        Update the last-seen timestamp for a peer in the bucket.

        :param peer_id: The ID of the peer to refresh

        Returns
        -------
        bool
            True if the peer was found and refreshed, False otherwise

        """
        if peer_id in self.peers:
            # Get current peer info and update the timestamp
            peer_info, _ = self.peers[peer_id]
            current_time = time.time()
            self.peers[peer_id] = (peer_info, current_time)
            # Move to end of ordered dict to mark as most recently seen
            self.peers.move_to_end(peer_id)
            return True
        return False

    def key_in_range(self, key: bytes) -> bool:
        """
        Check if a key is in the range of this bucket.

        :param key: The key to check (bytes)

        Returns
        -------
        bool
            True if the key is in range, False otherwise

        """
        key_int = key_to_int(key)
        return self.min_range <= key_int < self.max_range

    def peer_id_in_range(self, peer_id: ID) -> bool:
        """
        Check if a peer ID is in the range of this bucket.

        :param peer_id: The peer ID to check

        Returns
        -------
        bool
            True if the peer ID is in range, False otherwise

        """
        key = peer_id_to_key(peer_id)
        return self.key_in_range(key)

    def split(self) -> tuple["KBucket", "KBucket"]:
        """
        Split the bucket into two buckets.

        Returns
        -------
        tuple
            (lower_bucket, upper_bucket)

        """
        midpoint = (self.min_range + self.max_range) // 2

        lower_bucket = KBucket(self.host, self.bucket_size, self.min_range, midpoint)
        upper_bucket = KBucket(self.host, self.bucket_size, midpoint, self.max_range)

        # Redistribute peers
        for peer_id, (peer_info, timestamp) in self.peers.items():
            peer_key = peer_id_to_key(peer_id)
            peer_key_int = key_to_int(peer_key)
            if peer_key_int < midpoint:
                lower_bucket.peers[peer_id] = (peer_info, timestamp)
            else:
                upper_bucket.peers[peer_id] = (peer_info, timestamp)

        return lower_bucket, upper_bucket


class RoutingTable:
    """
    The Kademlia routing table maintains information on which peers to
    contact for any given peer ID in the network.
    """

    def __init__(self, local_id: ID, host: IHost) -> None:
        """
        Initialize the routing table.

        :param local_id: The ID of the local node.
        :param host: The host this routing table belongs to.
        """
        self.local_id = local_id
        self.host = host
        self.buckets = [KBucket(host, BUCKET_SIZE)]

    async def add_peer(self, peer_obj: PeerInfo | ID) -> bool:
        """
        Add a peer to the routing table.

        :param peer_obj: Either a PeerInfo object or a peer ID to add

        Returns
        -------
        bool
            True if the peer was added or updated, False otherwise

        """
        peer_id = None
        peer_info = None

        try:
            # Handle different types of input
            if isinstance(peer_obj, PeerInfo):
                # Already have a PeerInfo object
                peer_info = peer_obj
                peer_id = peer_obj.peer_id
            else:
                # Assume it's a peer ID
                peer_id = peer_obj

                # Try to get addresses from the peerstore if available
                try:
                    addrs = self.host.get_peerstore().addrs(peer_id)
                    if addrs:
                        # Create PeerInfo object
                        peer_info = PeerInfo(peer_id, addrs)
                    else:
                        logger.debug(
                            "No addresses found for peer %s in peerstore, skipping",
                            peer_id,
                        )
                        return False
                except Exception as peerstore_error:
                    # Handle case where peer is not in peerstore yet
                    logger.debug(
                        "Peer %s not found in peerstore: %s, skipping",
                        peer_id,
                        str(peerstore_error),
                    )
                    return False

            # Don't add ourselves
            if peer_id == self.local_id:
                return False

            # Find the right bucket for this peer
            bucket = self.find_bucket(peer_id)

            # Try to add to the bucket
            success = await bucket.add_peer(peer_info)
            if success:
                logger.debug(f"Successfully added peer {peer_id} to routing table")
                return True

            # If the bucket is full and the peer could not be added, try
            # splitting the bucket. Only split if the bucket contains our
            # own peer ID.
            if self._should_split_bucket(bucket):
                logger.debug(
                    f"Bucket is full, attempting to split bucket for peer {peer_id}"
                )
                split_success = self._split_bucket(bucket)
                if split_success:
                    # After splitting, find the appropriate bucket for the
                    # peer and try to add it again
                    target_bucket = self.find_bucket(peer_info.peer_id)
                    success = await target_bucket.add_peer(peer_info)
                    if success:
                        logger.debug(
                            f"Successfully added peer {peer_id} after bucket split"
                        )
                        return True
                    else:
                        logger.debug(
                            f"Failed to add peer {peer_id} even after bucket split"
                        )
                        return False
                else:
                    logger.debug(f"Failed to split bucket for peer {peer_id}")
                    return False
            else:
                logger.debug(
                    f"Bucket is full and cannot be split, peer {peer_id} not added"
                )
                return False
        except Exception as e:
            logger.debug(f"Error adding peer {peer_obj} to routing table: {e}")
            return False

    def remove_peer(self, peer_id: ID) -> bool:
        """
        Remove a peer from the routing table.

        :param peer_id: The ID of the peer to remove

        Returns
        -------
        bool
            True if the peer was removed, False otherwise

        """
        bucket = self.find_bucket(peer_id)
        return bucket.remove_peer(peer_id)

    def find_bucket(self, peer_id: ID) -> KBucket:
        """
        Find the bucket that would contain the given peer ID.

        :param peer_id: The peer ID to find a bucket for

        Returns
        -------
        KBucket
            The bucket for this peer

        """
        for bucket in self.buckets:
            if bucket.peer_id_in_range(peer_id):
                return bucket
        return self.buckets[0]

    def find_local_closest_peers(self, key: bytes, count: int = 20) -> list[ID]:
        """
        Find the closest peers to a given key.

        :param key: The key to find closest peers to (bytes)
        :param count: Maximum number of peers to return

        Returns
        -------
        list[ID]
            List of peer IDs closest to the key

        """
        # Get all peers from all buckets
        all_peers = []
        for bucket in self.buckets:
            all_peers.extend(bucket.peer_ids())

        # Sort by XOR distance to the key
        def distance_to_key(peer_id: ID) -> int:
            peer_key = peer_id_to_key(peer_id)
            return xor_distance(peer_key, key)

        all_peers.sort(key=distance_to_key)
        return all_peers[:count]

    def get_peer_ids(self) -> list[ID]:
        """
        Get all peer IDs in the routing table.

        Returns
        -------
        list[ID]
            List of all peer IDs

        """
        peers = []
        for bucket in self.buckets:
            peers.extend(bucket.peer_ids())
        return peers

    def get_peer_info(self, peer_id: ID) -> PeerInfo | None:
        """
        Get the peer info for a specific peer.

        :param peer_id: The ID of the peer to get info for

        Returns
        -------
        PeerInfo
            The peer info, or None if not found

        """
        bucket = self.find_bucket(peer_id)
        return bucket.get_peer_info(peer_id)

    def peer_in_table(self, peer_id: ID) -> bool:
        """
        Check if a peer is in the routing table.

        :param peer_id: The ID of the peer to check

        Returns
        -------
        bool
            True if the peer is in the routing table, False otherwise

        """
        bucket = self.find_bucket(peer_id)
        return bucket.has_peer(peer_id)

    def size(self) -> int:
        """
        Get the number of peers in the routing table.

        Returns
        -------
        int
            Number of peers

        """
        count = 0
        for bucket in self.buckets:
            count += bucket.size()
        return count

    def get_stale_peers(self, stale_threshold_seconds: int = 3600) -> list[ID]:
        """
        Get all stale peers from all buckets.

        :param stale_threshold_seconds: Time in seconds after which a peer
            is considered stale

        Returns
        -------
        list[ID]
            List of stale peer IDs

        """
        stale_peers = []
        for bucket in self.buckets:
            stale_peers.extend(bucket.get_stale_peers(stale_threshold_seconds))
        return stale_peers

    def get_peer_infos(self) -> list[PeerInfo]:
        """
        Get all PeerInfo objects in the routing table.

        Returns
        -------
        list[PeerInfo]
            List of all PeerInfo objects

        """
        peer_infos = []
        for bucket in self.buckets:
            peer_infos.extend(bucket.peer_infos())
        return peer_infos

    def cleanup_routing_table(self) -> None:
        """
        Clean up the routing table by removing all data.

        This is useful for resetting the routing table during tests or
        reinitialization.
        """
        self.buckets = [KBucket(self.host, BUCKET_SIZE)]
        logger.info("Routing table cleaned up, all data removed.")

    def _should_split_bucket(self, bucket: KBucket) -> bool:
        """
        Check if a bucket should be split according to Kademlia rules.

        :param bucket: The bucket to check
        :return: True if the bucket should be split
        """
        # Check if we've exceeded the maximum number of buckets
        if len(self.buckets) >= MAXIMUM_BUCKETS:
            logger.debug("Maximum number of buckets reached, cannot split")
            return False

        # Check if the bucket contains our local ID
        local_key = peer_id_to_key(self.local_id)
        local_key_int = key_to_int(local_key)
        contains_local_id = bucket.min_range <= local_key_int < bucket.max_range

        logger.debug(
            f"Bucket range: {bucket.min_range} - {bucket.max_range}, "
            f"local_key_int: {local_key_int}, contains_local: {contains_local_id}"
        )

        return contains_local_id

    def _split_bucket(self, bucket: KBucket) -> bool:
        """
        Split a bucket into two buckets.

        :param bucket: The bucket to split
        :return: True if the bucket was successfully split
        """
        try:
            # Find the bucket index
            bucket_index = self.buckets.index(bucket)
            logger.debug(f"Splitting bucket at index {bucket_index}")

            # Split the bucket
            lower_bucket, upper_bucket = bucket.split()

            # Replace the original bucket with the two new buckets
            self.buckets[bucket_index] = lower_bucket
            self.buckets.insert(bucket_index + 1, upper_bucket)

            logger.debug(
                f"Bucket split successful. New bucket count: {len(self.buckets)}"
            )
            logger.debug(
                f"Lower bucket range: "
                f"{lower_bucket.min_range} - {lower_bucket.max_range}, "
                f"peers: {lower_bucket.size()}"
            )
            logger.debug(
                f"Upper bucket range: "
                f"{upper_bucket.min_range} - {upper_bucket.max_range}, "
                f"peers: {upper_bucket.size()}"
            )

            return True
        except Exception as e:
            logger.error(f"Error splitting bucket: {e}")
            return False
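

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the routing table API): how the
# helpers above map peer IDs into the 256-bit key space. The peer ID bytes
# below are made-up placeholders, and the example assumes ID() accepts raw
# bytes as in py-libp2p; real peer IDs are derived from key material.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Two hypothetical peers built from arbitrary placeholder bytes.
    peer_a = ID(b"example-peer-1")
    peer_b = ID(b"example-peer-2")

    # Every peer ID is normalized to a 256-bit key via SHA-256.
    key_a = peer_id_to_key(peer_a)
    key_b = peer_id_to_key(peer_b)

    # The XOR metric orders peers by distance in the key space; closer peers
    # sort earlier in find_local_closest_peers.
    print("xor distance:", xor_distance(key_a, key_b))

    # key_to_int maps a key into the [0, 2**256) range used for bucket
    # boundaries; comparing against the midpoint shows which half of a
    # freshly split root bucket each peer would land in.
    midpoint = 2**255
    print("peer_a in lower half:", key_to_int(key_a) < midpoint)
    print("peer_b in lower half:", key_to_int(key_b) < midpoint)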