From cea1985c5c7b8aed6ea2b202b775adc949ad682b Mon Sep 17 00:00:00 2001 From: lla-dane Date: Thu, 14 Aug 2025 10:39:48 +0530 Subject: [PATCH] add reissuing mechanism of records if addrs dont change --- libp2p/abc.py | 7 +++ libp2p/host/basic_host.py | 9 +++ libp2p/kad_dht/kad_dht.py | 51 ++++------------ libp2p/kad_dht/peer_routing.py | 16 ++--- libp2p/kad_dht/provider_store.py | 21 ++----- libp2p/kad_dht/utils.py | 29 +++++++++ libp2p/kad_dht/value_store.py | 19 ++---- libp2p/peer/envelope.py | 5 ++ libp2p/peer/peerstore.py | 9 +++ tests/core/kad_dht/test_kad_dht.py | 95 ++++++++++++++++++++++++++---- 10 files changed, 170 insertions(+), 91 deletions(-) diff --git a/libp2p/abc.py b/libp2p/abc.py index 90ad6a45..614af8bf 100644 --- a/libp2p/abc.py +++ b/libp2p/abc.py @@ -970,6 +970,13 @@ class IPeerStore( # --------CERTIFIED-ADDR-BOOK---------- + @abstractmethod + def get_local_record(self) -> Optional["Envelope"]: + """Get the local-peer-record wrapped in Envelope""" + + def set_local_record(self, envelope: "Envelope") -> None: + """Set the local-peer-record wrapped in Envelope""" + @abstractmethod def consume_peer_record(self, envelope: "Envelope", ttl: int) -> bool: """ diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index b40b0128..a0311bd8 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -43,6 +43,7 @@ from libp2p.peer.id import ( from libp2p.peer.peerinfo import ( PeerInfo, ) +from libp2p.peer.peerstore import create_signed_peer_record from libp2p.protocol_muxer.exceptions import ( MultiselectClientError, MultiselectError, @@ -110,6 +111,14 @@ class BasicHost(IHost): if bootstrap: self.bootstrap = BootstrapDiscovery(network, bootstrap) + # Cache a signed-record if the local-node in the PeerStore + envelope = create_signed_peer_record( + self.get_id(), + self.get_addrs(), + self.get_private_key(), + ) + self.get_peerstore().set_local_record(envelope) + def get_id(self) -> ID: """ :return: peer_id of host diff --git a/libp2p/kad_dht/kad_dht.py b/libp2p/kad_dht/kad_dht.py index db0e635e..f93aa75e 100644 --- a/libp2p/kad_dht/kad_dht.py +++ b/libp2p/kad_dht/kad_dht.py @@ -22,7 +22,7 @@ from libp2p.abc import ( IHost, ) from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager -from libp2p.kad_dht.utils import maybe_consume_signed_record +from libp2p.kad_dht.utils import env_to_send_in_RPC, maybe_consume_signed_record from libp2p.network.stream.net_stream import ( INetStream, ) @@ -33,7 +33,6 @@ from libp2p.peer.id import ( from libp2p.peer.peerinfo import ( PeerInfo, ) -from libp2p.peer.peerstore import create_signed_peer_record from libp2p.tools.async_service import ( Service, ) @@ -319,12 +318,8 @@ class KadDHT(Service): ) # Create sender_signed_peer_record - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - response.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes # Serialize and send response response_bytes = response.SerializeToString() @@ -383,12 +378,8 @@ class KadDHT(Service): response.key = key # Add sender's signed-peer-record - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - response.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes response_bytes = response.SerializeToString() await stream.write(varint.encode(len(response_bytes))) @@ -416,12 +407,8 @@ class KadDHT(Service): response.key = key # Create sender_signed_peer_record for the response - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - response.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes # Add provider information to response for provider_info in providers: @@ -512,12 +499,8 @@ class KadDHT(Service): response.record.timeReceived = str(time.time()) # Create sender_signed_peer_record - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - response.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes # Serialize and send response response_bytes = response.SerializeToString() @@ -533,12 +516,8 @@ class KadDHT(Service): response.key = key # Create sender_signed_peer_record for the response - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - response.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes # Add closest peers to key closest_peers = self.routing_table.find_local_closest_peers( @@ -616,12 +595,8 @@ class KadDHT(Service): response.key = key # Create sender_signed_peer_record for the response - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - response.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes # Serialize and send response response_bytes = response.SerializeToString() diff --git a/libp2p/kad_dht/peer_routing.py b/libp2p/kad_dht/peer_routing.py index e36f7caf..4362ffea 100644 --- a/libp2p/kad_dht/peer_routing.py +++ b/libp2p/kad_dht/peer_routing.py @@ -22,7 +22,6 @@ from libp2p.peer.id import ( from libp2p.peer.peerinfo import ( PeerInfo, ) -from libp2p.peer.peerstore import create_signed_peer_record from .common import ( ALPHA, @@ -35,6 +34,7 @@ from .routing_table import ( RoutingTable, ) from .utils import ( + env_to_send_in_RPC, maybe_consume_signed_record, sort_peer_ids_by_distance, ) @@ -259,10 +259,8 @@ class PeerRouting(IPeerRouting): find_node_msg.key = target_key # Set target key directly as bytes # Create sender_signed_peer_record - envelope = create_signed_peer_record( - self.host.get_id(), self.host.get_addrs(), self.host.get_private_key() - ) - find_node_msg.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + find_node_msg.senderRecord = envelope_bytes # Serialize and send the protobuf message with varint length prefix proto_bytes = find_node_msg.SerializeToString() @@ -381,12 +379,8 @@ class PeerRouting(IPeerRouting): response.type = Message.MessageType.FIND_NODE # Create sender_signed_peer_record for the response - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - response.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes # Add peer information to response for peer_id in closest_peers: diff --git a/libp2p/kad_dht/provider_store.py b/libp2p/kad_dht/provider_store.py index 21bd1c80..4c6a8e06 100644 --- a/libp2p/kad_dht/provider_store.py +++ b/libp2p/kad_dht/provider_store.py @@ -22,14 +22,13 @@ from libp2p.abc import ( from libp2p.custom_types import ( TProtocol, ) -from libp2p.kad_dht.utils import maybe_consume_signed_record +from libp2p.kad_dht.utils import env_to_send_in_RPC, maybe_consume_signed_record from libp2p.peer.id import ( ID, ) from libp2p.peer.peerinfo import ( PeerInfo, ) -from libp2p.peer.peerstore import create_signed_peer_record from .common import ( ALPHA, @@ -243,12 +242,8 @@ class ProviderStore: message.key = key # Create sender's signed-peer-record - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - message.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + message.senderRecord = envelope_bytes # Add our provider info provider = message.providerPeers.add() @@ -256,7 +251,7 @@ class ProviderStore: provider.addrs.extend(addrs) # Add the provider's signed-peer-record - provider.signedRecord = envelope.marshal_envelope() + provider.signedRecord = envelope_bytes # Serialize and send the message proto_bytes = message.SerializeToString() @@ -394,12 +389,8 @@ class ProviderStore: message.key = key # Create sender's signed-peer-record - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - message.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + message.senderRecord = envelope_bytes # Serialize and send the message proto_bytes = message.SerializeToString() diff --git a/libp2p/kad_dht/utils.py b/libp2p/kad_dht/utils.py index 64976cb3..3cf79efd 100644 --- a/libp2p/kad_dht/utils.py +++ b/libp2p/kad_dht/utils.py @@ -12,6 +12,7 @@ from libp2p.peer.envelope import consume_envelope from libp2p.peer.id import ( ID, ) +from libp2p.peer.peerstore import create_signed_peer_record from .pb.kademlia_pb2 import ( Message, @@ -54,6 +55,34 @@ def maybe_consume_signed_record(msg: Message | Message.Peer, host: IHost) -> boo return True +def env_to_send_in_RPC(host: IHost) -> tuple[bytes, bool]: + listen_addrs_set = {addr for addr in host.get_addrs()} + local_env = host.get_peerstore().get_local_record() + + if local_env is None: + # No cached SPR yet -> create one + return issue_and_cache_local_record(host), True + else: + record_addrs_set = local_env._env_addrs_set() + if record_addrs_set == listen_addrs_set: + # Perfect match -> reuse cached envelope + return local_env.marshal_envelope(), False + else: + # Addresses changed -> issue a new SPR and cache it + return issue_and_cache_local_record(host), True + + +def issue_and_cache_local_record(host: IHost) -> bytes: + env = create_signed_peer_record( + host.get_id(), + host.get_addrs(), + host.get_private_key(), + ) + # Cache it for nexxt time use + host.get_peerstore().set_local_record(env) + return env.marshal_envelope() + + def create_key_from_binary(binary_data: bytes) -> bytes: """ Creates a key for the DHT by hashing binary data with SHA-256. diff --git a/libp2p/kad_dht/value_store.py b/libp2p/kad_dht/value_store.py index adc37b72..bb143dcd 100644 --- a/libp2p/kad_dht/value_store.py +++ b/libp2p/kad_dht/value_store.py @@ -15,11 +15,10 @@ from libp2p.abc import ( from libp2p.custom_types import ( TProtocol, ) -from libp2p.kad_dht.utils import maybe_consume_signed_record +from libp2p.kad_dht.utils import env_to_send_in_RPC, maybe_consume_signed_record from libp2p.peer.id import ( ID, ) -from libp2p.peer.peerstore import create_signed_peer_record from .common import ( DEFAULT_TTL, @@ -113,12 +112,8 @@ class ValueStore: message.type = Message.MessageType.PUT_VALUE # Create sender's signed-peer-record - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - message.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + message.senderRecord = envelope_bytes # Set message fields message.key = key @@ -245,12 +240,8 @@ class ValueStore: message.key = key # Create sender's signed-peer-record - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - message.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + message.senderRecord = envelope_bytes # Serialize and send the protobuf message proto_bytes = message.SerializeToString() diff --git a/libp2p/peer/envelope.py b/libp2p/peer/envelope.py index e93a8280..f8bf9f43 100644 --- a/libp2p/peer/envelope.py +++ b/libp2p/peer/envelope.py @@ -1,5 +1,7 @@ from typing import Any, cast +import multiaddr + from libp2p.crypto.ed25519 import Ed25519PublicKey from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.crypto.rsa import RSAPublicKey @@ -131,6 +133,9 @@ class Envelope: ) return False + def _env_addrs_set(self) -> set[multiaddr.Multiaddr]: + return {b for b in self.record().addrs} + def pub_key_to_protobuf(pub_key: PublicKey) -> cryto_pb.PublicKey: """ diff --git a/libp2p/peer/peerstore.py b/libp2p/peer/peerstore.py index 0faccb45..ad6f08db 100644 --- a/libp2p/peer/peerstore.py +++ b/libp2p/peer/peerstore.py @@ -65,8 +65,17 @@ class PeerStore(IPeerStore): self.peer_data_map = defaultdict(PeerData) self.addr_update_channels: dict[ID, MemorySendChannel[Multiaddr]] = {} self.peer_record_map: dict[ID, PeerRecordState] = {} + self.local_peer_record: Envelope | None = None self.max_records = max_records + def get_local_record(self) -> Envelope | None: + """Get the local-signed-record wrapped in Envelope""" + return self.local_peer_record + + def set_local_record(self, envelope: Envelope) -> None: + """Set the local-signed-record wrapped in Envelope""" + self.local_peer_record = envelope + def peer_info(self, peer_id: ID) -> PeerInfo: """ :param peer_id: peer ID to get info for diff --git a/tests/core/kad_dht/test_kad_dht.py b/tests/core/kad_dht/test_kad_dht.py index 70d9a5e9..a2e9ec4c 100644 --- a/tests/core/kad_dht/test_kad_dht.py +++ b/tests/core/kad_dht/test_kad_dht.py @@ -9,9 +9,12 @@ This module tests core functionality of the Kademlia DHT including: import hashlib import logging +import os +from unittest.mock import patch import uuid import pytest +import multiaddr import trio from libp2p.kad_dht.kad_dht import ( @@ -77,6 +80,18 @@ async def test_find_node(dht_pair: tuple[KadDHT, KadDHT]): """Test that nodes can find each other in the DHT.""" dht_a, dht_b = dht_pair + # An extra FIND_NODE req is sent between the 2 nodes while dht creation, + # so both the nodes will have records of each other before the next FIND_NODE + # req is sent + envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()) + envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id()) + + assert isinstance(envelope_a, Envelope) + assert isinstance(envelope_b, Envelope) + + record_a = envelope_a.record() + record_b = envelope_b.record() + # Node A should be able to find Node B with trio.fail_after(TEST_TIMEOUT): found_info = await dht_a.find_peer(dht_b.host.get_id()) @@ -91,6 +106,26 @@ async def test_find_node(dht_pair: tuple[KadDHT, KadDHT]): dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()), Envelope ) + # These are the records that were sent betweeen the peers during the FIND_NODE req + envelope_a_find_peer = dht_a.host.get_peerstore().get_peer_record( + dht_b.host.get_id() + ) + envelope_b_find_peer = dht_b.host.get_peerstore().get_peer_record( + dht_a.host.get_id() + ) + + assert isinstance(envelope_a_find_peer, Envelope) + assert isinstance(envelope_b_find_peer, Envelope) + + record_a_find_peer = envelope_a_find_peer.record() + record_b_find_peer = envelope_b_find_peer.record() + + # This proves that both the records are same, and a latest cached signed record + # was passed between the peers during FIND_NODE exceution, which proves the + # signed-record transfer/re-issuing works correctly in FIND_NODE executions. + assert record_a.seq == record_a_find_peer.seq + assert record_b.seq == record_b_find_peer.seq + # Verify that the found peer has the correct peer ID assert found_info is not None, "Failed to find the target peer" assert found_info.peer_id == dht_b.host.get_id(), "Found incorrect peer ID" @@ -144,11 +179,11 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]): record_a_put_value = envelope_a_put_value.record() record_b_put_value = envelope_b_put_value.record() - # This proves that both the records are different, and a new signed record + # This proves that both the records are same, and a latest cached signed record # was passed between the peers during PUT_VALUE exceution, which proves the - # signed-record transfer works correctly in PUT_VALUE executions. - assert record_a.seq < record_a_put_value.seq - assert record_b.seq < record_b_put_value.seq + # signed-record transfer/re-issuing works correctly in PUT_VALUE executions. + assert record_a.seq == record_a_put_value.seq + assert record_b.seq == record_b_put_value.seq # # Log debugging information logger.debug("Put value with key %s...", key.hex()[:10]) @@ -234,11 +269,12 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]): record_a_add_prov = envelope_a_add_prov.record() record_b_add_prov = envelope_b_add_prov.record() - # This proves that both the records are different, and a new signed record + # This proves that both the records are same, the latest cached signed record # was passed between the peers during ADD_PROVIDER exceution, which proves the - # signed-record transfer works correctly in ADD_PROVIDER executions. - assert record_a.seq < record_a_add_prov.seq - assert record_b.seq < record_b_add_prov.seq + # signed-record transfer/re-issuing of the latest record works correctly in + # ADD_PROVIDER executions. + assert record_a.seq == record_a_add_prov.seq + assert record_b.seq == record_b_add_prov.seq # Allow time for the provider record to propagate await trio.sleep(0.1) @@ -294,8 +330,41 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]): record_a_get_value = envelope_a_get_value.record() record_b_get_value = envelope_b_get_value.record() - # This proves that both the records are different, meaning that there was - # a new signed-record tranfer during the GET_VALUE execution by dht_b, which means - # the signed-record transfer works correctly in GET_VALUE executions. - assert record_a_find_prov.seq < record_a_get_value.seq - assert record_b_find_prov.seq < record_b_get_value.seq + # This proves that both the records are same, meaning that the latest cached + # signed-record tranfer happened during the GET_VALUE execution by dht_b, + # which means the signed-record transfer/re-issuing works correctly + # in GET_VALUE executions. + assert record_a_find_prov.seq == record_a_get_value.seq + assert record_b_find_prov.seq == record_b_get_value.seq + + +@pytest.mark.trio +async def test_reissue_when_listen_addrs_change(dht_pair: tuple[KadDHT, KadDHT]): + dht_a, dht_b = dht_pair + + # Warm-up: A stores B's current record + with trio.fail_after(10): + await dht_a.find_peer(dht_b.host.get_id()) + + env0 = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()) + assert isinstance(env0, Envelope) + seq0 = env0.record().seq + + # Simulate B's listen addrs changing (different port) + new_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123") + + # Patch just for the duration we force B to respond: + with patch.object(dht_b.host, "get_addrs", return_value=[new_addr]): + # Force B to send a response (which should include a fresh SPR) + with trio.fail_after(10): + await dht_a.peer_routing._query_peer_for_closest( + dht_b.host.get_id(), os.urandom(32) + ) + + # A should now hold B's new record with a bumped seq + env1 = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()) + assert isinstance(env1, Envelope) + seq1 = env1.record().seq + + # This proves that upon the change in listen_addrs, we issue new records + assert seq1 > seq0, f"Expected seq to bump after addr change, got {seq0} -> {seq1}"