diff --git a/libp2p/kad_dht/kad_dht.py b/libp2p/kad_dht/kad_dht.py index f93aa75e..adfd7400 100644 --- a/libp2p/kad_dht/kad_dht.py +++ b/libp2p/kad_dht/kad_dht.py @@ -280,7 +280,12 @@ class KadDHT(Service): logger.debug(f"Found {len(closest_peers)} peers close to target") # Consume the source signed_peer_record if sent - success = maybe_consume_signed_record(message, self.host) + if not maybe_consume_signed_record(message, self.host): + logger.error( + "Received an invalid-signed-record, dropping the stream" + ) + await stream.close() + return # Build response message with protobuf response = Message() @@ -336,7 +341,12 @@ class KadDHT(Service): logger.debug(f"Received ADD_PROVIDER for key {key.hex()}") # Consume the source signed-peer-record if sent - success = maybe_consume_signed_record(message, self.host) + if not maybe_consume_signed_record(message, self.host): + logger.error( + "Received an invalid-signed-record, dropping the stream" + ) + await stream.close() + return # Extract provider information for provider_proto in message.providerPeers: @@ -366,9 +376,13 @@ class KadDHT(Service): ) # Process the signed-records of provider if sent - success = maybe_consume_signed_record( - provider_proto, self.host - ) + if not maybe_consume_signed_record(message, self.host): + logger.error( + "Received an invalid-signed-record," + "dropping the stream" + ) + await stream.close() + return except Exception as e: logger.warning(f"Failed to process provider info: {e}") @@ -393,7 +407,12 @@ class KadDHT(Service): logger.debug(f"Received GET_PROVIDERS request for key {key.hex()}") # Consume the source signed_peer_record if sent - success = maybe_consume_signed_record(message, self.host) + if not maybe_consume_signed_record(message, self.host): + logger.error( + "Received an invalid-signed-record, dropping the stream" + ) + await stream.close() + return # Find providers for the key providers = self.provider_store.get_providers(key) @@ -482,7 +501,12 @@ class KadDHT(Service): logger.debug(f"Received GET_VALUE request for key {key.hex()}") # Consume the sender_signed_peer_record - success = maybe_consume_signed_record(message, self.host) + if not maybe_consume_signed_record(message, self.host): + logger.error( + "Received an invalid-signed-record, dropping the stream" + ) + await stream.close() + return value = self.value_store.get(key) if value: @@ -571,7 +595,12 @@ class KadDHT(Service): success = False # Consume the source signed_peer_record if sent - success = maybe_consume_signed_record(message, self.host) + if not maybe_consume_signed_record(message, self.host): + logger.error( + "Received an invalid-signed-record, dropping the stream" + ) + await stream.close() + return try: if not (key and value): diff --git a/libp2p/kad_dht/peer_routing.py b/libp2p/kad_dht/peer_routing.py index 4362ffea..cd1611ed 100644 --- a/libp2p/kad_dht/peer_routing.py +++ b/libp2p/kad_dht/peer_routing.py @@ -307,9 +307,20 @@ class PeerRouting(IPeerRouting): # Process closest peers from response if response_msg.type == Message.MessageType.FIND_NODE: # Consume the sender_signed_peer_record - _ = maybe_consume_signed_record(response_msg, self.host) + if not maybe_consume_signed_record(response_msg, self.host): + logger.error( + "Received an invalid-signed-record,ignoring the response" + ) + return [] for peer_data in response_msg.closerPeers: + # Consume the received closer_peers signed-records + if not maybe_consume_signed_record(peer_data, self.host): + logger.error( + "Received an invalid-signed-record,ignoring the response" + ) + return [] + new_peer_id = ID(peer_data.id) if new_peer_id not in results: results.append(new_peer_id) @@ -321,9 +332,6 @@ class PeerRouting(IPeerRouting): addrs = [Multiaddr(addr) for addr in peer_data.addrs] self.host.get_peerstore().add_addrs(new_peer_id, addrs, 3600) - # Consume the received closer_peers signed-records - _ = maybe_consume_signed_record(peer_data, self.host) - except Exception as e: logger.debug(f"Error querying peer {peer} for closest: {e}") @@ -364,7 +372,11 @@ class PeerRouting(IPeerRouting): if kad_message.type == Message.MessageType.FIND_NODE: # Consume the sender's signed-peer-record if sent - _ = maybe_consume_signed_record(kad_message, self.host) + if not maybe_consume_signed_record(kad_message, self.host): + logger.error( + "Receivedf an invalid-signed-record, dropping the stream" + ) + return # Get target key directly from protobuf message target_key = kad_message.key diff --git a/libp2p/kad_dht/provider_store.py b/libp2p/kad_dht/provider_store.py index 4c6a8e06..ee7adfe8 100644 --- a/libp2p/kad_dht/provider_store.py +++ b/libp2p/kad_dht/provider_store.py @@ -286,8 +286,13 @@ class ProviderStore: if response.type == Message.MessageType.ADD_PROVIDER: # Consume the sender's signed-peer-record if sent - _ = maybe_consume_signed_record(response, self.host) - result = True + if not maybe_consume_signed_record(response, self.host): + logger.error( + "Received an invalid-signed-record, ignoring the response" + ) + result = False + else: + result = True except Exception as e: logger.warning(f"Error sending ADD_PROVIDER to {peer_id}: {e}") @@ -427,12 +432,24 @@ class ProviderStore: return [] # Consume the sender's signed-peer-record if sent - _ = maybe_consume_signed_record(response, self.host) + if not maybe_consume_signed_record(response, self.host): + logger.error( + "Recieved an invalid-signed-record, ignoring the response" + ) + return [] # Extract provider information providers = [] for provider_proto in response.providerPeers: try: + # Consume the provider's signed-peer-record if sent + if not maybe_consume_signed_record(provider_proto, self.host): + logger.error( + "Recieved an invalid-signed-record, " + "ignoring the response" + ) + return [] + # Create peer ID from bytes provider_id = ID(provider_proto.id) @@ -447,9 +464,6 @@ class ProviderStore: # Create PeerInfo and add to result providers.append(PeerInfo(provider_id, addrs)) - # Consume the provider's signed-peer-record if sent - _ = maybe_consume_signed_record(provider_proto, self.host) - except Exception as e: logger.warning(f"Failed to parse provider info: {e}") diff --git a/libp2p/kad_dht/utils.py b/libp2p/kad_dht/utils.py index 3cf79efd..6d65d1af 100644 --- a/libp2p/kad_dht/utils.py +++ b/libp2p/kad_dht/utils.py @@ -27,7 +27,10 @@ def maybe_consume_signed_record(msg: Message | Message.Peer, host: IHost) -> boo try: # Convert the signed-peer-record(Envelope) from # protobuf bytes - envelope, _ = consume_envelope(msg.senderRecord, "libp2p-peer-record") + envelope, _ = consume_envelope( + msg.senderRecord, + "libp2p-peer-record", + ) # Use the default TTL of 2 hours (7200 seconds) if not host.get_peerstore().consume_peer_record(envelope, 7200): logger.error("Updating the certified-addr-book was unsuccessful") @@ -51,7 +54,7 @@ def maybe_consume_signed_record(msg: Message | Message.Peer, host: IHost) -> boo "Error updating the certified-addr-book: %s", e, ) - + return False return True diff --git a/libp2p/kad_dht/value_store.py b/libp2p/kad_dht/value_store.py index bb143dcd..aa545797 100644 --- a/libp2p/kad_dht/value_store.py +++ b/libp2p/kad_dht/value_store.py @@ -161,8 +161,11 @@ class ValueStore: # Check if response is valid if response.type == Message.MessageType.PUT_VALUE: # Consume the sender's signed-peer-record if sent - _ = maybe_consume_signed_record(response, self.host) - + if not maybe_consume_signed_record(response, self.host): + logger.error( + "Received an invalid-signed-record, ignoring the response" + ) + return False if response.key == key: result = True return result @@ -288,7 +291,11 @@ class ValueStore: and response.record.value ): # Consume the sender's signed-peer-record - _ = maybe_consume_signed_record(response, self.host) + if not maybe_consume_signed_record(response, self.host): + logger.error( + "Received an invalid-signed-record, ignoring the response" + ) + return None logger.debug( f"Received value for key {key.hex()} from peer {peer_id}"