Merge branch 'main' into chore01

This commit is contained in:
Manu Sheel Gupta
2025-08-29 03:14:51 +05:30
committed by GitHub
17 changed files with 818 additions and 160 deletions

View File

@ -49,6 +49,7 @@ from libp2p.peer.id import (
)
from libp2p.peer.peerstore import (
PeerStore,
create_signed_peer_record,
)
from libp2p.security.insecure.transport import (
PLAINTEXT_PROTOCOL_ID,
@ -155,7 +156,6 @@ def get_default_muxer_options() -> TMuxerOptions:
else: # YAMUX is default
return create_yamux_muxer_option()
def new_swarm(
key_pair: KeyPair | None = None,
muxer_opt: TMuxerOptions | None = None,

View File

@ -970,6 +970,14 @@ class IPeerStore(
# --------CERTIFIED-ADDR-BOOK----------
@abstractmethod
def get_local_record(self) -> Optional["Envelope"]:
"""Get the local-peer-record wrapped in Envelope"""
@abstractmethod
def set_local_record(self, envelope: "Envelope") -> None:
"""Set the local-peer-record wrapped in Envelope"""
@abstractmethod
def consume_peer_record(self, envelope: "Envelope", ttl: int) -> bool:
"""

View File

@ -43,6 +43,7 @@ from libp2p.peer.id import (
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.peer.peerstore import create_signed_peer_record
from libp2p.protocol_muxer.exceptions import (
MultiselectClientError,
MultiselectError,
@ -110,6 +111,14 @@ class BasicHost(IHost):
if bootstrap:
self.bootstrap = BootstrapDiscovery(network, bootstrap)
# Cache a signed-record if the local-node in the PeerStore
envelope = create_signed_peer_record(
self.get_id(),
self.get_addrs(),
self.get_private_key(),
)
self.get_peerstore().set_local_record(envelope)
def get_id(self) -> ID:
"""
:return: peer_id of host

View File

@ -15,8 +15,7 @@ from libp2p.custom_types import (
from libp2p.network.stream.exceptions import (
StreamClosed,
)
from libp2p.peer.envelope import seal_record
from libp2p.peer.peer_record import PeerRecord
from libp2p.peer.peerstore import env_to_send_in_RPC
from libp2p.utils import (
decode_varint_with_size,
get_agent_version,
@ -66,9 +65,7 @@ def _mk_identify_protobuf(
protocols = tuple(str(p) for p in host.get_mux().get_protocols() if p is not None)
# Create a signed peer-record for the remote peer
record = PeerRecord(host.get_id(), host.get_addrs())
envelope = seal_record(record, host.get_private_key())
protobuf = envelope.marshal_envelope()
envelope_bytes, _ = env_to_send_in_RPC(host)
observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b""
return Identify(
@ -78,7 +75,7 @@ def _mk_identify_protobuf(
listen_addrs=map(_multiaddr_to_bytes, laddrs),
observed_addr=observed_addr,
protocols=protocols,
signedPeerRecord=protobuf,
signedPeerRecord=envelope_bytes,
)

View File

@ -22,15 +22,18 @@ from libp2p.abc import (
IHost,
)
from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager
from libp2p.kad_dht.utils import maybe_consume_signed_record
from libp2p.network.stream.net_stream import (
INetStream,
)
from libp2p.peer.envelope import Envelope
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.peer.peerstore import env_to_send_in_RPC
from libp2p.tools.async_service import (
Service,
)
@ -234,6 +237,9 @@ class KadDHT(Service):
await self.add_peer(peer_id)
logger.debug(f"Added peer {peer_id} to routing table")
closer_peer_envelope: Envelope | None = None
provider_peer_envelope: Envelope | None = None
try:
# Read varint-prefixed length for the message
length_prefix = b""
@ -274,6 +280,14 @@ class KadDHT(Service):
)
logger.debug(f"Found {len(closest_peers)} peers close to target")
# Consume the source signed_peer_record if sent
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
# Build response message with protobuf
response = Message()
response.type = Message.MessageType.FIND_NODE
@ -298,6 +312,21 @@ class KadDHT(Service):
except Exception:
pass
# Add the signed-peer-record for each peer in the peer-proto
# if cached in the peerstore
closer_peer_envelope = (
self.host.get_peerstore().get_peer_record(peer)
)
if closer_peer_envelope is not None:
peer_proto.signedRecord = (
closer_peer_envelope.marshal_envelope()
)
# Create sender_signed_peer_record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Serialize and send response
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
@ -312,6 +341,14 @@ class KadDHT(Service):
key = message.key
logger.debug(f"Received ADD_PROVIDER for key {key.hex()}")
# Consume the source signed-peer-record if sent
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
# Extract provider information
for provider_proto in message.providerPeers:
try:
@ -338,6 +375,17 @@ class KadDHT(Service):
logger.debug(
f"Added provider {provider_id} for key {key.hex()}"
)
# Process the signed-records of provider if sent
if not maybe_consume_signed_record(
provider_proto, self.host
):
logger.error(
"Received an invalid-signed-record,"
"dropping the stream"
)
await stream.close()
return
except Exception as e:
logger.warning(f"Failed to process provider info: {e}")
@ -346,6 +394,10 @@ class KadDHT(Service):
response.type = Message.MessageType.ADD_PROVIDER
response.key = key
# Add sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
await stream.write(response_bytes)
@ -357,6 +409,14 @@ class KadDHT(Service):
key = message.key
logger.debug(f"Received GET_PROVIDERS request for key {key.hex()}")
# Consume the source signed_peer_record if sent
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
# Find providers for the key
providers = self.provider_store.get_providers(key)
logger.debug(
@ -368,12 +428,28 @@ class KadDHT(Service):
response.type = Message.MessageType.GET_PROVIDERS
response.key = key
# Create sender_signed_peer_record for the response
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Add provider information to response
for provider_info in providers:
provider_proto = response.providerPeers.add()
provider_proto.id = provider_info.peer_id.to_bytes()
provider_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add provider signed-records if cached
provider_peer_envelope = (
self.host.get_peerstore().get_peer_record(
provider_info.peer_id
)
)
if provider_peer_envelope is not None:
provider_proto.signedRecord = (
provider_peer_envelope.marshal_envelope()
)
# Add addresses if available
for addr in provider_info.addrs:
provider_proto.addrs.append(addr.to_bytes())
@ -397,6 +473,16 @@ class KadDHT(Service):
peer_proto.id = peer.to_bytes()
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add the signed-records of closest_peers if cached
closer_peer_envelope = (
self.host.get_peerstore().get_peer_record(peer)
)
if closer_peer_envelope is not None:
peer_proto.signedRecord = (
closer_peer_envelope.marshal_envelope()
)
# Add addresses if available
try:
addrs = self.host.get_peerstore().addrs(peer)
@ -417,6 +503,14 @@ class KadDHT(Service):
key = message.key
logger.debug(f"Received GET_VALUE request for key {key.hex()}")
# Consume the sender_signed_peer_record
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
value = self.value_store.get(key)
if value:
logger.debug(f"Found value for key {key.hex()}")
@ -431,6 +525,10 @@ class KadDHT(Service):
response.record.value = value
response.record.timeReceived = str(time.time())
# Create sender_signed_peer_record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Serialize and send response
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
@ -444,6 +542,10 @@ class KadDHT(Service):
response.type = Message.MessageType.GET_VALUE
response.key = key
# Create sender_signed_peer_record for the response
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Add closest peers to key
closest_peers = self.routing_table.find_local_closest_peers(
key, 20
@ -462,6 +564,16 @@ class KadDHT(Service):
peer_proto.id = peer.to_bytes()
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add signed-records of closer-peers if cached
closer_peer_envelope = (
self.host.get_peerstore().get_peer_record(peer)
)
if closer_peer_envelope is not None:
peer_proto.signedRecord = (
closer_peer_envelope.marshal_envelope()
)
# Add addresses if available
try:
addrs = self.host.get_peerstore().addrs(peer)
@ -484,6 +596,15 @@ class KadDHT(Service):
key = message.record.key
value = message.record.value
success = False
# Consume the source signed_peer_record if sent
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
try:
if not (key and value):
raise ValueError(
@ -504,6 +625,12 @@ class KadDHT(Service):
response.type = Message.MessageType.PUT_VALUE
if success:
response.key = key
# Create sender_signed_peer_record for the response
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Serialize and send response
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
await stream.write(response_bytes)

View File

@ -27,6 +27,7 @@ message Message {
bytes id = 1;
repeated bytes addrs = 2;
ConnectionType connection = 3;
optional bytes signedRecord = 4; // Envelope(PeerRecord) encoded
}
MessageType type = 1;
@ -35,4 +36,6 @@ message Message {
Record record = 3;
repeated Peer closerPeers = 8;
repeated Peer providerPeers = 9;
optional bytes senderRecord = 11; // Envelope(PeerRecord) encoded
}

View File

@ -1,11 +1,12 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: libp2p/kad_dht/pb/kademlia.proto
# Protobuf Python Version: 4.25.3
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
@ -13,21 +14,21 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xca\x03\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x1aN\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xa2\x04\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x12\x19\n\x0csenderRecord\x18\x0b \x01(\x0cH\x00\x88\x01\x01\x1az\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\x12\x19\n\x0csignedRecord\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x42\x0f\n\r_signedRecord\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x42\x0f\n\r_senderRecordb\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', globals())
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_RECORD._serialized_start=36
_RECORD._serialized_end=94
_MESSAGE._serialized_start=97
_MESSAGE._serialized_end=555
_MESSAGE_PEER._serialized_start=281
_MESSAGE_PEER._serialized_end=359
_MESSAGE_MESSAGETYPE._serialized_start=361
_MESSAGE_MESSAGETYPE._serialized_end=466
_MESSAGE_CONNECTIONTYPE._serialized_start=468
_MESSAGE_CONNECTIONTYPE._serialized_end=555
_globals['_RECORD']._serialized_start=36
_globals['_RECORD']._serialized_end=94
_globals['_MESSAGE']._serialized_start=97
_globals['_MESSAGE']._serialized_end=643
_globals['_MESSAGE_PEER']._serialized_start=308
_globals['_MESSAGE_PEER']._serialized_end=430
_globals['_MESSAGE_MESSAGETYPE']._serialized_start=432
_globals['_MESSAGE_MESSAGETYPE']._serialized_end=537
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_start=539
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_end=626
# @@protoc_insertion_point(module_scope)

View File

@ -1,133 +1,70 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
from google.protobuf.internal import containers as _containers
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import sys
import typing
DESCRIPTOR: _descriptor.FileDescriptor
if sys.version_info >= (3, 10):
import typing as typing_extensions
else:
import typing_extensions
class Record(_message.Message):
__slots__ = ("key", "value", "timeReceived")
KEY_FIELD_NUMBER: _ClassVar[int]
VALUE_FIELD_NUMBER: _ClassVar[int]
TIMERECEIVED_FIELD_NUMBER: _ClassVar[int]
key: bytes
value: bytes
timeReceived: str
def __init__(self, key: _Optional[bytes] = ..., value: _Optional[bytes] = ..., timeReceived: _Optional[str] = ...) -> None: ...
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class Record(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
TIMERECEIVED_FIELD_NUMBER: builtins.int
key: builtins.bytes
value: builtins.bytes
timeReceived: builtins.str
def __init__(
self,
*,
key: builtins.bytes = ...,
value: builtins.bytes = ...,
timeReceived: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "timeReceived", b"timeReceived", "value", b"value"]) -> None: ...
global___Record = Record
@typing.final
class Message(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class _MessageType:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _MessageTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._MessageType.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
PUT_VALUE: Message._MessageType.ValueType # 0
GET_VALUE: Message._MessageType.ValueType # 1
ADD_PROVIDER: Message._MessageType.ValueType # 2
GET_PROVIDERS: Message._MessageType.ValueType # 3
FIND_NODE: Message._MessageType.ValueType # 4
PING: Message._MessageType.ValueType # 5
class MessageType(_MessageType, metaclass=_MessageTypeEnumTypeWrapper): ...
PUT_VALUE: Message.MessageType.ValueType # 0
GET_VALUE: Message.MessageType.ValueType # 1
ADD_PROVIDER: Message.MessageType.ValueType # 2
GET_PROVIDERS: Message.MessageType.ValueType # 3
FIND_NODE: Message.MessageType.ValueType # 4
PING: Message.MessageType.ValueType # 5
class _ConnectionType:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _ConnectionTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._ConnectionType.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
NOT_CONNECTED: Message._ConnectionType.ValueType # 0
CONNECTED: Message._ConnectionType.ValueType # 1
CAN_CONNECT: Message._ConnectionType.ValueType # 2
CANNOT_CONNECT: Message._ConnectionType.ValueType # 3
class ConnectionType(_ConnectionType, metaclass=_ConnectionTypeEnumTypeWrapper): ...
NOT_CONNECTED: Message.ConnectionType.ValueType # 0
CONNECTED: Message.ConnectionType.ValueType # 1
CAN_CONNECT: Message.ConnectionType.ValueType # 2
CANNOT_CONNECT: Message.ConnectionType.ValueType # 3
@typing.final
class Peer(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
ID_FIELD_NUMBER: builtins.int
ADDRS_FIELD_NUMBER: builtins.int
CONNECTION_FIELD_NUMBER: builtins.int
id: builtins.bytes
connection: global___Message.ConnectionType.ValueType
@property
def addrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
def __init__(
self,
*,
id: builtins.bytes = ...,
addrs: collections.abc.Iterable[builtins.bytes] | None = ...,
connection: global___Message.ConnectionType.ValueType = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["addrs", b"addrs", "connection", b"connection", "id", b"id"]) -> None: ...
TYPE_FIELD_NUMBER: builtins.int
CLUSTERLEVELRAW_FIELD_NUMBER: builtins.int
KEY_FIELD_NUMBER: builtins.int
RECORD_FIELD_NUMBER: builtins.int
CLOSERPEERS_FIELD_NUMBER: builtins.int
PROVIDERPEERS_FIELD_NUMBER: builtins.int
type: global___Message.MessageType.ValueType
clusterLevelRaw: builtins.int
key: builtins.bytes
@property
def record(self) -> global___Record: ...
@property
def closerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ...
@property
def providerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ...
def __init__(
self,
*,
type: global___Message.MessageType.ValueType = ...,
clusterLevelRaw: builtins.int = ...,
key: builtins.bytes = ...,
record: global___Record | None = ...,
closerPeers: collections.abc.Iterable[global___Message.Peer] | None = ...,
providerPeers: collections.abc.Iterable[global___Message.Peer] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["record", b"record"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["closerPeers", b"closerPeers", "clusterLevelRaw", b"clusterLevelRaw", "key", b"key", "providerPeers", b"providerPeers", "record", b"record", "type", b"type"]) -> None: ...
global___Message = Message
class Message(_message.Message):
__slots__ = ("type", "clusterLevelRaw", "key", "record", "closerPeers", "providerPeers", "senderRecord")
class MessageType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
PUT_VALUE: _ClassVar[Message.MessageType]
GET_VALUE: _ClassVar[Message.MessageType]
ADD_PROVIDER: _ClassVar[Message.MessageType]
GET_PROVIDERS: _ClassVar[Message.MessageType]
FIND_NODE: _ClassVar[Message.MessageType]
PING: _ClassVar[Message.MessageType]
PUT_VALUE: Message.MessageType
GET_VALUE: Message.MessageType
ADD_PROVIDER: Message.MessageType
GET_PROVIDERS: Message.MessageType
FIND_NODE: Message.MessageType
PING: Message.MessageType
class ConnectionType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
NOT_CONNECTED: _ClassVar[Message.ConnectionType]
CONNECTED: _ClassVar[Message.ConnectionType]
CAN_CONNECT: _ClassVar[Message.ConnectionType]
CANNOT_CONNECT: _ClassVar[Message.ConnectionType]
NOT_CONNECTED: Message.ConnectionType
CONNECTED: Message.ConnectionType
CAN_CONNECT: Message.ConnectionType
CANNOT_CONNECT: Message.ConnectionType
class Peer(_message.Message):
__slots__ = ("id", "addrs", "connection", "signedRecord")
ID_FIELD_NUMBER: _ClassVar[int]
ADDRS_FIELD_NUMBER: _ClassVar[int]
CONNECTION_FIELD_NUMBER: _ClassVar[int]
SIGNEDRECORD_FIELD_NUMBER: _ClassVar[int]
id: bytes
addrs: _containers.RepeatedScalarFieldContainer[bytes]
connection: Message.ConnectionType
signedRecord: bytes
def __init__(self, id: _Optional[bytes] = ..., addrs: _Optional[_Iterable[bytes]] = ..., connection: _Optional[_Union[Message.ConnectionType, str]] = ..., signedRecord: _Optional[bytes] = ...) -> None: ...
TYPE_FIELD_NUMBER: _ClassVar[int]
CLUSTERLEVELRAW_FIELD_NUMBER: _ClassVar[int]
KEY_FIELD_NUMBER: _ClassVar[int]
RECORD_FIELD_NUMBER: _ClassVar[int]
CLOSERPEERS_FIELD_NUMBER: _ClassVar[int]
PROVIDERPEERS_FIELD_NUMBER: _ClassVar[int]
SENDERRECORD_FIELD_NUMBER: _ClassVar[int]
type: Message.MessageType
clusterLevelRaw: int
key: bytes
record: Record
closerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer]
providerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer]
senderRecord: bytes
def __init__(self, type: _Optional[_Union[Message.MessageType, str]] = ..., clusterLevelRaw: _Optional[int] = ..., key: _Optional[bytes] = ..., record: _Optional[_Union[Record, _Mapping]] = ..., closerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping]]] = ..., providerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping]]] = ..., senderRecord: _Optional[bytes] = ...) -> None: ... # type: ignore

View File

@ -15,12 +15,14 @@ from libp2p.abc import (
INetStream,
IPeerRouting,
)
from libp2p.peer.envelope import Envelope
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.peer.peerstore import env_to_send_in_RPC
from .common import (
ALPHA,
@ -33,6 +35,7 @@ from .routing_table import (
RoutingTable,
)
from .utils import (
maybe_consume_signed_record,
sort_peer_ids_by_distance,
)
@ -255,6 +258,10 @@ class PeerRouting(IPeerRouting):
find_node_msg.type = Message.MessageType.FIND_NODE
find_node_msg.key = target_key # Set target key directly as bytes
# Create sender_signed_peer_record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
find_node_msg.senderRecord = envelope_bytes
# Serialize and send the protobuf message with varint length prefix
proto_bytes = find_node_msg.SerializeToString()
logger.debug(
@ -299,7 +306,22 @@ class PeerRouting(IPeerRouting):
# Process closest peers from response
if response_msg.type == Message.MessageType.FIND_NODE:
# Consume the sender_signed_peer_record
if not maybe_consume_signed_record(response_msg, self.host, peer):
logger.error(
"Received an invalid-signed-record,ignoring the response"
)
return []
for peer_data in response_msg.closerPeers:
# Consume the received closer_peers signed-records, peer-id is
# sent with the peer-data
if not maybe_consume_signed_record(peer_data, self.host):
logger.error(
"Received an invalid-signed-record,ignoring the response"
)
return []
new_peer_id = ID(peer_data.id)
if new_peer_id not in results:
results.append(new_peer_id)
@ -332,6 +354,7 @@ class PeerRouting(IPeerRouting):
"""
try:
# Read message length
peer_id = stream.muxed_conn.peer_id
length_bytes = await stream.read(4)
if not length_bytes:
return
@ -345,10 +368,18 @@ class PeerRouting(IPeerRouting):
# Parse protobuf message
kad_message = Message()
closer_peer_envelope: Envelope | None = None
try:
kad_message.ParseFromString(message_bytes)
if kad_message.type == Message.MessageType.FIND_NODE:
# Consume the sender's signed-peer-record if sent
if not maybe_consume_signed_record(kad_message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
return
# Get target key directly from protobuf message
target_key = kad_message.key
@ -361,12 +392,26 @@ class PeerRouting(IPeerRouting):
response = Message()
response.type = Message.MessageType.FIND_NODE
# Create sender_signed_peer_record for the response
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Add peer information to response
for peer_id in closest_peers:
peer_proto = response.closerPeers.add()
peer_proto.id = peer_id.to_bytes()
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add the signed-records of closest_peers if cached
closer_peer_envelope = (
self.host.get_peerstore().get_peer_record(peer_id)
)
if isinstance(closer_peer_envelope, Envelope):
peer_proto.signedRecord = (
closer_peer_envelope.marshal_envelope()
)
# Add addresses if available
try:
addrs = self.host.get_peerstore().addrs(peer_id)

View File

@ -22,12 +22,14 @@ from libp2p.abc import (
from libp2p.custom_types import (
TProtocol,
)
from libp2p.kad_dht.utils import maybe_consume_signed_record
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.peer.peerstore import env_to_send_in_RPC
from .common import (
ALPHA,
@ -240,11 +242,18 @@ class ProviderStore:
message.type = Message.MessageType.ADD_PROVIDER
message.key = key
# Create sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
message.senderRecord = envelope_bytes
# Add our provider info
provider = message.providerPeers.add()
provider.id = self.local_peer_id.to_bytes()
provider.addrs.extend(addrs)
# Add the provider's signed-peer-record
provider.signedRecord = envelope_bytes
# Serialize and send the message
proto_bytes = message.SerializeToString()
await stream.write(varint.encode(len(proto_bytes)))
@ -276,10 +285,15 @@ class ProviderStore:
response = Message()
response.ParseFromString(response_bytes)
# Check response type
response.type == Message.MessageType.ADD_PROVIDER
if response.type:
result = True
if response.type == Message.MessageType.ADD_PROVIDER:
# Consume the sender's signed-peer-record if sent
if not maybe_consume_signed_record(response, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, ignoring the response"
)
result = False
else:
result = True
except Exception as e:
logger.warning(f"Error sending ADD_PROVIDER to {peer_id}: {e}")
@ -380,6 +394,10 @@ class ProviderStore:
message.type = Message.MessageType.GET_PROVIDERS
message.key = key
# Create sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
message.senderRecord = envelope_bytes
# Serialize and send the message
proto_bytes = message.SerializeToString()
await stream.write(varint.encode(len(proto_bytes)))
@ -414,10 +432,26 @@ class ProviderStore:
if response.type != Message.MessageType.GET_PROVIDERS:
return []
# Consume the sender's signed-peer-record if sent
if not maybe_consume_signed_record(response, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, ignoring the response"
)
return []
# Extract provider information
providers = []
for provider_proto in response.providerPeers:
try:
# Consume the provider's signed-peer-record if sent, peer-id
# already sent with the provider-proto
if not maybe_consume_signed_record(provider_proto, self.host):
logger.error(
"Received an invalid-signed-record, "
"ignoring the response"
)
return []
# Create peer ID from bytes
provider_id = ID(provider_proto.id)
@ -431,6 +465,7 @@ class ProviderStore:
# Create PeerInfo and add to result
providers.append(PeerInfo(provider_id, addrs))
except Exception as e:
logger.warning(f"Failed to parse provider info: {e}")

View File

@ -2,13 +2,93 @@
Utility functions for Kademlia DHT implementation.
"""
import logging
import base58
import multihash
from libp2p.abc import IHost
from libp2p.peer.envelope import consume_envelope
from libp2p.peer.id import (
ID,
)
from .pb.kademlia_pb2 import (
Message,
)
logger = logging.getLogger("kademlia-example.utils")
def maybe_consume_signed_record(
msg: Message | Message.Peer, host: IHost, peer_id: ID | None = None
) -> bool:
"""
Attempt to parse and store a signed-peer-record (Envelope) received during
DHT communication. If the record is invalid, the peer-id does not match, or
updating the peerstore fails, the function logs an error and returns False.
Parameters
----------
msg : Message | Message.Peer
The protobuf message received during DHT communication. Can either be a
top-level `Message` containing `senderRecord` or a `Message.Peer`
containing `signedRecord`.
host : IHost
The local host instance, providing access to the peerstore for storing
verified peer records.
peer_id : ID | None, optional
The expected peer ID for record validation. If provided, the peer ID
inside the record must match this value.
Returns
-------
bool
True if a valid signed peer record was successfully consumed and stored,
False otherwise.
"""
if isinstance(msg, Message):
if msg.HasField("senderRecord"):
try:
# Convert the signed-peer-record(Envelope) from
# protobuf bytes
envelope, record = consume_envelope(
msg.senderRecord,
"libp2p-peer-record",
)
if not (isinstance(peer_id, ID) and record.peer_id == peer_id):
return False
# Use the default TTL of 2 hours (7200 seconds)
if not host.get_peerstore().consume_peer_record(envelope, 7200):
logger.error("Failed to update the Certified-Addr-Book")
return False
except Exception as e:
logger.error("Failed to update the Certified-Addr-Book: %s", e)
return False
else:
if msg.HasField("signedRecord"):
try:
# Convert the signed-peer-record(Envelope) from
# protobuf bytes
envelope, record = consume_envelope(
msg.signedRecord,
"libp2p-peer-record",
)
if not record.peer_id.to_bytes() == msg.id:
return False
# Use the default TTL of 2 hours (7200 seconds)
if not host.get_peerstore().consume_peer_record(envelope, 7200):
logger.error("Failed to update the Certified-Addr-Book")
return False
except Exception as e:
logger.error(
"Failed to update the Certified-Addr-Book: %s",
e,
)
return False
return True
def create_key_from_binary(binary_data: bytes) -> bytes:
"""

View File

@ -15,9 +15,11 @@ from libp2p.abc import (
from libp2p.custom_types import (
TProtocol,
)
from libp2p.kad_dht.utils import maybe_consume_signed_record
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerstore import env_to_send_in_RPC
from .common import (
DEFAULT_TTL,
@ -110,6 +112,10 @@ class ValueStore:
message = Message()
message.type = Message.MessageType.PUT_VALUE
# Create sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
message.senderRecord = envelope_bytes
# Set message fields
message.key = key
message.record.key = key
@ -155,7 +161,13 @@ class ValueStore:
# Check if response is valid
if response.type == Message.MessageType.PUT_VALUE:
if response.key:
# Consume the sender's signed-peer-record if sent
if not maybe_consume_signed_record(response, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, ignoring the response"
)
return False
if response.key == key:
result = True
return result
@ -231,6 +243,10 @@ class ValueStore:
message.type = Message.MessageType.GET_VALUE
message.key = key
# Create sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
message.senderRecord = envelope_bytes
# Serialize and send the protobuf message
proto_bytes = message.SerializeToString()
await stream.write(varint.encode(len(proto_bytes)))
@ -275,6 +291,13 @@ class ValueStore:
and response.HasField("record")
and response.record.value
):
# Consume the sender's signed-peer-record
if not maybe_consume_signed_record(response, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, ignoring the response"
)
return None
logger.debug(
f"Received value for key {key.hex()} from peer {peer_id}"
)

View File

@ -1,5 +1,7 @@
from typing import Any, cast
import multiaddr
from libp2p.crypto.ed25519 import Ed25519PublicKey
from libp2p.crypto.keys import PrivateKey, PublicKey
from libp2p.crypto.rsa import RSAPublicKey
@ -131,6 +133,9 @@ class Envelope:
)
return False
def _env_addrs_set(self) -> set[multiaddr.Multiaddr]:
return {b for b in self.record().addrs}
def pub_key_to_protobuf(pub_key: PublicKey) -> cryto_pb.PublicKey:
"""

View File

@ -16,6 +16,7 @@ import trio
from trio import MemoryReceiveChannel, MemorySendChannel
from libp2p.abc import (
IHost,
IPeerStore,
)
from libp2p.crypto.keys import (
@ -23,7 +24,8 @@ from libp2p.crypto.keys import (
PrivateKey,
PublicKey,
)
from libp2p.peer.envelope import Envelope
from libp2p.peer.envelope import Envelope, seal_record
from libp2p.peer.peer_record import PeerRecord
from .id import (
ID,
@ -39,6 +41,86 @@ from .peerinfo import (
PERMANENT_ADDR_TTL = 0
def create_signed_peer_record(
peer_id: ID, addrs: list[Multiaddr], pvt_key: PrivateKey
) -> Envelope:
"""Creates a signed_peer_record wrapped in an Envelope"""
record = PeerRecord(peer_id, addrs)
envelope = seal_record(record, pvt_key)
return envelope
def env_to_send_in_RPC(host: IHost) -> tuple[bytes, bool]:
"""
Return the signed peer record (Envelope) to be sent in an RPC.
This function checks whether the host already has a cached signed peer record
(SPR). If one exists and its addresses match the host's current listen
addresses, the cached envelope is reused. Otherwise, a new signed peer record
is created, cached, and returned.
Parameters
----------
host : IHost
The local host instance, providing access to peer ID, listen addresses,
private key, and the peerstore.
Returns
-------
tuple[bytes, bool]
A 2-tuple where the first element is the serialized envelope (bytes)
for the signed peer record, and the second element is a boolean flag
indicating whether a new record was created (True) or an existing cached
one was reused (False).
"""
listen_addrs_set = {addr for addr in host.get_addrs()}
local_env = host.get_peerstore().get_local_record()
if local_env is None:
# No cached SPR yet -> create one
return issue_and_cache_local_record(host), True
else:
record_addrs_set = local_env._env_addrs_set()
if record_addrs_set == listen_addrs_set:
# Perfect match -> reuse cached envelope
return local_env.marshal_envelope(), False
else:
# Addresses changed -> issue a new SPR and cache it
return issue_and_cache_local_record(host), True
def issue_and_cache_local_record(host: IHost) -> bytes:
"""
Create and cache a new signed peer record (Envelope) for the host.
This function generates a new signed peer record from the hosts peer ID,
listen addresses, and private key. The resulting envelope is stored in
the peerstore as the local record for future reuse.
Parameters
----------
host : IHost
The local host instance, providing access to peer ID, listen addresses,
private key, and the peerstore.
Returns
-------
bytes
The serialized envelope (bytes) representing the newly created signed
peer record.
"""
env = create_signed_peer_record(
host.get_id(),
host.get_addrs(),
host.get_private_key(),
)
# Cache it for next time use
host.get_peerstore().set_local_record(env)
return env.marshal_envelope()
class PeerRecordState:
envelope: Envelope
seq: int
@ -55,8 +137,17 @@ class PeerStore(IPeerStore):
self.peer_data_map = defaultdict(PeerData)
self.addr_update_channels: dict[ID, MemorySendChannel[Multiaddr]] = {}
self.peer_record_map: dict[ID, PeerRecordState] = {}
self.local_peer_record: Envelope | None = None
self.max_records = max_records
def get_local_record(self) -> Envelope | None:
"""Get the local-signed-record wrapped in Envelope"""
return self.local_peer_record
def set_local_record(self, envelope: Envelope) -> None:
"""Set the local-signed-record wrapped in Envelope"""
self.local_peer_record = envelope
def peer_info(self, peer_id: ID) -> PeerInfo:
"""
:param peer_id: peer ID to get info for