diff --git a/libp2p/kad_dht/routing_table.py b/libp2p/kad_dht/routing_table.py
index 15b6721e..b688c1c7 100644
--- a/libp2p/kad_dht/routing_table.py
+++ b/libp2p/kad_dht/routing_table.py
@@ -8,6 +8,7 @@ from collections import (
 import logging
 import time
 
+import multihash
 import trio
 
 from libp2p.abc import (
@@ -40,6 +41,22 @@
 PEER_REFRESH_INTERVAL = 60  # Interval to refresh peers in seconds
 STALE_PEER_THRESHOLD = 3600  # Time in seconds after which a peer is considered stale
 
 
+def peer_id_to_key(peer_id: ID) -> bytes:
+    """
+    Convert a peer ID to a 256-bit key for routing table operations.
+
+    This normalizes all peer IDs to exactly 256 bits by hashing them with SHA-256.
+
+    :param peer_id: The peer ID to convert
+    :return: 32-byte (256-bit) key for routing table operations
+    """
+    return multihash.digest(peer_id.to_bytes(), "sha2-256").digest
+
+
+def key_to_int(key: bytes) -> int:
+    """Convert a 256-bit key to an integer for range calculations."""
+    return int.from_bytes(key, byteorder="big")
+
+
 class KBucket:
     """
     A k-bucket implementation for the Kademlia DHT.
@@ -357,9 +374,24 @@ class KBucket:
         True if the key is in range, False otherwise
 
         """
-        key_int = int.from_bytes(key, byteorder="big")
+        key_int = key_to_int(key)
         return self.min_range <= key_int < self.max_range
 
+    def peer_id_in_range(self, peer_id: ID) -> bool:
+        """
+        Check if a peer ID is in the range of this bucket.
+
+        :param peer_id: The peer ID to check
+
+        Returns
+        -------
+        bool
+            True if the peer ID is in range, False otherwise
+
+        """
+        key = peer_id_to_key(peer_id)
+        return self.key_in_range(key)
+
     def split(self) -> tuple["KBucket", "KBucket"]:
         """
         Split the bucket into two buckets.
@@ -376,8 +408,9 @@
 
         # Redistribute peers
         for peer_id, (peer_info, timestamp) in self.peers.items():
-            peer_key = int.from_bytes(peer_id.to_bytes(), byteorder="big")
-            if peer_key < midpoint:
+            peer_key = peer_id_to_key(peer_id)
+            peer_key_int = key_to_int(peer_key)
+            if peer_key_int < midpoint:
                 lower_bucket.peers[peer_id] = (peer_info, timestamp)
             else:
                 upper_bucket.peers[peer_id] = (peer_info, timestamp)
@@ -458,7 +491,38 @@ class RoutingTable:
             success = await bucket.add_peer(peer_info)
             if success:
                 logger.debug(f"Successfully added peer {peer_id} to routing table")
-            return success
+                return True
+
+            # If the bucket is full and the peer couldn't be added, try splitting.
+            # Only split if the bucket contains our own peer ID (Kademlia rule).
+            if self._should_split_bucket(bucket):
+                logger.debug(
+                    f"Bucket is full, attempting to split bucket for peer {peer_id}"
+                )
+                split_success = self._split_bucket(bucket)
+                if split_success:
+                    # After splitting, find the appropriate bucket
+                    # for the peer and try to add it again
+                    target_bucket = self.find_bucket(peer_info.peer_id)
+                    success = await target_bucket.add_peer(peer_info)
+                    if success:
+                        logger.debug(
+                            f"Successfully added peer {peer_id} after bucket split"
+                        )
+                        return True
+                    else:
+                        logger.debug(
+                            f"Failed to add peer {peer_id} even after bucket split"
+                        )
+                        return False
+                else:
+                    logger.debug(f"Failed to split bucket for peer {peer_id}")
+                    return False
+            else:
+                logger.debug(
+                    f"Bucket is full and cannot be split, peer {peer_id} not added"
+                )
+                return False
 
         except Exception as e:
             logger.debug(f"Error adding peer {peer_obj} to routing table: {e}")
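
Note: peer_id_to_key() is the heart of this patch: it normalizes
variable-length peer IDs into a uniform 256-bit key space, so every
bucket-range check and XOR distance compares keys of equal width. A minimal
sketch of the equivalent computation, using hashlib purely for illustration
(the patch itself goes through the multihash package; the function name
below is hypothetical):

    import hashlib

    def peer_id_to_key_sketch(peer_id_bytes: bytes) -> bytes:
        # SHA-256 always yields exactly 32 bytes (256 bits),
        # regardless of the length of the serialized peer ID.
        return hashlib.sha256(peer_id_bytes).digest()

    key = peer_id_to_key_sketch(b"example peer id bytes of any length")
    assert len(key) == 32
    assert 0 <= int.from_bytes(key, byteorder="big") < 2**256
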
@@ -480,9 +544,9 @@
     def find_bucket(self, peer_id: ID) -> KBucket:
         """
-        Find the bucket that would contain the given peer ID or PeerInfo.
+        Find the bucket that would contain the given peer ID.
 
-        :param peer_obj: Either a peer ID or a PeerInfo object
+        :param peer_id: The peer ID to find a bucket for
 
         Returns
         -------
@@ -490,7 +554,7 @@
 
         """
         for bucket in self.buckets:
-            if bucket.key_in_range(peer_id.to_bytes()):
+            if bucket.peer_id_in_range(peer_id):
                 return bucket
         return self.buckets[0]
@@ -513,7 +577,11 @@
             all_peers.extend(bucket.peer_ids())
 
         # Sort by XOR distance to the key
-        all_peers.sort(key=lambda p: xor_distance(p.to_bytes(), key))
+        def distance_to_key(peer_id: ID) -> int:
+            peer_key = peer_id_to_key(peer_id)
+            return xor_distance(peer_key, key)
+
+        all_peers.sort(key=distance_to_key)
 
         return all_peers[:count]
@@ -591,6 +659,20 @@
             stale_peers.extend(bucket.get_stale_peers(stale_threshold_seconds))
         return stale_peers
 
+    def get_peer_infos(self) -> list[PeerInfo]:
+        """
+        Get all PeerInfo objects in the routing table.
+
+        Returns
+        -------
+        List[PeerInfo]: List of all PeerInfo objects
+
+        """
+        peer_infos = []
+        for bucket in self.buckets:
+            peer_infos.extend(bucket.peer_infos())
+        return peer_infos
+
     def cleanup_routing_table(self) -> None:
         """
         Cleanup the routing table by removing all data.
@@ -598,3 +680,66 @@
         """
         self.buckets = [KBucket(self.host, BUCKET_SIZE)]
         logger.info("Routing table cleaned up, all data removed.")
+
+    def _should_split_bucket(self, bucket: KBucket) -> bool:
+        """
+        Check if a bucket should be split according to Kademlia rules.
+
+        :param bucket: The bucket to check
+        :return: True if the bucket should be split
+        """
+        # Check if we've exceeded the maximum number of buckets
+        if len(self.buckets) >= MAXIMUM_BUCKETS:
+            logger.debug("Maximum number of buckets reached, cannot split")
+            return False
+
+        # Check if the bucket's range contains our local peer ID
+        local_key = peer_id_to_key(self.local_id)
+        local_key_int = key_to_int(local_key)
+        contains_local_id = bucket.min_range <= local_key_int < bucket.max_range
+
+        logger.debug(
+            f"Bucket range: {bucket.min_range} - {bucket.max_range}, "
+            f"local_key_int: {local_key_int}, contains_local: {contains_local_id}"
+        )
+
+        return contains_local_id
+
+    def _split_bucket(self, bucket: KBucket) -> bool:
+        """
+        Split a bucket into two buckets.
+
+        :param bucket: The bucket to split
+        :return: True if the bucket was successfully split
+        """
+        try:
+            # Find the bucket index
+            bucket_index = self.buckets.index(bucket)
+            logger.debug(f"Splitting bucket at index {bucket_index}")
+
+            # Split the bucket
+            lower_bucket, upper_bucket = bucket.split()
+
+            # Replace the original bucket with the two new buckets
+            self.buckets[bucket_index] = lower_bucket
+            self.buckets.insert(bucket_index + 1, upper_bucket)
+
+            logger.debug(
+                f"Bucket split successful. "
+                f"New bucket count: {len(self.buckets)}"
+            )
+            logger.debug(
+                f"Lower bucket range: "
+                f"{lower_bucket.min_range} - {lower_bucket.max_range}, "
+                f"peers: {lower_bucket.size()}"
+            )
+            logger.debug(
+                f"Upper bucket range: "
+                f"{upper_bucket.min_range} - {upper_bucket.max_range}, "
+                f"peers: {upper_bucket.size()}"
+            )
+
+            return True
+
+        except Exception as e:
+            logger.error(f"Error splitting bucket: {e}")
+            return False
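
Note: _should_split_bucket() follows the standard Kademlia rule that a full
bucket may only be split while its range still covers the local node's key
(and MAXIMUM_BUCKETS has not been reached), which keeps the table from
growing unboundedly. A rough standalone sketch of the range arithmetic that
split() relies on (the function below is hypothetical, not part of the
patch):

    def split_range(
        min_range: int, max_range: int
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        # Halve the half-open interval [min_range, max_range); peers whose
        # key integer falls below the midpoint belong to the lower bucket.
        midpoint = (min_range + max_range) // 2
        return (min_range, midpoint), (midpoint, max_range)

    lower, upper = split_range(0, 2**256)
    assert lower == (0, 2**255)
    assert upper == (2**255, 2**256)
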
diff --git a/tests/core/kad_dht/test_unit_routing_table.py b/tests/core/kad_dht/test_unit_routing_table.py
index af77eda5..38c29adc 100644
--- a/tests/core/kad_dht/test_unit_routing_table.py
+++ b/tests/core/kad_dht/test_unit_routing_table.py
@@ -226,6 +226,32 @@ class TestKBucket:
 class TestRoutingTable:
     """Test suite for RoutingTable class."""
 
+    @pytest.mark.trio
+    async def test_kbucket_split_behavior(self, mock_host, local_peer_id):
+        """
+        Test that adding more than BUCKET_SIZE peers to the routing table
+        triggers k-bucket splitting and that all peers are added.
+        """
+        routing_table = RoutingTable(local_peer_id, mock_host)
+
+        num_peers = BUCKET_SIZE + 5
+        peer_ids = []
+        for i in range(num_peers):
+            key_pair = create_new_key_pair()
+            peer_id = ID.from_pubkey(key_pair.public_key)
+            peer_info = PeerInfo(peer_id, [Multiaddr(f"/ip4/127.0.0.1/tcp/{9000 + i}")])
+            peer_ids.append(peer_id)
+            added = await routing_table.add_peer(peer_info)
+            assert added, f"Peer {peer_id} should be added"
+
+        assert len(routing_table.buckets) > 1, "KBucket splitting did not occur"
+        for pid in peer_ids:
+            assert routing_table.peer_in_table(pid), f"Peer {pid} not found after split"
+        all_peer_ids = routing_table.get_peer_ids()
+        assert set(peer_ids).issubset(set(all_peer_ids)), (
+            "Not all peers present after split"
+        )
+
     @pytest.fixture
     def mock_host(self):
        """Create a mock host for testing."""
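
Note: with the distance_to_key() change above, closest-peer lookups sort by
XOR distance between hashed 256-bit keys rather than between raw peer ID
bytes of varying length. A minimal sketch of the XOR metric, assuming it
behaves like the library's xor_distance (the helper name below is
hypothetical):

    def xor_distance_sketch(a: bytes, b: bytes) -> int:
        # Kademlia's metric: interpret the bytewise XOR of two
        # equal-length keys as a big-endian integer.
        return int.from_bytes(bytes(x ^ y for x, y in zip(a, b)), byteorder="big")

    k0 = bytes(32)
    k1 = bytes([0x01]) + bytes(31)  # differs in the most significant byte
    k2 = bytes(31) + bytes([0x01])  # differs in the least significant byte
    assert xor_distance_sketch(k0, k1) == 2**248
    assert xor_distance_sketch(k0, k2) == 1
    assert xor_distance_sketch(k0, k0) == 0
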