Merge branch 'main' into add-ws-transport
@@ -8,6 +8,7 @@ from collections import (
 import logging
 import time
 
+import multihash
 import trio
 
 from libp2p.abc import (
@@ -40,6 +41,22 @@ PEER_REFRESH_INTERVAL = 60  # Interval to refresh peers in seconds
 STALE_PEER_THRESHOLD = 3600  # Time in seconds after which a peer is considered stale
 
 
+def peer_id_to_key(peer_id: ID) -> bytes:
+    """
+    Convert a peer ID to a 256-bit key for routing table operations.
+    This normalizes all peer IDs to exactly 256 bits by hashing them with SHA-256.
+
+    :param peer_id: The peer ID to convert
+    :return: 32-byte (256-bit) key for routing table operations
+    """
+    return multihash.digest(peer_id.to_bytes(), "sha2-256").digest
+
+
+def key_to_int(key: bytes) -> int:
+    """Convert a 256-bit key to an integer for range calculations."""
+    return int.from_bytes(key, byteorder="big")
+
+
 class KBucket:
     """
     A k-bucket implementation for the Kademlia DHT.
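
These two helpers are the basis for everything that follows in this commit. As a rough standalone check (not part of the diff): a sha2-256 multihash digest of some bytes is the plain SHA-256 digest of those bytes, so their behaviour can be sketched with only the standard library. The peer ID bytes below are made up:

import hashlib

# Hypothetical raw peer ID bytes; real values come from ID.to_bytes().
raw_id = b"example-peer-id"

# peer_id_to_key: normalize an ID of any length to exactly 256 bits.
key = hashlib.sha256(raw_id).digest()
assert len(key) == 32

# key_to_int: read the key as a big-endian integer in [0, 2**256).
key_int = int.from_bytes(key, byteorder="big")
assert 0 <= key_int < 2**256
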
@@ -357,9 +374,24 @@ class KBucket:
             True if the key is in range, False otherwise
 
         """
-        key_int = int.from_bytes(key, byteorder="big")
+        key_int = key_to_int(key)
         return self.min_range <= key_int < self.max_range
 
+    def peer_id_in_range(self, peer_id: ID) -> bool:
+        """
+        Check if a peer ID is in the range of this bucket.
+
+        :param peer_id: The peer ID to check
+
+        Returns
+        -------
+        bool
+            True if the peer ID is in range, False otherwise
+
+        """
+        key = peer_id_to_key(peer_id)
+        return self.key_in_range(key)
+
     def split(self) -> tuple["KBucket", "KBucket"]:
         """
         Split the bucket into two buckets.
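
Taken together, key_in_range and the new peer_id_in_range form a hash-then-range-check pipeline. A self-contained sketch, with an invented bucket covering the lower half of the keyspace and SHA-256 standing in for peer_id_to_key:

import hashlib

# Invented bucket covering the lower half of the 256-bit keyspace.
min_range, max_range = 0, 2**255

def key_in_range(key: bytes) -> bool:
    # Half-open interval check, mirroring KBucket.key_in_range.
    return min_range <= int.from_bytes(key, byteorder="big") < max_range

def peer_id_in_range(raw_id: bytes) -> bool:
    # Hash first, then range-check, mirroring the new method above.
    return key_in_range(hashlib.sha256(raw_id).digest())

print(peer_id_in_range(b"example-peer"))  # depends on where the hash lands
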
@@ -376,8 +408,9 @@ class KBucket:
 
         # Redistribute peers
         for peer_id, (peer_info, timestamp) in self.peers.items():
-            peer_key = int.from_bytes(peer_id.to_bytes(), byteorder="big")
-            if peer_key < midpoint:
+            peer_key = peer_id_to_key(peer_id)
+            peer_key_int = key_to_int(peer_key)
+            if peer_key_int < midpoint:
                 lower_bucket.peers[peer_id] = (peer_info, timestamp)
             else:
                 upper_bucket.peers[peer_id] = (peer_info, timestamp)
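
The redistribution now hashes before comparing, so peers land in a half by normalized key rather than by raw ID bytes. A minimal sketch of the rule, with a hypothetical bucket spanning the whole keyspace:

import hashlib

# Hypothetical bucket spanning the whole keyspace, splitting at its midpoint.
min_range, max_range = 0, 2**256
midpoint = (min_range + max_range) // 2

def target_half(raw_id: bytes) -> str:
    # Same rule as the loop above: hash the ID, then compare the integer
    # key against the midpoint to pick the lower or upper bucket.
    key_int = int.from_bytes(hashlib.sha256(raw_id).digest(), byteorder="big")
    return "lower" if key_int < midpoint else "upper"

for raw_id in (b"peer-a", b"peer-b", b"peer-c"):
    print(raw_id, target_half(raw_id))
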
@@ -458,7 +491,38 @@ class RoutingTable:
             success = await bucket.add_peer(peer_info)
             if success:
                 logger.debug(f"Successfully added peer {peer_id} to routing table")
-            return success
+                return True
+
+            # If bucket is full and couldn't add peer, try splitting the bucket
+            # Only split if the bucket contains our peer ID
+            if self._should_split_bucket(bucket):
+                logger.debug(
+                    f"Bucket is full, attempting to split bucket for peer {peer_id}"
+                )
+                split_success = self._split_bucket(bucket)
+                if split_success:
+                    # After splitting,
+                    # find the appropriate bucket for the peer and try to add it
+                    target_bucket = self.find_bucket(peer_info.peer_id)
+                    success = await target_bucket.add_peer(peer_info)
+                    if success:
+                        logger.debug(
+                            f"Successfully added peer {peer_id} after bucket split"
+                        )
+                        return True
+                    else:
+                        logger.debug(
+                            f"Failed to add peer {peer_id} even after bucket split"
+                        )
+                        return False
+                else:
+                    logger.debug(f"Failed to split bucket for peer {peer_id}")
+                    return False
+            else:
+                logger.debug(
+                    f"Bucket is full and cannot be split, peer {peer_id} not added"
+                )
+                return False
 
         except Exception as e:
             logger.debug(f"Error adding peer {peer_obj} to routing table: {e}")
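
For orientation, a hypothetical caller of this add_peer flow might look like the following. routing_table and discovered_peers are assumed to exist; per the logic above, a False return means the target bucket was full and could not be split:

import trio

async def track_peers(routing_table, discovered_peers) -> None:
    # Hypothetical driver loop; names are illustrative, not the library API.
    for peer_info in discovered_peers:
        added = await routing_table.add_peer(peer_info)
        if not added:
            # Per Kademlia, the new peer is simply not inserted.
            print(f"dropped {peer_info.peer_id}")

# trio.run(track_peers, routing_table, discovered_peers)
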
@@ -480,9 +544,9 @@ class RoutingTable:
 
     def find_bucket(self, peer_id: ID) -> KBucket:
         """
-        Find the bucket that would contain the given peer ID or PeerInfo.
+        Find the bucket that would contain the given peer ID.
 
-        :param peer_obj: Either a peer ID or a PeerInfo object
+        :param peer_id: The peer ID to find a bucket for
 
         Returns
         -------
@@ -490,7 +554,7 @@ class RoutingTable:
 
         """
         for bucket in self.buckets:
-            if bucket.key_in_range(peer_id.to_bytes()):
+            if bucket.peer_id_in_range(peer_id):
                 return bucket
 
         return self.buckets[0]
@@ -513,7 +577,11 @@ class RoutingTable:
             all_peers.extend(bucket.peer_ids())
 
         # Sort by XOR distance to the key
-        all_peers.sort(key=lambda p: xor_distance(p.to_bytes(), key))
+        def distance_to_key(peer_id: ID) -> int:
+            peer_key = peer_id_to_key(peer_id)
+            return xor_distance(peer_key, key)
+
+        all_peers.sort(key=distance_to_key)
 
         return all_peers[:count]
 
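
The sort key now hashes each peer ID before measuring distance. Assuming xor_distance is the usual Kademlia metric (bitwise XOR of two equal-length keys, read as an integer), the ordering can be reproduced in isolation:

import hashlib

def xor_distance(a: bytes, b: bytes) -> int:
    # Assumed semantics of the helper used above: bitwise XOR of two
    # equal-length keys, read as a big-endian integer.
    return int.from_bytes(a, "big") ^ int.from_bytes(b, "big")

target = hashlib.sha256(b"target").digest()
keys = [hashlib.sha256(bytes([i])).digest() for i in range(5)]

# Smaller XOR distance means a longer shared bit prefix with the target.
closest = sorted(keys, key=lambda k: xor_distance(k, target))
print(closest[0].hex())
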
@@ -591,6 +659,20 @@ class RoutingTable:
             stale_peers.extend(bucket.get_stale_peers(stale_threshold_seconds))
         return stale_peers
 
+    def get_peer_infos(self) -> list[PeerInfo]:
+        """
+        Get all PeerInfo objects in the routing table.
+
+        Returns
+        -------
+        List[PeerInfo]: List of all PeerInfo objects
+
+        """
+        peer_infos = []
+        for bucket in self.buckets:
+            peer_infos.extend(bucket.peer_infos())
+        return peer_infos
+
     def cleanup_routing_table(self) -> None:
         """
         Cleanup the routing table by removing all data.
@@ -598,3 +680,66 @@ class RoutingTable:
         """
         self.buckets = [KBucket(self.host, BUCKET_SIZE)]
         logger.info("Routing table cleaned up, all data removed.")
+
+    def _should_split_bucket(self, bucket: KBucket) -> bool:
+        """
+        Check if a bucket should be split according to Kademlia rules.
+
+        :param bucket: The bucket to check
+        :return: True if the bucket should be split
+        """
+        # Check if we've exceeded maximum buckets
+        if len(self.buckets) >= MAXIMUM_BUCKETS:
+            logger.debug("Maximum number of buckets reached, cannot split")
+            return False
+
+        # Check if the bucket contains our local ID
+        local_key = peer_id_to_key(self.local_id)
+        local_key_int = key_to_int(local_key)
+        contains_local_id = bucket.min_range <= local_key_int < bucket.max_range
+
+        logger.debug(
+            f"Bucket range: {bucket.min_range} - {bucket.max_range}, "
+            f"local_key_int: {local_key_int}, contains_local: {contains_local_id}"
+        )
+
+        return contains_local_id
+
+    def _split_bucket(self, bucket: KBucket) -> bool:
+        """
+        Split a bucket into two buckets.
+
+        :param bucket: The bucket to split
+        :return: True if the bucket was successfully split
+        """
+        try:
+            # Find the bucket index
+            bucket_index = self.buckets.index(bucket)
+            logger.debug(f"Splitting bucket at index {bucket_index}")
+
+            # Split the bucket
+            lower_bucket, upper_bucket = bucket.split()
+
+            # Replace the original bucket with the two new buckets
+            self.buckets[bucket_index] = lower_bucket
+            self.buckets.insert(bucket_index + 1, upper_bucket)
+
+            logger.debug(
+                f"Bucket split successful. New bucket count: {len(self.buckets)}"
+            )
+            logger.debug(
+                f"Lower bucket range: "
+                f"{lower_bucket.min_range} - {lower_bucket.max_range}, "
+                f"peers: {lower_bucket.size()}"
+            )
+            logger.debug(
+                f"Upper bucket range: "
+                f"{upper_bucket.min_range} - {upper_bucket.max_range}, "
+                f"peers: {upper_bucket.size()}"
+            )
+
+            return True
+
+        except Exception as e:
+            logger.error(f"Error splitting bucket: {e}")
+            return False
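
Stripped of logging, the split policy above reduces to two conditions. A distilled sketch with plain integers for the bucket bounds; the value of MAXIMUM_BUCKETS is an assumption here (one bucket per key bit), not taken from the diff:

MAXIMUM_BUCKETS = 256  # assumed value, not from the diff

def should_split(bucket_min: int, bucket_max: int,
                 local_key_int: int, bucket_count: int) -> bool:
    # Never exceed the bucket cap.
    if bucket_count >= MAXIMUM_BUCKETS:
        return False
    # Only the bucket covering our own key keeps splitting.
    return bucket_min <= local_key_int < bucket_max

print(should_split(0, 2**255, 123, 1))       # True: local key in range
print(should_split(2**255, 2**256, 123, 1))  # False: local key elsewhere
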
@@ -23,7 +23,8 @@ if TYPE_CHECKING:
 
 
 """
-Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee14e334f7caeedad722/swarm_conn.go
+Reference: https://github.com/libp2p/go-libp2p-swarm/blob/
+04c86bbdafd390651cb2ee14e334f7caeedad722/swarm_conn.go
 """
 
 
@@ -43,6 +44,21 @@ class SwarmConn(INetConn):
         self.streams = set()
         self.event_closed = trio.Event()
         self.event_started = trio.Event()
+        # Provide back-references/hooks expected by NetStream
+        try:
+            setattr(self.muxed_conn, "swarm", self.swarm)
+
+            # NetStream expects an awaitable remove_stream hook
+            async def _remove_stream_hook(stream: NetStream) -> None:
+                self.remove_stream(stream)
+
+            setattr(self.muxed_conn, "remove_stream", _remove_stream_hook)
+        except Exception as e:
+            logging.warning(
+                f"Failed to set optional conveniences on muxed_conn "
+                f"for peer {muxed_conn.peer_id}: {e}"
+            )
+        # optional conveniences
         if hasattr(muxed_conn, "on_close"):
             logging.debug(f"Setting on_close for peer {muxed_conn.peer_id}")
             setattr(muxed_conn, "on_close", self._on_muxed_conn_closed)
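
The _remove_stream_hook shim exists because the consumer awaits the hook while the underlying cleanup is synchronous. A toy illustration of the same pattern; all names here are invented, not the library's API:

import trio

class ConnShim:
    # Toy stand-in for the muxed connection, not the library's class.
    def __init__(self) -> None:
        self.streams: set[str] = set()

    def remove_stream(self, stream: str) -> None:
        self.streams.discard(stream)

async def main() -> None:
    conn = ConnShim()
    conn.streams.add("s1")

    async def remove_stream_hook(stream: str) -> None:
        # Awaitable wrapper around synchronous cleanup, as in the diff.
        conn.remove_stream(stream)

    await remove_stream_hook("s1")
    assert not conn.streams

trio.run(main)
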
@@ -1,3 +1,7 @@
+from collections.abc import (
+    Awaitable,
+    Callable,
+)
 import logging
 
 from multiaddr import (
@@ -333,8 +337,16 @@ class Swarm(Service, INetworkService):
 
         # Close all listeners
         if hasattr(self, "listeners"):
-            for listener in self.listeners.values():
+            for maddr_str, listener in self.listeners.items():
                 await listener.close()
+                # Notify about listener closure
+                try:
+                    multiaddr = Multiaddr(maddr_str)
+                    await self.notify_listen_close(multiaddr)
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to notify listen_close for {maddr_str}: {e}"
+                    )
             self.listeners.clear()
 
         # Close the transport if it exists and has a close method
@@ -418,7 +430,17 @@ class Swarm(Service, INetworkService):
             nursery.start_soon(notifee.listen, self, multiaddr)
 
     async def notify_closed_stream(self, stream: INetStream) -> None:
-        raise NotImplementedError
+        async with trio.open_nursery() as nursery:
+            for notifee in self.notifees:
+                nursery.start_soon(notifee.closed_stream, self, stream)
 
     async def notify_listen_close(self, multiaddr: Multiaddr) -> None:
-        raise NotImplementedError
+        async with trio.open_nursery() as nursery:
+            for notifee in self.notifees:
+                nursery.start_soon(notifee.listen_close, self, multiaddr)
+
+    # Generic notifier used by NetStream._notify_closed
+    async def notify_all(self, notifier: Callable[[INotifee], Awaitable[None]]) -> None:
+        async with trio.open_nursery() as nursery:
+            for notifee in self.notifees:
+                nursery.start_soon(notifier, notifee)
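
All three notify_* methods share one pattern: start every notifee callback in a single trio nursery and return only when all of them have completed (or one raises). A self-contained demo of that fan-out:

import trio

async def notify_all(notifees, notifier) -> None:
    # Start every callback concurrently; the nursery joins them all.
    async with trio.open_nursery() as nursery:
        for notifee in notifees:
            nursery.start_soon(notifier, notifee)

async def main() -> None:
    async def tell(name: str) -> None:
        await trio.sleep(0.01)
        print(f"notified {name}")

    await notify_all(["a", "b", "c"], tell)

trio.run(main)
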