Merge branch 'main' into add-ws-transport

This commit is contained in:
Manu Sheel Gupta
2025-08-18 22:00:20 +05:30
committed by GitHub
9 changed files with 421 additions and 36 deletions

View File

@ -8,6 +8,7 @@ from collections import (
import logging
import time
import multihash
import trio
from libp2p.abc import (
@ -40,6 +41,22 @@ PEER_REFRESH_INTERVAL = 60 # Interval to refresh peers in seconds
STALE_PEER_THRESHOLD = 3600 # Time in seconds after which a peer is considered stale
def peer_id_to_key(peer_id: ID) -> bytes:
"""
Convert a peer ID to a 256-bit key for routing table operations.
This normalizes all peer IDs to exactly 256 bits by hashing them with SHA-256.
:param peer_id: The peer ID to convert
:return: 32-byte (256-bit) key for routing table operations
"""
return multihash.digest(peer_id.to_bytes(), "sha2-256").digest
def key_to_int(key: bytes) -> int:
"""Convert a 256-bit key to an integer for range calculations."""
return int.from_bytes(key, byteorder="big")
class KBucket:
"""
A k-bucket implementation for the Kademlia DHT.
@ -357,9 +374,24 @@ class KBucket:
True if the key is in range, False otherwise
"""
key_int = int.from_bytes(key, byteorder="big")
key_int = key_to_int(key)
return self.min_range <= key_int < self.max_range
def peer_id_in_range(self, peer_id: ID) -> bool:
"""
Check if a peer ID is in the range of this bucket.
params: peer_id: The peer ID to check
Returns
-------
bool
True if the peer ID is in range, False otherwise
"""
key = peer_id_to_key(peer_id)
return self.key_in_range(key)
def split(self) -> tuple["KBucket", "KBucket"]:
"""
Split the bucket into two buckets.
@ -376,8 +408,9 @@ class KBucket:
# Redistribute peers
for peer_id, (peer_info, timestamp) in self.peers.items():
peer_key = int.from_bytes(peer_id.to_bytes(), byteorder="big")
if peer_key < midpoint:
peer_key = peer_id_to_key(peer_id)
peer_key_int = key_to_int(peer_key)
if peer_key_int < midpoint:
lower_bucket.peers[peer_id] = (peer_info, timestamp)
else:
upper_bucket.peers[peer_id] = (peer_info, timestamp)
@ -458,7 +491,38 @@ class RoutingTable:
success = await bucket.add_peer(peer_info)
if success:
logger.debug(f"Successfully added peer {peer_id} to routing table")
return success
return True
# If bucket is full and couldn't add peer, try splitting the bucket
# Only split if the bucket contains our Peer ID
if self._should_split_bucket(bucket):
logger.debug(
f"Bucket is full, attempting to split bucket for peer {peer_id}"
)
split_success = self._split_bucket(bucket)
if split_success:
# After splitting,
# find the appropriate bucket for the peer and try to add it
target_bucket = self.find_bucket(peer_info.peer_id)
success = await target_bucket.add_peer(peer_info)
if success:
logger.debug(
f"Successfully added peer {peer_id} after bucket split"
)
return True
else:
logger.debug(
f"Failed to add peer {peer_id} even after bucket split"
)
return False
else:
logger.debug(f"Failed to split bucket for peer {peer_id}")
return False
else:
logger.debug(
f"Bucket is full and cannot be split, peer {peer_id} not added"
)
return False
except Exception as e:
logger.debug(f"Error adding peer {peer_obj} to routing table: {e}")
@ -480,9 +544,9 @@ class RoutingTable:
def find_bucket(self, peer_id: ID) -> KBucket:
"""
Find the bucket that would contain the given peer ID or PeerInfo.
Find the bucket that would contain the given peer ID.
:param peer_obj: Either a peer ID or a PeerInfo object
:param peer_id: The peer ID to find a bucket for
Returns
-------
@ -490,7 +554,7 @@ class RoutingTable:
"""
for bucket in self.buckets:
if bucket.key_in_range(peer_id.to_bytes()):
if bucket.peer_id_in_range(peer_id):
return bucket
return self.buckets[0]
@ -513,7 +577,11 @@ class RoutingTable:
all_peers.extend(bucket.peer_ids())
# Sort by XOR distance to the key
all_peers.sort(key=lambda p: xor_distance(p.to_bytes(), key))
def distance_to_key(peer_id: ID) -> int:
peer_key = peer_id_to_key(peer_id)
return xor_distance(peer_key, key)
all_peers.sort(key=distance_to_key)
return all_peers[:count]
@ -591,6 +659,20 @@ class RoutingTable:
stale_peers.extend(bucket.get_stale_peers(stale_threshold_seconds))
return stale_peers
def get_peer_infos(self) -> list[PeerInfo]:
"""
Get all PeerInfo objects in the routing table.
Returns
-------
List[PeerInfo]: List of all PeerInfo objects
"""
peer_infos = []
for bucket in self.buckets:
peer_infos.extend(bucket.peer_infos())
return peer_infos
def cleanup_routing_table(self) -> None:
"""
Cleanup the routing table by removing all data.
@ -598,3 +680,66 @@ class RoutingTable:
"""
self.buckets = [KBucket(self.host, BUCKET_SIZE)]
logger.info("Routing table cleaned up, all data removed.")
def _should_split_bucket(self, bucket: KBucket) -> bool:
"""
Check if a bucket should be split according to Kademlia rules.
:param bucket: The bucket to check
:return: True if the bucket should be split
"""
# Check if we've exceeded maximum buckets
if len(self.buckets) >= MAXIMUM_BUCKETS:
logger.debug("Maximum number of buckets reached, cannot split")
return False
# Check if the bucket contains our local ID
local_key = peer_id_to_key(self.local_id)
local_key_int = key_to_int(local_key)
contains_local_id = bucket.min_range <= local_key_int < bucket.max_range
logger.debug(
f"Bucket range: {bucket.min_range} - {bucket.max_range}, "
f"local_key_int: {local_key_int}, contains_local: {contains_local_id}"
)
return contains_local_id
def _split_bucket(self, bucket: KBucket) -> bool:
"""
Split a bucket into two buckets.
:param bucket: The bucket to split
:return: True if the bucket was successfully split
"""
try:
# Find the bucket index
bucket_index = self.buckets.index(bucket)
logger.debug(f"Splitting bucket at index {bucket_index}")
# Split the bucket
lower_bucket, upper_bucket = bucket.split()
# Replace the original bucket with the two new buckets
self.buckets[bucket_index] = lower_bucket
self.buckets.insert(bucket_index + 1, upper_bucket)
logger.debug(
f"Bucket split successful. New bucket count: {len(self.buckets)}"
)
logger.debug(
f"Lower bucket range: "
f"{lower_bucket.min_range} - {lower_bucket.max_range}, "
f"peers: {lower_bucket.size()}"
)
logger.debug(
f"Upper bucket range: "
f"{upper_bucket.min_range} - {upper_bucket.max_range}, "
f"peers: {upper_bucket.size()}"
)
return True
except Exception as e:
logger.error(f"Error splitting bucket: {e}")
return False

View File

@ -23,7 +23,8 @@ if TYPE_CHECKING:
"""
Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee14e334f7caeedad722/swarm_conn.go
Reference: https://github.com/libp2p/go-libp2p-swarm/blob/
04c86bbdafd390651cb2ee14e334f7caeedad722/swarm_conn.go
"""
@ -43,6 +44,21 @@ class SwarmConn(INetConn):
self.streams = set()
self.event_closed = trio.Event()
self.event_started = trio.Event()
# Provide back-references/hooks expected by NetStream
try:
setattr(self.muxed_conn, "swarm", self.swarm)
# NetStream expects an awaitable remove_stream hook
async def _remove_stream_hook(stream: NetStream) -> None:
self.remove_stream(stream)
setattr(self.muxed_conn, "remove_stream", _remove_stream_hook)
except Exception as e:
logging.warning(
f"Failed to set optional conveniences on muxed_conn "
f"for peer {muxed_conn.peer_id}: {e}"
)
# optional conveniences
if hasattr(muxed_conn, "on_close"):
logging.debug(f"Setting on_close for peer {muxed_conn.peer_id}")
setattr(muxed_conn, "on_close", self._on_muxed_conn_closed)

View File

@ -1,3 +1,7 @@
from collections.abc import (
Awaitable,
Callable,
)
import logging
from multiaddr import (
@ -333,8 +337,16 @@ class Swarm(Service, INetworkService):
# Close all listeners
if hasattr(self, "listeners"):
for listener in self.listeners.values():
for maddr_str, listener in self.listeners.items():
await listener.close()
# Notify about listener closure
try:
multiaddr = Multiaddr(maddr_str)
await self.notify_listen_close(multiaddr)
except Exception as e:
logger.warning(
f"Failed to notify listen_close for {maddr_str}: {e}"
)
self.listeners.clear()
# Close the transport if it exists and has a close method
@ -418,7 +430,17 @@ class Swarm(Service, INetworkService):
nursery.start_soon(notifee.listen, self, multiaddr)
async def notify_closed_stream(self, stream: INetStream) -> None:
raise NotImplementedError
async with trio.open_nursery() as nursery:
for notifee in self.notifees:
nursery.start_soon(notifee.closed_stream, self, stream)
async def notify_listen_close(self, multiaddr: Multiaddr) -> None:
raise NotImplementedError
async with trio.open_nursery() as nursery:
for notifee in self.notifees:
nursery.start_soon(notifee.listen_close, self, multiaddr)
# Generic notifier used by NetStream._notify_closed
async def notify_all(self, notifier: Callable[[INotifee], Awaitable[None]]) -> None:
async with trio.open_nursery() as nursery:
for notifee in self.notifees:
nursery.start_soon(notifier, notifee)