mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-12 16:10:57 +00:00
feat: Implement Random walk in py-libp2p (#822)
* Implementing random walk in py libp2p * Add documentation for Random Walk module implementation in py-libp2p * Add Random Walk example for py-libp2p Kademlia DHT * refactor: peer eviction from routing table stopped * refactored location of random walk * add nodesin routing table from peerstore * random walk working as expected * removed extra functions * Removed all manual triggers * added newsfragments * fix linting issues * refacored logs and cleaned example file * refactor: update RandomWalk and RTRefreshManager to use query function for peer discovery * docs: added Random Walk example docs * added optional argument to use random walk in kademlia DHT * enabled random walk in example file * Added tests for RandomWalk module * fixed lint issues * Update refresh interval and some more tests are added. * Removed Random Walk module documentation file * Extra parentheses have been removed from the random walk logs. Co-authored-by: Paul Robinson <5199899+pacrob@users.noreply.github.com> --------- Co-authored-by: Manu Sheel Gupta <manusheel.edu@gmail.com> Co-authored-by: Paul Robinson <5199899+pacrob@users.noreply.github.com>
This commit is contained in:
208
libp2p/discovery/random_walk/rt_refresh_manager.py
Normal file
208
libp2p/discovery/random_walk/rt_refresh_manager.py
Normal file
@ -0,0 +1,208 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
import logging
|
||||
import time
|
||||
from typing import Protocol
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.abc import IHost
|
||||
from libp2p.discovery.random_walk.config import (
|
||||
MIN_RT_REFRESH_THRESHOLD,
|
||||
RANDOM_WALK_CONCURRENCY,
|
||||
RANDOM_WALK_ENABLED,
|
||||
REFRESH_INTERVAL,
|
||||
)
|
||||
from libp2p.discovery.random_walk.exceptions import RoutingTableRefreshError
|
||||
from libp2p.discovery.random_walk.random_walk import RandomWalk
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
|
||||
|
||||
class RoutingTableProtocol(Protocol):
|
||||
"""Protocol for routing table operations needed by RT refresh manager."""
|
||||
|
||||
def size(self) -> int:
|
||||
"""Return the current size of the routing table."""
|
||||
...
|
||||
|
||||
async def add_peer(self, peer_obj: PeerInfo) -> bool:
|
||||
"""Add a peer to the routing table."""
|
||||
...
|
||||
|
||||
|
||||
logger = logging.getLogger("libp2p.discovery.random_walk.rt_refresh_manager")
|
||||
|
||||
|
||||
class RTRefreshManager:
|
||||
"""
|
||||
Routing Table Refresh Manager for py-libp2p.
|
||||
|
||||
Manages periodic routing table refreshes and random walk operations
|
||||
to maintain routing table health and discover new peers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: IHost,
|
||||
routing_table: RoutingTableProtocol,
|
||||
local_peer_id: ID,
|
||||
query_function: Callable[[bytes], Awaitable[list[ID]]],
|
||||
enable_auto_refresh: bool = RANDOM_WALK_ENABLED,
|
||||
refresh_interval: float = REFRESH_INTERVAL,
|
||||
min_refresh_threshold: int = MIN_RT_REFRESH_THRESHOLD,
|
||||
):
|
||||
"""
|
||||
Initialize RT Refresh Manager.
|
||||
|
||||
Args:
|
||||
host: The libp2p host instance
|
||||
routing_table: Routing table of host
|
||||
local_peer_id: Local peer ID
|
||||
query_function: Function to query for closest peers given target key bytes
|
||||
enable_auto_refresh: Whether to enable automatic refresh
|
||||
refresh_interval: Interval between refreshes in seconds
|
||||
min_refresh_threshold: Minimum RT size before triggering refresh
|
||||
|
||||
"""
|
||||
self.host = host
|
||||
self.routing_table = routing_table
|
||||
self.local_peer_id = local_peer_id
|
||||
self.query_function = query_function
|
||||
|
||||
self.enable_auto_refresh = enable_auto_refresh
|
||||
self.refresh_interval = refresh_interval
|
||||
self.min_refresh_threshold = min_refresh_threshold
|
||||
|
||||
# Initialize random walk module
|
||||
self.random_walk = RandomWalk(
|
||||
host=host,
|
||||
local_peer_id=self.local_peer_id,
|
||||
query_function=query_function,
|
||||
)
|
||||
|
||||
# Control variables
|
||||
self._running = False
|
||||
self._nursery: trio.Nursery | None = None
|
||||
|
||||
# Tracking
|
||||
self._last_refresh_time = 0.0
|
||||
self._refresh_done_callbacks: list[Callable[[], None]] = []
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the RT Refresh Manager."""
|
||||
if self._running:
|
||||
logger.warning("RT Refresh Manager is already running")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
logger.info("Starting RT Refresh Manager")
|
||||
|
||||
# Start the main loop
|
||||
async with trio.open_nursery() as nursery:
|
||||
self._nursery = nursery
|
||||
nursery.start_soon(self._main_loop)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the RT Refresh Manager."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
logger.info("Stopping RT Refresh Manager")
|
||||
self._running = False
|
||||
|
||||
async def _main_loop(self) -> None:
|
||||
"""Main loop for the RT Refresh Manager."""
|
||||
logger.info("RT Refresh Manager main loop started")
|
||||
|
||||
# Initial refresh if auto-refresh is enabled
|
||||
if self.enable_auto_refresh:
|
||||
await self._do_refresh(force=True)
|
||||
|
||||
try:
|
||||
while self._running:
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Schedule periodic refresh if enabled
|
||||
if self.enable_auto_refresh:
|
||||
nursery.start_soon(self._periodic_refresh_task)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"RT Refresh Manager main loop error: {e}")
|
||||
finally:
|
||||
logger.info("RT Refresh Manager main loop stopped")
|
||||
|
||||
async def _periodic_refresh_task(self) -> None:
|
||||
"""Task for periodic refreshes."""
|
||||
while self._running:
|
||||
await trio.sleep(self.refresh_interval)
|
||||
if self._running:
|
||||
await self._do_refresh()
|
||||
|
||||
async def _do_refresh(self, force: bool = False) -> None:
|
||||
"""
|
||||
Perform routing table refresh operation.
|
||||
|
||||
Args:
|
||||
force: Whether to force refresh regardless of timing
|
||||
|
||||
"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
|
||||
# Check if refresh is needed
|
||||
if not force:
|
||||
if current_time - self._last_refresh_time < self.refresh_interval:
|
||||
logger.debug("Skipping refresh: interval not elapsed")
|
||||
return
|
||||
|
||||
if self.routing_table.size() >= self.min_refresh_threshold:
|
||||
logger.debug("Skipping refresh: routing table size above threshold")
|
||||
return
|
||||
|
||||
logger.info(f"Starting routing table refresh (force={force})")
|
||||
start_time = current_time
|
||||
|
||||
# Perform random walks to discover new peers
|
||||
logger.info("Running concurrent random walks to discover new peers")
|
||||
current_rt_size = self.routing_table.size()
|
||||
discovered_peers = await self.random_walk.run_concurrent_random_walks(
|
||||
count=RANDOM_WALK_CONCURRENCY,
|
||||
current_routing_table_size=current_rt_size,
|
||||
)
|
||||
|
||||
# Add discovered peers to routing table
|
||||
added_count = 0
|
||||
for peer_info in discovered_peers:
|
||||
result = await self.routing_table.add_peer(peer_info)
|
||||
if result:
|
||||
added_count += 1
|
||||
|
||||
self._last_refresh_time = current_time
|
||||
|
||||
duration = time.time() - start_time
|
||||
logger.info(
|
||||
f"Routing table refresh completed: "
|
||||
f"{added_count}/{len(discovered_peers)} peers added, "
|
||||
f"RT size: {self.routing_table.size()}, "
|
||||
f"duration: {duration:.2f}s"
|
||||
)
|
||||
|
||||
# Notify refresh completion
|
||||
for callback in self._refresh_done_callbacks:
|
||||
try:
|
||||
callback()
|
||||
except Exception as e:
|
||||
logger.warning(f"Refresh callback error: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Routing table refresh failed: {e}")
|
||||
raise RoutingTableRefreshError(f"Refresh operation failed: {e}") from e
|
||||
|
||||
def add_refresh_done_callback(self, callback: Callable[[], None]) -> None:
|
||||
"""Add a callback to be called when refresh completes."""
|
||||
self._refresh_done_callbacks.append(callback)
|
||||
|
||||
def remove_refresh_done_callback(self, callback: Callable[[], None]) -> None:
|
||||
"""Remove a refresh completion callback."""
|
||||
if callback in self._refresh_done_callbacks:
|
||||
self._refresh_done_callbacks.remove(callback)
|
||||
Reference in New Issue
Block a user