Merge branch 'main' into chore01

This commit is contained in:
Manu Sheel Gupta
2025-08-29 03:14:51 +05:30
committed by GitHub
17 changed files with 818 additions and 160 deletions

View File

@ -49,6 +49,7 @@ from libp2p.peer.id import (
) )
from libp2p.peer.peerstore import ( from libp2p.peer.peerstore import (
PeerStore, PeerStore,
create_signed_peer_record,
) )
from libp2p.security.insecure.transport import ( from libp2p.security.insecure.transport import (
PLAINTEXT_PROTOCOL_ID, PLAINTEXT_PROTOCOL_ID,
@ -155,7 +156,6 @@ def get_default_muxer_options() -> TMuxerOptions:
else: # YAMUX is default else: # YAMUX is default
return create_yamux_muxer_option() return create_yamux_muxer_option()
def new_swarm( def new_swarm(
key_pair: KeyPair | None = None, key_pair: KeyPair | None = None,
muxer_opt: TMuxerOptions | None = None, muxer_opt: TMuxerOptions | None = None,

View File

@ -970,6 +970,14 @@ class IPeerStore(
# --------CERTIFIED-ADDR-BOOK---------- # --------CERTIFIED-ADDR-BOOK----------
@abstractmethod
def get_local_record(self) -> Optional["Envelope"]:
"""Get the local-peer-record wrapped in Envelope"""
@abstractmethod
def set_local_record(self, envelope: "Envelope") -> None:
"""Set the local-peer-record wrapped in Envelope"""
@abstractmethod @abstractmethod
def consume_peer_record(self, envelope: "Envelope", ttl: int) -> bool: def consume_peer_record(self, envelope: "Envelope", ttl: int) -> bool:
""" """

View File

@ -43,6 +43,7 @@ from libp2p.peer.id import (
from libp2p.peer.peerinfo import ( from libp2p.peer.peerinfo import (
PeerInfo, PeerInfo,
) )
from libp2p.peer.peerstore import create_signed_peer_record
from libp2p.protocol_muxer.exceptions import ( from libp2p.protocol_muxer.exceptions import (
MultiselectClientError, MultiselectClientError,
MultiselectError, MultiselectError,
@ -110,6 +111,14 @@ class BasicHost(IHost):
if bootstrap: if bootstrap:
self.bootstrap = BootstrapDiscovery(network, bootstrap) self.bootstrap = BootstrapDiscovery(network, bootstrap)
# Cache a signed-record if the local-node in the PeerStore
envelope = create_signed_peer_record(
self.get_id(),
self.get_addrs(),
self.get_private_key(),
)
self.get_peerstore().set_local_record(envelope)
def get_id(self) -> ID: def get_id(self) -> ID:
""" """
:return: peer_id of host :return: peer_id of host

View File

@ -15,8 +15,7 @@ from libp2p.custom_types import (
from libp2p.network.stream.exceptions import ( from libp2p.network.stream.exceptions import (
StreamClosed, StreamClosed,
) )
from libp2p.peer.envelope import seal_record from libp2p.peer.peerstore import env_to_send_in_RPC
from libp2p.peer.peer_record import PeerRecord
from libp2p.utils import ( from libp2p.utils import (
decode_varint_with_size, decode_varint_with_size,
get_agent_version, get_agent_version,
@ -66,9 +65,7 @@ def _mk_identify_protobuf(
protocols = tuple(str(p) for p in host.get_mux().get_protocols() if p is not None) protocols = tuple(str(p) for p in host.get_mux().get_protocols() if p is not None)
# Create a signed peer-record for the remote peer # Create a signed peer-record for the remote peer
record = PeerRecord(host.get_id(), host.get_addrs()) envelope_bytes, _ = env_to_send_in_RPC(host)
envelope = seal_record(record, host.get_private_key())
protobuf = envelope.marshal_envelope()
observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b"" observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b""
return Identify( return Identify(
@ -78,7 +75,7 @@ def _mk_identify_protobuf(
listen_addrs=map(_multiaddr_to_bytes, laddrs), listen_addrs=map(_multiaddr_to_bytes, laddrs),
observed_addr=observed_addr, observed_addr=observed_addr,
protocols=protocols, protocols=protocols,
signedPeerRecord=protobuf, signedPeerRecord=envelope_bytes,
) )

View File

@ -22,15 +22,18 @@ from libp2p.abc import (
IHost, IHost,
) )
from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager
from libp2p.kad_dht.utils import maybe_consume_signed_record
from libp2p.network.stream.net_stream import ( from libp2p.network.stream.net_stream import (
INetStream, INetStream,
) )
from libp2p.peer.envelope import Envelope
from libp2p.peer.id import ( from libp2p.peer.id import (
ID, ID,
) )
from libp2p.peer.peerinfo import ( from libp2p.peer.peerinfo import (
PeerInfo, PeerInfo,
) )
from libp2p.peer.peerstore import env_to_send_in_RPC
from libp2p.tools.async_service import ( from libp2p.tools.async_service import (
Service, Service,
) )
@ -234,6 +237,9 @@ class KadDHT(Service):
await self.add_peer(peer_id) await self.add_peer(peer_id)
logger.debug(f"Added peer {peer_id} to routing table") logger.debug(f"Added peer {peer_id} to routing table")
closer_peer_envelope: Envelope | None = None
provider_peer_envelope: Envelope | None = None
try: try:
# Read varint-prefixed length for the message # Read varint-prefixed length for the message
length_prefix = b"" length_prefix = b""
@ -274,6 +280,14 @@ class KadDHT(Service):
) )
logger.debug(f"Found {len(closest_peers)} peers close to target") logger.debug(f"Found {len(closest_peers)} peers close to target")
# Consume the source signed_peer_record if sent
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
# Build response message with protobuf # Build response message with protobuf
response = Message() response = Message()
response.type = Message.MessageType.FIND_NODE response.type = Message.MessageType.FIND_NODE
@ -298,6 +312,21 @@ class KadDHT(Service):
except Exception: except Exception:
pass pass
# Add the signed-peer-record for each peer in the peer-proto
# if cached in the peerstore
closer_peer_envelope = (
self.host.get_peerstore().get_peer_record(peer)
)
if closer_peer_envelope is not None:
peer_proto.signedRecord = (
closer_peer_envelope.marshal_envelope()
)
# Create sender_signed_peer_record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Serialize and send response # Serialize and send response
response_bytes = response.SerializeToString() response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes))) await stream.write(varint.encode(len(response_bytes)))
@ -312,6 +341,14 @@ class KadDHT(Service):
key = message.key key = message.key
logger.debug(f"Received ADD_PROVIDER for key {key.hex()}") logger.debug(f"Received ADD_PROVIDER for key {key.hex()}")
# Consume the source signed-peer-record if sent
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
# Extract provider information # Extract provider information
for provider_proto in message.providerPeers: for provider_proto in message.providerPeers:
try: try:
@ -338,6 +375,17 @@ class KadDHT(Service):
logger.debug( logger.debug(
f"Added provider {provider_id} for key {key.hex()}" f"Added provider {provider_id} for key {key.hex()}"
) )
# Process the signed-records of provider if sent
if not maybe_consume_signed_record(
provider_proto, self.host
):
logger.error(
"Received an invalid-signed-record,"
"dropping the stream"
)
await stream.close()
return
except Exception as e: except Exception as e:
logger.warning(f"Failed to process provider info: {e}") logger.warning(f"Failed to process provider info: {e}")
@ -346,6 +394,10 @@ class KadDHT(Service):
response.type = Message.MessageType.ADD_PROVIDER response.type = Message.MessageType.ADD_PROVIDER
response.key = key response.key = key
# Add sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
response_bytes = response.SerializeToString() response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes))) await stream.write(varint.encode(len(response_bytes)))
await stream.write(response_bytes) await stream.write(response_bytes)
@ -357,6 +409,14 @@ class KadDHT(Service):
key = message.key key = message.key
logger.debug(f"Received GET_PROVIDERS request for key {key.hex()}") logger.debug(f"Received GET_PROVIDERS request for key {key.hex()}")
# Consume the source signed_peer_record if sent
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
# Find providers for the key # Find providers for the key
providers = self.provider_store.get_providers(key) providers = self.provider_store.get_providers(key)
logger.debug( logger.debug(
@ -368,12 +428,28 @@ class KadDHT(Service):
response.type = Message.MessageType.GET_PROVIDERS response.type = Message.MessageType.GET_PROVIDERS
response.key = key response.key = key
# Create sender_signed_peer_record for the response
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Add provider information to response # Add provider information to response
for provider_info in providers: for provider_info in providers:
provider_proto = response.providerPeers.add() provider_proto = response.providerPeers.add()
provider_proto.id = provider_info.peer_id.to_bytes() provider_proto.id = provider_info.peer_id.to_bytes()
provider_proto.connection = Message.ConnectionType.CAN_CONNECT provider_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add provider signed-records if cached
provider_peer_envelope = (
self.host.get_peerstore().get_peer_record(
provider_info.peer_id
)
)
if provider_peer_envelope is not None:
provider_proto.signedRecord = (
provider_peer_envelope.marshal_envelope()
)
# Add addresses if available # Add addresses if available
for addr in provider_info.addrs: for addr in provider_info.addrs:
provider_proto.addrs.append(addr.to_bytes()) provider_proto.addrs.append(addr.to_bytes())
@ -397,6 +473,16 @@ class KadDHT(Service):
peer_proto.id = peer.to_bytes() peer_proto.id = peer.to_bytes()
peer_proto.connection = Message.ConnectionType.CAN_CONNECT peer_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add the signed-records of closest_peers if cached
closer_peer_envelope = (
self.host.get_peerstore().get_peer_record(peer)
)
if closer_peer_envelope is not None:
peer_proto.signedRecord = (
closer_peer_envelope.marshal_envelope()
)
# Add addresses if available # Add addresses if available
try: try:
addrs = self.host.get_peerstore().addrs(peer) addrs = self.host.get_peerstore().addrs(peer)
@ -417,6 +503,14 @@ class KadDHT(Service):
key = message.key key = message.key
logger.debug(f"Received GET_VALUE request for key {key.hex()}") logger.debug(f"Received GET_VALUE request for key {key.hex()}")
# Consume the sender_signed_peer_record
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
value = self.value_store.get(key) value = self.value_store.get(key)
if value: if value:
logger.debug(f"Found value for key {key.hex()}") logger.debug(f"Found value for key {key.hex()}")
@ -431,6 +525,10 @@ class KadDHT(Service):
response.record.value = value response.record.value = value
response.record.timeReceived = str(time.time()) response.record.timeReceived = str(time.time())
# Create sender_signed_peer_record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Serialize and send response # Serialize and send response
response_bytes = response.SerializeToString() response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes))) await stream.write(varint.encode(len(response_bytes)))
@ -444,6 +542,10 @@ class KadDHT(Service):
response.type = Message.MessageType.GET_VALUE response.type = Message.MessageType.GET_VALUE
response.key = key response.key = key
# Create sender_signed_peer_record for the response
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Add closest peers to key # Add closest peers to key
closest_peers = self.routing_table.find_local_closest_peers( closest_peers = self.routing_table.find_local_closest_peers(
key, 20 key, 20
@ -462,6 +564,16 @@ class KadDHT(Service):
peer_proto.id = peer.to_bytes() peer_proto.id = peer.to_bytes()
peer_proto.connection = Message.ConnectionType.CAN_CONNECT peer_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add signed-records of closer-peers if cached
closer_peer_envelope = (
self.host.get_peerstore().get_peer_record(peer)
)
if closer_peer_envelope is not None:
peer_proto.signedRecord = (
closer_peer_envelope.marshal_envelope()
)
# Add addresses if available # Add addresses if available
try: try:
addrs = self.host.get_peerstore().addrs(peer) addrs = self.host.get_peerstore().addrs(peer)
@ -484,6 +596,15 @@ class KadDHT(Service):
key = message.record.key key = message.record.key
value = message.record.value value = message.record.value
success = False success = False
# Consume the source signed_peer_record if sent
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
try: try:
if not (key and value): if not (key and value):
raise ValueError( raise ValueError(
@ -504,6 +625,12 @@ class KadDHT(Service):
response.type = Message.MessageType.PUT_VALUE response.type = Message.MessageType.PUT_VALUE
if success: if success:
response.key = key response.key = key
# Create sender_signed_peer_record for the response
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Serialize and send response
response_bytes = response.SerializeToString() response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes))) await stream.write(varint.encode(len(response_bytes)))
await stream.write(response_bytes) await stream.write(response_bytes)

View File

@ -27,6 +27,7 @@ message Message {
bytes id = 1; bytes id = 1;
repeated bytes addrs = 2; repeated bytes addrs = 2;
ConnectionType connection = 3; ConnectionType connection = 3;
optional bytes signedRecord = 4; // Envelope(PeerRecord) encoded
} }
MessageType type = 1; MessageType type = 1;
@ -35,4 +36,6 @@ message Message {
Record record = 3; Record record = 3;
repeated Peer closerPeers = 8; repeated Peer closerPeers = 8;
repeated Peer providerPeers = 9; repeated Peer providerPeers = 9;
optional bytes senderRecord = 11; // Envelope(PeerRecord) encoded
} }

View File

@ -1,11 +1,12 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT! # Generated by the protocol buffer compiler. DO NOT EDIT!
# source: libp2p/kad_dht/pb/kademlia.proto # source: libp2p/kad_dht/pb/kademlia.proto
# Protobuf Python Version: 4.25.3
"""Generated protocol buffer code.""" """Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()
@ -13,21 +14,21 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xca\x03\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x1aN\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x62\x06proto3') DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xa2\x04\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x12\x19\n\x0csenderRecord\x18\x0b \x01(\x0cH\x00\x88\x01\x01\x1az\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\x12\x19\n\x0csignedRecord\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x42\x0f\n\r_signedRecord\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x42\x0f\n\r_senderRecordb\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _globals = globals()
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', globals()) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False: if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None DESCRIPTOR._options = None
_RECORD._serialized_start=36 _globals['_RECORD']._serialized_start=36
_RECORD._serialized_end=94 _globals['_RECORD']._serialized_end=94
_MESSAGE._serialized_start=97 _globals['_MESSAGE']._serialized_start=97
_MESSAGE._serialized_end=555 _globals['_MESSAGE']._serialized_end=643
_MESSAGE_PEER._serialized_start=281 _globals['_MESSAGE_PEER']._serialized_start=308
_MESSAGE_PEER._serialized_end=359 _globals['_MESSAGE_PEER']._serialized_end=430
_MESSAGE_MESSAGETYPE._serialized_start=361 _globals['_MESSAGE_MESSAGETYPE']._serialized_start=432
_MESSAGE_MESSAGETYPE._serialized_end=466 _globals['_MESSAGE_MESSAGETYPE']._serialized_end=537
_MESSAGE_CONNECTIONTYPE._serialized_start=468 _globals['_MESSAGE_CONNECTIONTYPE']._serialized_start=539
_MESSAGE_CONNECTIONTYPE._serialized_end=555 _globals['_MESSAGE_CONNECTIONTYPE']._serialized_end=626
# @@protoc_insertion_point(module_scope) # @@protoc_insertion_point(module_scope)

View File

@ -1,133 +1,70 @@
""" from google.protobuf.internal import containers as _containers
@generated by mypy-protobuf. Do not edit manually! from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
isort:skip_file from google.protobuf import descriptor as _descriptor
""" from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union
import builtins DESCRIPTOR: _descriptor.FileDescriptor
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import sys
import typing
if sys.version_info >= (3, 10): class Record(_message.Message):
import typing as typing_extensions __slots__ = ("key", "value", "timeReceived")
else: KEY_FIELD_NUMBER: _ClassVar[int]
import typing_extensions VALUE_FIELD_NUMBER: _ClassVar[int]
TIMERECEIVED_FIELD_NUMBER: _ClassVar[int]
key: bytes
value: bytes
timeReceived: str
def __init__(self, key: _Optional[bytes] = ..., value: _Optional[bytes] = ..., timeReceived: _Optional[str] = ...) -> None: ...
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor class Message(_message.Message):
__slots__ = ("type", "clusterLevelRaw", "key", "record", "closerPeers", "providerPeers", "senderRecord")
@typing.final class MessageType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
class Record(google.protobuf.message.Message): __slots__ = ()
DESCRIPTOR: google.protobuf.descriptor.Descriptor PUT_VALUE: _ClassVar[Message.MessageType]
GET_VALUE: _ClassVar[Message.MessageType]
KEY_FIELD_NUMBER: builtins.int ADD_PROVIDER: _ClassVar[Message.MessageType]
VALUE_FIELD_NUMBER: builtins.int GET_PROVIDERS: _ClassVar[Message.MessageType]
TIMERECEIVED_FIELD_NUMBER: builtins.int FIND_NODE: _ClassVar[Message.MessageType]
key: builtins.bytes PING: _ClassVar[Message.MessageType]
value: builtins.bytes PUT_VALUE: Message.MessageType
timeReceived: builtins.str GET_VALUE: Message.MessageType
def __init__( ADD_PROVIDER: Message.MessageType
self, GET_PROVIDERS: Message.MessageType
*, FIND_NODE: Message.MessageType
key: builtins.bytes = ..., PING: Message.MessageType
value: builtins.bytes = ..., class ConnectionType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
timeReceived: builtins.str = ..., __slots__ = ()
) -> None: ... NOT_CONNECTED: _ClassVar[Message.ConnectionType]
def ClearField(self, field_name: typing.Literal["key", b"key", "timeReceived", b"timeReceived", "value", b"value"]) -> None: ... CONNECTED: _ClassVar[Message.ConnectionType]
CAN_CONNECT: _ClassVar[Message.ConnectionType]
global___Record = Record CANNOT_CONNECT: _ClassVar[Message.ConnectionType]
NOT_CONNECTED: Message.ConnectionType
@typing.final CONNECTED: Message.ConnectionType
class Message(google.protobuf.message.Message): CAN_CONNECT: Message.ConnectionType
DESCRIPTOR: google.protobuf.descriptor.Descriptor CANNOT_CONNECT: Message.ConnectionType
class Peer(_message.Message):
class _MessageType: __slots__ = ("id", "addrs", "connection", "signedRecord")
ValueType = typing.NewType("ValueType", builtins.int) ID_FIELD_NUMBER: _ClassVar[int]
V: typing_extensions.TypeAlias = ValueType ADDRS_FIELD_NUMBER: _ClassVar[int]
CONNECTION_FIELD_NUMBER: _ClassVar[int]
class _MessageTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._MessageType.ValueType], builtins.type): SIGNEDRECORD_FIELD_NUMBER: _ClassVar[int]
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor id: bytes
PUT_VALUE: Message._MessageType.ValueType # 0 addrs: _containers.RepeatedScalarFieldContainer[bytes]
GET_VALUE: Message._MessageType.ValueType # 1 connection: Message.ConnectionType
ADD_PROVIDER: Message._MessageType.ValueType # 2 signedRecord: bytes
GET_PROVIDERS: Message._MessageType.ValueType # 3 def __init__(self, id: _Optional[bytes] = ..., addrs: _Optional[_Iterable[bytes]] = ..., connection: _Optional[_Union[Message.ConnectionType, str]] = ..., signedRecord: _Optional[bytes] = ...) -> None: ...
FIND_NODE: Message._MessageType.ValueType # 4 TYPE_FIELD_NUMBER: _ClassVar[int]
PING: Message._MessageType.ValueType # 5 CLUSTERLEVELRAW_FIELD_NUMBER: _ClassVar[int]
KEY_FIELD_NUMBER: _ClassVar[int]
class MessageType(_MessageType, metaclass=_MessageTypeEnumTypeWrapper): ... RECORD_FIELD_NUMBER: _ClassVar[int]
PUT_VALUE: Message.MessageType.ValueType # 0 CLOSERPEERS_FIELD_NUMBER: _ClassVar[int]
GET_VALUE: Message.MessageType.ValueType # 1 PROVIDERPEERS_FIELD_NUMBER: _ClassVar[int]
ADD_PROVIDER: Message.MessageType.ValueType # 2 SENDERRECORD_FIELD_NUMBER: _ClassVar[int]
GET_PROVIDERS: Message.MessageType.ValueType # 3 type: Message.MessageType
FIND_NODE: Message.MessageType.ValueType # 4 clusterLevelRaw: int
PING: Message.MessageType.ValueType # 5 key: bytes
record: Record
class _ConnectionType: closerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer]
ValueType = typing.NewType("ValueType", builtins.int) providerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer]
V: typing_extensions.TypeAlias = ValueType senderRecord: bytes
def __init__(self, type: _Optional[_Union[Message.MessageType, str]] = ..., clusterLevelRaw: _Optional[int] = ..., key: _Optional[bytes] = ..., record: _Optional[_Union[Record, _Mapping]] = ..., closerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping]]] = ..., providerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping]]] = ..., senderRecord: _Optional[bytes] = ...) -> None: ... # type: ignore
class _ConnectionTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._ConnectionType.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
NOT_CONNECTED: Message._ConnectionType.ValueType # 0
CONNECTED: Message._ConnectionType.ValueType # 1
CAN_CONNECT: Message._ConnectionType.ValueType # 2
CANNOT_CONNECT: Message._ConnectionType.ValueType # 3
class ConnectionType(_ConnectionType, metaclass=_ConnectionTypeEnumTypeWrapper): ...
NOT_CONNECTED: Message.ConnectionType.ValueType # 0
CONNECTED: Message.ConnectionType.ValueType # 1
CAN_CONNECT: Message.ConnectionType.ValueType # 2
CANNOT_CONNECT: Message.ConnectionType.ValueType # 3
@typing.final
class Peer(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
ID_FIELD_NUMBER: builtins.int
ADDRS_FIELD_NUMBER: builtins.int
CONNECTION_FIELD_NUMBER: builtins.int
id: builtins.bytes
connection: global___Message.ConnectionType.ValueType
@property
def addrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
def __init__(
self,
*,
id: builtins.bytes = ...,
addrs: collections.abc.Iterable[builtins.bytes] | None = ...,
connection: global___Message.ConnectionType.ValueType = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["addrs", b"addrs", "connection", b"connection", "id", b"id"]) -> None: ...
TYPE_FIELD_NUMBER: builtins.int
CLUSTERLEVELRAW_FIELD_NUMBER: builtins.int
KEY_FIELD_NUMBER: builtins.int
RECORD_FIELD_NUMBER: builtins.int
CLOSERPEERS_FIELD_NUMBER: builtins.int
PROVIDERPEERS_FIELD_NUMBER: builtins.int
type: global___Message.MessageType.ValueType
clusterLevelRaw: builtins.int
key: builtins.bytes
@property
def record(self) -> global___Record: ...
@property
def closerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ...
@property
def providerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ...
def __init__(
self,
*,
type: global___Message.MessageType.ValueType = ...,
clusterLevelRaw: builtins.int = ...,
key: builtins.bytes = ...,
record: global___Record | None = ...,
closerPeers: collections.abc.Iterable[global___Message.Peer] | None = ...,
providerPeers: collections.abc.Iterable[global___Message.Peer] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["record", b"record"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["closerPeers", b"closerPeers", "clusterLevelRaw", b"clusterLevelRaw", "key", b"key", "providerPeers", b"providerPeers", "record", b"record", "type", b"type"]) -> None: ...
global___Message = Message

View File

@ -15,12 +15,14 @@ from libp2p.abc import (
INetStream, INetStream,
IPeerRouting, IPeerRouting,
) )
from libp2p.peer.envelope import Envelope
from libp2p.peer.id import ( from libp2p.peer.id import (
ID, ID,
) )
from libp2p.peer.peerinfo import ( from libp2p.peer.peerinfo import (
PeerInfo, PeerInfo,
) )
from libp2p.peer.peerstore import env_to_send_in_RPC
from .common import ( from .common import (
ALPHA, ALPHA,
@ -33,6 +35,7 @@ from .routing_table import (
RoutingTable, RoutingTable,
) )
from .utils import ( from .utils import (
maybe_consume_signed_record,
sort_peer_ids_by_distance, sort_peer_ids_by_distance,
) )
@ -255,6 +258,10 @@ class PeerRouting(IPeerRouting):
find_node_msg.type = Message.MessageType.FIND_NODE find_node_msg.type = Message.MessageType.FIND_NODE
find_node_msg.key = target_key # Set target key directly as bytes find_node_msg.key = target_key # Set target key directly as bytes
# Create sender_signed_peer_record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
find_node_msg.senderRecord = envelope_bytes
# Serialize and send the protobuf message with varint length prefix # Serialize and send the protobuf message with varint length prefix
proto_bytes = find_node_msg.SerializeToString() proto_bytes = find_node_msg.SerializeToString()
logger.debug( logger.debug(
@ -299,7 +306,22 @@ class PeerRouting(IPeerRouting):
# Process closest peers from response # Process closest peers from response
if response_msg.type == Message.MessageType.FIND_NODE: if response_msg.type == Message.MessageType.FIND_NODE:
# Consume the sender_signed_peer_record
if not maybe_consume_signed_record(response_msg, self.host, peer):
logger.error(
"Received an invalid-signed-record,ignoring the response"
)
return []
for peer_data in response_msg.closerPeers: for peer_data in response_msg.closerPeers:
# Consume the received closer_peers signed-records, peer-id is
# sent with the peer-data
if not maybe_consume_signed_record(peer_data, self.host):
logger.error(
"Received an invalid-signed-record,ignoring the response"
)
return []
new_peer_id = ID(peer_data.id) new_peer_id = ID(peer_data.id)
if new_peer_id not in results: if new_peer_id not in results:
results.append(new_peer_id) results.append(new_peer_id)
@ -332,6 +354,7 @@ class PeerRouting(IPeerRouting):
""" """
try: try:
# Read message length # Read message length
peer_id = stream.muxed_conn.peer_id
length_bytes = await stream.read(4) length_bytes = await stream.read(4)
if not length_bytes: if not length_bytes:
return return
@ -345,10 +368,18 @@ class PeerRouting(IPeerRouting):
# Parse protobuf message # Parse protobuf message
kad_message = Message() kad_message = Message()
closer_peer_envelope: Envelope | None = None
try: try:
kad_message.ParseFromString(message_bytes) kad_message.ParseFromString(message_bytes)
if kad_message.type == Message.MessageType.FIND_NODE: if kad_message.type == Message.MessageType.FIND_NODE:
# Consume the sender's signed-peer-record if sent
if not maybe_consume_signed_record(kad_message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
return
# Get target key directly from protobuf message # Get target key directly from protobuf message
target_key = kad_message.key target_key = kad_message.key
@ -361,12 +392,26 @@ class PeerRouting(IPeerRouting):
response = Message() response = Message()
response.type = Message.MessageType.FIND_NODE response.type = Message.MessageType.FIND_NODE
# Create sender_signed_peer_record for the response
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Add peer information to response # Add peer information to response
for peer_id in closest_peers: for peer_id in closest_peers:
peer_proto = response.closerPeers.add() peer_proto = response.closerPeers.add()
peer_proto.id = peer_id.to_bytes() peer_proto.id = peer_id.to_bytes()
peer_proto.connection = Message.ConnectionType.CAN_CONNECT peer_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add the signed-records of closest_peers if cached
closer_peer_envelope = (
self.host.get_peerstore().get_peer_record(peer_id)
)
if isinstance(closer_peer_envelope, Envelope):
peer_proto.signedRecord = (
closer_peer_envelope.marshal_envelope()
)
# Add addresses if available # Add addresses if available
try: try:
addrs = self.host.get_peerstore().addrs(peer_id) addrs = self.host.get_peerstore().addrs(peer_id)

View File

@ -22,12 +22,14 @@ from libp2p.abc import (
from libp2p.custom_types import ( from libp2p.custom_types import (
TProtocol, TProtocol,
) )
from libp2p.kad_dht.utils import maybe_consume_signed_record
from libp2p.peer.id import ( from libp2p.peer.id import (
ID, ID,
) )
from libp2p.peer.peerinfo import ( from libp2p.peer.peerinfo import (
PeerInfo, PeerInfo,
) )
from libp2p.peer.peerstore import env_to_send_in_RPC
from .common import ( from .common import (
ALPHA, ALPHA,
@ -240,11 +242,18 @@ class ProviderStore:
message.type = Message.MessageType.ADD_PROVIDER message.type = Message.MessageType.ADD_PROVIDER
message.key = key message.key = key
# Create sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
message.senderRecord = envelope_bytes
# Add our provider info # Add our provider info
provider = message.providerPeers.add() provider = message.providerPeers.add()
provider.id = self.local_peer_id.to_bytes() provider.id = self.local_peer_id.to_bytes()
provider.addrs.extend(addrs) provider.addrs.extend(addrs)
# Add the provider's signed-peer-record
provider.signedRecord = envelope_bytes
# Serialize and send the message # Serialize and send the message
proto_bytes = message.SerializeToString() proto_bytes = message.SerializeToString()
await stream.write(varint.encode(len(proto_bytes))) await stream.write(varint.encode(len(proto_bytes)))
@ -276,10 +285,15 @@ class ProviderStore:
response = Message() response = Message()
response.ParseFromString(response_bytes) response.ParseFromString(response_bytes)
# Check response type if response.type == Message.MessageType.ADD_PROVIDER:
response.type == Message.MessageType.ADD_PROVIDER # Consume the sender's signed-peer-record if sent
if response.type: if not maybe_consume_signed_record(response, self.host, peer_id):
result = True logger.error(
"Received an invalid-signed-record, ignoring the response"
)
result = False
else:
result = True
except Exception as e: except Exception as e:
logger.warning(f"Error sending ADD_PROVIDER to {peer_id}: {e}") logger.warning(f"Error sending ADD_PROVIDER to {peer_id}: {e}")
@ -380,6 +394,10 @@ class ProviderStore:
message.type = Message.MessageType.GET_PROVIDERS message.type = Message.MessageType.GET_PROVIDERS
message.key = key message.key = key
# Create sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
message.senderRecord = envelope_bytes
# Serialize and send the message # Serialize and send the message
proto_bytes = message.SerializeToString() proto_bytes = message.SerializeToString()
await stream.write(varint.encode(len(proto_bytes))) await stream.write(varint.encode(len(proto_bytes)))
@ -414,10 +432,26 @@ class ProviderStore:
if response.type != Message.MessageType.GET_PROVIDERS: if response.type != Message.MessageType.GET_PROVIDERS:
return [] return []
# Consume the sender's signed-peer-record if sent
if not maybe_consume_signed_record(response, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, ignoring the response"
)
return []
# Extract provider information # Extract provider information
providers = [] providers = []
for provider_proto in response.providerPeers: for provider_proto in response.providerPeers:
try: try:
# Consume the provider's signed-peer-record if sent, peer-id
# already sent with the provider-proto
if not maybe_consume_signed_record(provider_proto, self.host):
logger.error(
"Received an invalid-signed-record, "
"ignoring the response"
)
return []
# Create peer ID from bytes # Create peer ID from bytes
provider_id = ID(provider_proto.id) provider_id = ID(provider_proto.id)
@ -431,6 +465,7 @@ class ProviderStore:
# Create PeerInfo and add to result # Create PeerInfo and add to result
providers.append(PeerInfo(provider_id, addrs)) providers.append(PeerInfo(provider_id, addrs))
except Exception as e: except Exception as e:
logger.warning(f"Failed to parse provider info: {e}") logger.warning(f"Failed to parse provider info: {e}")

View File

@ -2,13 +2,93 @@
Utility functions for Kademlia DHT implementation. Utility functions for Kademlia DHT implementation.
""" """
import logging
import base58 import base58
import multihash import multihash
from libp2p.abc import IHost
from libp2p.peer.envelope import consume_envelope
from libp2p.peer.id import ( from libp2p.peer.id import (
ID, ID,
) )
from .pb.kademlia_pb2 import (
Message,
)
logger = logging.getLogger("kademlia-example.utils")
def maybe_consume_signed_record(
msg: Message | Message.Peer, host: IHost, peer_id: ID | None = None
) -> bool:
"""
Attempt to parse and store a signed-peer-record (Envelope) received during
DHT communication. If the record is invalid, the peer-id does not match, or
updating the peerstore fails, the function logs an error and returns False.
Parameters
----------
msg : Message | Message.Peer
The protobuf message received during DHT communication. Can either be a
top-level `Message` containing `senderRecord` or a `Message.Peer`
containing `signedRecord`.
host : IHost
The local host instance, providing access to the peerstore for storing
verified peer records.
peer_id : ID | None, optional
The expected peer ID for record validation. If provided, the peer ID
inside the record must match this value.
Returns
-------
bool
True if a valid signed peer record was successfully consumed and stored,
False otherwise.
"""
if isinstance(msg, Message):
if msg.HasField("senderRecord"):
try:
# Convert the signed-peer-record(Envelope) from
# protobuf bytes
envelope, record = consume_envelope(
msg.senderRecord,
"libp2p-peer-record",
)
if not (isinstance(peer_id, ID) and record.peer_id == peer_id):
return False
# Use the default TTL of 2 hours (7200 seconds)
if not host.get_peerstore().consume_peer_record(envelope, 7200):
logger.error("Failed to update the Certified-Addr-Book")
return False
except Exception as e:
logger.error("Failed to update the Certified-Addr-Book: %s", e)
return False
else:
if msg.HasField("signedRecord"):
try:
# Convert the signed-peer-record(Envelope) from
# protobuf bytes
envelope, record = consume_envelope(
msg.signedRecord,
"libp2p-peer-record",
)
if not record.peer_id.to_bytes() == msg.id:
return False
# Use the default TTL of 2 hours (7200 seconds)
if not host.get_peerstore().consume_peer_record(envelope, 7200):
logger.error("Failed to update the Certified-Addr-Book")
return False
except Exception as e:
logger.error(
"Failed to update the Certified-Addr-Book: %s",
e,
)
return False
return True
def create_key_from_binary(binary_data: bytes) -> bytes: def create_key_from_binary(binary_data: bytes) -> bytes:
""" """

View File

@ -15,9 +15,11 @@ from libp2p.abc import (
from libp2p.custom_types import ( from libp2p.custom_types import (
TProtocol, TProtocol,
) )
from libp2p.kad_dht.utils import maybe_consume_signed_record
from libp2p.peer.id import ( from libp2p.peer.id import (
ID, ID,
) )
from libp2p.peer.peerstore import env_to_send_in_RPC
from .common import ( from .common import (
DEFAULT_TTL, DEFAULT_TTL,
@ -110,6 +112,10 @@ class ValueStore:
message = Message() message = Message()
message.type = Message.MessageType.PUT_VALUE message.type = Message.MessageType.PUT_VALUE
# Create sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
message.senderRecord = envelope_bytes
# Set message fields # Set message fields
message.key = key message.key = key
message.record.key = key message.record.key = key
@ -155,7 +161,13 @@ class ValueStore:
# Check if response is valid # Check if response is valid
if response.type == Message.MessageType.PUT_VALUE: if response.type == Message.MessageType.PUT_VALUE:
if response.key: # Consume the sender's signed-peer-record if sent
if not maybe_consume_signed_record(response, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, ignoring the response"
)
return False
if response.key == key:
result = True result = True
return result return result
@ -231,6 +243,10 @@ class ValueStore:
message.type = Message.MessageType.GET_VALUE message.type = Message.MessageType.GET_VALUE
message.key = key message.key = key
# Create sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
message.senderRecord = envelope_bytes
# Serialize and send the protobuf message # Serialize and send the protobuf message
proto_bytes = message.SerializeToString() proto_bytes = message.SerializeToString()
await stream.write(varint.encode(len(proto_bytes))) await stream.write(varint.encode(len(proto_bytes)))
@ -275,6 +291,13 @@ class ValueStore:
and response.HasField("record") and response.HasField("record")
and response.record.value and response.record.value
): ):
# Consume the sender's signed-peer-record
if not maybe_consume_signed_record(response, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, ignoring the response"
)
return None
logger.debug( logger.debug(
f"Received value for key {key.hex()} from peer {peer_id}" f"Received value for key {key.hex()} from peer {peer_id}"
) )

View File

@ -1,5 +1,7 @@
from typing import Any, cast from typing import Any, cast
import multiaddr
from libp2p.crypto.ed25519 import Ed25519PublicKey from libp2p.crypto.ed25519 import Ed25519PublicKey
from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.crypto.keys import PrivateKey, PublicKey
from libp2p.crypto.rsa import RSAPublicKey from libp2p.crypto.rsa import RSAPublicKey
@ -131,6 +133,9 @@ class Envelope:
) )
return False return False
def _env_addrs_set(self) -> set[multiaddr.Multiaddr]:
return {b for b in self.record().addrs}
def pub_key_to_protobuf(pub_key: PublicKey) -> cryto_pb.PublicKey: def pub_key_to_protobuf(pub_key: PublicKey) -> cryto_pb.PublicKey:
""" """

View File

@ -16,6 +16,7 @@ import trio
from trio import MemoryReceiveChannel, MemorySendChannel from trio import MemoryReceiveChannel, MemorySendChannel
from libp2p.abc import ( from libp2p.abc import (
IHost,
IPeerStore, IPeerStore,
) )
from libp2p.crypto.keys import ( from libp2p.crypto.keys import (
@ -23,7 +24,8 @@ from libp2p.crypto.keys import (
PrivateKey, PrivateKey,
PublicKey, PublicKey,
) )
from libp2p.peer.envelope import Envelope from libp2p.peer.envelope import Envelope, seal_record
from libp2p.peer.peer_record import PeerRecord
from .id import ( from .id import (
ID, ID,
@ -39,6 +41,86 @@ from .peerinfo import (
PERMANENT_ADDR_TTL = 0 PERMANENT_ADDR_TTL = 0
def create_signed_peer_record(
peer_id: ID, addrs: list[Multiaddr], pvt_key: PrivateKey
) -> Envelope:
"""Creates a signed_peer_record wrapped in an Envelope"""
record = PeerRecord(peer_id, addrs)
envelope = seal_record(record, pvt_key)
return envelope
def env_to_send_in_RPC(host: IHost) -> tuple[bytes, bool]:
"""
Return the signed peer record (Envelope) to be sent in an RPC.
This function checks whether the host already has a cached signed peer record
(SPR). If one exists and its addresses match the host's current listen
addresses, the cached envelope is reused. Otherwise, a new signed peer record
is created, cached, and returned.
Parameters
----------
host : IHost
The local host instance, providing access to peer ID, listen addresses,
private key, and the peerstore.
Returns
-------
tuple[bytes, bool]
A 2-tuple where the first element is the serialized envelope (bytes)
for the signed peer record, and the second element is a boolean flag
indicating whether a new record was created (True) or an existing cached
one was reused (False).
"""
listen_addrs_set = {addr for addr in host.get_addrs()}
local_env = host.get_peerstore().get_local_record()
if local_env is None:
# No cached SPR yet -> create one
return issue_and_cache_local_record(host), True
else:
record_addrs_set = local_env._env_addrs_set()
if record_addrs_set == listen_addrs_set:
# Perfect match -> reuse cached envelope
return local_env.marshal_envelope(), False
else:
# Addresses changed -> issue a new SPR and cache it
return issue_and_cache_local_record(host), True
def issue_and_cache_local_record(host: IHost) -> bytes:
"""
Create and cache a new signed peer record (Envelope) for the host.
This function generates a new signed peer record from the hosts peer ID,
listen addresses, and private key. The resulting envelope is stored in
the peerstore as the local record for future reuse.
Parameters
----------
host : IHost
The local host instance, providing access to peer ID, listen addresses,
private key, and the peerstore.
Returns
-------
bytes
The serialized envelope (bytes) representing the newly created signed
peer record.
"""
env = create_signed_peer_record(
host.get_id(),
host.get_addrs(),
host.get_private_key(),
)
# Cache it for next time use
host.get_peerstore().set_local_record(env)
return env.marshal_envelope()
class PeerRecordState: class PeerRecordState:
envelope: Envelope envelope: Envelope
seq: int seq: int
@ -55,8 +137,17 @@ class PeerStore(IPeerStore):
self.peer_data_map = defaultdict(PeerData) self.peer_data_map = defaultdict(PeerData)
self.addr_update_channels: dict[ID, MemorySendChannel[Multiaddr]] = {} self.addr_update_channels: dict[ID, MemorySendChannel[Multiaddr]] = {}
self.peer_record_map: dict[ID, PeerRecordState] = {} self.peer_record_map: dict[ID, PeerRecordState] = {}
self.local_peer_record: Envelope | None = None
self.max_records = max_records self.max_records = max_records
def get_local_record(self) -> Envelope | None:
"""Get the local-signed-record wrapped in Envelope"""
return self.local_peer_record
def set_local_record(self, envelope: Envelope) -> None:
"""Set the local-signed-record wrapped in Envelope"""
self.local_peer_record = envelope
def peer_info(self, peer_id: ID) -> PeerInfo: def peer_info(self, peer_id: ID) -> PeerInfo:
""" """
:param peer_id: peer ID to get info for :param peer_id: peer ID to get info for

View File

@ -0,0 +1 @@
KAD-DHT now include signed-peer-records in its protobuf message schema, for more secure peer-discovery.

View File

@ -9,11 +9,15 @@ This module tests core functionality of the Kademlia DHT including:
import hashlib import hashlib
import logging import logging
import os
from unittest.mock import patch
import uuid import uuid
import pytest import pytest
import multiaddr
import trio import trio
from libp2p.crypto.rsa import create_new_key_pair
from libp2p.kad_dht.kad_dht import ( from libp2p.kad_dht.kad_dht import (
DHTMode, DHTMode,
KadDHT, KadDHT,
@ -21,9 +25,13 @@ from libp2p.kad_dht.kad_dht import (
from libp2p.kad_dht.utils import ( from libp2p.kad_dht.utils import (
create_key_from_binary, create_key_from_binary,
) )
from libp2p.peer.envelope import Envelope, seal_record
from libp2p.peer.id import ID
from libp2p.peer.peer_record import PeerRecord
from libp2p.peer.peerinfo import ( from libp2p.peer.peerinfo import (
PeerInfo, PeerInfo,
) )
from libp2p.peer.peerstore import create_signed_peer_record
from libp2p.tools.async_service import ( from libp2p.tools.async_service import (
background_trio_service, background_trio_service,
) )
@ -76,10 +84,52 @@ async def test_find_node(dht_pair: tuple[KadDHT, KadDHT]):
"""Test that nodes can find each other in the DHT.""" """Test that nodes can find each other in the DHT."""
dht_a, dht_b = dht_pair dht_a, dht_b = dht_pair
# An extra FIND_NODE req is sent between the 2 nodes while dht creation,
# so both the nodes will have records of each other before the next FIND_NODE
# req is sent
envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id())
assert isinstance(envelope_a, Envelope)
assert isinstance(envelope_b, Envelope)
record_a = envelope_a.record()
record_b = envelope_b.record()
# Node A should be able to find Node B # Node A should be able to find Node B
with trio.fail_after(TEST_TIMEOUT): with trio.fail_after(TEST_TIMEOUT):
found_info = await dht_a.find_peer(dht_b.host.get_id()) found_info = await dht_a.find_peer(dht_b.host.get_id())
# Verifies if the senderRecord in the FIND_NODE request is correctly processed
assert isinstance(
dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id()), Envelope
)
# Verifies if the senderRecord in the FIND_NODE response is correctly processed
assert isinstance(
dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()), Envelope
)
# These are the records that were sent between the peers during the FIND_NODE req
envelope_a_find_peer = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_find_peer = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_find_peer, Envelope)
assert isinstance(envelope_b_find_peer, Envelope)
record_a_find_peer = envelope_a_find_peer.record()
record_b_find_peer = envelope_b_find_peer.record()
# This proves that both the records are same, and a latest cached signed record
# was passed between the peers during FIND_NODE execution, which proves the
# signed-record transfer/re-issuing works correctly in FIND_NODE executions.
assert record_a.seq == record_a_find_peer.seq
assert record_b.seq == record_b_find_peer.seq
# Verify that the found peer has the correct peer ID # Verify that the found peer has the correct peer ID
assert found_info is not None, "Failed to find the target peer" assert found_info is not None, "Failed to find the target peer"
assert found_info.peer_id == dht_b.host.get_id(), "Found incorrect peer ID" assert found_info.peer_id == dht_b.host.get_id(), "Found incorrect peer ID"
@ -104,14 +154,44 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]):
await dht_a.routing_table.add_peer(peer_b_info) await dht_a.routing_table.add_peer(peer_b_info)
print("Routing table of a has ", dht_a.routing_table.get_peer_ids()) print("Routing table of a has ", dht_a.routing_table.get_peer_ids())
# An extra FIND_NODE req is sent between the 2 nodes while dht creation,
# so both the nodes will have records of each other before PUT_VALUE req is sent
envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id())
assert isinstance(envelope_a, Envelope)
assert isinstance(envelope_b, Envelope)
record_a = envelope_a.record()
record_b = envelope_b.record()
# Store the value using the first node (this will also store locally) # Store the value using the first node (this will also store locally)
with trio.fail_after(TEST_TIMEOUT): with trio.fail_after(TEST_TIMEOUT):
await dht_a.put_value(key, value) await dht_a.put_value(key, value)
# These are the records that were sent between the peers during the PUT_VALUE req
envelope_a_put_value = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_put_value = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_put_value, Envelope)
assert isinstance(envelope_b_put_value, Envelope)
record_a_put_value = envelope_a_put_value.record()
record_b_put_value = envelope_b_put_value.record()
# This proves that both the records are same, and a latest cached signed record
# was passed between the peers during PUT_VALUE execution, which proves the
# signed-record transfer/re-issuing works correctly in PUT_VALUE executions.
assert record_a.seq == record_a_put_value.seq
assert record_b.seq == record_b_put_value.seq
# # Log debugging information # # Log debugging information
logger.debug("Put value with key %s...", key.hex()[:10]) logger.debug("Put value with key %s...", key.hex()[:10])
logger.debug("Node A value store: %s", dht_a.value_store.store) logger.debug("Node A value store: %s", dht_a.value_store.store)
print("hello test")
# # Allow more time for the value to propagate # # Allow more time for the value to propagate
await trio.sleep(0.5) await trio.sleep(0.5)
@ -126,6 +206,26 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]):
print("the value stored in node b is", dht_b.get_value_store_size()) print("the value stored in node b is", dht_b.get_value_store_size())
logger.debug("Retrieved value: %s", retrieved_value) logger.debug("Retrieved value: %s", retrieved_value)
# These are the records that were sent between the peers during the PUT_VALUE req
envelope_a_get_value = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_get_value = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_get_value, Envelope)
assert isinstance(envelope_b_get_value, Envelope)
record_a_get_value = envelope_a_get_value.record()
record_b_get_value = envelope_b_get_value.record()
# This proves that there was no record exchange between the nodes during GET_VALUE
# execution, as dht_b already had the key/value pair stored locally after the
# PUT_VALUE execution.
assert record_a_get_value.seq == record_a_put_value.seq
assert record_b_get_value.seq == record_b_put_value.seq
# Verify that the retrieved value matches the original # Verify that the retrieved value matches the original
assert retrieved_value == value, "Retrieved value does not match the stored value" assert retrieved_value == value, "Retrieved value does not match the stored value"
@ -142,11 +242,44 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]):
# Store content on the first node # Store content on the first node
dht_a.value_store.put(content_id, content) dht_a.value_store.put(content_id, content)
# An extra FIND_NODE req is sent between the 2 nodes while dht creation,
# so both the nodes will have records of each other before PUT_VALUE req is sent
envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id())
assert isinstance(envelope_a, Envelope)
assert isinstance(envelope_b, Envelope)
record_a = envelope_a.record()
record_b = envelope_b.record()
# Advertise the first node as a provider # Advertise the first node as a provider
with trio.fail_after(TEST_TIMEOUT): with trio.fail_after(TEST_TIMEOUT):
success = await dht_a.provide(content_id) success = await dht_a.provide(content_id)
assert success, "Failed to advertise as provider" assert success, "Failed to advertise as provider"
# These are the records that were sent between the peers during
# the ADD_PROVIDER req
envelope_a_add_prov = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_add_prov = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_add_prov, Envelope)
assert isinstance(envelope_b_add_prov, Envelope)
record_a_add_prov = envelope_a_add_prov.record()
record_b_add_prov = envelope_b_add_prov.record()
# This proves that both the records are same, the latest cached signed record
# was passed between the peers during ADD_PROVIDER execution, which proves the
# signed-record transfer/re-issuing of the latest record works correctly in
# ADD_PROVIDER executions.
assert record_a.seq == record_a_add_prov.seq
assert record_b.seq == record_b_add_prov.seq
# Allow time for the provider record to propagate # Allow time for the provider record to propagate
await trio.sleep(0.1) await trio.sleep(0.1)
@ -154,6 +287,26 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]):
with trio.fail_after(TEST_TIMEOUT): with trio.fail_after(TEST_TIMEOUT):
providers = await dht_b.find_providers(content_id) providers = await dht_b.find_providers(content_id)
# These are the records in each peer after the find_provider execution
envelope_a_find_prov = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_find_prov = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_find_prov, Envelope)
assert isinstance(envelope_b_find_prov, Envelope)
record_a_find_prov = envelope_a_find_prov.record()
record_b_find_prov = envelope_b_find_prov.record()
# This proves that both the records are same, as the dht_b already
# has the provider record for the content_id, after the ADD_PROVIDER
# advertisement by dht_a
assert record_a_find_prov.seq == record_a_add_prov.seq
assert record_b_find_prov.seq == record_b_add_prov.seq
# Verify that we found the first node as a provider # Verify that we found the first node as a provider
assert providers, "No providers found" assert providers, "No providers found"
assert any(p.peer_id == dht_a.local_peer_id for p in providers), ( assert any(p.peer_id == dht_a.local_peer_id for p in providers), (
@ -166,3 +319,143 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]):
assert retrieved_value == content, ( assert retrieved_value == content, (
"Retrieved content does not match the original" "Retrieved content does not match the original"
) )
# These are the record state of each peer aftet the GET_VALUE execution
envelope_a_get_value = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_get_value = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_get_value, Envelope)
assert isinstance(envelope_b_get_value, Envelope)
record_a_get_value = envelope_a_get_value.record()
record_b_get_value = envelope_b_get_value.record()
# This proves that both the records are same, meaning that the latest cached
# signed-record tranfer happened during the GET_VALUE execution by dht_b,
# which means the signed-record transfer/re-issuing works correctly
# in GET_VALUE executions.
assert record_a_find_prov.seq == record_a_get_value.seq
assert record_b_find_prov.seq == record_b_get_value.seq
# Create a new provider record in dht_a
provider_key_pair = create_new_key_pair()
provider_peer_id = ID.from_pubkey(provider_key_pair.public_key)
provider_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123")
provider_peer_info = PeerInfo(peer_id=provider_peer_id, addrs=[provider_addr])
# Generate a random content ID
content_2 = f"random-content-{uuid.uuid4()}".encode()
content_id_2 = hashlib.sha256(content_2).digest()
provider_signed_envelope = create_signed_peer_record(
provider_peer_id, [provider_addr], provider_key_pair.private_key
)
assert (
dht_a.host.get_peerstore().consume_peer_record(provider_signed_envelope, 7200)
is True
)
# Store this provider record in dht_a
dht_a.provider_store.add_provider(content_id_2, provider_peer_info)
# Fetch the provider-record via peer-discovery at dht_b's end
peerinfo = await dht_b.provider_store.find_providers(content_id_2)
assert len(peerinfo) == 1
assert peerinfo[0].peer_id == provider_peer_id
provider_envelope = dht_b.host.get_peerstore().get_peer_record(provider_peer_id)
# This proves that the signed-envelope of provider is consumed on dht_b's end
assert provider_envelope is not None
assert (
provider_signed_envelope.marshal_envelope()
== provider_envelope.marshal_envelope()
)
@pytest.mark.trio
async def test_reissue_when_listen_addrs_change(dht_pair: tuple[KadDHT, KadDHT]):
dht_a, dht_b = dht_pair
# Warm-up: A stores B's current record
with trio.fail_after(10):
await dht_a.find_peer(dht_b.host.get_id())
env0 = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
assert isinstance(env0, Envelope)
seq0 = env0.record().seq
# Simulate B's listen addrs changing (different port)
new_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123")
# Patch just for the duration we force B to respond:
with patch.object(dht_b.host, "get_addrs", return_value=[new_addr]):
# Force B to send a response (which should include a fresh SPR)
with trio.fail_after(10):
await dht_a.peer_routing._query_peer_for_closest(
dht_b.host.get_id(), os.urandom(32)
)
# A should now hold B's new record with a bumped seq
env1 = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
assert isinstance(env1, Envelope)
seq1 = env1.record().seq
# This proves that upon the change in listen_addrs, we issue new records
assert seq1 > seq0, f"Expected seq to bump after addr change, got {seq0} -> {seq1}"
@pytest.mark.trio
async def test_dht_req_fail_with_invalid_record_transfer(
dht_pair: tuple[KadDHT, KadDHT],
):
"""
Testing showing failure of storing and retrieving values in the DHT,
if invalid signed-records are sent.
"""
dht_a, dht_b = dht_pair
peer_b_info = PeerInfo(dht_b.host.get_id(), dht_b.host.get_addrs())
# Generate a random key and value
key = create_key_from_binary(b"test-key")
value = b"test-value"
# First add the value directly to node A's store to verify storage works
dht_a.value_store.put(key, value)
local_value = dht_a.value_store.get(key)
assert local_value == value, "Local value storage failed"
await dht_a.routing_table.add_peer(peer_b_info)
# Corrupt dht_a's local peer_record
envelope = dht_a.host.get_peerstore().get_local_record()
if envelope is not None:
true_record = envelope.record()
key_pair = create_new_key_pair()
if envelope is not None:
envelope.public_key = key_pair.public_key
dht_a.host.get_peerstore().set_local_record(envelope)
await dht_a.put_value(key, value)
retrieved_value = dht_b.value_store.get(key)
# This proves that DHT_B rejected DHT_A PUT_RECORD req upon receiving
# the corrupted invalid record
assert retrieved_value is None
# Create a corrupt envelope with correct signature but false peer_id
false_record = PeerRecord(ID.from_pubkey(key_pair.public_key), true_record.addrs)
false_envelope = seal_record(false_record, dht_a.host.get_private_key())
dht_a.host.get_peerstore().set_local_record(false_envelope)
await dht_a.put_value(key, value)
retrieved_value = dht_b.value_store.get(key)
# This proves that DHT_B rejected DHT_A PUT_RECORD req upon receving
# the record with a different peer_id regardless of a valid signature
assert retrieved_value is None

View File

@ -57,7 +57,10 @@ class TestPeerRouting:
def mock_host(self): def mock_host(self):
"""Create a mock host for testing.""" """Create a mock host for testing."""
host = Mock() host = Mock()
host.get_id.return_value = create_valid_peer_id("local") key_pair = create_new_key_pair()
host.get_id.return_value = ID.from_pubkey(key_pair.public_key)
host.get_public_key.return_value = key_pair.public_key
host.get_private_key.return_value = key_pair.private_key
host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")] host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
host.get_peerstore.return_value = Mock() host.get_peerstore.return_value = Mock()
host.new_stream = AsyncMock() host.new_stream = AsyncMock()