add reissuing mechanism of records if addrs dont change

This commit is contained in:
lla-dane
2025-08-14 10:39:48 +05:30
parent 702ad4876e
commit cea1985c5c
10 changed files with 170 additions and 91 deletions

View File

@ -970,6 +970,13 @@ class IPeerStore(
# --------CERTIFIED-ADDR-BOOK---------- # --------CERTIFIED-ADDR-BOOK----------
@abstractmethod
def get_local_record(self) -> Optional["Envelope"]:
"""Get the local-peer-record wrapped in Envelope"""
def set_local_record(self, envelope: "Envelope") -> None:
"""Set the local-peer-record wrapped in Envelope"""
@abstractmethod @abstractmethod
def consume_peer_record(self, envelope: "Envelope", ttl: int) -> bool: def consume_peer_record(self, envelope: "Envelope", ttl: int) -> bool:
""" """

View File

@ -43,6 +43,7 @@ from libp2p.peer.id import (
from libp2p.peer.peerinfo import ( from libp2p.peer.peerinfo import (
PeerInfo, PeerInfo,
) )
from libp2p.peer.peerstore import create_signed_peer_record
from libp2p.protocol_muxer.exceptions import ( from libp2p.protocol_muxer.exceptions import (
MultiselectClientError, MultiselectClientError,
MultiselectError, MultiselectError,
@ -110,6 +111,14 @@ class BasicHost(IHost):
if bootstrap: if bootstrap:
self.bootstrap = BootstrapDiscovery(network, bootstrap) self.bootstrap = BootstrapDiscovery(network, bootstrap)
# Cache a signed-record if the local-node in the PeerStore
envelope = create_signed_peer_record(
self.get_id(),
self.get_addrs(),
self.get_private_key(),
)
self.get_peerstore().set_local_record(envelope)
def get_id(self) -> ID: def get_id(self) -> ID:
""" """
:return: peer_id of host :return: peer_id of host

View File

@ -22,7 +22,7 @@ from libp2p.abc import (
IHost, IHost,
) )
from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager
from libp2p.kad_dht.utils import maybe_consume_signed_record from libp2p.kad_dht.utils import env_to_send_in_RPC, maybe_consume_signed_record
from libp2p.network.stream.net_stream import ( from libp2p.network.stream.net_stream import (
INetStream, INetStream,
) )
@ -33,7 +33,6 @@ from libp2p.peer.id import (
from libp2p.peer.peerinfo import ( from libp2p.peer.peerinfo import (
PeerInfo, PeerInfo,
) )
from libp2p.peer.peerstore import create_signed_peer_record
from libp2p.tools.async_service import ( from libp2p.tools.async_service import (
Service, Service,
) )
@ -319,12 +318,8 @@ class KadDHT(Service):
) )
# Create sender_signed_peer_record # Create sender_signed_peer_record
envelope = create_signed_peer_record( envelope_bytes, bool = env_to_send_in_RPC(self.host)
self.host.get_id(), response.senderRecord = envelope_bytes
self.host.get_addrs(),
self.host.get_private_key(),
)
response.senderRecord = envelope.marshal_envelope()
# Serialize and send response # Serialize and send response
response_bytes = response.SerializeToString() response_bytes = response.SerializeToString()
@ -383,12 +378,8 @@ class KadDHT(Service):
response.key = key response.key = key
# Add sender's signed-peer-record # Add sender's signed-peer-record
envelope = create_signed_peer_record( envelope_bytes, bool = env_to_send_in_RPC(self.host)
self.host.get_id(), response.senderRecord = envelope_bytes
self.host.get_addrs(),
self.host.get_private_key(),
)
response.senderRecord = envelope.marshal_envelope()
response_bytes = response.SerializeToString() response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes))) await stream.write(varint.encode(len(response_bytes)))
@ -416,12 +407,8 @@ class KadDHT(Service):
response.key = key response.key = key
# Create sender_signed_peer_record for the response # Create sender_signed_peer_record for the response
envelope = create_signed_peer_record( envelope_bytes, bool = env_to_send_in_RPC(self.host)
self.host.get_id(), response.senderRecord = envelope_bytes
self.host.get_addrs(),
self.host.get_private_key(),
)
response.senderRecord = envelope.marshal_envelope()
# Add provider information to response # Add provider information to response
for provider_info in providers: for provider_info in providers:
@ -512,12 +499,8 @@ class KadDHT(Service):
response.record.timeReceived = str(time.time()) response.record.timeReceived = str(time.time())
# Create sender_signed_peer_record # Create sender_signed_peer_record
envelope = create_signed_peer_record( envelope_bytes, bool = env_to_send_in_RPC(self.host)
self.host.get_id(), response.senderRecord = envelope_bytes
self.host.get_addrs(),
self.host.get_private_key(),
)
response.senderRecord = envelope.marshal_envelope()
# Serialize and send response # Serialize and send response
response_bytes = response.SerializeToString() response_bytes = response.SerializeToString()
@ -533,12 +516,8 @@ class KadDHT(Service):
response.key = key response.key = key
# Create sender_signed_peer_record for the response # Create sender_signed_peer_record for the response
envelope = create_signed_peer_record( envelope_bytes, bool = env_to_send_in_RPC(self.host)
self.host.get_id(), response.senderRecord = envelope_bytes
self.host.get_addrs(),
self.host.get_private_key(),
)
response.senderRecord = envelope.marshal_envelope()
# Add closest peers to key # Add closest peers to key
closest_peers = self.routing_table.find_local_closest_peers( closest_peers = self.routing_table.find_local_closest_peers(
@ -616,12 +595,8 @@ class KadDHT(Service):
response.key = key response.key = key
# Create sender_signed_peer_record for the response # Create sender_signed_peer_record for the response
envelope = create_signed_peer_record( envelope_bytes, bool = env_to_send_in_RPC(self.host)
self.host.get_id(), response.senderRecord = envelope_bytes
self.host.get_addrs(),
self.host.get_private_key(),
)
response.senderRecord = envelope.marshal_envelope()
# Serialize and send response # Serialize and send response
response_bytes = response.SerializeToString() response_bytes = response.SerializeToString()

View File

@ -22,7 +22,6 @@ from libp2p.peer.id import (
from libp2p.peer.peerinfo import ( from libp2p.peer.peerinfo import (
PeerInfo, PeerInfo,
) )
from libp2p.peer.peerstore import create_signed_peer_record
from .common import ( from .common import (
ALPHA, ALPHA,
@ -35,6 +34,7 @@ from .routing_table import (
RoutingTable, RoutingTable,
) )
from .utils import ( from .utils import (
env_to_send_in_RPC,
maybe_consume_signed_record, maybe_consume_signed_record,
sort_peer_ids_by_distance, sort_peer_ids_by_distance,
) )
@ -259,10 +259,8 @@ class PeerRouting(IPeerRouting):
find_node_msg.key = target_key # Set target key directly as bytes find_node_msg.key = target_key # Set target key directly as bytes
# Create sender_signed_peer_record # Create sender_signed_peer_record
envelope = create_signed_peer_record( envelope_bytes, bool = env_to_send_in_RPC(self.host)
self.host.get_id(), self.host.get_addrs(), self.host.get_private_key() find_node_msg.senderRecord = envelope_bytes
)
find_node_msg.senderRecord = envelope.marshal_envelope()
# Serialize and send the protobuf message with varint length prefix # Serialize and send the protobuf message with varint length prefix
proto_bytes = find_node_msg.SerializeToString() proto_bytes = find_node_msg.SerializeToString()
@ -381,12 +379,8 @@ class PeerRouting(IPeerRouting):
response.type = Message.MessageType.FIND_NODE response.type = Message.MessageType.FIND_NODE
# Create sender_signed_peer_record for the response # Create sender_signed_peer_record for the response
envelope = create_signed_peer_record( envelope_bytes, bool = env_to_send_in_RPC(self.host)
self.host.get_id(), response.senderRecord = envelope_bytes
self.host.get_addrs(),
self.host.get_private_key(),
)
response.senderRecord = envelope.marshal_envelope()
# Add peer information to response # Add peer information to response
for peer_id in closest_peers: for peer_id in closest_peers:

View File

@ -22,14 +22,13 @@ from libp2p.abc import (
from libp2p.custom_types import ( from libp2p.custom_types import (
TProtocol, TProtocol,
) )
from libp2p.kad_dht.utils import maybe_consume_signed_record from libp2p.kad_dht.utils import env_to_send_in_RPC, maybe_consume_signed_record
from libp2p.peer.id import ( from libp2p.peer.id import (
ID, ID,
) )
from libp2p.peer.peerinfo import ( from libp2p.peer.peerinfo import (
PeerInfo, PeerInfo,
) )
from libp2p.peer.peerstore import create_signed_peer_record
from .common import ( from .common import (
ALPHA, ALPHA,
@ -243,12 +242,8 @@ class ProviderStore:
message.key = key message.key = key
# Create sender's signed-peer-record # Create sender's signed-peer-record
envelope = create_signed_peer_record( envelope_bytes, bool = env_to_send_in_RPC(self.host)
self.host.get_id(), message.senderRecord = envelope_bytes
self.host.get_addrs(),
self.host.get_private_key(),
)
message.senderRecord = envelope.marshal_envelope()
# Add our provider info # Add our provider info
provider = message.providerPeers.add() provider = message.providerPeers.add()
@ -256,7 +251,7 @@ class ProviderStore:
provider.addrs.extend(addrs) provider.addrs.extend(addrs)
# Add the provider's signed-peer-record # Add the provider's signed-peer-record
provider.signedRecord = envelope.marshal_envelope() provider.signedRecord = envelope_bytes
# Serialize and send the message # Serialize and send the message
proto_bytes = message.SerializeToString() proto_bytes = message.SerializeToString()
@ -394,12 +389,8 @@ class ProviderStore:
message.key = key message.key = key
# Create sender's signed-peer-record # Create sender's signed-peer-record
envelope = create_signed_peer_record( envelope_bytes, bool = env_to_send_in_RPC(self.host)
self.host.get_id(), message.senderRecord = envelope_bytes
self.host.get_addrs(),
self.host.get_private_key(),
)
message.senderRecord = envelope.marshal_envelope()
# Serialize and send the message # Serialize and send the message
proto_bytes = message.SerializeToString() proto_bytes = message.SerializeToString()

View File

@ -12,6 +12,7 @@ from libp2p.peer.envelope import consume_envelope
from libp2p.peer.id import ( from libp2p.peer.id import (
ID, ID,
) )
from libp2p.peer.peerstore import create_signed_peer_record
from .pb.kademlia_pb2 import ( from .pb.kademlia_pb2 import (
Message, Message,
@ -54,6 +55,34 @@ def maybe_consume_signed_record(msg: Message | Message.Peer, host: IHost) -> boo
return True return True
def env_to_send_in_RPC(host: IHost) -> tuple[bytes, bool]:
listen_addrs_set = {addr for addr in host.get_addrs()}
local_env = host.get_peerstore().get_local_record()
if local_env is None:
# No cached SPR yet -> create one
return issue_and_cache_local_record(host), True
else:
record_addrs_set = local_env._env_addrs_set()
if record_addrs_set == listen_addrs_set:
# Perfect match -> reuse cached envelope
return local_env.marshal_envelope(), False
else:
# Addresses changed -> issue a new SPR and cache it
return issue_and_cache_local_record(host), True
def issue_and_cache_local_record(host: IHost) -> bytes:
env = create_signed_peer_record(
host.get_id(),
host.get_addrs(),
host.get_private_key(),
)
# Cache it for nexxt time use
host.get_peerstore().set_local_record(env)
return env.marshal_envelope()
def create_key_from_binary(binary_data: bytes) -> bytes: def create_key_from_binary(binary_data: bytes) -> bytes:
""" """
Creates a key for the DHT by hashing binary data with SHA-256. Creates a key for the DHT by hashing binary data with SHA-256.

View File

@ -15,11 +15,10 @@ from libp2p.abc import (
from libp2p.custom_types import ( from libp2p.custom_types import (
TProtocol, TProtocol,
) )
from libp2p.kad_dht.utils import maybe_consume_signed_record from libp2p.kad_dht.utils import env_to_send_in_RPC, maybe_consume_signed_record
from libp2p.peer.id import ( from libp2p.peer.id import (
ID, ID,
) )
from libp2p.peer.peerstore import create_signed_peer_record
from .common import ( from .common import (
DEFAULT_TTL, DEFAULT_TTL,
@ -113,12 +112,8 @@ class ValueStore:
message.type = Message.MessageType.PUT_VALUE message.type = Message.MessageType.PUT_VALUE
# Create sender's signed-peer-record # Create sender's signed-peer-record
envelope = create_signed_peer_record( envelope_bytes, bool = env_to_send_in_RPC(self.host)
self.host.get_id(), message.senderRecord = envelope_bytes
self.host.get_addrs(),
self.host.get_private_key(),
)
message.senderRecord = envelope.marshal_envelope()
# Set message fields # Set message fields
message.key = key message.key = key
@ -245,12 +240,8 @@ class ValueStore:
message.key = key message.key = key
# Create sender's signed-peer-record # Create sender's signed-peer-record
envelope = create_signed_peer_record( envelope_bytes, bool = env_to_send_in_RPC(self.host)
self.host.get_id(), message.senderRecord = envelope_bytes
self.host.get_addrs(),
self.host.get_private_key(),
)
message.senderRecord = envelope.marshal_envelope()
# Serialize and send the protobuf message # Serialize and send the protobuf message
proto_bytes = message.SerializeToString() proto_bytes = message.SerializeToString()

View File

@ -1,5 +1,7 @@
from typing import Any, cast from typing import Any, cast
import multiaddr
from libp2p.crypto.ed25519 import Ed25519PublicKey from libp2p.crypto.ed25519 import Ed25519PublicKey
from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.crypto.keys import PrivateKey, PublicKey
from libp2p.crypto.rsa import RSAPublicKey from libp2p.crypto.rsa import RSAPublicKey
@ -131,6 +133,9 @@ class Envelope:
) )
return False return False
def _env_addrs_set(self) -> set[multiaddr.Multiaddr]:
return {b for b in self.record().addrs}
def pub_key_to_protobuf(pub_key: PublicKey) -> cryto_pb.PublicKey: def pub_key_to_protobuf(pub_key: PublicKey) -> cryto_pb.PublicKey:
""" """

View File

@ -65,8 +65,17 @@ class PeerStore(IPeerStore):
self.peer_data_map = defaultdict(PeerData) self.peer_data_map = defaultdict(PeerData)
self.addr_update_channels: dict[ID, MemorySendChannel[Multiaddr]] = {} self.addr_update_channels: dict[ID, MemorySendChannel[Multiaddr]] = {}
self.peer_record_map: dict[ID, PeerRecordState] = {} self.peer_record_map: dict[ID, PeerRecordState] = {}
self.local_peer_record: Envelope | None = None
self.max_records = max_records self.max_records = max_records
def get_local_record(self) -> Envelope | None:
"""Get the local-signed-record wrapped in Envelope"""
return self.local_peer_record
def set_local_record(self, envelope: Envelope) -> None:
"""Set the local-signed-record wrapped in Envelope"""
self.local_peer_record = envelope
def peer_info(self, peer_id: ID) -> PeerInfo: def peer_info(self, peer_id: ID) -> PeerInfo:
""" """
:param peer_id: peer ID to get info for :param peer_id: peer ID to get info for

View File

@ -9,9 +9,12 @@ This module tests core functionality of the Kademlia DHT including:
import hashlib import hashlib
import logging import logging
import os
from unittest.mock import patch
import uuid import uuid
import pytest import pytest
import multiaddr
import trio import trio
from libp2p.kad_dht.kad_dht import ( from libp2p.kad_dht.kad_dht import (
@ -77,6 +80,18 @@ async def test_find_node(dht_pair: tuple[KadDHT, KadDHT]):
"""Test that nodes can find each other in the DHT.""" """Test that nodes can find each other in the DHT."""
dht_a, dht_b = dht_pair dht_a, dht_b = dht_pair
# An extra FIND_NODE req is sent between the 2 nodes while dht creation,
# so both the nodes will have records of each other before the next FIND_NODE
# req is sent
envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id())
assert isinstance(envelope_a, Envelope)
assert isinstance(envelope_b, Envelope)
record_a = envelope_a.record()
record_b = envelope_b.record()
# Node A should be able to find Node B # Node A should be able to find Node B
with trio.fail_after(TEST_TIMEOUT): with trio.fail_after(TEST_TIMEOUT):
found_info = await dht_a.find_peer(dht_b.host.get_id()) found_info = await dht_a.find_peer(dht_b.host.get_id())
@ -91,6 +106,26 @@ async def test_find_node(dht_pair: tuple[KadDHT, KadDHT]):
dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()), Envelope dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()), Envelope
) )
# These are the records that were sent betweeen the peers during the FIND_NODE req
envelope_a_find_peer = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_find_peer = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_find_peer, Envelope)
assert isinstance(envelope_b_find_peer, Envelope)
record_a_find_peer = envelope_a_find_peer.record()
record_b_find_peer = envelope_b_find_peer.record()
# This proves that both the records are same, and a latest cached signed record
# was passed between the peers during FIND_NODE exceution, which proves the
# signed-record transfer/re-issuing works correctly in FIND_NODE executions.
assert record_a.seq == record_a_find_peer.seq
assert record_b.seq == record_b_find_peer.seq
# Verify that the found peer has the correct peer ID # Verify that the found peer has the correct peer ID
assert found_info is not None, "Failed to find the target peer" assert found_info is not None, "Failed to find the target peer"
assert found_info.peer_id == dht_b.host.get_id(), "Found incorrect peer ID" assert found_info.peer_id == dht_b.host.get_id(), "Found incorrect peer ID"
@ -144,11 +179,11 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]):
record_a_put_value = envelope_a_put_value.record() record_a_put_value = envelope_a_put_value.record()
record_b_put_value = envelope_b_put_value.record() record_b_put_value = envelope_b_put_value.record()
# This proves that both the records are different, and a new signed record # This proves that both the records are same, and a latest cached signed record
# was passed between the peers during PUT_VALUE exceution, which proves the # was passed between the peers during PUT_VALUE exceution, which proves the
# signed-record transfer works correctly in PUT_VALUE executions. # signed-record transfer/re-issuing works correctly in PUT_VALUE executions.
assert record_a.seq < record_a_put_value.seq assert record_a.seq == record_a_put_value.seq
assert record_b.seq < record_b_put_value.seq assert record_b.seq == record_b_put_value.seq
# # Log debugging information # # Log debugging information
logger.debug("Put value with key %s...", key.hex()[:10]) logger.debug("Put value with key %s...", key.hex()[:10])
@ -234,11 +269,12 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]):
record_a_add_prov = envelope_a_add_prov.record() record_a_add_prov = envelope_a_add_prov.record()
record_b_add_prov = envelope_b_add_prov.record() record_b_add_prov = envelope_b_add_prov.record()
# This proves that both the records are different, and a new signed record # This proves that both the records are same, the latest cached signed record
# was passed between the peers during ADD_PROVIDER exceution, which proves the # was passed between the peers during ADD_PROVIDER exceution, which proves the
# signed-record transfer works correctly in ADD_PROVIDER executions. # signed-record transfer/re-issuing of the latest record works correctly in
assert record_a.seq < record_a_add_prov.seq # ADD_PROVIDER executions.
assert record_b.seq < record_b_add_prov.seq assert record_a.seq == record_a_add_prov.seq
assert record_b.seq == record_b_add_prov.seq
# Allow time for the provider record to propagate # Allow time for the provider record to propagate
await trio.sleep(0.1) await trio.sleep(0.1)
@ -294,8 +330,41 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]):
record_a_get_value = envelope_a_get_value.record() record_a_get_value = envelope_a_get_value.record()
record_b_get_value = envelope_b_get_value.record() record_b_get_value = envelope_b_get_value.record()
# This proves that both the records are different, meaning that there was # This proves that both the records are same, meaning that the latest cached
# a new signed-record tranfer during the GET_VALUE execution by dht_b, which means # signed-record tranfer happened during the GET_VALUE execution by dht_b,
# the signed-record transfer works correctly in GET_VALUE executions. # which means the signed-record transfer/re-issuing works correctly
assert record_a_find_prov.seq < record_a_get_value.seq # in GET_VALUE executions.
assert record_b_find_prov.seq < record_b_get_value.seq assert record_a_find_prov.seq == record_a_get_value.seq
assert record_b_find_prov.seq == record_b_get_value.seq
@pytest.mark.trio
async def test_reissue_when_listen_addrs_change(dht_pair: tuple[KadDHT, KadDHT]):
dht_a, dht_b = dht_pair
# Warm-up: A stores B's current record
with trio.fail_after(10):
await dht_a.find_peer(dht_b.host.get_id())
env0 = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
assert isinstance(env0, Envelope)
seq0 = env0.record().seq
# Simulate B's listen addrs changing (different port)
new_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123")
# Patch just for the duration we force B to respond:
with patch.object(dht_b.host, "get_addrs", return_value=[new_addr]):
# Force B to send a response (which should include a fresh SPR)
with trio.fail_after(10):
await dht_a.peer_routing._query_peer_for_closest(
dht_b.host.get_id(), os.urandom(32)
)
# A should now hold B's new record with a bumped seq
env1 = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
assert isinstance(env1, Envelope)
seq1 = env1.record().seq
# This proves that upon the change in listen_addrs, we issue new records
assert seq1 > seq0, f"Expected seq to bump after addr change, got {seq0} -> {seq1}"