26 Commits

Author SHA1 Message Date
999315a74a Merge branch 'main' into noise-arch-change 2025-08-29 03:23:05 +05:30
5c11ac20e7 Merge pull request #815 from lla-dane/kad-record
Signed-Peer-Record support in KAD-DHT message transfer mechanism.
2025-08-29 03:09:07 +05:30
c2c4228591 added test for ADD_PROVIDER record processing 2025-08-27 13:02:32 +05:30
943bcc4d36 fix the logic error in add_provider handling 2025-08-27 10:17:40 +05:30
2006b2c92c added newsfragment 2025-08-26 12:59:18 +05:30
fe3f7adc1b fix typos 2025-08-26 12:49:51 +05:30
7b2d637382 Now using env_to_send_in_RPC for issuing records in Identify rpc messages 2025-08-26 12:49:51 +05:30
91bee9df89 Moved env_to_send_in_RPC function to libp2p/peer/peerstore.py 2025-08-26 12:49:51 +05:30
5bf9c7b537 Fix spinx error 2025-08-26 12:49:51 +05:30
8958c0fac3 Moved env_to_send_in_RPC function to libp2p/init.py 2025-08-26 12:49:51 +05:30
091ac082b9 Commented out the bool variable from env_to_send_in_RPC() at places 2025-08-26 12:49:51 +05:30
15f4a399ec Added and docstrings and removed typos 2025-08-26 12:49:51 +05:30
3917d7b596 verify peer_id in signed-record matches authenticated sender 2025-08-26 12:49:51 +05:30
3aacb3a391 remove the timeout bound from the kad-dht test 2025-08-26 12:49:51 +05:30
ba39e91a2e added test for req rejection upon invalid record transfer 2025-08-26 12:49:51 +05:30
57d1c9d807 reject dht-msgs upon receiving invalid records 2025-08-26 12:49:51 +05:30
efc899e872 fix abc.py file 2025-08-26 12:49:51 +05:30
cea1985c5c add reissuing mechanism of records if addrs dont change 2025-08-26 12:49:51 +05:30
702ad4876e remove too much repeatitive code 2025-08-26 12:49:51 +05:30
a21d9e878b recompile protobuf schema and remove typos 2025-08-26 12:49:51 +05:30
5ab68026d6 removed redundant logs 2025-08-26 12:49:51 +05:30
d1792588f9 added tests for signed-peee-record transfer in kad-dht 2025-08-26 12:49:51 +05:30
53db128f69 fix typos 2025-08-26 12:49:51 +05:30
cacb3c8aca feat: add webtransport certhashes field to NoiseExtensions and implement serialization test
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
2025-08-26 12:49:21 +05:30
05fde3ad40 Merge branch 'main' into noise-arch-change 2025-08-25 16:21:43 +05:30
e4ab3cb2c5 Add early data support to Noise protocol
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
2025-08-19 04:41:14 +05:30
29 changed files with 1074 additions and 219 deletions

View File

@ -24,13 +24,8 @@ async def main():
noise_transport = NoiseTransport(
# local_key_pair: The key pair used for libp2p identity and authentication
libp2p_keypair=key_pair,
# noise_privkey: The private key used for Noise protocol encryption
noise_privkey=key_pair.private_key,
# early_data: Optional data to send during the handshake
# (None means no early data)
early_data=None,
# with_noise_pipes: Whether to use Noise pipes for additional security features
with_noise_pipes=False,
# TODO: add early data
)
# Create a security options dictionary mapping protocol ID to transport

View File

@ -28,9 +28,7 @@ async def main():
noise_privkey=key_pair.private_key,
# early_data: Optional data to send during the handshake
# (None means no early data)
early_data=None,
# with_noise_pipes: Whether to use Noise pipes for additional security features
with_noise_pipes=False,
# TODO: add early data
)
# Create a security options dictionary mapping protocol ID to transport

View File

@ -31,9 +31,7 @@ async def main():
noise_privkey=key_pair.private_key,
# early_data: Optional data to send during the handshake
# (None means no early data)
early_data=None,
# with_noise_pipes: Whether to use Noise pipes for additional security features
with_noise_pipes=False,
# TODO: add early data
)
# Create a security options dictionary mapping protocol ID to transport

View File

@ -28,9 +28,7 @@ async def main():
noise_privkey=key_pair.private_key,
# early_data: Optional data to send during the handshake
# (None means no early data)
early_data=None,
# with_noise_pipes: Whether to use Noise pipes for additional security features
with_noise_pipes=False,
# TODO: add early data
)
# Create a security options dictionary mapping protocol ID to transport

View File

@ -49,6 +49,7 @@ from libp2p.peer.id import (
)
from libp2p.peer.peerstore import (
PeerStore,
create_signed_peer_record,
)
from libp2p.security.insecure.transport import (
PLAINTEXT_PROTOCOL_ID,
@ -155,7 +156,6 @@ def get_default_muxer_options() -> TMuxerOptions:
else: # YAMUX is default
return create_yamux_muxer_option()
def new_swarm(
key_pair: KeyPair | None = None,
muxer_opt: TMuxerOptions | None = None,

View File

@ -970,6 +970,14 @@ class IPeerStore(
# --------CERTIFIED-ADDR-BOOK----------
@abstractmethod
def get_local_record(self) -> Optional["Envelope"]:
"""Get the local-peer-record wrapped in Envelope"""
@abstractmethod
def set_local_record(self, envelope: "Envelope") -> None:
"""Set the local-peer-record wrapped in Envelope"""
@abstractmethod
def consume_peer_record(self, envelope: "Envelope", ttl: int) -> bool:
"""

View File

@ -43,6 +43,7 @@ from libp2p.peer.id import (
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.peer.peerstore import create_signed_peer_record
from libp2p.protocol_muxer.exceptions import (
MultiselectClientError,
MultiselectError,
@ -110,6 +111,14 @@ class BasicHost(IHost):
if bootstrap:
self.bootstrap = BootstrapDiscovery(network, bootstrap)
# Cache a signed-record if the local-node in the PeerStore
envelope = create_signed_peer_record(
self.get_id(),
self.get_addrs(),
self.get_private_key(),
)
self.get_peerstore().set_local_record(envelope)
def get_id(self) -> ID:
"""
:return: peer_id of host

View File

@ -15,8 +15,7 @@ from libp2p.custom_types import (
from libp2p.network.stream.exceptions import (
StreamClosed,
)
from libp2p.peer.envelope import seal_record
from libp2p.peer.peer_record import PeerRecord
from libp2p.peer.peerstore import env_to_send_in_RPC
from libp2p.utils import (
decode_varint_with_size,
get_agent_version,
@ -66,9 +65,7 @@ def _mk_identify_protobuf(
protocols = tuple(str(p) for p in host.get_mux().get_protocols() if p is not None)
# Create a signed peer-record for the remote peer
record = PeerRecord(host.get_id(), host.get_addrs())
envelope = seal_record(record, host.get_private_key())
protobuf = envelope.marshal_envelope()
envelope_bytes, _ = env_to_send_in_RPC(host)
observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b""
return Identify(
@ -78,7 +75,7 @@ def _mk_identify_protobuf(
listen_addrs=map(_multiaddr_to_bytes, laddrs),
observed_addr=observed_addr,
protocols=protocols,
signedPeerRecord=protobuf,
signedPeerRecord=envelope_bytes,
)

View File

@ -22,15 +22,18 @@ from libp2p.abc import (
IHost,
)
from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager
from libp2p.kad_dht.utils import maybe_consume_signed_record
from libp2p.network.stream.net_stream import (
INetStream,
)
from libp2p.peer.envelope import Envelope
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.peer.peerstore import env_to_send_in_RPC
from libp2p.tools.async_service import (
Service,
)
@ -234,6 +237,9 @@ class KadDHT(Service):
await self.add_peer(peer_id)
logger.debug(f"Added peer {peer_id} to routing table")
closer_peer_envelope: Envelope | None = None
provider_peer_envelope: Envelope | None = None
try:
# Read varint-prefixed length for the message
length_prefix = b""
@ -274,6 +280,14 @@ class KadDHT(Service):
)
logger.debug(f"Found {len(closest_peers)} peers close to target")
# Consume the source signed_peer_record if sent
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
# Build response message with protobuf
response = Message()
response.type = Message.MessageType.FIND_NODE
@ -298,6 +312,21 @@ class KadDHT(Service):
except Exception:
pass
# Add the signed-peer-record for each peer in the peer-proto
# if cached in the peerstore
closer_peer_envelope = (
self.host.get_peerstore().get_peer_record(peer)
)
if closer_peer_envelope is not None:
peer_proto.signedRecord = (
closer_peer_envelope.marshal_envelope()
)
# Create sender_signed_peer_record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Serialize and send response
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
@ -312,6 +341,14 @@ class KadDHT(Service):
key = message.key
logger.debug(f"Received ADD_PROVIDER for key {key.hex()}")
# Consume the source signed-peer-record if sent
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
# Extract provider information
for provider_proto in message.providerPeers:
try:
@ -338,6 +375,17 @@ class KadDHT(Service):
logger.debug(
f"Added provider {provider_id} for key {key.hex()}"
)
# Process the signed-records of provider if sent
if not maybe_consume_signed_record(
provider_proto, self.host
):
logger.error(
"Received an invalid-signed-record,"
"dropping the stream"
)
await stream.close()
return
except Exception as e:
logger.warning(f"Failed to process provider info: {e}")
@ -346,6 +394,10 @@ class KadDHT(Service):
response.type = Message.MessageType.ADD_PROVIDER
response.key = key
# Add sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
await stream.write(response_bytes)
@ -357,6 +409,14 @@ class KadDHT(Service):
key = message.key
logger.debug(f"Received GET_PROVIDERS request for key {key.hex()}")
# Consume the source signed_peer_record if sent
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
# Find providers for the key
providers = self.provider_store.get_providers(key)
logger.debug(
@ -368,12 +428,28 @@ class KadDHT(Service):
response.type = Message.MessageType.GET_PROVIDERS
response.key = key
# Create sender_signed_peer_record for the response
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Add provider information to response
for provider_info in providers:
provider_proto = response.providerPeers.add()
provider_proto.id = provider_info.peer_id.to_bytes()
provider_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add provider signed-records if cached
provider_peer_envelope = (
self.host.get_peerstore().get_peer_record(
provider_info.peer_id
)
)
if provider_peer_envelope is not None:
provider_proto.signedRecord = (
provider_peer_envelope.marshal_envelope()
)
# Add addresses if available
for addr in provider_info.addrs:
provider_proto.addrs.append(addr.to_bytes())
@ -397,6 +473,16 @@ class KadDHT(Service):
peer_proto.id = peer.to_bytes()
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add the signed-records of closest_peers if cached
closer_peer_envelope = (
self.host.get_peerstore().get_peer_record(peer)
)
if closer_peer_envelope is not None:
peer_proto.signedRecord = (
closer_peer_envelope.marshal_envelope()
)
# Add addresses if available
try:
addrs = self.host.get_peerstore().addrs(peer)
@ -417,6 +503,14 @@ class KadDHT(Service):
key = message.key
logger.debug(f"Received GET_VALUE request for key {key.hex()}")
# Consume the sender_signed_peer_record
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
value = self.value_store.get(key)
if value:
logger.debug(f"Found value for key {key.hex()}")
@ -431,6 +525,10 @@ class KadDHT(Service):
response.record.value = value
response.record.timeReceived = str(time.time())
# Create sender_signed_peer_record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Serialize and send response
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
@ -444,6 +542,10 @@ class KadDHT(Service):
response.type = Message.MessageType.GET_VALUE
response.key = key
# Create sender_signed_peer_record for the response
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Add closest peers to key
closest_peers = self.routing_table.find_local_closest_peers(
key, 20
@ -462,6 +564,16 @@ class KadDHT(Service):
peer_proto.id = peer.to_bytes()
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add signed-records of closer-peers if cached
closer_peer_envelope = (
self.host.get_peerstore().get_peer_record(peer)
)
if closer_peer_envelope is not None:
peer_proto.signedRecord = (
closer_peer_envelope.marshal_envelope()
)
# Add addresses if available
try:
addrs = self.host.get_peerstore().addrs(peer)
@ -484,6 +596,15 @@ class KadDHT(Service):
key = message.record.key
value = message.record.value
success = False
# Consume the source signed_peer_record if sent
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
try:
if not (key and value):
raise ValueError(
@ -504,6 +625,12 @@ class KadDHT(Service):
response.type = Message.MessageType.PUT_VALUE
if success:
response.key = key
# Create sender_signed_peer_record for the response
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Serialize and send response
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
await stream.write(response_bytes)

View File

@ -27,6 +27,7 @@ message Message {
bytes id = 1;
repeated bytes addrs = 2;
ConnectionType connection = 3;
optional bytes signedRecord = 4; // Envelope(PeerRecord) encoded
}
MessageType type = 1;
@ -35,4 +36,6 @@ message Message {
Record record = 3;
repeated Peer closerPeers = 8;
repeated Peer providerPeers = 9;
optional bytes senderRecord = 11; // Envelope(PeerRecord) encoded
}

View File

@ -1,11 +1,12 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: libp2p/kad_dht/pb/kademlia.proto
# Protobuf Python Version: 4.25.3
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
@ -13,21 +14,21 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xca\x03\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x1aN\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xa2\x04\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x12\x19\n\x0csenderRecord\x18\x0b \x01(\x0cH\x00\x88\x01\x01\x1az\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\x12\x19\n\x0csignedRecord\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x42\x0f\n\r_signedRecord\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x42\x0f\n\r_senderRecordb\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', globals())
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_RECORD._serialized_start=36
_RECORD._serialized_end=94
_MESSAGE._serialized_start=97
_MESSAGE._serialized_end=555
_MESSAGE_PEER._serialized_start=281
_MESSAGE_PEER._serialized_end=359
_MESSAGE_MESSAGETYPE._serialized_start=361
_MESSAGE_MESSAGETYPE._serialized_end=466
_MESSAGE_CONNECTIONTYPE._serialized_start=468
_MESSAGE_CONNECTIONTYPE._serialized_end=555
_globals['_RECORD']._serialized_start=36
_globals['_RECORD']._serialized_end=94
_globals['_MESSAGE']._serialized_start=97
_globals['_MESSAGE']._serialized_end=643
_globals['_MESSAGE_PEER']._serialized_start=308
_globals['_MESSAGE_PEER']._serialized_end=430
_globals['_MESSAGE_MESSAGETYPE']._serialized_start=432
_globals['_MESSAGE_MESSAGETYPE']._serialized_end=537
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_start=539
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_end=626
# @@protoc_insertion_point(module_scope)

View File

@ -1,133 +1,70 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
from google.protobuf.internal import containers as _containers
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import sys
import typing
DESCRIPTOR: _descriptor.FileDescriptor
if sys.version_info >= (3, 10):
import typing as typing_extensions
else:
import typing_extensions
class Record(_message.Message):
__slots__ = ("key", "value", "timeReceived")
KEY_FIELD_NUMBER: _ClassVar[int]
VALUE_FIELD_NUMBER: _ClassVar[int]
TIMERECEIVED_FIELD_NUMBER: _ClassVar[int]
key: bytes
value: bytes
timeReceived: str
def __init__(self, key: _Optional[bytes] = ..., value: _Optional[bytes] = ..., timeReceived: _Optional[str] = ...) -> None: ...
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class Record(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
TIMERECEIVED_FIELD_NUMBER: builtins.int
key: builtins.bytes
value: builtins.bytes
timeReceived: builtins.str
def __init__(
self,
*,
key: builtins.bytes = ...,
value: builtins.bytes = ...,
timeReceived: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "timeReceived", b"timeReceived", "value", b"value"]) -> None: ...
global___Record = Record
@typing.final
class Message(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class _MessageType:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _MessageTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._MessageType.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
PUT_VALUE: Message._MessageType.ValueType # 0
GET_VALUE: Message._MessageType.ValueType # 1
ADD_PROVIDER: Message._MessageType.ValueType # 2
GET_PROVIDERS: Message._MessageType.ValueType # 3
FIND_NODE: Message._MessageType.ValueType # 4
PING: Message._MessageType.ValueType # 5
class MessageType(_MessageType, metaclass=_MessageTypeEnumTypeWrapper): ...
PUT_VALUE: Message.MessageType.ValueType # 0
GET_VALUE: Message.MessageType.ValueType # 1
ADD_PROVIDER: Message.MessageType.ValueType # 2
GET_PROVIDERS: Message.MessageType.ValueType # 3
FIND_NODE: Message.MessageType.ValueType # 4
PING: Message.MessageType.ValueType # 5
class _ConnectionType:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _ConnectionTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._ConnectionType.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
NOT_CONNECTED: Message._ConnectionType.ValueType # 0
CONNECTED: Message._ConnectionType.ValueType # 1
CAN_CONNECT: Message._ConnectionType.ValueType # 2
CANNOT_CONNECT: Message._ConnectionType.ValueType # 3
class ConnectionType(_ConnectionType, metaclass=_ConnectionTypeEnumTypeWrapper): ...
NOT_CONNECTED: Message.ConnectionType.ValueType # 0
CONNECTED: Message.ConnectionType.ValueType # 1
CAN_CONNECT: Message.ConnectionType.ValueType # 2
CANNOT_CONNECT: Message.ConnectionType.ValueType # 3
@typing.final
class Peer(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
ID_FIELD_NUMBER: builtins.int
ADDRS_FIELD_NUMBER: builtins.int
CONNECTION_FIELD_NUMBER: builtins.int
id: builtins.bytes
connection: global___Message.ConnectionType.ValueType
@property
def addrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
def __init__(
self,
*,
id: builtins.bytes = ...,
addrs: collections.abc.Iterable[builtins.bytes] | None = ...,
connection: global___Message.ConnectionType.ValueType = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["addrs", b"addrs", "connection", b"connection", "id", b"id"]) -> None: ...
TYPE_FIELD_NUMBER: builtins.int
CLUSTERLEVELRAW_FIELD_NUMBER: builtins.int
KEY_FIELD_NUMBER: builtins.int
RECORD_FIELD_NUMBER: builtins.int
CLOSERPEERS_FIELD_NUMBER: builtins.int
PROVIDERPEERS_FIELD_NUMBER: builtins.int
type: global___Message.MessageType.ValueType
clusterLevelRaw: builtins.int
key: builtins.bytes
@property
def record(self) -> global___Record: ...
@property
def closerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ...
@property
def providerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ...
def __init__(
self,
*,
type: global___Message.MessageType.ValueType = ...,
clusterLevelRaw: builtins.int = ...,
key: builtins.bytes = ...,
record: global___Record | None = ...,
closerPeers: collections.abc.Iterable[global___Message.Peer] | None = ...,
providerPeers: collections.abc.Iterable[global___Message.Peer] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["record", b"record"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["closerPeers", b"closerPeers", "clusterLevelRaw", b"clusterLevelRaw", "key", b"key", "providerPeers", b"providerPeers", "record", b"record", "type", b"type"]) -> None: ...
global___Message = Message
class Message(_message.Message):
__slots__ = ("type", "clusterLevelRaw", "key", "record", "closerPeers", "providerPeers", "senderRecord")
class MessageType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
PUT_VALUE: _ClassVar[Message.MessageType]
GET_VALUE: _ClassVar[Message.MessageType]
ADD_PROVIDER: _ClassVar[Message.MessageType]
GET_PROVIDERS: _ClassVar[Message.MessageType]
FIND_NODE: _ClassVar[Message.MessageType]
PING: _ClassVar[Message.MessageType]
PUT_VALUE: Message.MessageType
GET_VALUE: Message.MessageType
ADD_PROVIDER: Message.MessageType
GET_PROVIDERS: Message.MessageType
FIND_NODE: Message.MessageType
PING: Message.MessageType
class ConnectionType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
NOT_CONNECTED: _ClassVar[Message.ConnectionType]
CONNECTED: _ClassVar[Message.ConnectionType]
CAN_CONNECT: _ClassVar[Message.ConnectionType]
CANNOT_CONNECT: _ClassVar[Message.ConnectionType]
NOT_CONNECTED: Message.ConnectionType
CONNECTED: Message.ConnectionType
CAN_CONNECT: Message.ConnectionType
CANNOT_CONNECT: Message.ConnectionType
class Peer(_message.Message):
__slots__ = ("id", "addrs", "connection", "signedRecord")
ID_FIELD_NUMBER: _ClassVar[int]
ADDRS_FIELD_NUMBER: _ClassVar[int]
CONNECTION_FIELD_NUMBER: _ClassVar[int]
SIGNEDRECORD_FIELD_NUMBER: _ClassVar[int]
id: bytes
addrs: _containers.RepeatedScalarFieldContainer[bytes]
connection: Message.ConnectionType
signedRecord: bytes
def __init__(self, id: _Optional[bytes] = ..., addrs: _Optional[_Iterable[bytes]] = ..., connection: _Optional[_Union[Message.ConnectionType, str]] = ..., signedRecord: _Optional[bytes] = ...) -> None: ...
TYPE_FIELD_NUMBER: _ClassVar[int]
CLUSTERLEVELRAW_FIELD_NUMBER: _ClassVar[int]
KEY_FIELD_NUMBER: _ClassVar[int]
RECORD_FIELD_NUMBER: _ClassVar[int]
CLOSERPEERS_FIELD_NUMBER: _ClassVar[int]
PROVIDERPEERS_FIELD_NUMBER: _ClassVar[int]
SENDERRECORD_FIELD_NUMBER: _ClassVar[int]
type: Message.MessageType
clusterLevelRaw: int
key: bytes
record: Record
closerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer]
providerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer]
senderRecord: bytes
def __init__(self, type: _Optional[_Union[Message.MessageType, str]] = ..., clusterLevelRaw: _Optional[int] = ..., key: _Optional[bytes] = ..., record: _Optional[_Union[Record, _Mapping]] = ..., closerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping]]] = ..., providerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping]]] = ..., senderRecord: _Optional[bytes] = ...) -> None: ... # type: ignore

View File

@ -15,12 +15,14 @@ from libp2p.abc import (
INetStream,
IPeerRouting,
)
from libp2p.peer.envelope import Envelope
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.peer.peerstore import env_to_send_in_RPC
from .common import (
ALPHA,
@ -33,6 +35,7 @@ from .routing_table import (
RoutingTable,
)
from .utils import (
maybe_consume_signed_record,
sort_peer_ids_by_distance,
)
@ -255,6 +258,10 @@ class PeerRouting(IPeerRouting):
find_node_msg.type = Message.MessageType.FIND_NODE
find_node_msg.key = target_key # Set target key directly as bytes
# Create sender_signed_peer_record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
find_node_msg.senderRecord = envelope_bytes
# Serialize and send the protobuf message with varint length prefix
proto_bytes = find_node_msg.SerializeToString()
logger.debug(
@ -299,7 +306,22 @@ class PeerRouting(IPeerRouting):
# Process closest peers from response
if response_msg.type == Message.MessageType.FIND_NODE:
# Consume the sender_signed_peer_record
if not maybe_consume_signed_record(response_msg, self.host, peer):
logger.error(
"Received an invalid-signed-record,ignoring the response"
)
return []
for peer_data in response_msg.closerPeers:
# Consume the received closer_peers signed-records, peer-id is
# sent with the peer-data
if not maybe_consume_signed_record(peer_data, self.host):
logger.error(
"Received an invalid-signed-record,ignoring the response"
)
return []
new_peer_id = ID(peer_data.id)
if new_peer_id not in results:
results.append(new_peer_id)
@ -332,6 +354,7 @@ class PeerRouting(IPeerRouting):
"""
try:
# Read message length
peer_id = stream.muxed_conn.peer_id
length_bytes = await stream.read(4)
if not length_bytes:
return
@ -345,10 +368,18 @@ class PeerRouting(IPeerRouting):
# Parse protobuf message
kad_message = Message()
closer_peer_envelope: Envelope | None = None
try:
kad_message.ParseFromString(message_bytes)
if kad_message.type == Message.MessageType.FIND_NODE:
# Consume the sender's signed-peer-record if sent
if not maybe_consume_signed_record(kad_message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
return
# Get target key directly from protobuf message
target_key = kad_message.key
@ -361,12 +392,26 @@ class PeerRouting(IPeerRouting):
response = Message()
response.type = Message.MessageType.FIND_NODE
# Create sender_signed_peer_record for the response
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Add peer information to response
for peer_id in closest_peers:
peer_proto = response.closerPeers.add()
peer_proto.id = peer_id.to_bytes()
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add the signed-records of closest_peers if cached
closer_peer_envelope = (
self.host.get_peerstore().get_peer_record(peer_id)
)
if isinstance(closer_peer_envelope, Envelope):
peer_proto.signedRecord = (
closer_peer_envelope.marshal_envelope()
)
# Add addresses if available
try:
addrs = self.host.get_peerstore().addrs(peer_id)

View File

@ -22,12 +22,14 @@ from libp2p.abc import (
from libp2p.custom_types import (
TProtocol,
)
from libp2p.kad_dht.utils import maybe_consume_signed_record
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.peer.peerstore import env_to_send_in_RPC
from .common import (
ALPHA,
@ -240,11 +242,18 @@ class ProviderStore:
message.type = Message.MessageType.ADD_PROVIDER
message.key = key
# Create sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
message.senderRecord = envelope_bytes
# Add our provider info
provider = message.providerPeers.add()
provider.id = self.local_peer_id.to_bytes()
provider.addrs.extend(addrs)
# Add the provider's signed-peer-record
provider.signedRecord = envelope_bytes
# Serialize and send the message
proto_bytes = message.SerializeToString()
await stream.write(varint.encode(len(proto_bytes)))
@ -276,10 +285,15 @@ class ProviderStore:
response = Message()
response.ParseFromString(response_bytes)
# Check response type
response.type == Message.MessageType.ADD_PROVIDER
if response.type:
result = True
if response.type == Message.MessageType.ADD_PROVIDER:
# Consume the sender's signed-peer-record if sent
if not maybe_consume_signed_record(response, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, ignoring the response"
)
result = False
else:
result = True
except Exception as e:
logger.warning(f"Error sending ADD_PROVIDER to {peer_id}: {e}")
@ -380,6 +394,10 @@ class ProviderStore:
message.type = Message.MessageType.GET_PROVIDERS
message.key = key
# Create sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
message.senderRecord = envelope_bytes
# Serialize and send the message
proto_bytes = message.SerializeToString()
await stream.write(varint.encode(len(proto_bytes)))
@ -414,10 +432,26 @@ class ProviderStore:
if response.type != Message.MessageType.GET_PROVIDERS:
return []
# Consume the sender's signed-peer-record if sent
if not maybe_consume_signed_record(response, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, ignoring the response"
)
return []
# Extract provider information
providers = []
for provider_proto in response.providerPeers:
try:
# Consume the provider's signed-peer-record if sent, peer-id
# already sent with the provider-proto
if not maybe_consume_signed_record(provider_proto, self.host):
logger.error(
"Received an invalid-signed-record, "
"ignoring the response"
)
return []
# Create peer ID from bytes
provider_id = ID(provider_proto.id)
@ -431,6 +465,7 @@ class ProviderStore:
# Create PeerInfo and add to result
providers.append(PeerInfo(provider_id, addrs))
except Exception as e:
logger.warning(f"Failed to parse provider info: {e}")

View File

@ -2,13 +2,93 @@
Utility functions for Kademlia DHT implementation.
"""
import logging
import base58
import multihash
from libp2p.abc import IHost
from libp2p.peer.envelope import consume_envelope
from libp2p.peer.id import (
ID,
)
from .pb.kademlia_pb2 import (
Message,
)
logger = logging.getLogger("kademlia-example.utils")
def maybe_consume_signed_record(
msg: Message | Message.Peer, host: IHost, peer_id: ID | None = None
) -> bool:
"""
Attempt to parse and store a signed-peer-record (Envelope) received during
DHT communication. If the record is invalid, the peer-id does not match, or
updating the peerstore fails, the function logs an error and returns False.
Parameters
----------
msg : Message | Message.Peer
The protobuf message received during DHT communication. Can either be a
top-level `Message` containing `senderRecord` or a `Message.Peer`
containing `signedRecord`.
host : IHost
The local host instance, providing access to the peerstore for storing
verified peer records.
peer_id : ID | None, optional
The expected peer ID for record validation. If provided, the peer ID
inside the record must match this value.
Returns
-------
bool
True if a valid signed peer record was successfully consumed and stored,
False otherwise.
"""
if isinstance(msg, Message):
if msg.HasField("senderRecord"):
try:
# Convert the signed-peer-record(Envelope) from
# protobuf bytes
envelope, record = consume_envelope(
msg.senderRecord,
"libp2p-peer-record",
)
if not (isinstance(peer_id, ID) and record.peer_id == peer_id):
return False
# Use the default TTL of 2 hours (7200 seconds)
if not host.get_peerstore().consume_peer_record(envelope, 7200):
logger.error("Failed to update the Certified-Addr-Book")
return False
except Exception as e:
logger.error("Failed to update the Certified-Addr-Book: %s", e)
return False
else:
if msg.HasField("signedRecord"):
try:
# Convert the signed-peer-record(Envelope) from
# protobuf bytes
envelope, record = consume_envelope(
msg.signedRecord,
"libp2p-peer-record",
)
if not record.peer_id.to_bytes() == msg.id:
return False
# Use the default TTL of 2 hours (7200 seconds)
if not host.get_peerstore().consume_peer_record(envelope, 7200):
logger.error("Failed to update the Certified-Addr-Book")
return False
except Exception as e:
logger.error(
"Failed to update the Certified-Addr-Book: %s",
e,
)
return False
return True
def create_key_from_binary(binary_data: bytes) -> bytes:
"""

View File

@ -15,9 +15,11 @@ from libp2p.abc import (
from libp2p.custom_types import (
TProtocol,
)
from libp2p.kad_dht.utils import maybe_consume_signed_record
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerstore import env_to_send_in_RPC
from .common import (
DEFAULT_TTL,
@ -110,6 +112,10 @@ class ValueStore:
message = Message()
message.type = Message.MessageType.PUT_VALUE
# Create sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
message.senderRecord = envelope_bytes
# Set message fields
message.key = key
message.record.key = key
@ -155,7 +161,13 @@ class ValueStore:
# Check if response is valid
if response.type == Message.MessageType.PUT_VALUE:
if response.key:
# Consume the sender's signed-peer-record if sent
if not maybe_consume_signed_record(response, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, ignoring the response"
)
return False
if response.key == key:
result = True
return result
@ -231,6 +243,10 @@ class ValueStore:
message.type = Message.MessageType.GET_VALUE
message.key = key
# Create sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
message.senderRecord = envelope_bytes
# Serialize and send the protobuf message
proto_bytes = message.SerializeToString()
await stream.write(varint.encode(len(proto_bytes)))
@ -275,6 +291,13 @@ class ValueStore:
and response.HasField("record")
and response.record.value
):
# Consume the sender's signed-peer-record
if not maybe_consume_signed_record(response, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, ignoring the response"
)
return None
logger.debug(
f"Received value for key {key.hex()} from peer {peer_id}"
)

View File

@ -1,5 +1,7 @@
from typing import Any, cast
import multiaddr
from libp2p.crypto.ed25519 import Ed25519PublicKey
from libp2p.crypto.keys import PrivateKey, PublicKey
from libp2p.crypto.rsa import RSAPublicKey
@ -131,6 +133,9 @@ class Envelope:
)
return False
def _env_addrs_set(self) -> set[multiaddr.Multiaddr]:
return {b for b in self.record().addrs}
def pub_key_to_protobuf(pub_key: PublicKey) -> cryto_pb.PublicKey:
"""

View File

@ -16,6 +16,7 @@ import trio
from trio import MemoryReceiveChannel, MemorySendChannel
from libp2p.abc import (
IHost,
IPeerStore,
)
from libp2p.crypto.keys import (
@ -23,7 +24,8 @@ from libp2p.crypto.keys import (
PrivateKey,
PublicKey,
)
from libp2p.peer.envelope import Envelope
from libp2p.peer.envelope import Envelope, seal_record
from libp2p.peer.peer_record import PeerRecord
from .id import (
ID,
@ -39,6 +41,86 @@ from .peerinfo import (
PERMANENT_ADDR_TTL = 0
def create_signed_peer_record(
peer_id: ID, addrs: list[Multiaddr], pvt_key: PrivateKey
) -> Envelope:
"""Creates a signed_peer_record wrapped in an Envelope"""
record = PeerRecord(peer_id, addrs)
envelope = seal_record(record, pvt_key)
return envelope
def env_to_send_in_RPC(host: IHost) -> tuple[bytes, bool]:
"""
Return the signed peer record (Envelope) to be sent in an RPC.
This function checks whether the host already has a cached signed peer record
(SPR). If one exists and its addresses match the host's current listen
addresses, the cached envelope is reused. Otherwise, a new signed peer record
is created, cached, and returned.
Parameters
----------
host : IHost
The local host instance, providing access to peer ID, listen addresses,
private key, and the peerstore.
Returns
-------
tuple[bytes, bool]
A 2-tuple where the first element is the serialized envelope (bytes)
for the signed peer record, and the second element is a boolean flag
indicating whether a new record was created (True) or an existing cached
one was reused (False).
"""
listen_addrs_set = {addr for addr in host.get_addrs()}
local_env = host.get_peerstore().get_local_record()
if local_env is None:
# No cached SPR yet -> create one
return issue_and_cache_local_record(host), True
else:
record_addrs_set = local_env._env_addrs_set()
if record_addrs_set == listen_addrs_set:
# Perfect match -> reuse cached envelope
return local_env.marshal_envelope(), False
else:
# Addresses changed -> issue a new SPR and cache it
return issue_and_cache_local_record(host), True
def issue_and_cache_local_record(host: IHost) -> bytes:
"""
Create and cache a new signed peer record (Envelope) for the host.
This function generates a new signed peer record from the hosts peer ID,
listen addresses, and private key. The resulting envelope is stored in
the peerstore as the local record for future reuse.
Parameters
----------
host : IHost
The local host instance, providing access to peer ID, listen addresses,
private key, and the peerstore.
Returns
-------
bytes
The serialized envelope (bytes) representing the newly created signed
peer record.
"""
env = create_signed_peer_record(
host.get_id(),
host.get_addrs(),
host.get_private_key(),
)
# Cache it for next time use
host.get_peerstore().set_local_record(env)
return env.marshal_envelope()
class PeerRecordState:
envelope: Envelope
seq: int
@ -55,8 +137,17 @@ class PeerStore(IPeerStore):
self.peer_data_map = defaultdict(PeerData)
self.addr_update_channels: dict[ID, MemorySendChannel[Multiaddr]] = {}
self.peer_record_map: dict[ID, PeerRecordState] = {}
self.local_peer_record: Envelope | None = None
self.max_records = max_records
def get_local_record(self) -> Envelope | None:
"""Get the local-signed-record wrapped in Envelope"""
return self.local_peer_record
def set_local_record(self, envelope: Envelope) -> None:
"""Set the local-signed-record wrapped in Envelope"""
self.local_peer_record = envelope
def peer_info(self, peer_id: ID) -> PeerInfo:
"""
:param peer_id: peer ID to get info for

View File

@ -0,0 +1,68 @@
from abc import ABC, abstractmethod
from libp2p.abc import IRawConnection
from libp2p.custom_types import TProtocol
from libp2p.peer.id import ID
from .pb import noise_pb2 as noise_pb
class EarlyDataHandler(ABC):
"""Interface for handling early data during Noise handshake"""
@abstractmethod
async def send(
self, conn: IRawConnection, peer_id: ID
) -> noise_pb.NoiseExtensions | None:
"""Called to generate early data to send during handshake"""
pass
@abstractmethod
async def received(
self, conn: IRawConnection, extensions: noise_pb.NoiseExtensions | None
) -> None:
"""Called when early data is received during handshake"""
pass
class TransportEarlyDataHandler(EarlyDataHandler):
"""Default early data handler for muxer negotiation"""
def __init__(self, supported_muxers: list[TProtocol]):
self.supported_muxers = supported_muxers
self.received_muxers: list[TProtocol] = []
async def send(
self, conn: IRawConnection, peer_id: ID
) -> noise_pb.NoiseExtensions | None:
"""Send our supported muxers list"""
if not self.supported_muxers:
return None
extensions = noise_pb.NoiseExtensions()
# Convert TProtocol to string for serialization
extensions.stream_muxers[:] = [str(muxer) for muxer in self.supported_muxers]
return extensions
async def received(
self, conn: IRawConnection, extensions: noise_pb.NoiseExtensions | None
) -> None:
"""Store received muxers list"""
if extensions and extensions.stream_muxers:
self.received_muxers = [
TProtocol(muxer) for muxer in extensions.stream_muxers
]
def match_muxers(self, is_initiator: bool) -> TProtocol | None:
"""Find first common muxer between local and remote"""
if is_initiator:
# Initiator: find first local muxer that remote supports
for local_muxer in self.supported_muxers:
if local_muxer in self.received_muxers:
return local_muxer
else:
# Responder: find first remote muxer that we support
for remote_muxer in self.received_muxers:
if remote_muxer in self.supported_muxers:
return remote_muxer
return None

View File

@ -30,6 +30,9 @@ from libp2p.security.secure_session import (
SecureSession,
)
from .early_data import (
EarlyDataHandler,
)
from .exceptions import (
HandshakeHasNotFinished,
InvalidSignature,
@ -45,6 +48,7 @@ from .messages import (
make_handshake_payload_sig,
verify_handshake_payload_sig,
)
from .pb import noise_pb2 as noise_pb
class IPattern(ABC):
@ -62,7 +66,8 @@ class BasePattern(IPattern):
noise_static_key: PrivateKey
local_peer: ID
libp2p_privkey: PrivateKey
early_data: bytes | None
initiator_early_data_handler: EarlyDataHandler | None
responder_early_data_handler: EarlyDataHandler | None
def create_noise_state(self) -> NoiseState:
noise_state = NoiseState.from_name(self.protocol_name)
@ -73,11 +78,50 @@ class BasePattern(IPattern):
raise NoiseStateError("noise_protocol is not initialized")
return noise_state
def make_handshake_payload(self) -> NoiseHandshakePayload:
async def make_handshake_payload(
self, conn: IRawConnection, peer_id: ID, is_initiator: bool
) -> NoiseHandshakePayload:
signature = make_handshake_payload_sig(
self.libp2p_privkey, self.noise_static_key.get_public_key()
)
return NoiseHandshakePayload(self.libp2p_privkey.get_public_key(), signature)
# NEW: Get early data from appropriate handler
extensions = None
if is_initiator and self.initiator_early_data_handler:
extensions = await self.initiator_early_data_handler.send(conn, peer_id)
elif not is_initiator and self.responder_early_data_handler:
extensions = await self.responder_early_data_handler.send(conn, peer_id)
# NEW: Serialize extensions into early_data field
early_data = None
if extensions:
early_data = extensions.SerializeToString()
return NoiseHandshakePayload(
self.libp2p_privkey.get_public_key(),
signature,
early_data, # ← This is the key addition
)
async def handle_received_payload(
self, conn: IRawConnection, payload: NoiseHandshakePayload, is_initiator: bool
) -> None:
"""Process early data from received payload"""
if not payload.early_data:
return
# Deserialize the NoiseExtensions from early_data field
try:
extensions = noise_pb.NoiseExtensions.FromString(payload.early_data)
except Exception:
# Invalid extensions, ignore silently
return
# Pass to appropriate handler
if is_initiator and self.initiator_early_data_handler:
await self.initiator_early_data_handler.received(conn, extensions)
elif not is_initiator and self.responder_early_data_handler:
await self.responder_early_data_handler.received(conn, extensions)
class PatternXX(BasePattern):
@ -86,13 +130,15 @@ class PatternXX(BasePattern):
local_peer: ID,
libp2p_privkey: PrivateKey,
noise_static_key: PrivateKey,
early_data: bytes | None = None,
initiator_early_data_handler: EarlyDataHandler | None,
responder_early_data_handler: EarlyDataHandler | None,
) -> None:
self.protocol_name = b"Noise_XX_25519_ChaChaPoly_SHA256"
self.local_peer = local_peer
self.libp2p_privkey = libp2p_privkey
self.noise_static_key = noise_static_key
self.early_data = early_data
self.initiator_early_data_handler = initiator_early_data_handler
self.responder_early_data_handler = responder_early_data_handler
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
noise_state = self.create_noise_state()
@ -106,18 +152,23 @@ class PatternXX(BasePattern):
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
# Consume msg#1.
# 1. Consume msg#1 (just empty bytes)
await read_writer.read_msg()
# Send msg#2, which should include our handshake payload.
our_payload = self.make_handshake_payload()
# 2. Send msg#2 with our payload INCLUDING EARLY DATA
our_payload = await self.make_handshake_payload(
conn,
self.local_peer, # We send our own peer ID in responder role
is_initiator=False,
)
msg_2 = our_payload.serialize()
await read_writer.write_msg(msg_2)
# Receive and consume msg#3.
# 3. Receive msg#3
msg_3 = await read_writer.read_msg()
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3)
# Extract remote pubkey from noise handshake state
if handshake_state.rs is None:
raise NoiseStateError(
"something is wrong in the underlying noise `handshake_state`: "
@ -126,14 +177,31 @@ class PatternXX(BasePattern):
)
remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
# 4. Verify signature (unchanged)
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
raise InvalidSignature
# NEW: Process early data from msg#3 AFTER signature verification
await self.handle_received_payload(
conn, peer_handshake_payload, is_initiator=False
)
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
if not noise_state.handshake_finished:
raise HandshakeHasNotFinished(
"handshake is done but it is not marked as finished in `noise_state`"
)
# NEW: Get negotiated muxer for connection state
# negotiated_muxer = None
if self.responder_early_data_handler and hasattr(
self.responder_early_data_handler, "match_muxers"
):
# negotiated_muxer =
# self.responder_early_data_handler.match_muxers(is_initiator=False)
pass
transport_read_writer = NoiseTransportReadWriter(conn, noise_state)
return SecureSession(
local_peer=self.local_peer,
@ -142,6 +210,8 @@ class PatternXX(BasePattern):
remote_permanent_pubkey=remote_pubkey,
is_initiator=False,
conn=transport_read_writer,
# NOTE: negotiated_muxer would need to be added to SecureSession constructor
# For now, store it in connection metadata or similar
)
async def handshake_outbound(
@ -158,24 +228,27 @@ class PatternXX(BasePattern):
if handshake_state is None:
raise NoiseStateError("Handshake state is not initialized")
# Send msg#1, which is *not* encrypted.
# 1. Send msg#1 (empty) - no early data possible in XX pattern
msg_1 = b""
await read_writer.write_msg(msg_1)
# Read msg#2 from the remote, which contains the public key of the peer.
# 2. Read msg#2 from responder
msg_2 = await read_writer.read_msg()
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2)
# Extract remote pubkey from noise handshake state
if handshake_state.rs is None:
raise NoiseStateError(
"something is wrong in the underlying noise `handshake_state`: "
"we received and consumed msg#3, which should have included the "
"we received and consumed msg#2, which should have included the "
"remote static public key, but it is not present in the handshake_state"
)
remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
# Verify signature BEFORE processing early data (security)
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
raise InvalidSignature
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
if remote_peer_id_from_pubkey != remote_peer:
raise PeerIDMismatchesPubkey(
@ -184,8 +257,15 @@ class PatternXX(BasePattern):
f"remote_peer_id_from_pubkey={remote_peer_id_from_pubkey}"
)
# Send msg#3, which includes our encrypted payload and our noise static key.
our_payload = self.make_handshake_payload()
# NEW: Process early data from msg#2 AFTER verification
await self.handle_received_payload(
conn, peer_handshake_payload, is_initiator=True
)
# 3. Send msg#3 with our payload INCLUDING EARLY DATA
our_payload = await self.make_handshake_payload(
conn, remote_peer, is_initiator=True
)
msg_3 = our_payload.serialize()
await read_writer.write_msg(msg_3)
@ -193,6 +273,16 @@ class PatternXX(BasePattern):
raise HandshakeHasNotFinished(
"handshake is done but it is not marked as finished in `noise_state`"
)
# NEW: Get negotiated muxer
# negotiated_muxer = None
if self.initiator_early_data_handler and hasattr(
self.initiator_early_data_handler, "match_muxers"
):
pass
# negotiated_muxer =
# self.initiator_early_data_handler.match_muxers(is_initiator=True)
transport_read_writer = NoiseTransportReadWriter(conn, noise_state)
return SecureSession(
local_peer=self.local_peer,
@ -201,6 +291,8 @@ class PatternXX(BasePattern):
remote_permanent_pubkey=remote_pubkey,
is_initiator=True,
conn=transport_read_writer,
# NOTE: negotiated_muxer would need to be added to SecureSession constructor
# For now, store it in connection metadata or similar
)
@staticmethod

View File

@ -1,8 +1,13 @@
syntax = "proto3";
syntax = "proto2";
package pb;
message NoiseHandshakePayload {
bytes identity_key = 1;
bytes identity_sig = 2;
bytes data = 3;
message NoiseExtensions {
repeated bytes webtransport_certhashes = 1;
repeated string stream_muxers = 2;
}
message NoiseHandshakePayload {
optional bytes identity_key = 1;
optional bytes identity_sig = 2;
optional bytes data = 3;
}

View File

@ -13,13 +13,15 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$libp2p/security/noise/pb/noise.proto\x12\x02pb\"Q\n\x15NoiseHandshakePayload\x12\x14\n\x0cidentity_key\x18\x01 \x01(\x0c\x12\x14\n\x0cidentity_sig\x18\x02 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$libp2p/security/noise/pb/noise.proto\x12\x02pb\"I\n\x0fNoiseExtensions\x12\x1f\n\x17webtransport_certhashes\x18\x01 \x03(\x0c\x12\x15\n\rstream_muxers\x18\x02 \x03(\t\"Q\n\x15NoiseHandshakePayload\x12\x14\n\x0cidentity_key\x18\x01 \x01(\x0c\x12\x14\n\x0cidentity_sig\x18\x02 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.security.noise.pb.noise_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_NOISEHANDSHAKEPAYLOAD._serialized_start=44
_NOISEHANDSHAKEPAYLOAD._serialized_end=125
_NOISEEXTENSIONS._serialized_start=44
_NOISEEXTENSIONS._serialized_end=117
_NOISEHANDSHAKEPAYLOAD._serialized_start=119
_NOISEHANDSHAKEPAYLOAD._serialized_end=200
# @@protoc_insertion_point(module_scope)

View File

@ -4,12 +4,34 @@ isort:skip_file
"""
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
import typing
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class NoiseExtensions(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
WEBTRANSPORT_CERTHASHES_FIELD_NUMBER: builtins.int
STREAM_MUXERS_FIELD_NUMBER: builtins.int
@property
def webtransport_certhashes(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
@property
def stream_muxers(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
def __init__(
self,
*,
webtransport_certhashes: collections.abc.Iterable[builtins.bytes] | None = ...,
stream_muxers: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["stream_muxers", b"stream_muxers", "webtransport_certhashes", b"webtransport_certhashes"]) -> None: ...
global___NoiseExtensions = NoiseExtensions
@typing.final
class NoiseHandshakePayload(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@ -23,10 +45,11 @@ class NoiseHandshakePayload(google.protobuf.message.Message):
def __init__(
self,
*,
identity_key: builtins.bytes = ...,
identity_sig: builtins.bytes = ...,
data: builtins.bytes = ...,
identity_key: builtins.bytes | None = ...,
identity_sig: builtins.bytes | None = ...,
data: builtins.bytes | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["data", b"data", "identity_key", b"identity_key", "identity_sig", b"identity_sig"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["data", b"data", "identity_key", b"identity_key", "identity_sig", b"identity_sig"]) -> None: ...
global___NoiseHandshakePayload = NoiseHandshakePayload

View File

@ -14,6 +14,7 @@ from libp2p.peer.id import (
ID,
)
from .early_data import EarlyDataHandler, TransportEarlyDataHandler
from .patterns import (
IPattern,
PatternXX,
@ -26,35 +27,40 @@ class Transport(ISecureTransport):
libp2p_privkey: PrivateKey
noise_privkey: PrivateKey
local_peer: ID
early_data: bytes | None
with_noise_pipes: bool
supported_muxers: list[TProtocol]
initiator_early_data_handler: EarlyDataHandler | None
responder_early_data_handler: EarlyDataHandler | None
def __init__(
self,
libp2p_keypair: KeyPair,
noise_privkey: PrivateKey,
early_data: bytes | None = None,
with_noise_pipes: bool = False,
supported_muxers: list[TProtocol] | None = None,
initiator_handler: EarlyDataHandler | None = None,
responder_handler: EarlyDataHandler | None = None,
) -> None:
self.libp2p_privkey = libp2p_keypair.private_key
self.noise_privkey = noise_privkey
self.local_peer = ID.from_pubkey(libp2p_keypair.public_key)
self.early_data = early_data
self.with_noise_pipes = with_noise_pipes
self.supported_muxers = supported_muxers or []
if self.with_noise_pipes:
raise NotImplementedError
# Create default handlers for muxer negotiation if none provided
if initiator_handler is None and self.supported_muxers:
initiator_handler = TransportEarlyDataHandler(self.supported_muxers)
if responder_handler is None and self.supported_muxers:
responder_handler = TransportEarlyDataHandler(self.supported_muxers)
self.initiator_early_data_handler = initiator_handler
self.responder_early_data_handler = responder_handler
def get_pattern(self) -> IPattern:
if self.with_noise_pipes:
raise NotImplementedError
else:
return PatternXX(
self.local_peer,
self.libp2p_privkey,
self.noise_privkey,
self.early_data,
)
return PatternXX(
self.local_peer,
self.libp2p_privkey,
self.noise_privkey,
self.initiator_early_data_handler,
self.responder_early_data_handler,
)
async def secure_inbound(self, conn: IRawConnection) -> ISecureConn:
pattern = self.get_pattern()

View File

@ -0,0 +1 @@
KAD-DHT now include signed-peer-records in its protobuf message schema, for more secure peer-discovery.

View File

@ -9,11 +9,15 @@ This module tests core functionality of the Kademlia DHT including:
import hashlib
import logging
import os
from unittest.mock import patch
import uuid
import pytest
import multiaddr
import trio
from libp2p.crypto.rsa import create_new_key_pair
from libp2p.kad_dht.kad_dht import (
DHTMode,
KadDHT,
@ -21,9 +25,13 @@ from libp2p.kad_dht.kad_dht import (
from libp2p.kad_dht.utils import (
create_key_from_binary,
)
from libp2p.peer.envelope import Envelope, seal_record
from libp2p.peer.id import ID
from libp2p.peer.peer_record import PeerRecord
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.peer.peerstore import create_signed_peer_record
from libp2p.tools.async_service import (
background_trio_service,
)
@ -76,10 +84,52 @@ async def test_find_node(dht_pair: tuple[KadDHT, KadDHT]):
"""Test that nodes can find each other in the DHT."""
dht_a, dht_b = dht_pair
# An extra FIND_NODE req is sent between the 2 nodes while dht creation,
# so both the nodes will have records of each other before the next FIND_NODE
# req is sent
envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id())
assert isinstance(envelope_a, Envelope)
assert isinstance(envelope_b, Envelope)
record_a = envelope_a.record()
record_b = envelope_b.record()
# Node A should be able to find Node B
with trio.fail_after(TEST_TIMEOUT):
found_info = await dht_a.find_peer(dht_b.host.get_id())
# Verifies if the senderRecord in the FIND_NODE request is correctly processed
assert isinstance(
dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id()), Envelope
)
# Verifies if the senderRecord in the FIND_NODE response is correctly processed
assert isinstance(
dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()), Envelope
)
# These are the records that were sent between the peers during the FIND_NODE req
envelope_a_find_peer = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_find_peer = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_find_peer, Envelope)
assert isinstance(envelope_b_find_peer, Envelope)
record_a_find_peer = envelope_a_find_peer.record()
record_b_find_peer = envelope_b_find_peer.record()
# This proves that both the records are same, and a latest cached signed record
# was passed between the peers during FIND_NODE execution, which proves the
# signed-record transfer/re-issuing works correctly in FIND_NODE executions.
assert record_a.seq == record_a_find_peer.seq
assert record_b.seq == record_b_find_peer.seq
# Verify that the found peer has the correct peer ID
assert found_info is not None, "Failed to find the target peer"
assert found_info.peer_id == dht_b.host.get_id(), "Found incorrect peer ID"
@ -104,14 +154,44 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]):
await dht_a.routing_table.add_peer(peer_b_info)
print("Routing table of a has ", dht_a.routing_table.get_peer_ids())
# An extra FIND_NODE req is sent between the 2 nodes while dht creation,
# so both the nodes will have records of each other before PUT_VALUE req is sent
envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id())
assert isinstance(envelope_a, Envelope)
assert isinstance(envelope_b, Envelope)
record_a = envelope_a.record()
record_b = envelope_b.record()
# Store the value using the first node (this will also store locally)
with trio.fail_after(TEST_TIMEOUT):
await dht_a.put_value(key, value)
# These are the records that were sent between the peers during the PUT_VALUE req
envelope_a_put_value = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_put_value = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_put_value, Envelope)
assert isinstance(envelope_b_put_value, Envelope)
record_a_put_value = envelope_a_put_value.record()
record_b_put_value = envelope_b_put_value.record()
# This proves that both the records are same, and a latest cached signed record
# was passed between the peers during PUT_VALUE execution, which proves the
# signed-record transfer/re-issuing works correctly in PUT_VALUE executions.
assert record_a.seq == record_a_put_value.seq
assert record_b.seq == record_b_put_value.seq
# # Log debugging information
logger.debug("Put value with key %s...", key.hex()[:10])
logger.debug("Node A value store: %s", dht_a.value_store.store)
print("hello test")
# # Allow more time for the value to propagate
await trio.sleep(0.5)
@ -126,6 +206,26 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]):
print("the value stored in node b is", dht_b.get_value_store_size())
logger.debug("Retrieved value: %s", retrieved_value)
# These are the records that were sent between the peers during the PUT_VALUE req
envelope_a_get_value = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_get_value = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_get_value, Envelope)
assert isinstance(envelope_b_get_value, Envelope)
record_a_get_value = envelope_a_get_value.record()
record_b_get_value = envelope_b_get_value.record()
# This proves that there was no record exchange between the nodes during GET_VALUE
# execution, as dht_b already had the key/value pair stored locally after the
# PUT_VALUE execution.
assert record_a_get_value.seq == record_a_put_value.seq
assert record_b_get_value.seq == record_b_put_value.seq
# Verify that the retrieved value matches the original
assert retrieved_value == value, "Retrieved value does not match the stored value"
@ -142,11 +242,44 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]):
# Store content on the first node
dht_a.value_store.put(content_id, content)
# An extra FIND_NODE req is sent between the 2 nodes while dht creation,
# so both the nodes will have records of each other before PUT_VALUE req is sent
envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id())
assert isinstance(envelope_a, Envelope)
assert isinstance(envelope_b, Envelope)
record_a = envelope_a.record()
record_b = envelope_b.record()
# Advertise the first node as a provider
with trio.fail_after(TEST_TIMEOUT):
success = await dht_a.provide(content_id)
assert success, "Failed to advertise as provider"
# These are the records that were sent between the peers during
# the ADD_PROVIDER req
envelope_a_add_prov = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_add_prov = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_add_prov, Envelope)
assert isinstance(envelope_b_add_prov, Envelope)
record_a_add_prov = envelope_a_add_prov.record()
record_b_add_prov = envelope_b_add_prov.record()
# This proves that both the records are same, the latest cached signed record
# was passed between the peers during ADD_PROVIDER execution, which proves the
# signed-record transfer/re-issuing of the latest record works correctly in
# ADD_PROVIDER executions.
assert record_a.seq == record_a_add_prov.seq
assert record_b.seq == record_b_add_prov.seq
# Allow time for the provider record to propagate
await trio.sleep(0.1)
@ -154,6 +287,26 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]):
with trio.fail_after(TEST_TIMEOUT):
providers = await dht_b.find_providers(content_id)
# These are the records in each peer after the find_provider execution
envelope_a_find_prov = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_find_prov = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_find_prov, Envelope)
assert isinstance(envelope_b_find_prov, Envelope)
record_a_find_prov = envelope_a_find_prov.record()
record_b_find_prov = envelope_b_find_prov.record()
# This proves that both the records are same, as the dht_b already
# has the provider record for the content_id, after the ADD_PROVIDER
# advertisement by dht_a
assert record_a_find_prov.seq == record_a_add_prov.seq
assert record_b_find_prov.seq == record_b_add_prov.seq
# Verify that we found the first node as a provider
assert providers, "No providers found"
assert any(p.peer_id == dht_a.local_peer_id for p in providers), (
@ -166,3 +319,143 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]):
assert retrieved_value == content, (
"Retrieved content does not match the original"
)
# These are the record state of each peer aftet the GET_VALUE execution
envelope_a_get_value = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_get_value = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_get_value, Envelope)
assert isinstance(envelope_b_get_value, Envelope)
record_a_get_value = envelope_a_get_value.record()
record_b_get_value = envelope_b_get_value.record()
# This proves that both the records are same, meaning that the latest cached
# signed-record tranfer happened during the GET_VALUE execution by dht_b,
# which means the signed-record transfer/re-issuing works correctly
# in GET_VALUE executions.
assert record_a_find_prov.seq == record_a_get_value.seq
assert record_b_find_prov.seq == record_b_get_value.seq
# Create a new provider record in dht_a
provider_key_pair = create_new_key_pair()
provider_peer_id = ID.from_pubkey(provider_key_pair.public_key)
provider_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123")
provider_peer_info = PeerInfo(peer_id=provider_peer_id, addrs=[provider_addr])
# Generate a random content ID
content_2 = f"random-content-{uuid.uuid4()}".encode()
content_id_2 = hashlib.sha256(content_2).digest()
provider_signed_envelope = create_signed_peer_record(
provider_peer_id, [provider_addr], provider_key_pair.private_key
)
assert (
dht_a.host.get_peerstore().consume_peer_record(provider_signed_envelope, 7200)
is True
)
# Store this provider record in dht_a
dht_a.provider_store.add_provider(content_id_2, provider_peer_info)
# Fetch the provider-record via peer-discovery at dht_b's end
peerinfo = await dht_b.provider_store.find_providers(content_id_2)
assert len(peerinfo) == 1
assert peerinfo[0].peer_id == provider_peer_id
provider_envelope = dht_b.host.get_peerstore().get_peer_record(provider_peer_id)
# This proves that the signed-envelope of provider is consumed on dht_b's end
assert provider_envelope is not None
assert (
provider_signed_envelope.marshal_envelope()
== provider_envelope.marshal_envelope()
)
@pytest.mark.trio
async def test_reissue_when_listen_addrs_change(dht_pair: tuple[KadDHT, KadDHT]):
dht_a, dht_b = dht_pair
# Warm-up: A stores B's current record
with trio.fail_after(10):
await dht_a.find_peer(dht_b.host.get_id())
env0 = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
assert isinstance(env0, Envelope)
seq0 = env0.record().seq
# Simulate B's listen addrs changing (different port)
new_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123")
# Patch just for the duration we force B to respond:
with patch.object(dht_b.host, "get_addrs", return_value=[new_addr]):
# Force B to send a response (which should include a fresh SPR)
with trio.fail_after(10):
await dht_a.peer_routing._query_peer_for_closest(
dht_b.host.get_id(), os.urandom(32)
)
# A should now hold B's new record with a bumped seq
env1 = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
assert isinstance(env1, Envelope)
seq1 = env1.record().seq
# This proves that upon the change in listen_addrs, we issue new records
assert seq1 > seq0, f"Expected seq to bump after addr change, got {seq0} -> {seq1}"
@pytest.mark.trio
async def test_dht_req_fail_with_invalid_record_transfer(
dht_pair: tuple[KadDHT, KadDHT],
):
"""
Testing showing failure of storing and retrieving values in the DHT,
if invalid signed-records are sent.
"""
dht_a, dht_b = dht_pair
peer_b_info = PeerInfo(dht_b.host.get_id(), dht_b.host.get_addrs())
# Generate a random key and value
key = create_key_from_binary(b"test-key")
value = b"test-value"
# First add the value directly to node A's store to verify storage works
dht_a.value_store.put(key, value)
local_value = dht_a.value_store.get(key)
assert local_value == value, "Local value storage failed"
await dht_a.routing_table.add_peer(peer_b_info)
# Corrupt dht_a's local peer_record
envelope = dht_a.host.get_peerstore().get_local_record()
if envelope is not None:
true_record = envelope.record()
key_pair = create_new_key_pair()
if envelope is not None:
envelope.public_key = key_pair.public_key
dht_a.host.get_peerstore().set_local_record(envelope)
await dht_a.put_value(key, value)
retrieved_value = dht_b.value_store.get(key)
# This proves that DHT_B rejected DHT_A PUT_RECORD req upon receiving
# the corrupted invalid record
assert retrieved_value is None
# Create a corrupt envelope with correct signature but false peer_id
false_record = PeerRecord(ID.from_pubkey(key_pair.public_key), true_record.addrs)
false_envelope = seal_record(false_record, dht_a.host.get_private_key())
dht_a.host.get_peerstore().set_local_record(false_envelope)
await dht_a.put_value(key, value)
retrieved_value = dht_b.value_store.get(key)
# This proves that DHT_B rejected DHT_A PUT_RECORD req upon receving
# the record with a different peer_id regardless of a valid signature
assert retrieved_value is None

View File

@ -57,7 +57,10 @@ class TestPeerRouting:
def mock_host(self):
"""Create a mock host for testing."""
host = Mock()
host.get_id.return_value = create_valid_peer_id("local")
key_pair = create_new_key_pair()
host.get_id.return_value = ID.from_pubkey(key_pair.public_key)
host.get_public_key.return_value = key_pair.public_key
host.get_private_key.return_value = key_pair.private_key
host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
host.get_peerstore.return_value = Mock()
host.new_stream = AsyncMock()

View File

@ -0,0 +1,13 @@
from libp2p.security.noise.pb import noise_pb2 as noise_pb
def test_noise_extensions_serialization():
# Test NoiseExtensions
ext = noise_pb.NoiseExtensions()
ext.stream_muxers.append("/mplex/6.7.0")
ext.stream_muxers.append("/yamux/1.0.0")
# Serialize and deserialize
data = ext.SerializeToString()
ext2 = noise_pb.NoiseExtensions.FromString(data)
assert list(ext2.stream_muxers) == ["/mplex/6.7.0", "/yamux/1.0.0"]

View File

@ -173,8 +173,7 @@ def noise_transport_factory(key_pair: KeyPair) -> ISecureTransport:
return NoiseTransport(
libp2p_keypair=key_pair,
noise_privkey=noise_static_key_factory(),
early_data=None,
with_noise_pipes=False,
# TODO: add early data
)