From 22d93b39ae86600e17a26566b9db59c0073f39c6 Mon Sep 17 00:00:00 2001 From: Sukhman Singh <63765293+sukhman-sukh@users.noreply.github.com> Date: Tue, 10 Jun 2025 00:12:59 +0530 Subject: [PATCH] Add ttl for peer data expiration (#655) * Add ttl and last_identified to peerdata * Add test for ttl Signed-off-by: sukhman * Fix lint and add newsfragments Signed-off-by: sukhman * Fix failing ci Signed-off-by: sukhman * fix ttl time from 600 to 120 Signed-off-by: sukhman * fix test ttl timeout and lint errors Signed-off-by: sukhman * Fix docstrings Signed-off-by: sukhman * rebase main * remove print statement --------- Signed-off-by: sukhman Co-authored-by: pacrob <5199899+pacrob@users.noreply.github.com> --- libp2p/abc.py | 54 ++++++++++++++++++++++++++++++ libp2p/host/basic_host.py | 2 +- libp2p/host/routed_host.py | 4 +-- libp2p/peer/peerdata.py | 39 ++++++++++++++++++++- libp2p/peer/peerstore.py | 39 ++++++++++++++++----- newsfragments/650.feature.rst | 1 + tests/core/peer/test_peerstore.py | 31 +++++++++++++++-- tests/core/pubsub/test_floodsub.py | 4 +-- 8 files changed, 158 insertions(+), 16 deletions(-) create mode 100644 newsfragments/650.feature.rst diff --git a/libp2p/abc.py b/libp2p/abc.py index 06570eaa..a50a364d 100644 --- a/libp2p/abc.py +++ b/libp2p/abc.py @@ -1440,6 +1440,60 @@ class IPeerData(ABC): """ + @abstractmethod + def update_last_identified(self) -> None: + """ + Updates timestamp to current time. + """ + + @abstractmethod + def get_last_identified(self) -> int: + """ + Fetch the last identified timestamp + + Returns + ------- + last_identified_timestamp + The lastIdentified time of peer. + + """ + + @abstractmethod + def get_ttl(self) -> int: + """ + Get ttl value for the peer for validity check + + Returns + ------- + int + The ttl of the peer. + + """ + + @abstractmethod + def set_ttl(self, ttl: int) -> None: + """ + Set ttl value for the peer for validity check + + Parameters + ---------- + ttl : int + The ttl for the peer. + + """ + + @abstractmethod + def is_expired(self) -> bool: + """ + Check if the peer is expired based on last_identified and ttl + + Returns + ------- + bool + True, if last_identified + ttl > current_time + + """ + # ------------------ multiselect_communicator interface.py ------------------ diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 6d844bee..1dea876d 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -234,7 +234,7 @@ class BasicHost(IHost): :param peer_info: peer_info of the peer we want to connect to :type peer_info: peer.peerinfo.PeerInfo """ - self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 10) + self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 120) # there is already a connection to this peer if peer_info.peer_id in self._network.connections: diff --git a/libp2p/host/routed_host.py b/libp2p/host/routed_host.py index 7cbe81d9..b637e1eb 100644 --- a/libp2p/host/routed_host.py +++ b/libp2p/host/routed_host.py @@ -40,8 +40,8 @@ class RoutedHost(BasicHost): found_peer_info = await self._router.find_peer(peer_info.peer_id) if not found_peer_info: raise ConnectionFailure("Unable to find Peer address") - self.peerstore.add_addrs(peer_info.peer_id, found_peer_info.addrs, 10) - self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 10) + self.peerstore.add_addrs(peer_info.peer_id, found_peer_info.addrs, 120) + self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 120) # there is already a connection to this peer if peer_info.peer_id in self._network.connections: diff --git a/libp2p/peer/peerdata.py b/libp2p/peer/peerdata.py index fa9f4f54..386e31ef 100644 --- a/libp2p/peer/peerdata.py +++ b/libp2p/peer/peerdata.py @@ -1,7 +1,10 @@ from collections.abc import ( Sequence, ) -from typing import Any +import time +from typing import ( + Any, +) from multiaddr import ( Multiaddr, @@ -22,6 +25,8 @@ class PeerData(IPeerData): metadata: dict[Any, Any] protocols: list[str] addrs: list[Multiaddr] + last_identified: int + ttl: int # Keep ttl=0 by default for always valid def __init__(self) -> None: self.pubkey = None @@ -29,6 +34,8 @@ class PeerData(IPeerData): self.metadata = {} self.protocols = [] self.addrs = [] + self.last_identified = int(time.time()) + self.ttl = 0 def get_protocols(self) -> list[str]: """ @@ -113,6 +120,36 @@ class PeerData(IPeerData): raise PeerDataError("private key not found") return self.privkey + def update_last_identified(self) -> None: + self.last_identified = int(time.time()) + + def get_last_identified(self) -> int: + """ + :return: last identified timestamp + """ + return self.last_identified + + def get_ttl(self) -> int: + """ + :return: ttl for current peer + """ + return self.ttl + + def set_ttl(self, ttl: int) -> None: + """ + :param ttl: ttl to set + """ + self.ttl = ttl + + def is_expired(self) -> bool: + """ + :return: true, if last_identified+ttl > current_time + """ + # for ttl = 0; peer_data is always valid + if self.ttl > 0 and self.last_identified + self.ttl < int(time.time()): + return True + return False + class PeerDataError(KeyError): """Raised when a key is not found in peer metadata.""" diff --git a/libp2p/peer/peerstore.py b/libp2p/peer/peerstore.py index efee6059..3bb729d2 100644 --- a/libp2p/peer/peerstore.py +++ b/libp2p/peer/peerstore.py @@ -4,7 +4,6 @@ from collections import ( from collections.abc import ( Sequence, ) -import sys from typing import ( Any, ) @@ -33,7 +32,7 @@ from .peerinfo import ( PeerInfo, ) -PERMANENT_ADDR_TTL = sys.maxsize +PERMANENT_ADDR_TTL = 0 class PeerStore(IPeerStore): @@ -49,6 +48,8 @@ class PeerStore(IPeerStore): """ if peer_id in self.peer_data_map: peer_data = self.peer_data_map[peer_id] + if peer_data.is_expired(): + peer_data.clear_addrs() return PeerInfo(peer_id, peer_data.get_addrs()) raise PeerStoreError("peer ID not found") @@ -84,6 +85,18 @@ class PeerStore(IPeerStore): """ return list(self.peer_data_map.keys()) + def valid_peer_ids(self) -> list[ID]: + """ + :return: all of the valid peer IDs stored in peer store + """ + valid_peer_ids: list[ID] = [] + for peer_id, peer_data in self.peer_data_map.items(): + if not peer_data.is_expired(): + valid_peer_ids.append(peer_id) + else: + peer_data.clear_addrs() + return valid_peer_ids + def get(self, peer_id: ID, key: str) -> Any: """ :param peer_id: peer ID to get peer data for @@ -108,7 +121,7 @@ class PeerStore(IPeerStore): peer_data = self.peer_data_map[peer_id] peer_data.put_metadata(key, val) - def add_addr(self, peer_id: ID, addr: Multiaddr, ttl: int) -> None: + def add_addr(self, peer_id: ID, addr: Multiaddr, ttl: int = 0) -> None: """ :param peer_id: peer ID to add address for :param addr: @@ -116,24 +129,30 @@ class PeerStore(IPeerStore): """ self.add_addrs(peer_id, [addr], ttl) - def add_addrs(self, peer_id: ID, addrs: Sequence[Multiaddr], ttl: int) -> None: + def add_addrs(self, peer_id: ID, addrs: Sequence[Multiaddr], ttl: int = 0) -> None: """ :param peer_id: peer ID to add address for :param addrs: :param ttl: time-to-live for the this record """ - # Ignore ttl for now peer_data = self.peer_data_map[peer_id] peer_data.add_addrs(list(addrs)) + peer_data.set_ttl(ttl) + peer_data.update_last_identified() def addrs(self, peer_id: ID) -> list[Multiaddr]: """ :param peer_id: peer ID to get addrs for - :return: list of addrs + :return: list of addrs of a valid peer. :raise PeerStoreError: if peer ID not found """ if peer_id in self.peer_data_map: - return self.peer_data_map[peer_id].get_addrs() + peer_data = self.peer_data_map[peer_id] + if not peer_data.is_expired(): + return peer_data.get_addrs() + else: + peer_data.clear_addrs() + raise PeerStoreError("peer ID is expired") raise PeerStoreError("peer ID not found") def clear_addrs(self, peer_id: ID) -> None: @@ -153,7 +172,11 @@ class PeerStore(IPeerStore): for peer_id in self.peer_data_map: if len(self.peer_data_map[peer_id].get_addrs()) >= 1: - output.append(peer_id) + peer_data = self.peer_data_map[peer_id] + if not peer_data.is_expired(): + output.append(peer_id) + else: + peer_data.clear_addrs() return output def add_pubkey(self, peer_id: ID, pubkey: PublicKey) -> None: diff --git a/newsfragments/650.feature.rst b/newsfragments/650.feature.rst new file mode 100644 index 00000000..80a84675 --- /dev/null +++ b/newsfragments/650.feature.rst @@ -0,0 +1 @@ +fix: remove expired peers from peerstore based on TTL diff --git a/tests/core/peer/test_peerstore.py b/tests/core/peer/test_peerstore.py index fcfc83a2..b0d8ed81 100644 --- a/tests/core/peer/test_peerstore.py +++ b/tests/core/peer/test_peerstore.py @@ -1,3 +1,5 @@ +import time + import pytest from multiaddr import Multiaddr @@ -18,9 +20,34 @@ def test_peer_info_empty(): def test_peer_info_basic(): store = PeerStore() - store.add_addr(ID(b"peer"), Multiaddr("/ip4/127.0.0.1/tcp/4001"), 10) - info = store.peer_info(ID(b"peer")) + store.add_addr(ID(b"peer"), Multiaddr("/ip4/127.0.0.1/tcp/4001"), 1) + # update ttl to new value + store.add_addr(ID(b"peer"), Multiaddr("/ip4/127.0.0.1/tcp/4002"), 2) + + time.sleep(1) + info = store.peer_info(ID(b"peer")) + assert info.peer_id == ID(b"peer") + assert info.addrs == [ + Multiaddr("/ip4/127.0.0.1/tcp/4001"), + Multiaddr("/ip4/127.0.0.1/tcp/4002"), + ] + + # Check that addresses are cleared after ttl + time.sleep(2) + info = store.peer_info(ID(b"peer")) + assert info.peer_id == ID(b"peer") + assert info.addrs == [] + assert store.peer_ids() == [ID(b"peer")] + assert store.valid_peer_ids() == [] + + +# Check if all the data remains valid if ttl is set to default(0) +def test_peer_permanent_ttl(): + store = PeerStore() + store.add_addr(ID(b"peer"), Multiaddr("/ip4/127.0.0.1/tcp/4001")) + time.sleep(1) + info = store.peer_info(ID(b"peer")) assert info.peer_id == ID(b"peer") assert info.addrs == [Multiaddr("/ip4/127.0.0.1/tcp/4001")] diff --git a/tests/core/pubsub/test_floodsub.py b/tests/core/pubsub/test_floodsub.py index 135cbbec..f6ab8996 100644 --- a/tests/core/pubsub/test_floodsub.py +++ b/tests/core/pubsub/test_floodsub.py @@ -44,12 +44,12 @@ async def test_simple_two_nodes(): @pytest.mark.trio async def test_timed_cache_two_nodes(): - # Two nodes using LastSeenCache with a TTL of 120 seconds + # Two nodes using LastSeenCache with a TTL of 10 seconds def get_msg_id(msg): return msg.data + msg.from_id async with PubsubFactory.create_batch_with_floodsub( - 2, seen_ttl=120, msg_id_constructor=get_msg_id + 2, seen_ttl=10, msg_id_constructor=get_msg_id ) as pubsubs_fsub: message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1] expected_received_indices = [1, 2, 3, 4, 5]