Merge branch 'main' into feature/bootstrap

This commit is contained in:
Manu Sheel Gupta
2025-07-07 08:51:53 -07:00
committed by GitHub
12 changed files with 1617 additions and 190 deletions

View File

@ -385,6 +385,18 @@ class IPeerMetadata(ABC):
:raises Exception: If the operation is unsuccessful.
"""
@abstractmethod
def clear_metadata(self, peer_id: ID) -> None:
"""
Remove all stored metadata for the specified peer.
Parameters
----------
peer_id : ID
The peer identifier whose metadata are to be removed.
"""
# -------------------------- addrbook interface.py --------------------------
@ -476,10 +488,272 @@ class IAddrBook(ABC):
"""
# -------------------------- keybook interface.py --------------------------
class IKeyBook(ABC):
"""
Interface for an key book.
Provides methods for managing cryptographic keys.
"""
@abstractmethod
def pubkey(self, peer_id: ID) -> PublicKey:
"""
Returns the public key of the specified peer
Parameters
----------
peer_id : ID
The peer identifier whose public key is to be returned.
"""
@abstractmethod
def privkey(self, peer_id: ID) -> PrivateKey:
"""
Returns the private key of the specified peer
Parameters
----------
peer_id : ID
The peer identifier whose private key is to be returned.
"""
@abstractmethod
def add_pubkey(self, peer_id: ID, pubkey: PublicKey) -> None:
"""
Adds the public key for a specified peer
Parameters
----------
peer_id : ID
The peer identifier whose public key is to be added
pubkey: PublicKey
The public key of the peer
"""
@abstractmethod
def add_privkey(self, peer_id: ID, privkey: PrivateKey) -> None:
"""
Adds the private key for a specified peer
Parameters
----------
peer_id : ID
The peer identifier whose private key is to be added
privkey: PrivateKey
The private key of the peer
"""
@abstractmethod
def add_key_pair(self, peer_id: ID, key_pair: KeyPair) -> None:
"""
Adds the key pair for a specified peer
Parameters
----------
peer_id : ID
The peer identifier whose key pair is to be added
key_pair: KeyPair
The key pair of the peer
"""
@abstractmethod
def peer_with_keys(self) -> list[ID]:
"""Returns all the peer IDs stored in the AddrBook"""
@abstractmethod
def clear_keydata(self, peer_id: ID) -> None:
"""
Remove all stored keydata for the specified peer.
Parameters
----------
peer_id : ID
The peer identifier whose keys are to be removed.
"""
# -------------------------- metrics interface.py --------------------------
class IMetrics(ABC):
"""
Interface for metrics of peer interaction.
Provides methods for managing the metrics.
"""
@abstractmethod
def record_latency(self, peer_id: ID, RTT: float) -> None:
"""
Records a new round-trip time (RTT) latency value for the specified peer
using Exponentially Weighted Moving Average (EWMA).
Parameters
----------
peer_id : ID
The identifier of the peer for which latency is being recorded.
RTT : float
The round-trip time latency value to record.
"""
@abstractmethod
def latency_EWMA(self, peer_id: ID) -> float:
"""
Returns the current latency value for the specified peer using
Exponentially Weighted Moving Average (EWMA).
Parameters
----------
peer_id : ID
The identifier of the peer whose latency EWMA is to be returned.
"""
@abstractmethod
def clear_metrics(self, peer_id: ID) -> None:
"""
Clears the stored latency metrics for the specified peer.
Parameters
----------
peer_id : ID
The identifier of the peer whose latency metrics are to be cleared.
"""
# -------------------------- protobook interface.py --------------------------
class IProtoBook(ABC):
"""
Interface for a protocol book.
Provides methods for managing the list of supported protocols.
"""
@abstractmethod
def get_protocols(self, peer_id: ID) -> list[str]:
"""
Returns the list of protocols associated with the specified peer.
Parameters
----------
peer_id : ID
The identifier of the peer whose supported protocols are to be returned.
"""
@abstractmethod
def add_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
"""
Adds the given protocols to the specified peer's protocol list.
Parameters
----------
peer_id : ID
The identifier of the peer to which protocols will be added.
protocols : Sequence[str]
A sequence of protocol strings to add.
"""
@abstractmethod
def set_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
"""
Replaces the existing protocols of the specified peer with the given list.
Parameters
----------
peer_id : ID
The identifier of the peer whose protocols are to be set.
protocols : Sequence[str]
A sequence of protocol strings to assign.
"""
@abstractmethod
def remove_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
"""
Removes the specified protocols from the peer's protocol list.
Parameters
----------
peer_id : ID
The identifier of the peer from which protocols will be removed.
protocols : Sequence[str]
A sequence of protocol strings to remove.
"""
@abstractmethod
def supports_protocols(self, peer_id: ID, protocols: Sequence[str]) -> list[str]:
"""
Returns the list of protocols from the input sequence that the peer supports.
Parameters
----------
peer_id : ID
The identifier of the peer to check for protocol support.
protocols : Sequence[str]
A sequence of protocol strings to check against the peer's
supported protocols.
"""
@abstractmethod
def first_supported_protocol(self, peer_id: ID, protocols: Sequence[str]) -> str:
"""
Returns the first protocol from the input list that the peer supports.
Parameters
----------
peer_id : ID
The identifier of the peer to check for supported protocols.
protocols : Sequence[str]
A sequence of protocol strings to check.
Returns
-------
str
The first matching protocol string, or an empty string
if none are supported.
"""
@abstractmethod
def clear_protocol_data(self, peer_id: ID) -> None:
"""
Clears all protocol data associated with the specified peer.
Parameters
----------
peer_id : ID
The identifier of the peer whose protocol data will be cleared.
"""
# -------------------------- peerstore interface.py --------------------------
class IPeerStore(IAddrBook, IPeerMetadata):
class IPeerStore(IPeerMetadata, IAddrBook, IKeyBook, IMetrics, IProtoBook):
"""
Interface for a peer store.
@ -487,85 +761,7 @@ class IPeerStore(IAddrBook, IPeerMetadata):
management, protocol handling, and key storage.
"""
@abstractmethod
def peer_info(self, peer_id: ID) -> PeerInfo:
"""
Retrieve the peer information for the specified peer.
Parameters
----------
peer_id : ID
The identifier of the peer.
Returns
-------
PeerInfo
The peer information object for the given peer.
"""
@abstractmethod
def get_protocols(self, peer_id: ID) -> list[str]:
"""
Retrieve the protocols associated with the specified peer.
Parameters
----------
peer_id : ID
The identifier of the peer.
Returns
-------
list[str]
A list of protocol identifiers.
Raises
------
PeerStoreError
If the peer ID is not found.
"""
@abstractmethod
def add_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
"""
Add additional protocols for the specified peer.
Parameters
----------
peer_id : ID
The identifier of the peer.
protocols : Sequence[str]
The protocols to add.
"""
@abstractmethod
def set_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
"""
Set the protocols for the specified peer.
Parameters
----------
peer_id : ID
The identifier of the peer.
protocols : Sequence[str]
The protocols to set.
"""
@abstractmethod
def peer_ids(self) -> list[ID]:
"""
Retrieve all peer identifiers stored in the peer store.
Returns
-------
list[ID]
A list of all peer IDs in the store.
"""
# -------METADATA---------
@abstractmethod
def get(self, peer_id: ID, key: str) -> Any:
"""
@ -606,6 +802,19 @@ class IPeerStore(IAddrBook, IPeerMetadata):
"""
@abstractmethod
def clear_metadata(self, peer_id: ID) -> None:
"""
Clears the stored latency metrics for the specified peer.
Parameters
----------
peer_id : ID
The identifier of the peer whose latency metrics are to be cleared.
"""
# --------ADDR-BOOK---------
@abstractmethod
def add_addr(self, peer_id: ID, addr: Multiaddr, ttl: int) -> None:
"""
@ -679,25 +888,7 @@ class IPeerStore(IAddrBook, IPeerMetadata):
"""
@abstractmethod
def add_pubkey(self, peer_id: ID, pubkey: PublicKey) -> None:
"""
Add a public key for the specified peer.
Parameters
----------
peer_id : ID
The identifier of the peer.
pubkey : PublicKey
The public key to add.
Raises
------
PeerStoreError
If the peer already has a public key set.
"""
# --------KEY-BOOK----------
@abstractmethod
def pubkey(self, peer_id: ID) -> PublicKey:
"""
@ -720,25 +911,6 @@ class IPeerStore(IAddrBook, IPeerMetadata):
"""
@abstractmethod
def add_privkey(self, peer_id: ID, privkey: PrivateKey) -> None:
"""
Add a private key for the specified peer.
Parameters
----------
peer_id : ID
The identifier of the peer.
privkey : PrivateKey
The private key to add.
Raises
------
PeerStoreError
If the peer already has a private key set.
"""
@abstractmethod
def privkey(self, peer_id: ID) -> PrivateKey:
"""
@ -761,6 +933,44 @@ class IPeerStore(IAddrBook, IPeerMetadata):
"""
@abstractmethod
def add_pubkey(self, peer_id: ID, pubkey: PublicKey) -> None:
"""
Add a public key for the specified peer.
Parameters
----------
peer_id : ID
The identifier of the peer.
pubkey : PublicKey
The public key to add.
Raises
------
PeerStoreError
If the peer already has a public key set.
"""
@abstractmethod
def add_privkey(self, peer_id: ID, privkey: PrivateKey) -> None:
"""
Add a private key for the specified peer.
Parameters
----------
peer_id : ID
The identifier of the peer.
privkey : PrivateKey
The private key to add.
Raises
------
PeerStoreError
If the peer already has a private key set.
"""
@abstractmethod
def add_key_pair(self, peer_id: ID, key_pair: KeyPair) -> None:
"""
@ -780,6 +990,213 @@ class IPeerStore(IAddrBook, IPeerMetadata):
"""
@abstractmethod
def peer_with_keys(self) -> list[ID]:
"""Returns all the peer IDs stored in the AddrBook"""
@abstractmethod
def clear_keydata(self, peer_id: ID) -> None:
"""
Remove all stored keydata for the specified peer.
Parameters
----------
peer_id : ID
The peer identifier whose keys are to be removed.
"""
# -------METRICS---------
@abstractmethod
def record_latency(self, peer_id: ID, RTT: float) -> None:
"""
Records a new round-trip time (RTT) latency value for the specified peer
using Exponentially Weighted Moving Average (EWMA).
Parameters
----------
peer_id : ID
The identifier of the peer for which latency is being recorded.
RTT : float
The round-trip time latency value to record.
"""
@abstractmethod
def latency_EWMA(self, peer_id: ID) -> float:
"""
Returns the current latency value for the specified peer using
Exponentially Weighted Moving Average (EWMA).
Parameters
----------
peer_id : ID
The identifier of the peer whose latency EWMA is to be returned.
"""
@abstractmethod
def clear_metrics(self, peer_id: ID) -> None:
"""
Clears the stored latency metrics for the specified peer.
Parameters
----------
peer_id : ID
The identifier of the peer whose latency metrics are to be cleared.
"""
# --------PROTO-BOOK----------
@abstractmethod
def get_protocols(self, peer_id: ID) -> list[str]:
"""
Retrieve the protocols associated with the specified peer.
Parameters
----------
peer_id : ID
The identifier of the peer.
Returns
-------
list[str]
A list of protocol identifiers.
Raises
------
PeerStoreError
If the peer ID is not found.
"""
@abstractmethod
def add_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
"""
Add additional protocols for the specified peer.
Parameters
----------
peer_id : ID
The identifier of the peer.
protocols : Sequence[str]
The protocols to add.
"""
@abstractmethod
def set_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
"""
Set the protocols for the specified peer.
Parameters
----------
peer_id : ID
The identifier of the peer.
protocols : Sequence[str]
The protocols to set.
"""
@abstractmethod
def remove_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
"""
Removes the specified protocols from the peer's protocol list.
Parameters
----------
peer_id : ID
The identifier of the peer from which protocols will be removed.
protocols : Sequence[str]
A sequence of protocol strings to remove.
"""
@abstractmethod
def supports_protocols(self, peer_id: ID, protocols: Sequence[str]) -> list[str]:
"""
Returns the list of protocols from the input sequence that the peer supports.
Parameters
----------
peer_id : ID
The identifier of the peer to check for protocol support.
protocols : Sequence[str]
A sequence of protocol strings to check against the peer's
supported protocols.
"""
@abstractmethod
def first_supported_protocol(self, peer_id: ID, protocols: Sequence[str]) -> str:
"""
Returns the first protocol from the input list that the peer supports.
Parameters
----------
peer_id : ID
The identifier of the peer to check for supported protocols.
protocols : Sequence[str]
A sequence of protocol strings to check.
Returns
-------
str
The first matching protocol string, or an empty string
if none are supported.
"""
@abstractmethod
def clear_protocol_data(self, peer_id: ID) -> None:
"""
Clears all protocol data associated with the specified peer.
Parameters
----------
peer_id : ID
The identifier of the peer whose protocol data will be cleared.
"""
# --------PEER-STORE--------
@abstractmethod
def peer_info(self, peer_id: ID) -> PeerInfo:
"""
Retrieve the peer information for the specified peer.
Parameters
----------
peer_id : ID
The identifier of the peer.
Returns
-------
PeerInfo
The peer information object for the given peer.
"""
@abstractmethod
def peer_ids(self) -> list[ID]:
"""
Retrieve all peer identifiers stored in the peer store.
Returns
-------
list[ID]
A list of all peer IDs in the store.
"""
@abstractmethod
def clear_peerdata(self, peer_id: ID) -> None:
"""clear_peerdata"""
# -------------------------- listener interface.py --------------------------
@ -1315,6 +1732,60 @@ class IPeerData(ABC):
"""
@abstractmethod
def remove_protocols(self, protocols: Sequence[str]) -> None:
"""
Removes the specified protocols from this peer's list of supported protocols.
Parameters
----------
protocols : Sequence[str]
A sequence of protocol strings to be removed.
"""
@abstractmethod
def supports_protocols(self, protocols: Sequence[str]) -> list[str]:
"""
Returns the list of protocols from the input sequence that are supported
by this peer.
Parameters
----------
protocols : Sequence[str]
A sequence of protocol strings to check against this peer's supported
protocols.
Returns
-------
list[str]
A list of protocol strings that are supported.
"""
@abstractmethod
def first_supported_protocol(self, protocols: Sequence[str]) -> str:
"""
Returns the first protocol from the input list that this peer supports.
Parameters
----------
protocols : Sequence[str]
A sequence of protocol strings to check for support.
Returns
-------
str
The first matching protocol, or an empty string if none are supported.
"""
@abstractmethod
def clear_protocol_data(self) -> None:
"""
Clears all protocol data associated with this peer.
"""
@abstractmethod
def add_addrs(self, addrs: Sequence[Multiaddr]) -> None:
"""
@ -1324,6 +1795,8 @@ class IPeerData(ABC):
----------
addrs : Sequence[Multiaddr]
A sequence of multiaddresses to add.
ttl: inr
Time to live for the peer record
"""
@ -1382,6 +1855,12 @@ class IPeerData(ABC):
"""
@abstractmethod
def clear_metadata(self) -> None:
"""
Clears all metadata entries associated with this peer.
"""
@abstractmethod
def add_pubkey(self, pubkey: PublicKey) -> None:
"""
@ -1440,6 +1919,45 @@ class IPeerData(ABC):
"""
@abstractmethod
def clear_keydata(self) -> None:
"""
Clears all cryptographic key data associated with this peer,
including both public and private keys.
"""
@abstractmethod
def record_latency(self, new_latency: float) -> None:
"""
Records a new latency measurement using
Exponentially Weighted Moving Average (EWMA).
Parameters
----------
new_latency : float
The new round-trip time (RTT) latency value to incorporate
into the EWMA calculation.
"""
@abstractmethod
def latency_EWMA(self) -> float:
"""
Returns the current EWMA value of the recorded latency.
Returns
-------
float
The current latency estimate based on EWMA.
"""
@abstractmethod
def clear_metrics(self) -> None:
"""
Clears all latency-related metrics and resets the internal state.
"""
@abstractmethod
def update_last_identified(self) -> None:
"""

View File

@ -18,6 +18,13 @@ from libp2p.crypto.keys import (
PublicKey,
)
"""
Latency EWMA Smoothing governs the deacy of the EWMA (the speed at which
is changes). This must be a normalized (0-1) value.
1 is 100% change, 0 is no change.
"""
LATENCY_EWMA_SMOOTHING = 0.1
class PeerData(IPeerData):
pubkey: PublicKey | None
@ -27,6 +34,7 @@ class PeerData(IPeerData):
addrs: list[Multiaddr]
last_identified: int
ttl: int # Keep ttl=0 by default for always valid
latmap: float
def __init__(self) -> None:
self.pubkey = None
@ -36,6 +44,9 @@ class PeerData(IPeerData):
self.addrs = []
self.last_identified = int(time.time())
self.ttl = 0
self.latmap = 0
# --------PROTO-BOOK--------
def get_protocols(self) -> list[str]:
"""
@ -55,6 +66,37 @@ class PeerData(IPeerData):
"""
self.protocols = list(protocols)
def remove_protocols(self, protocols: Sequence[str]) -> None:
"""
:param protocols: protocols to remove
"""
for protocol in protocols:
if protocol in self.protocols:
self.protocols.remove(protocol)
def supports_protocols(self, protocols: Sequence[str]) -> list[str]:
"""
:param protocols: protocols to check from
:return: all supported protocols in the given list
"""
return [proto for proto in protocols if proto in self.protocols]
def first_supported_protocol(self, protocols: Sequence[str]) -> str:
"""
:param protocols: protocols to check from
:return: first supported protocol in the given list
"""
for protocol in protocols:
if protocol in self.protocols:
return protocol
return "None supported"
def clear_protocol_data(self) -> None:
"""Clear all protocols"""
self.protocols = []
# -------ADDR-BOOK---------
def add_addrs(self, addrs: Sequence[Multiaddr]) -> None:
"""
:param addrs: multiaddresses to add
@ -73,6 +115,7 @@ class PeerData(IPeerData):
"""Clear all addresses."""
self.addrs = []
# -------METADATA-----------
def put_metadata(self, key: str, val: Any) -> None:
"""
:param key: key in KV pair
@ -90,6 +133,11 @@ class PeerData(IPeerData):
return self.metadata[key]
raise PeerDataError("key not found")
def clear_metadata(self) -> None:
"""Clears metadata."""
self.metadata = {}
# -------KEY-BOOK---------------
def add_pubkey(self, pubkey: PublicKey) -> None:
"""
:param pubkey:
@ -120,9 +168,41 @@ class PeerData(IPeerData):
raise PeerDataError("private key not found")
return self.privkey
def clear_keydata(self) -> None:
"""Clears keydata"""
self.pubkey = None
self.privkey = None
# ----------METRICS--------------
def record_latency(self, new_latency: float) -> None:
"""
Records a new latency measurement for the given peer
using Exponentially Weighted Moving Average (EWMA)
:param new_latency: the new latency value
"""
s = LATENCY_EWMA_SMOOTHING
if s > 1 or s < 0:
s = 0.1
if self.latmap == 0:
self.latmap = new_latency
else:
prev = self.latmap
updated = ((1.0 - s) * prev) + (s * new_latency)
self.latmap = updated
def latency_EWMA(self) -> float:
"""Returns the latency EWMA value"""
return self.latmap
def clear_metrics(self) -> None:
"""Clear the latency metrics"""
self.latmap = 0
def update_last_identified(self) -> None:
self.last_identified = int(time.time())
# ----------TTL------------------
def get_last_identified(self) -> int:
"""
:return: last identified timestamp

View File

@ -2,6 +2,7 @@ from collections import (
defaultdict,
)
from collections.abc import (
AsyncIterable,
Sequence,
)
from typing import (
@ -11,6 +12,8 @@ from typing import (
from multiaddr import (
Multiaddr,
)
import trio
from trio import MemoryReceiveChannel, MemorySendChannel
from libp2p.abc import (
IPeerStore,
@ -40,6 +43,7 @@ class PeerStore(IPeerStore):
def __init__(self) -> None:
self.peer_data_map = defaultdict(PeerData)
self.addr_update_channels: dict[ID, MemorySendChannel[Multiaddr]] = {}
def peer_info(self, peer_id: ID) -> PeerInfo:
"""
@ -53,6 +57,29 @@ class PeerStore(IPeerStore):
return PeerInfo(peer_id, peer_data.get_addrs())
raise PeerStoreError("peer ID not found")
def peer_ids(self) -> list[ID]:
"""
:return: all of the peer IDs stored in peer store
"""
return list(self.peer_data_map.keys())
def clear_peerdata(self, peer_id: ID) -> None:
"""Clears the peer data of the peer"""
def valid_peer_ids(self) -> list[ID]:
"""
:return: all of the valid peer IDs stored in peer store
"""
valid_peer_ids: list[ID] = []
for peer_id, peer_data in self.peer_data_map.items():
if not peer_data.is_expired():
valid_peer_ids.append(peer_id)
else:
peer_data.clear_addrs()
return valid_peer_ids
# --------PROTO-BOOK--------
def get_protocols(self, peer_id: ID) -> list[str]:
"""
:param peer_id: peer ID to get protocols for
@ -79,23 +106,31 @@ class PeerStore(IPeerStore):
peer_data = self.peer_data_map[peer_id]
peer_data.set_protocols(list(protocols))
def peer_ids(self) -> list[ID]:
def remove_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
"""
:param peer_id: peer ID to get info for
:param protocols: unsupported protocols to remove
"""
peer_data = self.peer_data_map[peer_id]
peer_data.remove_protocols(protocols)
def supports_protocols(self, peer_id: ID, protocols: Sequence[str]) -> list[str]:
"""
:return: all of the peer IDs stored in peer store
"""
return list(self.peer_data_map.keys())
peer_data = self.peer_data_map[peer_id]
return peer_data.supports_protocols(protocols)
def valid_peer_ids(self) -> list[ID]:
"""
:return: all of the valid peer IDs stored in peer store
"""
valid_peer_ids: list[ID] = []
for peer_id, peer_data in self.peer_data_map.items():
if not peer_data.is_expired():
valid_peer_ids.append(peer_id)
else:
peer_data.clear_addrs()
return valid_peer_ids
def first_supported_protocol(self, peer_id: ID, protocols: Sequence[str]) -> str:
peer_data = self.peer_data_map[peer_id]
return peer_data.first_supported_protocol(protocols)
def clear_protocol_data(self, peer_id: ID) -> None:
"""Clears prtocoldata"""
peer_data = self.peer_data_map[peer_id]
peer_data.clear_protocol_data()
# ------METADATA---------
def get(self, peer_id: ID, key: str) -> Any:
"""
@ -121,6 +156,13 @@ class PeerStore(IPeerStore):
peer_data = self.peer_data_map[peer_id]
peer_data.put_metadata(key, val)
def clear_metadata(self, peer_id: ID) -> None:
"""Clears metadata"""
peer_data = self.peer_data_map[peer_id]
peer_data.clear_metadata()
# -------ADDR-BOOK--------
def add_addr(self, peer_id: ID, addr: Multiaddr, ttl: int = 0) -> None:
"""
:param peer_id: peer ID to add address for
@ -140,6 +182,13 @@ class PeerStore(IPeerStore):
peer_data.set_ttl(ttl)
peer_data.update_last_identified()
if peer_id in self.addr_update_channels:
for addr in addrs:
try:
self.addr_update_channels[peer_id].send_nowait(addr)
except trio.WouldBlock:
pass # Or consider logging / dropping / replacing stream
def addrs(self, peer_id: ID) -> list[Multiaddr]:
"""
:param peer_id: peer ID to get addrs for
@ -165,7 +214,7 @@ class PeerStore(IPeerStore):
def peers_with_addrs(self) -> list[ID]:
"""
:return: all of the peer IDs which has addrs stored in peer store
:return: all of the peer IDs which has addrsfloat stored in peer store
"""
# Add all peers with addrs at least 1 to output
output: list[ID] = []
@ -179,6 +228,27 @@ class PeerStore(IPeerStore):
peer_data.clear_addrs()
return output
async def addr_stream(self, peer_id: ID) -> AsyncIterable[Multiaddr]:
"""
Returns an async stream of newly added addresses for the given peer.
This function allows consumers to subscribe to address updates for a peer
and receive each new address as it is added via `add_addr` or `add_addrs`.
:param peer_id: The ID of the peer to monitor address updates for.
:return: An async iterator yielding Multiaddr instances as they are added.
"""
send: MemorySendChannel[Multiaddr]
receive: MemoryReceiveChannel[Multiaddr]
send, receive = trio.open_memory_channel(0)
self.addr_update_channels[peer_id] = send
async for addr in receive:
yield addr
# -------KEY-BOOK---------
def add_pubkey(self, peer_id: ID, pubkey: PublicKey) -> None:
"""
:param peer_id: peer ID to add public key for
@ -239,6 +309,45 @@ class PeerStore(IPeerStore):
self.add_pubkey(peer_id, key_pair.public_key)
self.add_privkey(peer_id, key_pair.private_key)
def peer_with_keys(self) -> list[ID]:
"""Returns the peer_ids for which keys are stored"""
return [
peer_id
for peer_id, pdata in self.peer_data_map.items()
if pdata.pubkey is not None
]
def clear_keydata(self, peer_id: ID) -> None:
"""Clears the keys of the peer"""
peer_data = self.peer_data_map[peer_id]
peer_data.clear_keydata()
# --------METRICS--------
def record_latency(self, peer_id: ID, RTT: float) -> None:
"""
Records a new latency measurement for the given peer
using Exponentially Weighted Moving Average (EWMA)
:param peer_id: peer ID to get private key for
:param RTT: the new latency value (round trip time)
"""
peer_data = self.peer_data_map[peer_id]
peer_data.record_latency(RTT)
def latency_EWMA(self, peer_id: ID) -> float:
"""
:param peer_id: peer ID to get private key for
:return: The latency EWMA value for that peer
"""
peer_data = self.peer_data_map[peer_id]
return peer_data.latency_EWMA()
def clear_metrics(self, peer_id: ID) -> None:
"""Clear the latency metrics"""
peer_data = self.peer_data_map[peer_id]
peer_data.clear_metrics()
class PeerStoreError(KeyError):
"""Raised when peer ID is not found in peer store."""

View File

@ -98,16 +98,32 @@ class YamuxStream(IMuxedStream):
# Flow control: Check if we have enough send window
total_len = len(data)
sent = 0
logging.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
while sent < total_len:
# Wait for available window with timeout
timeout = False
async with self.window_lock:
# Wait for available window
while self.send_window == 0 and not self.closed:
# Release lock while waiting
if self.send_window == 0:
logging.debug(
f"Stream {self.stream_id}: Window is zero, waiting for update"
)
# Release lock and wait with timeout
self.window_lock.release()
await trio.sleep(0.01)
# To avoid re-acquiring the lock immediately,
with trio.move_on_after(5.0) as cancel_scope:
while self.send_window == 0 and not self.closed:
await trio.sleep(0.01)
# If we timed out, cancel the scope
timeout = cancel_scope.cancelled_caught
# Re-acquire lock
await self.window_lock.acquire()
# If we timed out waiting for window update, raise an error
if timeout:
raise MuxedStreamError(
"Timed out waiting for window update after 5 seconds."
)
if self.closed:
raise MuxedStreamError("Stream is closed")
@ -123,25 +139,45 @@ class YamuxStream(IMuxedStream):
await self.conn.secured_conn.write(header + chunk)
sent += to_send
# If window is getting low, consider updating
if self.send_window < DEFAULT_WINDOW_SIZE // 2:
await self.send_window_update()
async def send_window_update(self, increment: int | None = None) -> None:
"""Send a window update to peer."""
if increment is None:
increment = DEFAULT_WINDOW_SIZE - self.recv_window
async def send_window_update(self, increment: int, skip_lock: bool = False) -> None:
"""
Send a window update to peer.
param:increment: The amount to increment the window size by.
If None, uses the difference between DEFAULT_WINDOW_SIZE
and current receive window.
param:skip_lock (bool): If True, skips acquiring window_lock.
This should only be used when calling from a context
that already holds the lock.
"""
if increment <= 0:
# If increment is zero or negative, skip sending update
logging.debug(
f"Stream {self.stream_id}: Skipping window update"
f"(increment={increment})"
)
return
logging.debug(
f"Stream {self.stream_id}: Sending window update with increment={increment}"
)
async with self.window_lock:
self.recv_window += increment
async def _do_window_update() -> None:
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_WINDOW_UPDATE, 0, self.stream_id, increment
YAMUX_HEADER_FORMAT,
0,
TYPE_WINDOW_UPDATE,
0,
self.stream_id,
increment,
)
await self.conn.secured_conn.write(header)
if skip_lock:
await _do_window_update()
else:
async with self.window_lock:
await _do_window_update()
async def read(self, n: int | None = -1) -> bytes:
# Handle None value for n by converting it to -1
if n is None:
@ -154,55 +190,68 @@ class YamuxStream(IMuxedStream):
)
raise MuxedStreamEOF("Stream is closed for receiving")
# If reading until EOF (n == -1), block until stream is closed
if n == -1:
while not self.recv_closed and not self.conn.event_shutting_down.is_set():
data = b""
while not self.conn.event_shutting_down.is_set():
# Check if there's data in the buffer
buffer = self.conn.stream_buffers.get(self.stream_id)
if buffer and len(buffer) > 0:
# Wait for closure even if data is available
logging.debug(
f"Stream {self.stream_id}:Waiting for FIN before returning data"
)
await self.conn.stream_events[self.stream_id].wait()
self.conn.stream_events[self.stream_id] = trio.Event()
else:
# No data, wait for data or closure
logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
await self.conn.stream_events[self.stream_id].wait()
self.conn.stream_events[self.stream_id] = trio.Event()
# After loop, check if stream is closed or shutting down
async with self.conn.streams_lock:
if self.conn.event_shutting_down.is_set():
logging.debug(f"Stream {self.stream_id}: Connection shutting down")
raise MuxedStreamEOF("Connection shut down")
if self.closed:
if self.reset_received:
logging.debug(f"Stream {self.stream_id}: Stream was reset")
raise MuxedStreamReset("Stream was reset")
else:
logging.debug(
f"Stream {self.stream_id}: Stream closed cleanly (EOF)"
)
raise MuxedStreamEOF("Stream closed cleanly (EOF)")
buffer = self.conn.stream_buffers.get(self.stream_id)
# If buffer is not available, check if stream is closed
if buffer is None:
logging.debug(
f"Stream {self.stream_id}: Buffer gone, assuming closed"
)
logging.debug(f"Stream {self.stream_id}: No buffer available")
raise MuxedStreamEOF("Stream buffer closed")
# If we have data in buffer, process it
if len(buffer) > 0:
chunk = bytes(buffer)
buffer.clear()
data += chunk
# Send window update for the chunk we just read
async with self.window_lock:
self.recv_window += len(chunk)
logging.debug(f"Stream {self.stream_id}: Update {len(chunk)}")
await self.send_window_update(len(chunk), skip_lock=True)
# If stream is closed (FIN received) and buffer is empty, break
if self.recv_closed and len(buffer) == 0:
logging.debug(f"Stream {self.stream_id}: EOF reached")
raise MuxedStreamEOF("Stream is closed for receiving")
# Return all buffered data
data = bytes(buffer)
buffer.clear()
logging.debug(f"Stream {self.stream_id}: Returning {len(data)} bytes")
logging.debug(f"Stream {self.stream_id}: Closed with empty buffer")
break
# If stream was reset, raise reset error
if self.reset_received:
logging.debug(f"Stream {self.stream_id}: Stream was reset")
raise MuxedStreamReset("Stream was reset")
# Wait for more data or stream closure
logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
await self.conn.stream_events[self.stream_id].wait()
self.conn.stream_events[self.stream_id] = trio.Event()
# After loop exit, first check if we have data to return
if data:
logging.debug(
f"Stream {self.stream_id}: Returning {len(data)} bytes after loop"
)
return data
# For specific size read (n > 0), return available data immediately
return await self.conn.read_stream(self.stream_id, n)
# No data accumulated, now check why we exited the loop
if self.conn.event_shutting_down.is_set():
logging.debug(f"Stream {self.stream_id}: Connection shutting down")
raise MuxedStreamEOF("Connection shut down")
# Return empty data
return b""
else:
data = await self.conn.read_stream(self.stream_id, n)
async with self.window_lock:
self.recv_window += len(data)
logging.debug(
f"Stream {self.stream_id}: Sending window update after read, "
f"increment={len(data)}"
)
await self.send_window_update(len(data), skip_lock=True)
return data
async def close(self) -> None:
if not self.send_closed:

View File

@ -0,0 +1,6 @@
Fixed several flow-control and concurrency issues in the `YamuxStream` class. Previously, stress-testing revealed that transferring data over `DEFAULT_WINDOW_SIZE` would break the stream due to inconsistent window update handling and lock management. The fixes include:
- Removed sending of window updates during writes to maintain correct flow-control.
- Added proper timeout handling when releasing and acquiring locks to prevent concurrency errors.
- Corrected the `read` function to properly handle window updates for both `read_until_EOF` and `read_n_bytes`.
- Added event logging at `send_window_updates` and `waiting_for_window_updates` for better observability.

View File

@ -0,0 +1 @@
Added extra tests for identify push concurrency cap under high peer load

View File

@ -35,6 +35,8 @@ from tests.utils.factories import (
)
from tests.utils.utils import (
create_mock_connections,
run_host_forever,
wait_until_listening,
)
logger = logging.getLogger("libp2p.identity.identify-push-test")
@ -503,3 +505,91 @@ async def test_push_identify_to_peers_respects_concurrency_limit():
assert state["max_observed"] <= CONCURRENCY_LIMIT, (
f"Max concurrency observed: {state['max_observed']}"
)
@pytest.mark.trio
async def test_all_peers_receive_identify_push_with_semaphore(security_protocol):
dummy_peers = []
async with host_pair_factory(security_protocol=security_protocol) as (host_a, _):
# Create dummy peers
for _ in range(50):
key_pair = create_new_key_pair()
dummy_host = new_host(key_pair=key_pair)
dummy_host.set_stream_handler(
ID_PUSH, identify_push_handler_for(dummy_host)
)
listen_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
dummy_peers.append((dummy_host, listen_addr))
async with trio.open_nursery() as nursery:
# Start all dummy hosts
for host, listen_addr in dummy_peers:
nursery.start_soon(run_host_forever, host, listen_addr)
# Wait for all hosts to finish setting up listeners
for host, _ in dummy_peers:
await wait_until_listening(host)
# Now connect host_a → dummy peers
for host, _ in dummy_peers:
await host_a.connect(info_from_p2p_addr(host.get_addrs()[0]))
await push_identify_to_peers(
host_a,
)
await trio.sleep(0.5)
peer_id_a = host_a.get_id()
for host, _ in dummy_peers:
dummy_peerstore = host.get_peerstore()
assert peer_id_a in dummy_peerstore.peer_ids()
nursery.cancel_scope.cancel()
@pytest.mark.trio
async def test_all_peers_receive_identify_push_with_semaphore_under_high_peer_load(
security_protocol,
):
dummy_peers = []
async with host_pair_factory(security_protocol=security_protocol) as (host_a, _):
# Create dummy peers
# Breaking with more than 500 peers
# Trio have a async tasks limit of 1000
for _ in range(499):
key_pair = create_new_key_pair()
dummy_host = new_host(key_pair=key_pair)
dummy_host.set_stream_handler(
ID_PUSH, identify_push_handler_for(dummy_host)
)
listen_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
dummy_peers.append((dummy_host, listen_addr))
async with trio.open_nursery() as nursery:
# Start all dummy hosts
for host, listen_addr in dummy_peers:
nursery.start_soon(run_host_forever, host, listen_addr)
# Wait for all hosts to finish setting up listeners
for host, _ in dummy_peers:
await wait_until_listening(host)
# Now connect host_a → dummy peers
for host, _ in dummy_peers:
await host_a.connect(info_from_p2p_addr(host.get_addrs()[0]))
await push_identify_to_peers(
host_a,
)
await trio.sleep(0.5)
peer_id_a = host_a.get_id()
for host, _ in dummy_peers:
dummy_peerstore = host.get_peerstore()
assert peer_id_a in dummy_peerstore.peer_ids()
nursery.cancel_scope.cancel()

View File

@ -6,10 +6,12 @@ from multiaddr import Multiaddr
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.peer.id import ID
from libp2p.peer.peerdata import (
PeerData,
PeerDataError,
)
from libp2p.peer.peerstore import PeerStore
MOCK_ADDR = Multiaddr("/ip4/127.0.0.1/tcp/4001")
MOCK_KEYPAIR = create_new_key_pair()
@ -39,6 +41,59 @@ def test_set_protocols():
assert peer_data.get_protocols() == protocols
# Test case when removing protocols:
def test_remove_protocols():
peer_data = PeerData()
protocols: Sequence[str] = ["protocol1", "protocol2"]
peer_data.set_protocols(protocols)
peer_data.remove_protocols(["protocol1"])
assert peer_data.get_protocols() == ["protocol2"]
# Test case when clearing the protocol list:
def test_clear_protocol_data():
peer_data = PeerData()
protocols: Sequence[str] = ["protocol1", "protocol2"]
peer_data.set_protocols(protocols)
peer_data.clear_protocol_data()
assert peer_data.get_protocols() == []
# Test case when supports protocols:
def test_supports_protocols():
peer_data = PeerData()
peer_data.set_protocols(["protocol1", "protocol2", "protocol3"])
input_protocols = ["protocol1", "protocol4", "protocol2"]
supported = peer_data.supports_protocols(input_protocols)
assert supported == ["protocol1", "protocol2"]
# Test case for first supported protocol is found
def test_first_supported_protocol_found():
peer_data = PeerData()
peer_data.set_protocols(["protocolA", "protocolB"])
input_protocols = ["protocolC", "protocolB", "protocolA"]
first = peer_data.first_supported_protocol(input_protocols)
assert first == "protocolB"
# Test case for first supported protocol not found
def test_first_supported_protocol_none():
peer_data = PeerData()
peer_data.set_protocols(["protocolX", "protocolY"])
input_protocols = ["protocolA", "protocolB"]
first = peer_data.first_supported_protocol(input_protocols)
assert first == "None supported"
# Test case when adding addresses
def test_add_addrs():
peer_data = PeerData()
@ -81,6 +136,15 @@ def test_get_metadata_key_not_found():
peer_data.get_metadata("nonexistent_key")
# Test case for clearing metadata
def test_clear_metadata():
peer_data = PeerData()
peer_data.metadata = {"key1": "value1", "key2": "value2"}
peer_data.clear_metadata()
assert peer_data.metadata == {}
# Test case for adding public key
def test_add_pubkey():
peer_data = PeerData()
@ -107,3 +171,71 @@ def test_get_privkey_not_found():
peer_data = PeerData()
with pytest.raises(PeerDataError):
peer_data.get_privkey()
# Test case for returning all the peers with stored keys
def test_peer_with_keys():
peer_store = PeerStore()
peer_id_1 = ID(b"peer1")
peer_id_2 = ID(b"peer2")
peer_data_1 = PeerData()
peer_data_2 = PeerData()
peer_data_1.pubkey = MOCK_PUBKEY
peer_data_2.pubkey = None
peer_store.peer_data_map = {
peer_id_1: peer_data_1,
peer_id_2: peer_data_2,
}
assert peer_store.peer_with_keys() == [peer_id_1]
# Test case for clearing the key book
def test_clear_keydata():
peer_store = PeerStore()
peer_id = ID(b"peer123")
peer_data = PeerData()
peer_data.pubkey = MOCK_PUBKEY
peer_data.privkey = MOCK_PRIVKEY
peer_store.peer_data_map = {peer_id: peer_data}
peer_store.clear_keydata(peer_id)
assert peer_data.pubkey is None
assert peer_data.privkey is None
# Test case for recording latency for the first time
def test_record_latency_initial():
peer_data = PeerData()
assert peer_data.latency_EWMA() == 0
peer_data.record_latency(100.0)
assert peer_data.latency_EWMA() == 100.0
# Test case for updating latency
def test_record_latency_updates_ewma():
peer_data = PeerData()
peer_data.record_latency(100.0) # first measurement
first = peer_data.latency_EWMA()
peer_data.record_latency(50.0) # second measurement
second = peer_data.latency_EWMA()
assert second < first # EWMA should have smoothed downward
assert second > 50.0 # Not as low as the new latency
assert second != first
def test_clear_metrics():
peer_data = PeerData()
peer_data.record_latency(200.0)
assert peer_data.latency_EWMA() == 200.0
peer_data.clear_metrics()
assert peer_data.latency_EWMA() == 0

View File

@ -2,6 +2,7 @@ import time
import pytest
from multiaddr import Multiaddr
import trio
from libp2p.peer.id import ID
from libp2p.peer.peerstore import (
@ -89,3 +90,33 @@ def test_peers():
store.add_addr(ID(b"peer3"), Multiaddr("/ip4/127.0.0.1/tcp/4001"), 10)
assert set(store.peer_ids()) == {ID(b"peer1"), ID(b"peer2"), ID(b"peer3")}
@pytest.mark.trio
async def test_addr_stream_yields_new_addrs():
store = PeerStore()
peer_id = ID(b"peer1")
addr1 = Multiaddr("/ip4/127.0.0.1/tcp/4001")
addr2 = Multiaddr("/ip4/127.0.0.1/tcp/4002")
collected = []
async def consume_addrs():
async for addr in store.addr_stream(peer_id):
collected.append(addr)
if len(collected) == 2:
break
async with trio.open_nursery() as nursery:
nursery.start_soon(consume_addrs)
await trio.sleep(2) # Give time for the stream to start
store.add_addr(peer_id, addr1, ttl=10)
await trio.sleep(0.2)
store.add_addr(peer_id, addr2, ttl=10)
await trio.sleep(0.2)
# After collecting expected addresses, cancel the stream
nursery.cancel_scope.cancel()
assert collected == [addr1, addr2]

View File

@ -0,0 +1,199 @@
import logging
import pytest
import trio
from trio.testing import (
memory_stream_pair,
)
from libp2p.abc import IRawConnection
from libp2p.crypto.ed25519 import (
create_new_key_pair,
)
from libp2p.peer.id import (
ID,
)
from libp2p.security.insecure.transport import (
InsecureTransport,
)
from libp2p.stream_muxer.yamux.yamux import (
Yamux,
YamuxStream,
)
class TrioStreamAdapter(IRawConnection):
"""Adapter to make trio memory streams work with libp2p."""
def __init__(self, send_stream, receive_stream, is_initiator=False):
self.send_stream = send_stream
self.receive_stream = receive_stream
self.is_initiator = is_initiator
async def write(self, data: bytes) -> None:
logging.debug(f"Attempting to write {len(data)} bytes")
with trio.move_on_after(2):
await self.send_stream.send_all(data)
async def read(self, n: int | None = None) -> bytes:
if n is None or n <= 0:
raise ValueError("Reading unbounded or zero bytes not supported")
logging.debug(f"Attempting to read {n} bytes")
with trio.move_on_after(2):
data = await self.receive_stream.receive_some(n)
logging.debug(f"Read {len(data)} bytes")
return data
async def close(self) -> None:
logging.debug("Closing stream")
await self.send_stream.aclose()
await self.receive_stream.aclose()
def get_remote_address(self) -> tuple[str, int] | None:
"""Return None since this is a test adapter without real network info."""
return None
@pytest.fixture
def key_pair():
return create_new_key_pair()
@pytest.fixture
def peer_id(key_pair):
return ID.from_pubkey(key_pair.public_key)
@pytest.fixture
async def secure_conn_pair(key_pair, peer_id):
"""Create a pair of secure connections for testing."""
logging.debug("Setting up secure_conn_pair")
client_send, server_receive = memory_stream_pair()
server_send, client_receive = memory_stream_pair()
client_rw = TrioStreamAdapter(client_send, client_receive)
server_rw = TrioStreamAdapter(server_send, server_receive)
insecure_transport = InsecureTransport(key_pair)
async def run_outbound(nursery_results):
with trio.move_on_after(5):
client_conn = await insecure_transport.secure_outbound(client_rw, peer_id)
logging.debug("Outbound handshake complete")
nursery_results["client"] = client_conn
async def run_inbound(nursery_results):
with trio.move_on_after(5):
server_conn = await insecure_transport.secure_inbound(server_rw)
logging.debug("Inbound handshake complete")
nursery_results["server"] = server_conn
nursery_results = {}
async with trio.open_nursery() as nursery:
nursery.start_soon(run_outbound, nursery_results)
nursery.start_soon(run_inbound, nursery_results)
await trio.sleep(0.1) # Give tasks a chance to finish
client_conn = nursery_results.get("client")
server_conn = nursery_results.get("server")
if client_conn is None or server_conn is None:
raise RuntimeError("Handshake failed: client_conn or server_conn is None")
logging.debug("secure_conn_pair setup complete")
return client_conn, server_conn
@pytest.fixture
async def yamux_pair(secure_conn_pair, peer_id):
"""Create a pair of Yamux multiplexers for testing."""
logging.debug("Setting up yamux_pair")
client_conn, server_conn = secure_conn_pair
client_yamux = Yamux(client_conn, peer_id, is_initiator=True)
server_yamux = Yamux(server_conn, peer_id, is_initiator=False)
async with trio.open_nursery() as nursery:
with trio.move_on_after(5):
nursery.start_soon(client_yamux.start)
nursery.start_soon(server_yamux.start)
await trio.sleep(0.1)
logging.debug("yamux_pair started")
yield client_yamux, server_yamux
logging.debug("yamux_pair cleanup")
@pytest.mark.trio
async def test_yamux_race_condition_without_locks(yamux_pair):
"""
Test for race-around/interleaving in Yamux streams,when reading in
segments of data.
This launches concurrent writers/readers on both sides of a stream.
If there is no proper locking, the received data may be interleaved
or corrupted.
The test creates structured messages and verifies they are received
intact and in order.
Without proper locking, concurrent read/write operations could cause
data corruption
or message interleaving, which this test will catch.
"""
client_yamux, server_yamux = yamux_pair
client_stream: YamuxStream = await client_yamux.open_stream()
server_stream: YamuxStream = await server_yamux.accept_stream()
MSG_COUNT = 10
MSG_SIZE = 256 * 1024 # At max,only DEFAULT_WINDOW_SIZE bytes can be read
client_msgs = [
f"CLIENT-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"C") for i in range(MSG_COUNT)
]
server_msgs = [
f"SERVER-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"S") for i in range(MSG_COUNT)
]
client_received = []
server_received = []
async def writer(stream, msgs, name):
"""Write messages with minimal delays to encourage race conditions."""
for i, msg in enumerate(msgs):
await stream.write(msg)
# Yield control frequently to encourage interleaving
if i % 5 == 0:
await trio.sleep(0.005)
async def reader(stream, received, name):
"""Read messages and store them for verification."""
for i in range(MSG_COUNT):
data = await stream.read(MSG_SIZE)
received.append(data)
if i % 3 == 0:
await trio.sleep(0.001)
# Running all operations concurrently
async with trio.open_nursery() as nursery:
nursery.start_soon(writer, client_stream, client_msgs, "client")
nursery.start_soon(writer, server_stream, server_msgs, "server")
nursery.start_soon(reader, client_stream, client_received, "client")
nursery.start_soon(reader, server_stream, server_received, "server")
assert len(client_received) == MSG_COUNT, (
f"Client received {len(client_received)} messages, expected {MSG_COUNT}"
)
assert len(server_received) == MSG_COUNT, (
f"Server received {len(server_received)} messages, expected {MSG_COUNT}"
)
assert client_received == server_msgs, (
"Client did not receive server messages in order or intact!"
)
assert server_received == client_msgs, (
"Server did not receive client messages in order or intact!"
)
for i, msg in enumerate(client_received):
assert len(msg) == MSG_SIZE, (
f"Client message {i} has wrong size: {len(msg)} != {MSG_SIZE}"
)
for i, msg in enumerate(server_received):
assert len(msg) == MSG_SIZE, (
f"Server message {i} has wrong size: {len(msg)} != {MSG_SIZE}"
)
await client_stream.close()
await server_stream.close()

View File

@ -0,0 +1,195 @@
import logging
import pytest
import trio
from trio.testing import (
memory_stream_pair,
)
from libp2p.abc import IRawConnection
from libp2p.crypto.ed25519 import (
create_new_key_pair,
)
from libp2p.peer.id import (
ID,
)
from libp2p.security.insecure.transport import (
InsecureTransport,
)
from libp2p.stream_muxer.exceptions import MuxedStreamEOF
from libp2p.stream_muxer.yamux.yamux import (
Yamux,
YamuxStream,
)
class TrioStreamAdapter(IRawConnection):
"""Adapter to make trio memory streams work with libp2p."""
def __init__(self, send_stream, receive_stream, is_initiator=False):
self.send_stream = send_stream
self.receive_stream = receive_stream
self.is_initiator = is_initiator
async def write(self, data: bytes) -> None:
logging.debug(f"Attempting to write {len(data)} bytes")
with trio.move_on_after(2):
await self.send_stream.send_all(data)
async def read(self, n: int | None = None) -> bytes:
if n is None or n <= 0:
raise ValueError("Reading unbounded or zero bytes not supported")
logging.debug(f"Attempting to read {n} bytes")
with trio.move_on_after(2):
data = await self.receive_stream.receive_some(n)
logging.debug(f"Read {len(data)} bytes")
return data
async def close(self) -> None:
logging.debug("Closing stream")
await self.send_stream.aclose()
await self.receive_stream.aclose()
def get_remote_address(self) -> tuple[str, int] | None:
"""Return None since this is a test adapter without real network info."""
return None
@pytest.fixture
def key_pair():
return create_new_key_pair()
@pytest.fixture
def peer_id(key_pair):
return ID.from_pubkey(key_pair.public_key)
@pytest.fixture
async def secure_conn_pair(key_pair, peer_id):
"""Create a pair of secure connections for testing."""
logging.debug("Setting up secure_conn_pair")
client_send, server_receive = memory_stream_pair()
server_send, client_receive = memory_stream_pair()
client_rw = TrioStreamAdapter(client_send, client_receive)
server_rw = TrioStreamAdapter(server_send, server_receive)
insecure_transport = InsecureTransport(key_pair)
async def run_outbound(nursery_results):
with trio.move_on_after(5):
client_conn = await insecure_transport.secure_outbound(client_rw, peer_id)
logging.debug("Outbound handshake complete")
nursery_results["client"] = client_conn
async def run_inbound(nursery_results):
with trio.move_on_after(5):
server_conn = await insecure_transport.secure_inbound(server_rw)
logging.debug("Inbound handshake complete")
nursery_results["server"] = server_conn
nursery_results = {}
async with trio.open_nursery() as nursery:
nursery.start_soon(run_outbound, nursery_results)
nursery.start_soon(run_inbound, nursery_results)
await trio.sleep(0.1) # Give tasks a chance to finish
client_conn = nursery_results.get("client")
server_conn = nursery_results.get("server")
if client_conn is None or server_conn is None:
raise RuntimeError("Handshake failed: client_conn or server_conn is None")
logging.debug("secure_conn_pair setup complete")
return client_conn, server_conn
@pytest.fixture
async def yamux_pair(secure_conn_pair, peer_id):
"""Create a pair of Yamux multiplexers for testing."""
logging.debug("Setting up yamux_pair")
client_conn, server_conn = secure_conn_pair
client_yamux = Yamux(client_conn, peer_id, is_initiator=True)
server_yamux = Yamux(server_conn, peer_id, is_initiator=False)
async with trio.open_nursery() as nursery:
with trio.move_on_after(5):
nursery.start_soon(client_yamux.start)
nursery.start_soon(server_yamux.start)
await trio.sleep(0.1)
logging.debug("yamux_pair started")
yield client_yamux, server_yamux
logging.debug("yamux_pair cleanup")
@pytest.mark.trio
async def test_yamux_race_condition_without_locks(yamux_pair):
"""
Test for race-around/interleaving in Yamux streams,when reading till
EOF is being used.
This launches concurrent writers/readers on both sides of a stream.
If there is no proper locking, the received data may be interleaved
or corrupted.
The test creates structured messages and verifies they are received
intact and in order.
Without proper locking, concurrent read/write operations could cause
data corruption
or message interleaving, which this test will catch.
"""
client_yamux, server_yamux = yamux_pair
client_stream: YamuxStream = await client_yamux.open_stream()
server_stream: YamuxStream = await server_yamux.accept_stream()
MSG_COUNT = 1
MSG_SIZE = 512 * 1024
client_msgs = [
f"CLIENT-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"C") for i in range(MSG_COUNT)
]
server_msgs = [
f"SERVER-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"S") for i in range(MSG_COUNT)
]
client_received = []
server_received = []
async def writer(stream, msgs, name):
"""Write messages with minimal delays to encourage race conditions."""
for i, msg in enumerate(msgs):
await stream.write(msg)
# Yield control frequently to encourage interleaving
if i % 5 == 0:
await trio.sleep(0.005)
async def reader(stream, received, name):
"""Read messages and store them for verification."""
try:
data = await stream.read()
if data:
received.append(data)
except MuxedStreamEOF:
pass
# Running all operations concurrently
async with trio.open_nursery() as nursery:
nursery.start_soon(writer, client_stream, client_msgs, "client")
nursery.start_soon(writer, server_stream, server_msgs, "server")
nursery.start_soon(reader, client_stream, client_received, "client")
nursery.start_soon(reader, server_stream, server_received, "server")
assert client_received == server_msgs, (
"Client did not receive server messages in order or intact!"
)
assert server_received == client_msgs, (
"Server did not receive client messages in order or intact!"
)
for i, msg in enumerate(client_received):
assert len(msg) == MSG_SIZE, (
f"Client message {i} has wrong size: {len(msg)} != {MSG_SIZE}"
)
for i, msg in enumerate(server_received):
assert len(msg) == MSG_SIZE, (
f"Server message {i} has wrong size: {len(msg)} != {MSG_SIZE}"
)
await client_stream.close()
await server_stream.close()

View File

@ -2,13 +2,30 @@ from unittest.mock import (
MagicMock,
)
import trio
def create_mock_connections() -> dict:
from libp2p.abc import IHost
def create_mock_connections(count: int = 50) -> dict:
connections = {}
for i in range(1, 31):
for i in range(1, count):
peer_id = f"peer-{i}"
mock_conn = MagicMock(name=f"INetConn-{i}")
connections[peer_id] = mock_conn
return connections
async def run_host_forever(host: IHost, addr):
async with host.run([addr]):
await trio.sleep_forever()
async def wait_until_listening(host, timeout=3):
with trio.move_on_after(timeout):
while not host.get_addrs():
await trio.sleep(0.05)
return
raise RuntimeError("Timed out waiting for host to get an address")