diff --git a/libp2p/__init__.py b/libp2p/__init__.py index d2ce122a..350ae46b 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -49,6 +49,7 @@ from libp2p.peer.id import ( ) from libp2p.peer.peerstore import ( PeerStore, + create_signed_peer_record, ) from libp2p.security.insecure.transport import ( PLAINTEXT_PROTOCOL_ID, @@ -155,7 +156,6 @@ def get_default_muxer_options() -> TMuxerOptions: else: # YAMUX is default return create_yamux_muxer_option() - def new_swarm( key_pair: KeyPair | None = None, muxer_opt: TMuxerOptions | None = None, diff --git a/libp2p/abc.py b/libp2p/abc.py index 90ad6a45..a9748339 100644 --- a/libp2p/abc.py +++ b/libp2p/abc.py @@ -970,6 +970,14 @@ class IPeerStore( # --------CERTIFIED-ADDR-BOOK---------- + @abstractmethod + def get_local_record(self) -> Optional["Envelope"]: + """Get the local-peer-record wrapped in Envelope""" + + @abstractmethod + def set_local_record(self, envelope: "Envelope") -> None: + """Set the local-peer-record wrapped in Envelope""" + @abstractmethod def consume_peer_record(self, envelope: "Envelope", ttl: int) -> bool: """ diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index b40b0128..a0311bd8 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -43,6 +43,7 @@ from libp2p.peer.id import ( ) from libp2p.peer.peerinfo import ( PeerInfo, ) +from libp2p.peer.peerstore import create_signed_peer_record from libp2p.protocol_muxer.exceptions import ( MultiselectClientError, MultiselectError, @@ -110,6 +111,14 @@ class BasicHost(IHost): if bootstrap: self.bootstrap = BootstrapDiscovery(network, bootstrap) + # Cache a signed-record of the local node in the PeerStore + envelope = create_signed_peer_record( + self.get_id(), + self.get_addrs(), + self.get_private_key(), + ) + self.get_peerstore().set_local_record(envelope) + def get_id(self) -> ID: """ :return: peer_id of host diff --git a/libp2p/identity/identify/identify.py b/libp2p/identity/identify/identify.py index b2811ff9..146fbd2d 100644 --- a/libp2p/identity/identify/identify.py +++ b/libp2p/identity/identify/identify.py @@ -15,8 +15,7 @@ from libp2p.custom_types import ( from libp2p.network.stream.exceptions import ( StreamClosed, ) -from libp2p.peer.envelope import seal_record -from libp2p.peer.peer_record import PeerRecord +from libp2p.peer.peerstore import env_to_send_in_RPC from libp2p.utils import ( decode_varint_with_size, get_agent_version, @@ -66,9 +65,7 @@ def _mk_identify_protobuf( protocols = tuple(str(p) for p in host.get_mux().get_protocols() if p is not None) # Create a signed peer-record for the remote peer - record = PeerRecord(host.get_id(), host.get_addrs()) - envelope = seal_record(record, host.get_private_key()) - protobuf = envelope.marshal_envelope() + envelope_bytes, _ = env_to_send_in_RPC(host) observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b"" return Identify( @@ -78,7 +75,7 @@ def _mk_identify_protobuf( listen_addrs=map(_multiaddr_to_bytes, laddrs), observed_addr=observed_addr, protocols=protocols, - signedPeerRecord=protobuf, + signedPeerRecord=envelope_bytes, ) diff --git a/libp2p/kad_dht/kad_dht.py b/libp2p/kad_dht/kad_dht.py index 097b6c48..0d05aeb8 100644 --- a/libp2p/kad_dht/kad_dht.py +++ b/libp2p/kad_dht/kad_dht.py @@ -22,15 +22,18 @@ from libp2p.abc import ( IHost, ) from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager +from libp2p.kad_dht.utils import maybe_consume_signed_record from libp2p.network.stream.net_stream import ( INetStream, ) +from 
libp2p.peer.envelope import Envelope from libp2p.peer.id import ( ID, ) from libp2p.peer.peerinfo import ( PeerInfo, ) +from libp2p.peer.peerstore import env_to_send_in_RPC from libp2p.tools.async_service import ( Service, ) @@ -234,6 +237,9 @@ class KadDHT(Service): await self.add_peer(peer_id) logger.debug(f"Added peer {peer_id} to routing table") + closer_peer_envelope: Envelope | None = None + provider_peer_envelope: Envelope | None = None + try: # Read varint-prefixed length for the message length_prefix = b"" @@ -274,6 +280,14 @@ class KadDHT(Service): ) logger.debug(f"Found {len(closest_peers)} peers close to target") + # Consume the source signed_peer_record if sent + if not maybe_consume_signed_record(message, self.host, peer_id): + logger.error( + "Received an invalid-signed-record, dropping the stream" + ) + await stream.close() + return + # Build response message with protobuf response = Message() response.type = Message.MessageType.FIND_NODE @@ -298,6 +312,21 @@ class KadDHT(Service): except Exception: pass + # Add the signed-peer-record for each peer in the peer-proto + # if cached in the peerstore + closer_peer_envelope = ( + self.host.get_peerstore().get_peer_record(peer) + ) + + if closer_peer_envelope is not None: + peer_proto.signedRecord = ( + closer_peer_envelope.marshal_envelope() + ) + + # Create sender_signed_peer_record + envelope_bytes, _ = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes + # Serialize and send response response_bytes = response.SerializeToString() await stream.write(varint.encode(len(response_bytes))) @@ -312,6 +341,14 @@ class KadDHT(Service): key = message.key logger.debug(f"Received ADD_PROVIDER for key {key.hex()}") + # Consume the source signed-peer-record if sent + if not maybe_consume_signed_record(message, self.host, peer_id): + logger.error( + "Received an invalid-signed-record, dropping the stream" + ) + await stream.close() + return + # Extract provider information for provider_proto in message.providerPeers: try: @@ -338,6 +375,17 @@ class KadDHT(Service): logger.debug( f"Added provider {provider_id} for key {key.hex()}" ) + + # Process the signed-records of provider if sent + if not maybe_consume_signed_record( + provider_proto, self.host + ): + logger.error( + "Received an invalid-signed-record, " + "dropping the stream" + ) + await stream.close() + return except Exception as e: logger.warning(f"Failed to process provider info: {e}") @@ -346,6 +394,10 @@ class KadDHT(Service): response.type = Message.MessageType.ADD_PROVIDER response.key = key + # Add sender's signed-peer-record + envelope_bytes, _ = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes + response_bytes = response.SerializeToString() await stream.write(varint.encode(len(response_bytes))) await stream.write(response_bytes) @@ -357,6 +409,14 @@ class KadDHT(Service): key = message.key logger.debug(f"Received GET_PROVIDERS request for key {key.hex()}") + # Consume the source signed_peer_record if sent + if not maybe_consume_signed_record(message, self.host, peer_id): + logger.error( + "Received an invalid-signed-record, dropping the stream" + ) + await stream.close() + return + # Find providers for the key providers = self.provider_store.get_providers(key) logger.debug( @@ -368,12 +428,28 @@ response.type = Message.MessageType.GET_PROVIDERS response.key = key + # Create sender_signed_peer_record for the response + envelope_bytes, _ = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes + # 
Add provider information to response for provider_info in providers: provider_proto = response.providerPeers.add() provider_proto.id = provider_info.peer_id.to_bytes() provider_proto.connection = Message.ConnectionType.CAN_CONNECT + # Add provider signed-records if cached + provider_peer_envelope = ( + self.host.get_peerstore().get_peer_record( + provider_info.peer_id + ) + ) + + if provider_peer_envelope is not None: + provider_proto.signedRecord = ( + provider_peer_envelope.marshal_envelope() + ) + # Add addresses if available for addr in provider_info.addrs: provider_proto.addrs.append(addr.to_bytes()) @@ -397,6 +473,16 @@ class KadDHT(Service): peer_proto.id = peer.to_bytes() peer_proto.connection = Message.ConnectionType.CAN_CONNECT + # Add the signed-records of closest_peers if cached + closer_peer_envelope = ( + self.host.get_peerstore().get_peer_record(peer) + ) + + if closer_peer_envelope is not None: + peer_proto.signedRecord = ( + closer_peer_envelope.marshal_envelope() + ) + # Add addresses if available try: addrs = self.host.get_peerstore().addrs(peer) @@ -417,6 +503,14 @@ class KadDHT(Service): key = message.key logger.debug(f"Received GET_VALUE request for key {key.hex()}") + # Consume the sender_signed_peer_record + if not maybe_consume_signed_record(message, self.host, peer_id): + logger.error( + "Received an invalid-signed-record, dropping the stream" + ) + await stream.close() + return + value = self.value_store.get(key) if value: logger.debug(f"Found value for key {key.hex()}") @@ -431,6 +525,10 @@ class KadDHT(Service): response.record.value = value response.record.timeReceived = str(time.time()) + # Create sender_signed_peer_record + envelope_bytes, _ = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes + # Serialize and send response response_bytes = response.SerializeToString() await stream.write(varint.encode(len(response_bytes))) @@ -444,6 +542,10 @@ class KadDHT(Service): response.type = Message.MessageType.GET_VALUE response.key = key + # Create sender_signed_peer_record for the response + envelope_bytes, _ = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes + # Add closest peers to key closest_peers = self.routing_table.find_local_closest_peers( key, 20 @@ -462,6 +564,16 @@ class KadDHT(Service): peer_proto.id = peer.to_bytes() peer_proto.connection = Message.ConnectionType.CAN_CONNECT + # Add signed-records of closer-peers if cached + closer_peer_envelope = ( + self.host.get_peerstore().get_peer_record(peer) + ) + + if closer_peer_envelope is not None: + peer_proto.signedRecord = ( + closer_peer_envelope.marshal_envelope() + ) + # Add addresses if available try: addrs = self.host.get_peerstore().addrs(peer) @@ -484,6 +596,15 @@ class KadDHT(Service): key = message.record.key value = message.record.value success = False + + # Consume the source signed_peer_record if sent + if not maybe_consume_signed_record(message, self.host, peer_id): + logger.error( + "Received an invalid-signed-record, dropping the stream" + ) + await stream.close() + return + try: if not (key and value): raise ValueError( @@ -504,6 +625,12 @@ class KadDHT(Service): response.type = Message.MessageType.PUT_VALUE if success: response.key = key + + # Create sender_signed_peer_record for the response + envelope_bytes, _ = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes + + # Serialize and send response response_bytes = response.SerializeToString() await stream.write(varint.encode(len(response_bytes))) await 
stream.write(response_bytes) diff --git a/libp2p/kad_dht/pb/kademlia.proto b/libp2p/kad_dht/pb/kademlia.proto index fd198d28..7c3e5bad 100644 --- a/libp2p/kad_dht/pb/kademlia.proto +++ b/libp2p/kad_dht/pb/kademlia.proto @@ -27,6 +27,7 @@ message Message { bytes id = 1; repeated bytes addrs = 2; ConnectionType connection = 3; + optional bytes signedRecord = 4; // Envelope(PeerRecord) encoded } MessageType type = 1; @@ -35,4 +36,6 @@ message Message { Record record = 3; repeated Peer closerPeers = 8; repeated Peer providerPeers = 9; + + optional bytes senderRecord = 11; // Envelope(PeerRecord) encoded } diff --git a/libp2p/kad_dht/pb/kademlia_pb2.py b/libp2p/kad_dht/pb/kademlia_pb2.py index 781333bf..ac23169c 100644 --- a/libp2p/kad_dht/pb/kademlia_pb2.py +++ b/libp2p/kad_dht/pb/kademlia_pb2.py @@ -1,11 +1,12 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: libp2p/kad_dht/pb/kademlia.proto +# Protobuf Python Version: 4.25.3 """Generated protocol buffer code.""" -from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -13,21 +14,21 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xca\x03\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x1aN\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xa2\x04\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x12\x19\n\x0csenderRecord\x18\x0b \x01(\x0cH\x00\x88\x01\x01\x1az\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\x12\x19\n\x0csignedRecord\x18\x04 
\x01(\x0cH\x00\x88\x01\x01\x42\x0f\n\r_signedRecord\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x42\x0f\n\r_senderRecordb\x06proto3') -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', globals()) +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _RECORD._serialized_start=36 - _RECORD._serialized_end=94 - _MESSAGE._serialized_start=97 - _MESSAGE._serialized_end=555 - _MESSAGE_PEER._serialized_start=281 - _MESSAGE_PEER._serialized_end=359 - _MESSAGE_MESSAGETYPE._serialized_start=361 - _MESSAGE_MESSAGETYPE._serialized_end=466 - _MESSAGE_CONNECTIONTYPE._serialized_start=468 - _MESSAGE_CONNECTIONTYPE._serialized_end=555 + _globals['_RECORD']._serialized_start=36 + _globals['_RECORD']._serialized_end=94 + _globals['_MESSAGE']._serialized_start=97 + _globals['_MESSAGE']._serialized_end=643 + _globals['_MESSAGE_PEER']._serialized_start=308 + _globals['_MESSAGE_PEER']._serialized_end=430 + _globals['_MESSAGE_MESSAGETYPE']._serialized_start=432 + _globals['_MESSAGE_MESSAGETYPE']._serialized_end=537 + _globals['_MESSAGE_CONNECTIONTYPE']._serialized_start=539 + _globals['_MESSAGE_CONNECTIONTYPE']._serialized_end=626 # @@protoc_insertion_point(module_scope) diff --git a/libp2p/kad_dht/pb/kademlia_pb2.pyi b/libp2p/kad_dht/pb/kademlia_pb2.pyi index c8f16db2..6d80d77d 100644 --- a/libp2p/kad_dht/pb/kademlia_pb2.pyi +++ b/libp2p/kad_dht/pb/kademlia_pb2.pyi @@ -1,133 +1,70 @@ -""" -@generated by mypy-protobuf. Do not edit manually! -isort:skip_file -""" +from google.protobuf.internal import containers as _containers +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union -import builtins -import collections.abc -import google.protobuf.descriptor -import google.protobuf.internal.containers -import google.protobuf.internal.enum_type_wrapper -import google.protobuf.message -import sys -import typing +DESCRIPTOR: _descriptor.FileDescriptor -if sys.version_info >= (3, 10): - import typing as typing_extensions -else: - import typing_extensions +class Record(_message.Message): + __slots__ = ("key", "value", "timeReceived") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + TIMERECEIVED_FIELD_NUMBER: _ClassVar[int] + key: bytes + value: bytes + timeReceived: str + def __init__(self, key: _Optional[bytes] = ..., value: _Optional[bytes] = ..., timeReceived: _Optional[str] = ...) -> None: ... 
-DESCRIPTOR: google.protobuf.descriptor.FileDescriptor - -@typing.final -class Record(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - TIMERECEIVED_FIELD_NUMBER: builtins.int - key: builtins.bytes - value: builtins.bytes - timeReceived: builtins.str - def __init__( - self, - *, - key: builtins.bytes = ..., - value: builtins.bytes = ..., - timeReceived: builtins.str = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["key", b"key", "timeReceived", b"timeReceived", "value", b"value"]) -> None: ... - -global___Record = Record - -@typing.final -class Message(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - class _MessageType: - ValueType = typing.NewType("ValueType", builtins.int) - V: typing_extensions.TypeAlias = ValueType - - class _MessageTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._MessageType.ValueType], builtins.type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor - PUT_VALUE: Message._MessageType.ValueType # 0 - GET_VALUE: Message._MessageType.ValueType # 1 - ADD_PROVIDER: Message._MessageType.ValueType # 2 - GET_PROVIDERS: Message._MessageType.ValueType # 3 - FIND_NODE: Message._MessageType.ValueType # 4 - PING: Message._MessageType.ValueType # 5 - - class MessageType(_MessageType, metaclass=_MessageTypeEnumTypeWrapper): ... - PUT_VALUE: Message.MessageType.ValueType # 0 - GET_VALUE: Message.MessageType.ValueType # 1 - ADD_PROVIDER: Message.MessageType.ValueType # 2 - GET_PROVIDERS: Message.MessageType.ValueType # 3 - FIND_NODE: Message.MessageType.ValueType # 4 - PING: Message.MessageType.ValueType # 5 - - class _ConnectionType: - ValueType = typing.NewType("ValueType", builtins.int) - V: typing_extensions.TypeAlias = ValueType - - class _ConnectionTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._ConnectionType.ValueType], builtins.type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor - NOT_CONNECTED: Message._ConnectionType.ValueType # 0 - CONNECTED: Message._ConnectionType.ValueType # 1 - CAN_CONNECT: Message._ConnectionType.ValueType # 2 - CANNOT_CONNECT: Message._ConnectionType.ValueType # 3 - - class ConnectionType(_ConnectionType, metaclass=_ConnectionTypeEnumTypeWrapper): ... - NOT_CONNECTED: Message.ConnectionType.ValueType # 0 - CONNECTED: Message.ConnectionType.ValueType # 1 - CAN_CONNECT: Message.ConnectionType.ValueType # 2 - CANNOT_CONNECT: Message.ConnectionType.ValueType # 3 - - @typing.final - class Peer(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - ID_FIELD_NUMBER: builtins.int - ADDRS_FIELD_NUMBER: builtins.int - CONNECTION_FIELD_NUMBER: builtins.int - id: builtins.bytes - connection: global___Message.ConnectionType.ValueType - @property - def addrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... - def __init__( - self, - *, - id: builtins.bytes = ..., - addrs: collections.abc.Iterable[builtins.bytes] | None = ..., - connection: global___Message.ConnectionType.ValueType = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["addrs", b"addrs", "connection", b"connection", "id", b"id"]) -> None: ... 
- - TYPE_FIELD_NUMBER: builtins.int - CLUSTERLEVELRAW_FIELD_NUMBER: builtins.int - KEY_FIELD_NUMBER: builtins.int - RECORD_FIELD_NUMBER: builtins.int - CLOSERPEERS_FIELD_NUMBER: builtins.int - PROVIDERPEERS_FIELD_NUMBER: builtins.int - type: global___Message.MessageType.ValueType - clusterLevelRaw: builtins.int - key: builtins.bytes - @property - def record(self) -> global___Record: ... - @property - def closerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ... - @property - def providerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ... - def __init__( - self, - *, - type: global___Message.MessageType.ValueType = ..., - clusterLevelRaw: builtins.int = ..., - key: builtins.bytes = ..., - record: global___Record | None = ..., - closerPeers: collections.abc.Iterable[global___Message.Peer] | None = ..., - providerPeers: collections.abc.Iterable[global___Message.Peer] | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["record", b"record"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["closerPeers", b"closerPeers", "clusterLevelRaw", b"clusterLevelRaw", "key", b"key", "providerPeers", b"providerPeers", "record", b"record", "type", b"type"]) -> None: ... - -global___Message = Message +class Message(_message.Message): + __slots__ = ("type", "clusterLevelRaw", "key", "record", "closerPeers", "providerPeers", "senderRecord") + class MessageType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + PUT_VALUE: _ClassVar[Message.MessageType] + GET_VALUE: _ClassVar[Message.MessageType] + ADD_PROVIDER: _ClassVar[Message.MessageType] + GET_PROVIDERS: _ClassVar[Message.MessageType] + FIND_NODE: _ClassVar[Message.MessageType] + PING: _ClassVar[Message.MessageType] + PUT_VALUE: Message.MessageType + GET_VALUE: Message.MessageType + ADD_PROVIDER: Message.MessageType + GET_PROVIDERS: Message.MessageType + FIND_NODE: Message.MessageType + PING: Message.MessageType + class ConnectionType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + NOT_CONNECTED: _ClassVar[Message.ConnectionType] + CONNECTED: _ClassVar[Message.ConnectionType] + CAN_CONNECT: _ClassVar[Message.ConnectionType] + CANNOT_CONNECT: _ClassVar[Message.ConnectionType] + NOT_CONNECTED: Message.ConnectionType + CONNECTED: Message.ConnectionType + CAN_CONNECT: Message.ConnectionType + CANNOT_CONNECT: Message.ConnectionType + class Peer(_message.Message): + __slots__ = ("id", "addrs", "connection", "signedRecord") + ID_FIELD_NUMBER: _ClassVar[int] + ADDRS_FIELD_NUMBER: _ClassVar[int] + CONNECTION_FIELD_NUMBER: _ClassVar[int] + SIGNEDRECORD_FIELD_NUMBER: _ClassVar[int] + id: bytes + addrs: _containers.RepeatedScalarFieldContainer[bytes] + connection: Message.ConnectionType + signedRecord: bytes + def __init__(self, id: _Optional[bytes] = ..., addrs: _Optional[_Iterable[bytes]] = ..., connection: _Optional[_Union[Message.ConnectionType, str]] = ..., signedRecord: _Optional[bytes] = ...) -> None: ... 
+ TYPE_FIELD_NUMBER: _ClassVar[int] + CLUSTERLEVELRAW_FIELD_NUMBER: _ClassVar[int] + KEY_FIELD_NUMBER: _ClassVar[int] + RECORD_FIELD_NUMBER: _ClassVar[int] + CLOSERPEERS_FIELD_NUMBER: _ClassVar[int] + PROVIDERPEERS_FIELD_NUMBER: _ClassVar[int] + SENDERRECORD_FIELD_NUMBER: _ClassVar[int] + type: Message.MessageType + clusterLevelRaw: int + key: bytes + record: Record + closerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer] + providerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer] + senderRecord: bytes + def __init__(self, type: _Optional[_Union[Message.MessageType, str]] = ..., clusterLevelRaw: _Optional[int] = ..., key: _Optional[bytes] = ..., record: _Optional[_Union[Record, _Mapping]] = ..., closerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping]]] = ..., providerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping]]] = ..., senderRecord: _Optional[bytes] = ...) -> None: ... # type: ignore diff --git a/libp2p/kad_dht/peer_routing.py b/libp2p/kad_dht/peer_routing.py index c4a066f7..f5313cb6 100644 --- a/libp2p/kad_dht/peer_routing.py +++ b/libp2p/kad_dht/peer_routing.py @@ -15,12 +15,14 @@ from libp2p.abc import ( INetStream, IPeerRouting, ) +from libp2p.peer.envelope import Envelope from libp2p.peer.id import ( ID, ) from libp2p.peer.peerinfo import ( PeerInfo, ) +from libp2p.peer.peerstore import env_to_send_in_RPC from .common import ( ALPHA, @@ -33,6 +35,7 @@ from .routing_table import ( RoutingTable, ) from .utils import ( + maybe_consume_signed_record, sort_peer_ids_by_distance, ) @@ -255,6 +258,10 @@ class PeerRouting(IPeerRouting): find_node_msg.type = Message.MessageType.FIND_NODE find_node_msg.key = target_key # Set target key directly as bytes + # Create sender_signed_peer_record + envelope_bytes, _ = env_to_send_in_RPC(self.host) + find_node_msg.senderRecord = envelope_bytes + # Serialize and send the protobuf message with varint length prefix proto_bytes = find_node_msg.SerializeToString() logger.debug( @@ -299,7 +306,22 @@ class PeerRouting(IPeerRouting): # Process closest peers from response if response_msg.type == Message.MessageType.FIND_NODE: + # Consume the sender_signed_peer_record + if not maybe_consume_signed_record(response_msg, self.host, peer): + logger.error( + "Received an invalid-signed-record, ignoring the response" + ) + return [] + for peer_data in response_msg.closerPeers: + # Consume the received closer_peers signed-records, peer-id is + # sent with the peer-data + if not maybe_consume_signed_record(peer_data, self.host): + logger.error( + "Received an invalid-signed-record, ignoring the response" + ) + return [] + new_peer_id = ID(peer_data.id) if new_peer_id not in results: results.append(new_peer_id) @@ -332,6 +354,7 @@ class PeerRouting(IPeerRouting): """ try: # Read message length + peer_id = stream.muxed_conn.peer_id length_bytes = await stream.read(4) if not length_bytes: return @@ -345,10 +368,18 @@ class PeerRouting(IPeerRouting): # Parse protobuf message kad_message = Message() + closer_peer_envelope: Envelope | None = None try: kad_message.ParseFromString(message_bytes) if kad_message.type == Message.MessageType.FIND_NODE: + # Consume the sender's signed-peer-record if sent + if not maybe_consume_signed_record(kad_message, self.host, peer_id): + logger.error( + "Received an invalid-signed-record, dropping the stream" + ) + return + # Get target key directly from protobuf message target_key = kad_message.key @@ -361,12 +392,26 @@ class PeerRouting(IPeerRouting): response = Message() response.type = 
Message.MessageType.FIND_NODE + # Create sender_signed_peer_record for the response + envelope_bytes, _ = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes + # Add peer information to response for peer_id in closest_peers: peer_proto = response.closerPeers.add() peer_proto.id = peer_id.to_bytes() peer_proto.connection = Message.ConnectionType.CAN_CONNECT + # Add the signed-records of closest_peers if cached + closer_peer_envelope = ( + self.host.get_peerstore().get_peer_record(peer_id) + ) + + if isinstance(closer_peer_envelope, Envelope): + peer_proto.signedRecord = ( + closer_peer_envelope.marshal_envelope() + ) + # Add addresses if available try: addrs = self.host.get_peerstore().addrs(peer_id) diff --git a/libp2p/kad_dht/provider_store.py b/libp2p/kad_dht/provider_store.py index 5c34f0c7..77bb464f 100644 --- a/libp2p/kad_dht/provider_store.py +++ b/libp2p/kad_dht/provider_store.py @@ -22,12 +22,14 @@ from libp2p.abc import ( from libp2p.custom_types import ( TProtocol, ) +from libp2p.kad_dht.utils import maybe_consume_signed_record from libp2p.peer.id import ( ID, ) from libp2p.peer.peerinfo import ( PeerInfo, ) +from libp2p.peer.peerstore import env_to_send_in_RPC from .common import ( ALPHA, @@ -240,11 +242,18 @@ class ProviderStore: message.type = Message.MessageType.ADD_PROVIDER message.key = key + # Create sender's signed-peer-record + envelope_bytes, _ = env_to_send_in_RPC(self.host) + message.senderRecord = envelope_bytes + # Add our provider info provider = message.providerPeers.add() provider.id = self.local_peer_id.to_bytes() provider.addrs.extend(addrs) + # Add the provider's signed-peer-record + provider.signedRecord = envelope_bytes + # Serialize and send the message proto_bytes = message.SerializeToString() await stream.write(varint.encode(len(proto_bytes))) @@ -276,10 +285,15 @@ class ProviderStore: response = Message() response.ParseFromString(response_bytes) - # Check response type - response.type == Message.MessageType.ADD_PROVIDER - if response.type: - result = True + if response.type == Message.MessageType.ADD_PROVIDER: + # Consume the sender's signed-peer-record if sent + if not maybe_consume_signed_record(response, self.host, peer_id): + logger.error( + "Received an invalid-signed-record, ignoring the response" + ) + result = False + else: + result = True except Exception as e: logger.warning(f"Error sending ADD_PROVIDER to {peer_id}: {e}") @@ -380,6 +394,10 @@ class ProviderStore: message.type = Message.MessageType.GET_PROVIDERS message.key = key + # Create sender's signed-peer-record + envelope_bytes, _ = env_to_send_in_RPC(self.host) + message.senderRecord = envelope_bytes + # Serialize and send the message proto_bytes = message.SerializeToString() await stream.write(varint.encode(len(proto_bytes))) @@ -414,10 +432,26 @@ class ProviderStore: if response.type != Message.MessageType.GET_PROVIDERS: return [] + # Consume the sender's signed-peer-record if sent + if not maybe_consume_signed_record(response, self.host, peer_id): + logger.error( + "Received an invalid-signed-record, ignoring the response" + ) + return [] + # Extract provider information providers = [] for provider_proto in response.providerPeers: try: + # Consume the provider's signed-peer-record if sent, peer-id + # already sent with the provider-proto + if not maybe_consume_signed_record(provider_proto, self.host): + logger.error( + "Received an invalid-signed-record, " + "ignoring the response" + ) + return [] + # Create peer ID from bytes provider_id = ID(provider_proto.id) @@ 
-431,6 +465,7 @@ class ProviderStore: # Create PeerInfo and add to result providers.append(PeerInfo(provider_id, addrs)) + except Exception as e: logger.warning(f"Failed to parse provider info: {e}") diff --git a/libp2p/kad_dht/utils.py b/libp2p/kad_dht/utils.py index 61158320..fe768723 100644 --- a/libp2p/kad_dht/utils.py +++ b/libp2p/kad_dht/utils.py @@ -2,13 +2,93 @@ Utility functions for Kademlia DHT implementation. """ +import logging + import base58 import multihash +from libp2p.abc import IHost +from libp2p.peer.envelope import consume_envelope from libp2p.peer.id import ( ID, ) +from .pb.kademlia_pb2 import ( + Message, +) + +logger = logging.getLogger("libp2p.kad_dht.utils") + + +def maybe_consume_signed_record( + msg: Message | Message.Peer, host: IHost, peer_id: ID | None = None +) -> bool: + """ + Attempt to parse and store a signed-peer-record (Envelope) received during + DHT communication. If the record is invalid, the peer-id does not match, or + updating the peerstore fails, the function logs an error and returns False. + + Parameters + ---------- + msg : Message | Message.Peer + The protobuf message received during DHT communication. Can either be a + top-level `Message` containing `senderRecord` or a `Message.Peer` + containing `signedRecord`. + host : IHost + The local host instance, providing access to the peerstore for storing + verified peer records. + peer_id : ID | None, optional + The expected peer ID for record validation. If provided, the peer ID + inside the record must match this value. + + Returns + ------- + bool + True if a valid signed peer record was successfully consumed and stored, + False otherwise. + + """ + if isinstance(msg, Message): + if msg.HasField("senderRecord"): + try: + # Convert the signed-peer-record(Envelope) from + # protobuf bytes + envelope, record = consume_envelope( + msg.senderRecord, + "libp2p-peer-record", + ) + if not (isinstance(peer_id, ID) and record.peer_id == peer_id): + return False + # Use the default TTL of 2 hours (7200 seconds) + if not host.get_peerstore().consume_peer_record(envelope, 7200): + logger.error("Failed to update the Certified-Addr-Book") + return False + except Exception as e: + logger.error("Failed to update the Certified-Addr-Book: %s", e) + return False + else: + if msg.HasField("signedRecord"): + try: + # Convert the signed-peer-record(Envelope) from + # protobuf bytes + envelope, record = consume_envelope( + msg.signedRecord, + "libp2p-peer-record", + ) + if record.peer_id.to_bytes() != msg.id: + return False + # Use the default TTL of 2 hours (7200 seconds) + if not host.get_peerstore().consume_peer_record(envelope, 7200): + logger.error("Failed to update the Certified-Addr-Book") + return False + except Exception as e: + logger.error( + "Failed to update the Certified-Addr-Book: %s", + e, + ) + return False + return True + def create_key_from_binary(binary_data: bytes) -> bytes: """ diff --git a/libp2p/kad_dht/value_store.py b/libp2p/kad_dht/value_store.py index b79425fd..2002965f 100644 --- a/libp2p/kad_dht/value_store.py +++ b/libp2p/kad_dht/value_store.py @@ -15,9 +15,11 @@ from libp2p.abc import ( from libp2p.custom_types import ( TProtocol, ) +from libp2p.kad_dht.utils import maybe_consume_signed_record from libp2p.peer.id import ( ID, ) +from libp2p.peer.peerstore import env_to_send_in_RPC from .common import ( DEFAULT_TTL, @@ -110,6 +112,10 @@ class ValueStore: message = Message() message.type = Message.MessageType.PUT_VALUE + # Create sender's signed-peer-record + envelope_bytes, _ = 
env_to_send_in_RPC(self.host) + message.senderRecord = envelope_bytes + # Set message fields message.key = key message.record.key = key @@ -155,7 +161,13 @@ class ValueStore: # Check if response is valid if response.type == Message.MessageType.PUT_VALUE: - if response.key: + # Consume the sender's signed-peer-record if sent + if not maybe_consume_signed_record(response, self.host, peer_id): + logger.error( + "Received an invalid-signed-record, ignoring the response" + ) + return False + if response.key == key: result = True return result @@ -231,6 +243,10 @@ class ValueStore: message.type = Message.MessageType.GET_VALUE message.key = key + # Create sender's signed-peer-record + envelope_bytes, _ = env_to_send_in_RPC(self.host) + message.senderRecord = envelope_bytes + # Serialize and send the protobuf message proto_bytes = message.SerializeToString() await stream.write(varint.encode(len(proto_bytes))) @@ -275,6 +291,13 @@ class ValueStore: and response.HasField("record") and response.record.value ): + # Consume the sender's signed-peer-record + if not maybe_consume_signed_record(response, self.host, peer_id): + logger.error( + "Received an invalid-signed-record, ignoring the response" + ) + return None + logger.debug( f"Received value for key {key.hex()} from peer {peer_id}" ) diff --git a/libp2p/peer/envelope.py b/libp2p/peer/envelope.py index e93a8280..f8bf9f43 100644 --- a/libp2p/peer/envelope.py +++ b/libp2p/peer/envelope.py @@ -1,5 +1,7 @@ from typing import Any, cast +import multiaddr + from libp2p.crypto.ed25519 import Ed25519PublicKey from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.crypto.rsa import RSAPublicKey @@ -131,6 +133,9 @@ class Envelope: ) return False + def _env_addrs_set(self) -> set[multiaddr.Multiaddr]: + return {b for b in self.record().addrs} + def pub_key_to_protobuf(pub_key: PublicKey) -> cryto_pb.PublicKey: """ diff --git a/libp2p/peer/peerstore.py b/libp2p/peer/peerstore.py index 043aaf0d..ddf1af1f 100644 --- a/libp2p/peer/peerstore.py +++ b/libp2p/peer/peerstore.py @@ -16,6 +16,7 @@ import trio from trio import MemoryReceiveChannel, MemorySendChannel from libp2p.abc import ( + IHost, IPeerStore, ) from libp2p.crypto.keys import ( @@ -23,7 +24,8 @@ from libp2p.crypto.keys import ( PrivateKey, PublicKey, ) -from libp2p.peer.envelope import Envelope +from libp2p.peer.envelope import Envelope, seal_record +from libp2p.peer.peer_record import PeerRecord from .id import ( ID, @@ -39,6 +41,86 @@ from .peerinfo import ( PERMANENT_ADDR_TTL = 0 +def create_signed_peer_record( + peer_id: ID, addrs: list[Multiaddr], pvt_key: PrivateKey +) -> Envelope: + """Creates a signed_peer_record wrapped in an Envelope""" + record = PeerRecord(peer_id, addrs) + envelope = seal_record(record, pvt_key) + return envelope + + +def env_to_send_in_RPC(host: IHost) -> tuple[bytes, bool]: + """ + Return the signed peer record (Envelope) to be sent in an RPC. + + This function checks whether the host already has a cached signed peer record + (SPR). If one exists and its addresses match the host's current listen + addresses, the cached envelope is reused. Otherwise, a new signed peer record + is created, cached, and returned. + + Parameters + ---------- + host : IHost + The local host instance, providing access to peer ID, listen addresses, + private key, and the peerstore. 
+ + Returns ------- tuple[bytes, bool] A 2-tuple where the first element is the serialized envelope (bytes) + for the signed peer record, and the second element is a boolean flag + indicating whether a new record was created (True) or an existing cached + one was reused (False). + + """ + listen_addrs_set = set(host.get_addrs()) + local_env = host.get_peerstore().get_local_record() + + if local_env is None: + # No cached SPR yet -> create one + return issue_and_cache_local_record(host), True + else: + record_addrs_set = local_env._env_addrs_set() + if record_addrs_set == listen_addrs_set: + # Perfect match -> reuse cached envelope + return local_env.marshal_envelope(), False + else: + # Addresses changed -> issue a new SPR and cache it + return issue_and_cache_local_record(host), True + + +def issue_and_cache_local_record(host: IHost) -> bytes: + """ + Create and cache a new signed peer record (Envelope) for the host. + + This function generates a new signed peer record from the host’s peer ID, + listen addresses, and private key. The resulting envelope is stored in + the peerstore as the local record for future reuse. + + Parameters + ---------- + host : IHost + The local host instance, providing access to peer ID, listen addresses, + private key, and the peerstore. + + Returns + ------- + bytes + The serialized envelope (bytes) representing the newly created signed + peer record. + + """ + env = create_signed_peer_record( + host.get_id(), + host.get_addrs(), + host.get_private_key(), + ) + # Cache it for future use + host.get_peerstore().set_local_record(env) + return env.marshal_envelope() + + class PeerRecordState: envelope: Envelope seq: int @@ -55,8 +137,17 @@ class PeerStore(IPeerStore): self.peer_data_map = defaultdict(PeerData) self.addr_update_channels: dict[ID, MemorySendChannel[Multiaddr]] = {} self.peer_record_map: dict[ID, PeerRecordState] = {} + self.local_peer_record: Envelope | None = None self.max_records = max_records + def get_local_record(self) -> Envelope | None: + """Get the local-signed-record wrapped in Envelope""" + return self.local_peer_record + + def set_local_record(self, envelope: Envelope) -> None: + """Set the local-signed-record wrapped in Envelope""" + self.local_peer_record = envelope + def peer_info(self, peer_id: ID) -> PeerInfo: """ :param peer_id: peer ID to get info for diff --git a/newsfragments/815.feature.rst b/newsfragments/815.feature.rst new file mode 100644 index 00000000..8fcf6fea --- /dev/null +++ b/newsfragments/815.feature.rst @@ -0,0 +1 @@ +KAD-DHT now includes signed-peer-records in its protobuf message schema for more secure peer discovery. 
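Reviewer note: for clarity, here is a minimal sketch (not part of this diff) of how the two new helpers compose on the wire. `host_a`/`host_b` and the function names are illustrative assumptions; only `Message`, `env_to_send_in_RPC`, and `maybe_consume_signed_record` come from this change:

```python
# Minimal sketch (assumes two already-running IHost instances, host_a/host_b).
from libp2p.kad_dht.pb.kademlia_pb2 import Message
from libp2p.kad_dht.utils import maybe_consume_signed_record
from libp2p.peer.id import ID
from libp2p.peer.peerstore import env_to_send_in_RPC


def build_find_node(host_a, target_key: bytes) -> bytes:
    """Build a FIND_NODE request carrying host_a's signed peer record."""
    msg = Message()
    msg.type = Message.MessageType.FIND_NODE
    msg.key = target_key
    # Reuses the cached envelope unless host_a's listen addrs changed,
    # in which case a fresh record is sealed and re-cached.
    envelope_bytes, _created_new = env_to_send_in_RPC(host_a)
    msg.senderRecord = envelope_bytes
    return msg.SerializeToString()


def handle_find_node(host_b, raw: bytes, sender: ID) -> Message | None:
    """Validate the sender's record before trusting the request."""
    msg = Message()
    msg.ParseFromString(raw)
    # Rejects a bad signature or a peer-id mismatch; on success the envelope
    # lands in host_b's certified addr book with the default 2h TTL.
    if not maybe_consume_signed_record(msg, host_b, sender):
        return None
    return msg
```

Returning `(bytes, bool)` from `env_to_send_in_RPC` lets callers reuse the cached envelope cheaply; the `bool` only reports whether a fresh record had to be sealed.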
diff --git a/tests/core/kad_dht/test_kad_dht.py b/tests/core/kad_dht/test_kad_dht.py index a6f73074..5bf4f3e8 100644 --- a/tests/core/kad_dht/test_kad_dht.py +++ b/tests/core/kad_dht/test_kad_dht.py @@ -9,11 +9,15 @@ This module tests core functionality of the Kademlia DHT including: import hashlib import logging +import os +from unittest.mock import patch import uuid import pytest +import multiaddr import trio +from libp2p.crypto.rsa import create_new_key_pair from libp2p.kad_dht.kad_dht import ( DHTMode, KadDHT, @@ -21,9 +25,13 @@ from libp2p.kad_dht.kad_dht import ( from libp2p.kad_dht.utils import ( create_key_from_binary, ) +from libp2p.peer.envelope import Envelope, seal_record +from libp2p.peer.id import ID +from libp2p.peer.peer_record import PeerRecord from libp2p.peer.peerinfo import ( PeerInfo, ) +from libp2p.peer.peerstore import create_signed_peer_record from libp2p.tools.async_service import ( background_trio_service, ) @@ -76,10 +84,52 @@ async def test_find_node(dht_pair: tuple[KadDHT, KadDHT]): """Test that nodes can find each other in the DHT.""" dht_a, dht_b = dht_pair + # An extra FIND_NODE req is sent between the 2 nodes during DHT creation, + # so both nodes will have records of each other before the next FIND_NODE + # req is sent + envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()) + envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id()) + + assert isinstance(envelope_a, Envelope) + assert isinstance(envelope_b, Envelope) + + record_a = envelope_a.record() + record_b = envelope_b.record() + # Node A should be able to find Node B with trio.fail_after(TEST_TIMEOUT): found_info = await dht_a.find_peer(dht_b.host.get_id()) + # Verifies that the senderRecord in the FIND_NODE request is correctly processed + assert isinstance( + dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id()), Envelope + ) + + # Verifies that the senderRecord in the FIND_NODE response is correctly processed + assert isinstance( + dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()), Envelope + ) + + # These are the records that were sent between the peers during the FIND_NODE req + envelope_a_find_peer = dht_a.host.get_peerstore().get_peer_record( + dht_b.host.get_id() + ) + envelope_b_find_peer = dht_b.host.get_peerstore().get_peer_record( + dht_a.host.get_id() + ) + + assert isinstance(envelope_a_find_peer, Envelope) + assert isinstance(envelope_b_find_peer, Envelope) + + record_a_find_peer = envelope_a_find_peer.record() + record_b_find_peer = envelope_b_find_peer.record() + + # This proves that both records are the same, and that the latest cached signed + # record was passed between the peers during FIND_NODE execution, which shows the + # signed-record transfer/re-issuing works correctly in FIND_NODE executions. 
+ assert record_a.seq == record_a_find_peer.seq + assert record_b.seq == record_b_find_peer.seq + # Verify that the found peer has the correct peer ID assert found_info is not None, "Failed to find the target peer" assert found_info.peer_id == dht_b.host.get_id(), "Found incorrect peer ID" @@ -104,14 +154,44 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]): await dht_a.routing_table.add_peer(peer_b_info) print("Routing table of a has ", dht_a.routing_table.get_peer_ids()) + # An extra FIND_NODE req is sent between the 2 nodes during DHT creation, + # so both nodes will have records of each other before the PUT_VALUE req is sent + envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()) + envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id()) + + assert isinstance(envelope_a, Envelope) + assert isinstance(envelope_b, Envelope) + + record_a = envelope_a.record() + record_b = envelope_b.record() + # Store the value using the first node (this will also store locally) with trio.fail_after(TEST_TIMEOUT): await dht_a.put_value(key, value) + # These are the records that were sent between the peers during the PUT_VALUE req + envelope_a_put_value = dht_a.host.get_peerstore().get_peer_record( + dht_b.host.get_id() + ) + envelope_b_put_value = dht_b.host.get_peerstore().get_peer_record( + dht_a.host.get_id() + ) + + assert isinstance(envelope_a_put_value, Envelope) + assert isinstance(envelope_b_put_value, Envelope) + + record_a_put_value = envelope_a_put_value.record() + record_b_put_value = envelope_b_put_value.record() + + # This proves that both records are the same, and that the latest cached signed + # record was passed between the peers during PUT_VALUE execution, which shows the + # signed-record transfer/re-issuing works correctly in PUT_VALUE executions. + assert record_a.seq == record_a_put_value.seq + assert record_b.seq == record_b_put_value.seq + # # Log debugging information logger.debug("Put value with key %s...", key.hex()[:10]) logger.debug("Node A value store: %s", dht_a.value_store.store) - print("hello test") # # Allow more time for the value to propagate await trio.sleep(0.5) @@ -126,6 +206,26 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]): print("the value stored in node b is", dht_b.get_value_store_size()) logger.debug("Retrieved value: %s", retrieved_value) + # These are the records held by each peer after the GET_VALUE req + envelope_a_get_value = dht_a.host.get_peerstore().get_peer_record( + dht_b.host.get_id() + ) + envelope_b_get_value = dht_b.host.get_peerstore().get_peer_record( + dht_a.host.get_id() + ) + + assert isinstance(envelope_a_get_value, Envelope) + assert isinstance(envelope_b_get_value, Envelope) + + record_a_get_value = envelope_a_get_value.record() + record_b_get_value = envelope_b_get_value.record() + + # This proves that there was no record exchange between the nodes during GET_VALUE + # execution, as dht_b already had the key/value pair stored locally after the + # PUT_VALUE execution. 
+ assert record_a_get_value.seq == record_a_put_value.seq + assert record_b_get_value.seq == record_b_put_value.seq + # Verify that the retrieved value matches the original assert retrieved_value == value, "Retrieved value does not match the stored value" @@ -142,11 +242,44 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]): # Store content on the first node dht_a.value_store.put(content_id, content) + # An extra FIND_NODE req is sent between the 2 nodes during DHT creation, + # so both nodes will have records of each other before the ADD_PROVIDER req is sent + envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()) + envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id()) + + assert isinstance(envelope_a, Envelope) + assert isinstance(envelope_b, Envelope) + + record_a = envelope_a.record() + record_b = envelope_b.record() + # Advertise the first node as a provider with trio.fail_after(TEST_TIMEOUT): success = await dht_a.provide(content_id) assert success, "Failed to advertise as provider" + # These are the records that were sent between the peers during + # the ADD_PROVIDER req + envelope_a_add_prov = dht_a.host.get_peerstore().get_peer_record( + dht_b.host.get_id() + ) + envelope_b_add_prov = dht_b.host.get_peerstore().get_peer_record( + dht_a.host.get_id() + ) + + assert isinstance(envelope_a_add_prov, Envelope) + assert isinstance(envelope_b_add_prov, Envelope) + + record_a_add_prov = envelope_a_add_prov.record() + record_b_add_prov = envelope_b_add_prov.record() + + # This proves that both records are the same: the latest cached signed record + # was passed between the peers during ADD_PROVIDER execution, which shows the + # signed-record transfer/re-issuing of the latest record works correctly in + # ADD_PROVIDER executions. 
+ assert record_a.seq == record_a_add_prov.seq + assert record_b.seq == record_b_add_prov.seq + # Allow time for the provider record to propagate await trio.sleep(0.1) @@ -154,6 +287,26 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]): with trio.fail_after(TEST_TIMEOUT): providers = await dht_b.find_providers(content_id) + # These are the records in each peer after the find_provider execution + envelope_a_find_prov = dht_a.host.get_peerstore().get_peer_record( + dht_b.host.get_id() + ) + envelope_b_find_prov = dht_b.host.get_peerstore().get_peer_record( + dht_a.host.get_id() + ) + + assert isinstance(envelope_a_find_prov, Envelope) + assert isinstance(envelope_b_find_prov, Envelope) + + record_a_find_prov = envelope_a_find_prov.record() + record_b_find_prov = envelope_b_find_prov.record() + + # This proves that both records are the same, as dht_b already + # has the provider record for the content_id, after the ADD_PROVIDER + # advertisement by dht_a + assert record_a_find_prov.seq == record_a_add_prov.seq + assert record_b_find_prov.seq == record_b_add_prov.seq + # Verify that we found the first node as a provider assert providers, "No providers found" assert any(p.peer_id == dht_a.local_peer_id for p in providers), ( @@ -166,3 +319,141 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]): assert retrieved_value == content, ( "Retrieved content does not match the original" ) + + # These are the record states of each peer after the GET_VALUE execution + envelope_a_get_value = dht_a.host.get_peerstore().get_peer_record( + dht_b.host.get_id() + ) + envelope_b_get_value = dht_b.host.get_peerstore().get_peer_record( + dht_a.host.get_id() + ) + + assert isinstance(envelope_a_get_value, Envelope) + assert isinstance(envelope_b_get_value, Envelope) + + record_a_get_value = envelope_a_get_value.record() + record_b_get_value = envelope_b_get_value.record() + + # This proves that both records are the same, meaning that the latest cached + # signed-record transfer happened during the GET_VALUE execution by dht_b, + # which means the signed-record transfer/re-issuing works correctly + # in GET_VALUE executions. 
+ assert record_a_find_prov.seq == record_a_get_value.seq + assert record_b_find_prov.seq == record_b_get_value.seq + + # Create a new provider record in dht_a + provider_key_pair = create_new_key_pair() + provider_peer_id = ID.from_pubkey(provider_key_pair.public_key) + provider_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123") + provider_peer_info = PeerInfo(peer_id=provider_peer_id, addrs=[provider_addr]) + + # Generate a random content ID + content_2 = f"random-content-{uuid.uuid4()}".encode() + content_id_2 = hashlib.sha256(content_2).digest() + + provider_signed_envelope = create_signed_peer_record( + provider_peer_id, [provider_addr], provider_key_pair.private_key + ) + assert ( + dht_a.host.get_peerstore().consume_peer_record(provider_signed_envelope, 7200) + is True + ) + + # Store this provider record in dht_a + dht_a.provider_store.add_provider(content_id_2, provider_peer_info) + + # Fetch the provider-record via peer-discovery at dht_b's end + peerinfo = await dht_b.provider_store.find_providers(content_id_2) + + assert len(peerinfo) == 1 + assert peerinfo[0].peer_id == provider_peer_id + provider_envelope = dht_b.host.get_peerstore().get_peer_record(provider_peer_id) + + # This proves that the signed-envelope of the provider is consumed on dht_b's end + assert provider_envelope is not None + assert ( + provider_signed_envelope.marshal_envelope() + == provider_envelope.marshal_envelope() + ) + + +@pytest.mark.trio +async def test_reissue_when_listen_addrs_change(dht_pair: tuple[KadDHT, KadDHT]): + dht_a, dht_b = dht_pair + + # Warm-up: A stores B's current record + with trio.fail_after(10): + await dht_a.find_peer(dht_b.host.get_id()) + + env0 = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()) + assert isinstance(env0, Envelope) + seq0 = env0.record().seq + + # Simulate B's listen addrs changing (different port) + new_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123") + + # Patch just for the duration we force B to respond: + with patch.object(dht_b.host, "get_addrs", return_value=[new_addr]): + # Force B to send a response (which should include a fresh SPR) + with trio.fail_after(10): + await dht_a.peer_routing._query_peer_for_closest( + dht_b.host.get_id(), os.urandom(32) + ) + + # A should now hold B's new record with a bumped seq + env1 = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()) + assert isinstance(env1, Envelope) + seq1 = env1.record().seq + + # This proves that upon the change in listen_addrs, we issue new records + assert seq1 > seq0, f"Expected seq to bump after addr change, got {seq0} -> {seq1}" + + +@pytest.mark.trio +async def test_dht_req_fail_with_invalid_record_transfer( + dht_pair: tuple[KadDHT, KadDHT], +): + """ + Test showing that storing and retrieving values in the DHT fails + when invalid signed-records are sent. 
+ """ + dht_a, dht_b = dht_pair + peer_b_info = PeerInfo(dht_b.host.get_id(), dht_b.host.get_addrs()) + + # Generate a random key and value + key = create_key_from_binary(b"test-key") + value = b"test-value" + + # First add the value directly to node A's store to verify storage works + dht_a.value_store.put(key, value) + local_value = dht_a.value_store.get(key) + assert local_value == value, "Local value storage failed" + await dht_a.routing_table.add_peer(peer_b_info) + + # Corrupt dht_a's local peer_record + envelope = dht_a.host.get_peerstore().get_local_record() + if envelope is not None: + true_record = envelope.record() + key_pair = create_new_key_pair() + + if envelope is not None: + envelope.public_key = key_pair.public_key + dht_a.host.get_peerstore().set_local_record(envelope) + + await dht_a.put_value(key, value) + retrieved_value = dht_b.value_store.get(key) + + # This proves that DHT_B rejected DHT_A PUT_RECORD req upon receiving + # the corrupted invalid record + assert retrieved_value is None + + # Create a corrupt envelope with correct signature but false peer_id + false_record = PeerRecord(ID.from_pubkey(key_pair.public_key), true_record.addrs) + false_envelope = seal_record(false_record, dht_a.host.get_private_key()) + + dht_a.host.get_peerstore().set_local_record(false_envelope) + + await dht_a.put_value(key, value) + retrieved_value = dht_b.value_store.get(key) + + # This proves that DHT_B rejected DHT_A PUT_RECORD req upon receving + # the record with a different peer_id regardless of a valid signature + assert retrieved_value is None diff --git a/tests/core/kad_dht/test_unit_peer_routing.py b/tests/core/kad_dht/test_unit_peer_routing.py index ffe20655..6e15ce7e 100644 --- a/tests/core/kad_dht/test_unit_peer_routing.py +++ b/tests/core/kad_dht/test_unit_peer_routing.py @@ -57,7 +57,10 @@ class TestPeerRouting: def mock_host(self): """Create a mock host for testing.""" host = Mock() - host.get_id.return_value = create_valid_peer_id("local") + key_pair = create_new_key_pair() + host.get_id.return_value = ID.from_pubkey(key_pair.public_key) + host.get_public_key.return_value = key_pair.public_key + host.get_private_key.return_value = key_pair.private_key host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")] host.get_peerstore.return_value = Mock() host.new_stream = AsyncMock()