added : timed_cache sub-module

This commit is contained in:
Mystical
2025-03-04 20:27:09 +05:30
committed by Paul Robinson
parent 0fa8711ca7
commit e5f3e88134
12 changed files with 158 additions and 20 deletions

View File

@ -17,6 +17,7 @@ Subpackages
libp2p.pubsub libp2p.pubsub
libp2p.security libp2p.security
libp2p.stream_muxer libp2p.stream_muxer
libp2p.timed_cache
libp2p.tools libp2p.tools
libp2p.transport libp2p.transport

View File

@ -0,0 +1,37 @@
libp2p.timed_cache package
==========================
Submodules
----------
libp2p.timed\_cache.basic\_time\_cache module
---------------------------------------------
.. automodule:: libp2p.timed_cache.basic_time_cache
:members:
:undoc-members:
:show-inheritance:
libp2p.timed\_cache.first\_seen\_cache module
---------------------------------------------
.. automodule:: libp2p.timed_cache.first_seen_cache
:members:
:undoc-members:
:show-inheritance:
libp2p.timed\_cache.last\_seen\_cache module
--------------------------------------------
.. automodule:: libp2p.timed_cache.last_seen_cache
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: libp2p.timed_cache
:members:
:undoc-members:
:show-inheritance:

View File

@ -580,7 +580,7 @@ class GossipSub(IPubsubRouter, Service):
# Get list of all seen (seqnos, from) from the (seqno, from) tuples in # Get list of all seen (seqnos, from) from the (seqno, from) tuples in
# seen_messages cache # seen_messages cache
seen_seqnos_and_peers = [ seen_seqnos_and_peers = [
seqno_and_from for seqno_and_from in self.pubsub.seen_messages.keys() seqno_and_from for seqno_and_from in self.pubsub.seen_messages.cache.keys()
] ]
# Add all unknown message ids (ids that appear in ihave_msg but not in # Add all unknown message ids (ids that appear in ihave_msg but not in

View File

@ -18,9 +18,6 @@ from typing import (
) )
import base58 import base58
from lru import (
LRU,
)
import trio import trio
from libp2p.abc import ( from libp2p.abc import (
@ -56,6 +53,9 @@ from libp2p.network.stream.exceptions import (
from libp2p.peer.id import ( from libp2p.peer.id import (
ID, ID,
) )
from libp2p.timed_cache.last_seen_cache import (
LastSeenCache,
)
from libp2p.tools.async_service import ( from libp2p.tools.async_service import (
Service, Service,
) )
@ -112,7 +112,7 @@ class Pubsub(Service, IPubsub):
peer_receive_channel: trio.MemoryReceiveChannel[ID] peer_receive_channel: trio.MemoryReceiveChannel[ID]
dead_peer_receive_channel: trio.MemoryReceiveChannel[ID] dead_peer_receive_channel: trio.MemoryReceiveChannel[ID]
seen_messages: LRU[bytes, bool] seen_messages: LastSeenCache
subscribed_topics_send: dict[str, trio.MemorySendChannel[rpc_pb2.Message]] subscribed_topics_send: dict[str, trio.MemorySendChannel[rpc_pb2.Message]]
subscribed_topics_receive: dict[str, TrioSubscriptionAPI] subscribed_topics_receive: dict[str, TrioSubscriptionAPI]
@ -136,6 +136,7 @@ class Pubsub(Service, IPubsub):
host: IHost, host: IHost,
router: IPubsubRouter, router: IPubsubRouter,
cache_size: int = None, cache_size: int = None,
seen_ttl: int = 120,
strict_signing: bool = True, strict_signing: bool = True,
msg_id_constructor: Callable[ msg_id_constructor: Callable[
[rpc_pb2.Message], bytes [rpc_pb2.Message], bytes
@ -187,7 +188,7 @@ class Pubsub(Service, IPubsub):
else: else:
self.sign_key = None self.sign_key = None
self.seen_messages = LRU(self.cache_size) self.seen_messages = LastSeenCache(seen_ttl)
# Map of topics we are subscribed to blocking queues # Map of topics we are subscribed to blocking queues
# for when the given topic receives a message # for when the given topic receives a message
@ -662,11 +663,11 @@ class Pubsub(Service, IPubsub):
def _is_msg_seen(self, msg: rpc_pb2.Message) -> bool: def _is_msg_seen(self, msg: rpc_pb2.Message) -> bool:
msg_id = self._msg_id_constructor(msg) msg_id = self._msg_id_constructor(msg)
return msg_id in self.seen_messages return self.seen_messages.has(msg_id)
def _mark_msg_seen(self, msg: rpc_pb2.Message) -> None: def _mark_msg_seen(self, msg: rpc_pb2.Message) -> None:
msg_id = self._msg_id_constructor(msg) msg_id = self._msg_id_constructor(msg)
self.seen_messages[msg_id] = True self.seen_messages.add(msg_id)
def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool: def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
return any(topic in self.topic_ids for topic in msg.topicIDs) return any(topic in self.topic_ids for topic in msg.topicIDs)

View File

View File

@ -0,0 +1,51 @@
import threading
import time
class TimedCache:
    """
    Shared machinery for caches whose entries carry an expiry timestamp.

    The cache maps each key to the time at which it expires. A daemon thread
    started by the constructor periodically removes entries whose expiry has
    passed. The insertion/lookup policy (``add`` / ``has``) is left to
    subclasses.
    """

    cache: dict[bytes, int]

    SWEEP_INTERVAL = 60  # seconds the cleanup thread waits between two sweeps

    def __init__(self, ttl: int) -> None:
        """
        Create an empty cache and launch its background cleanup thread.

        :param ttl: lifetime of a cache entry, in seconds
        """
        self.ttl = ttl
        self.cache = {}
        # Guards ``self.cache`` against concurrent access from the sweeper.
        self.lock = threading.Lock()
        self._stop_event = threading.Event()
        self._thread = threading.Thread(
            target=self._background_cleanup,
            daemon=True,
        )
        self._thread.start()

    def _background_cleanup(self) -> None:
        """Run ``_sweep`` every ``SWEEP_INTERVAL`` seconds until stopped."""
        while True:
            # ``Event.wait`` returns True once ``stop`` has set the event and
            # False when the interval elapsed without it being set.
            if self._stop_event.wait(self.SWEEP_INTERVAL):
                return
            self._sweep()

    def _sweep(self) -> None:
        """Delete every entry whose expiry is strictly earlier than now."""
        cutoff = time.time()
        with self.lock:
            # Collect first so the dict is not mutated while being iterated.
            stale_keys = [key for key in self.cache if self.cache[key] < cutoff]
            for stale_key in stale_keys:
                self.cache.pop(stale_key)

    def stop(self) -> None:
        """Signal the cleanup thread to exit and wait for it to finish."""
        self._stop_event.set()
        self._thread.join()

    def length(self) -> int:
        """Return the number of entries currently stored in the cache."""
        return len(self.cache)

    def add(self, key: bytes) -> bool:
        """Record ``key`` in the cache; subclasses must override this."""
        raise NotImplementedError

    def has(self, key: bytes) -> bool:
        """Report whether ``key`` is cached; subclasses must override this."""
        raise NotImplementedError

View File

@ -0,0 +1,20 @@
import time
from .basic_time_cache import (
TimedCache,
)
class FirstSeenCache(TimedCache):
    """Timed cache in which a key keeps the expiry assigned at first insertion."""

    def add(self, key: bytes) -> bool:
        """
        Insert ``key`` with a fresh expiry unless it is already cached.

        :param key: entry to record
        :return: True if the key was newly inserted; False if it was already
            present, in which case its existing expiry is left untouched
        """
        with self.lock:
            already_cached = key in self.cache
            if not already_cached:
                self.cache[key] = int(time.time()) + self.ttl
            return not already_cached

    def has(self, key: bytes) -> bool:
        """
        Report whether ``key`` is currently stored.

        This is a pure membership check: the stored expiry is neither
        consulted nor modified here.

        :param key: entry to look up
        :return: True if the key is in the cache, False otherwise
        """
        with self.lock:
            found = key in self.cache
        return found

View File

@ -0,0 +1,22 @@
import time
from .basic_time_cache import (
TimedCache,
)
class LastSeenCache(TimedCache):
    """Timed cache in which every ``add`` or successful ``has`` renews the expiry."""

    def add(self, key: bytes) -> bool:
        """
        Store ``key`` and (re)set its expiry to now plus the TTL.

        :param key: entry to record
        :return: True if the key was not in the cache before this call,
            False if an existing entry was refreshed
        """
        with self.lock:
            was_absent = key not in self.cache
            self.cache[key] = int(time.time()) + self.ttl
        return was_absent

    def has(self, key: bytes) -> bool:
        """
        Report whether ``key`` is stored, renewing its expiry when it is.

        :param key: entry to look up
        :return: True if the key is in the cache (its expiry is pushed forward
            as a side effect), False otherwise
        """
        with self.lock:
            if key not in self.cache:
                return False
            # A hit counts as "seen again", so extend the entry's lifetime.
            self.cache[key] = int(time.time()) + self.ttl
            return True

View File

@ -447,6 +447,7 @@ class PubsubFactory(factory.Factory):
host: IHost, host: IHost,
router: IPubsubRouter, router: IPubsubRouter,
cache_size: int, cache_size: int,
seen_ttl: int,
strict_signing: bool, strict_signing: bool,
msg_id_constructor: Callable[[rpc_pb2.Message], bytes] = None, msg_id_constructor: Callable[[rpc_pb2.Message], bytes] = None,
) -> AsyncIterator[Pubsub]: ) -> AsyncIterator[Pubsub]:
@ -454,6 +455,7 @@ class PubsubFactory(factory.Factory):
host=host, host=host,
router=router, router=router,
cache_size=cache_size, cache_size=cache_size,
seen_ttl=seen_ttl,
strict_signing=strict_signing, strict_signing=strict_signing,
msg_id_constructor=msg_id_constructor, msg_id_constructor=msg_id_constructor,
) )
@ -468,6 +470,7 @@ class PubsubFactory(factory.Factory):
number: int, number: int,
routers: Sequence[IPubsubRouter], routers: Sequence[IPubsubRouter],
cache_size: int = None, cache_size: int = None,
seen_ttl: int = None,
strict_signing: bool = False, strict_signing: bool = False,
security_protocol: TProtocol = None, security_protocol: TProtocol = None,
muxer_opt: TMuxerOptions = None, muxer_opt: TMuxerOptions = None,
@ -481,7 +484,12 @@ class PubsubFactory(factory.Factory):
pubsubs = [ pubsubs = [
await stack.enter_async_context( await stack.enter_async_context(
cls.create_and_start( cls.create_and_start(
host, router, cache_size, strict_signing, msg_id_constructor host,
router,
cache_size,
seen_ttl,
strict_signing,
msg_id_constructor,
) )
) )
for host, router in zip(hosts, routers) for host, router in zip(hosts, routers)
@ -494,6 +502,7 @@ class PubsubFactory(factory.Factory):
cls, cls,
number: int, number: int,
cache_size: int = None, cache_size: int = None,
seen_ttl: int = 120,
strict_signing: bool = False, strict_signing: bool = False,
protocols: Sequence[TProtocol] = None, protocols: Sequence[TProtocol] = None,
security_protocol: TProtocol = None, security_protocol: TProtocol = None,
@ -510,6 +519,7 @@ class PubsubFactory(factory.Factory):
number, number,
floodsubs, floodsubs,
cache_size, cache_size,
seen_ttl,
strict_signing, strict_signing,
security_protocol=security_protocol, security_protocol=security_protocol,
muxer_opt=muxer_opt, muxer_opt=muxer_opt,

View File

@ -0,0 +1 @@
Implemented the ``timed_cache`` module, which makes it possible to add a configurable ``seen_ttl`` parameter for pubsub and its derived protocols.

View File

@ -43,21 +43,16 @@ async def test_simple_two_nodes():
@pytest.mark.trio @pytest.mark.trio
async def test_lru_cache_two_nodes(): async def test_timed_cache_two_nodes():
# two nodes with cache_size of 4 # Two nodes using LastSeenCache with a TTL of 120 seconds
# Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`.
def get_msg_id(msg): def get_msg_id(msg):
# Originally it is `(msg.seqno, msg.from_id)`
return (msg.data, msg.from_id) return (msg.data, msg.from_id)
async with PubsubFactory.create_batch_with_floodsub( async with PubsubFactory.create_batch_with_floodsub(
2, cache_size=4, msg_id_constructor=get_msg_id 2, seen_ttl=120, msg_id_constructor=get_msg_id
) as pubsubs_fsub: ) as pubsubs_fsub:
# `node_a` send the following messages to node_b
message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1] message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
# `node_b` should only receive the following expected_received_indices = [1, 2, 3, 4, 5]
expected_received_indices = [1, 2, 3, 4, 5, 1]
topic = "my_topic" topic = "my_topic"

View File

@ -635,8 +635,8 @@ async def test_strict_signing():
await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA) await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
await trio.sleep(1) await trio.sleep(1)
assert len(pubsubs_fsub[0].seen_messages) == 1 assert pubsubs_fsub[0].seen_messages.length() == 1
assert len(pubsubs_fsub[1].seen_messages) == 1 assert pubsubs_fsub[1].seen_messages.length() == 1
@pytest.mark.trio @pytest.mark.trio