mirror of https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Merge branch 'main' into feature/bootstrap
This commit is contained in:
754 libp2p/abc.py
@@ -385,6 +385,18 @@ class IPeerMetadata(ABC):
        :raises Exception: If the operation is unsuccessful.
        """

    @abstractmethod
    def clear_metadata(self, peer_id: ID) -> None:
        """
        Remove all stored metadata for the specified peer.

        Parameters
        ----------
        peer_id : ID
            The peer identifier whose metadata are to be removed.

        """


# -------------------------- addrbook interface.py --------------------------
@@ -476,10 +488,272 @@ class IAddrBook(ABC):
        """


# -------------------------- keybook interface.py --------------------------


class IKeyBook(ABC):
    """
    Interface for a key book.

    Provides methods for managing cryptographic keys.
    """
    @abstractmethod
    def pubkey(self, peer_id: ID) -> PublicKey:
        """
        Returns the public key of the specified peer.

        Parameters
        ----------
        peer_id : ID
            The peer identifier whose public key is to be returned.

        """

    @abstractmethod
    def privkey(self, peer_id: ID) -> PrivateKey:
        """
        Returns the private key of the specified peer.

        Parameters
        ----------
        peer_id : ID
            The peer identifier whose private key is to be returned.

        """

    @abstractmethod
    def add_pubkey(self, peer_id: ID, pubkey: PublicKey) -> None:
        """
        Adds the public key for a specified peer.

        Parameters
        ----------
        peer_id : ID
            The peer identifier whose public key is to be added.
        pubkey : PublicKey
            The public key of the peer.

        """

    @abstractmethod
    def add_privkey(self, peer_id: ID, privkey: PrivateKey) -> None:
        """
        Adds the private key for a specified peer.

        Parameters
        ----------
        peer_id : ID
            The peer identifier whose private key is to be added.
        privkey : PrivateKey
            The private key of the peer.

        """

    @abstractmethod
    def add_key_pair(self, peer_id: ID, key_pair: KeyPair) -> None:
        """
        Adds the key pair for a specified peer.

        Parameters
        ----------
        peer_id : ID
            The peer identifier whose key pair is to be added.
        key_pair : KeyPair
            The key pair of the peer.

        """

    @abstractmethod
    def peer_with_keys(self) -> list[ID]:
        """Returns all the peer IDs that have keys stored in the key book."""

    @abstractmethod
    def clear_keydata(self, peer_id: ID) -> None:
        """
        Remove all stored key data for the specified peer.

        Parameters
        ----------
        peer_id : ID
            The peer identifier whose keys are to be removed.

        """

# -------------------------- metrics interface.py --------------------------


class IMetrics(ABC):
    """
    Interface for metrics of peer interaction.

    Provides methods for managing the metrics.
    """

    @abstractmethod
    def record_latency(self, peer_id: ID, RTT: float) -> None:
        """
        Records a new round-trip time (RTT) latency value for the specified peer
        using an Exponentially Weighted Moving Average (EWMA).

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer for which latency is being recorded.

        RTT : float
            The round-trip time latency value to record.

        """

    @abstractmethod
    def latency_EWMA(self, peer_id: ID) -> float:
        """
        Returns the current latency value for the specified peer using an
        Exponentially Weighted Moving Average (EWMA).

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer whose latency EWMA is to be returned.

        """

    @abstractmethod
    def clear_metrics(self, peer_id: ID) -> None:
        """
        Clears the stored latency metrics for the specified peer.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer whose latency metrics are to be cleared.

        """

# -------------------------- protobook interface.py --------------------------


class IProtoBook(ABC):
    """
    Interface for a protocol book.

    Provides methods for managing the list of supported protocols.
    """

    @abstractmethod
    def get_protocols(self, peer_id: ID) -> list[str]:
        """
        Returns the list of protocols associated with the specified peer.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer whose supported protocols are to be returned.

        """

    @abstractmethod
    def add_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
        """
        Adds the given protocols to the specified peer's protocol list.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer to which protocols will be added.

        protocols : Sequence[str]
            A sequence of protocol strings to add.

        """

    @abstractmethod
    def set_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
        """
        Replaces the existing protocols of the specified peer with the given list.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer whose protocols are to be set.

        protocols : Sequence[str]
            A sequence of protocol strings to assign.

        """

    @abstractmethod
    def remove_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
        """
        Removes the specified protocols from the peer's protocol list.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer from which protocols will be removed.

        protocols : Sequence[str]
            A sequence of protocol strings to remove.

        """

    @abstractmethod
    def supports_protocols(self, peer_id: ID, protocols: Sequence[str]) -> list[str]:
        """
        Returns the list of protocols from the input sequence that the peer supports.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer to check for protocol support.

        protocols : Sequence[str]
            A sequence of protocol strings to check against the peer's
            supported protocols.

        """

    @abstractmethod
    def first_supported_protocol(self, peer_id: ID, protocols: Sequence[str]) -> str:
        """
        Returns the first protocol from the input list that the peer supports.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer to check for supported protocols.

        protocols : Sequence[str]
            A sequence of protocol strings to check.

        Returns
        -------
        str
            The first matching protocol string, or the sentinel string
            "None supported" if none are supported.

        """

    @abstractmethod
    def clear_protocol_data(self, peer_id: ID) -> None:
        """
        Clears all protocol data associated with the specified peer.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer whose protocol data will be cleared.

        """

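As a quick orientation before the peer store interface below, here is a hedged usage sketch of these protocol-book methods against the PeerStore implementation that appears later in this commit; the protocol strings are placeholders:

# Illustrative only: exercising the IProtoBook surface via PeerStore.
from libp2p.peer.id import ID
from libp2p.peer.peerstore import PeerStore

store = PeerStore()
pid = ID(b"peer1")
store.add_protocols(pid, ["/echo/1.0.0", "/ping/1.0.0"])
assert store.supports_protocols(pid, ["/noise", "/echo/1.0.0"]) == ["/echo/1.0.0"]
assert store.first_supported_protocol(pid, ["/noise"]) == "None supported"
store.clear_protocol_data(pid)
assert store.get_protocols(pid) == []
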
# -------------------------- peerstore interface.py --------------------------


class IPeerStore(IAddrBook, IPeerMetadata):
class IPeerStore(IPeerMetadata, IAddrBook, IKeyBook, IMetrics, IProtoBook):
    """
    Interface for a peer store.

@@ -487,85 +761,7 @@ class IPeerStore(IAddrBook, IPeerMetadata):
    management, protocol handling, and key storage.
    """

    @abstractmethod
    def peer_info(self, peer_id: ID) -> PeerInfo:
        """
        Retrieve the peer information for the specified peer.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer.

        Returns
        -------
        PeerInfo
            The peer information object for the given peer.

        """

    @abstractmethod
    def get_protocols(self, peer_id: ID) -> list[str]:
        """
        Retrieve the protocols associated with the specified peer.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer.

        Returns
        -------
        list[str]
            A list of protocol identifiers.

        Raises
        ------
        PeerStoreError
            If the peer ID is not found.

        """

    @abstractmethod
    def add_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
        """
        Add additional protocols for the specified peer.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer.
        protocols : Sequence[str]
            The protocols to add.

        """

    @abstractmethod
    def set_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
        """
        Set the protocols for the specified peer.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer.
        protocols : Sequence[str]
            The protocols to set.

        """

    @abstractmethod
    def peer_ids(self) -> list[ID]:
        """
        Retrieve all peer identifiers stored in the peer store.

        Returns
        -------
        list[ID]
            A list of all peer IDs in the store.

        """

    # -------METADATA---------
    @abstractmethod
    def get(self, peer_id: ID, key: str) -> Any:
        """
@@ -606,6 +802,19 @@ class IPeerStore(IAddrBook, IPeerMetadata):

        """

    @abstractmethod
    def clear_metadata(self, peer_id: ID) -> None:
        """
        Clears the stored metadata for the specified peer.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer whose metadata are to be cleared.

        """

    # --------ADDR-BOOK---------
    @abstractmethod
    def add_addr(self, peer_id: ID, addr: Multiaddr, ttl: int) -> None:
        """
@@ -679,25 +888,7 @@ class IPeerStore(IAddrBook, IPeerMetadata):

        """

    @abstractmethod
    def add_pubkey(self, peer_id: ID, pubkey: PublicKey) -> None:
        """
        Add a public key for the specified peer.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer.
        pubkey : PublicKey
            The public key to add.

        Raises
        ------
        PeerStoreError
            If the peer already has a public key set.

        """

    # --------KEY-BOOK----------
    @abstractmethod
    def pubkey(self, peer_id: ID) -> PublicKey:
        """
@@ -720,25 +911,6 @@ class IPeerStore(IAddrBook, IPeerMetadata):

        """

    @abstractmethod
    def add_privkey(self, peer_id: ID, privkey: PrivateKey) -> None:
        """
        Add a private key for the specified peer.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer.
        privkey : PrivateKey
            The private key to add.

        Raises
        ------
        PeerStoreError
            If the peer already has a private key set.

        """

    @abstractmethod
    def privkey(self, peer_id: ID) -> PrivateKey:
        """
@@ -761,6 +933,44 @@ class IPeerStore(IAddrBook, IPeerMetadata):

        """

    @abstractmethod
    def add_pubkey(self, peer_id: ID, pubkey: PublicKey) -> None:
        """
        Add a public key for the specified peer.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer.
        pubkey : PublicKey
            The public key to add.

        Raises
        ------
        PeerStoreError
            If the peer already has a public key set.

        """

    @abstractmethod
    def add_privkey(self, peer_id: ID, privkey: PrivateKey) -> None:
        """
        Add a private key for the specified peer.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer.
        privkey : PrivateKey
            The private key to add.

        Raises
        ------
        PeerStoreError
            If the peer already has a private key set.

        """

    @abstractmethod
    def add_key_pair(self, peer_id: ID, key_pair: KeyPair) -> None:
        """
@@ -780,6 +990,213 @@ class IPeerStore(IAddrBook, IPeerMetadata):

        """

    @abstractmethod
    def peer_with_keys(self) -> list[ID]:
        """Returns all the peer IDs that have keys stored in the key book."""

    @abstractmethod
    def clear_keydata(self, peer_id: ID) -> None:
        """
        Remove all stored key data for the specified peer.

        Parameters
        ----------
        peer_id : ID
            The peer identifier whose keys are to be removed.

        """

    # -------METRICS---------
    @abstractmethod
    def record_latency(self, peer_id: ID, RTT: float) -> None:
        """
        Records a new round-trip time (RTT) latency value for the specified peer
        using an Exponentially Weighted Moving Average (EWMA).

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer for which latency is being recorded.

        RTT : float
            The round-trip time latency value to record.

        """

    @abstractmethod
    def latency_EWMA(self, peer_id: ID) -> float:
        """
        Returns the current latency value for the specified peer using an
        Exponentially Weighted Moving Average (EWMA).

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer whose latency EWMA is to be returned.

        """

    @abstractmethod
    def clear_metrics(self, peer_id: ID) -> None:
        """
        Clears the stored latency metrics for the specified peer.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer whose latency metrics are to be cleared.

        """

    # --------PROTO-BOOK----------
    @abstractmethod
    def get_protocols(self, peer_id: ID) -> list[str]:
        """
        Retrieve the protocols associated with the specified peer.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer.

        Returns
        -------
        list[str]
            A list of protocol identifiers.

        Raises
        ------
        PeerStoreError
            If the peer ID is not found.

        """

    @abstractmethod
    def add_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
        """
        Add additional protocols for the specified peer.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer.
        protocols : Sequence[str]
            The protocols to add.

        """

    @abstractmethod
    def set_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
        """
        Set the protocols for the specified peer.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer.
        protocols : Sequence[str]
            The protocols to set.

        """

    @abstractmethod
    def remove_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
        """
        Removes the specified protocols from the peer's protocol list.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer from which protocols will be removed.

        protocols : Sequence[str]
            A sequence of protocol strings to remove.

        """

    @abstractmethod
    def supports_protocols(self, peer_id: ID, protocols: Sequence[str]) -> list[str]:
        """
        Returns the list of protocols from the input sequence that the peer supports.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer to check for protocol support.

        protocols : Sequence[str]
            A sequence of protocol strings to check against the peer's
            supported protocols.

        """

    @abstractmethod
    def first_supported_protocol(self, peer_id: ID, protocols: Sequence[str]) -> str:
        """
        Returns the first protocol from the input list that the peer supports.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer to check for supported protocols.

        protocols : Sequence[str]
            A sequence of protocol strings to check.

        Returns
        -------
        str
            The first matching protocol string, or the sentinel string
            "None supported" if none are supported.

        """

    @abstractmethod
    def clear_protocol_data(self, peer_id: ID) -> None:
        """
        Clears all protocol data associated with the specified peer.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer whose protocol data will be cleared.

        """

    # --------PEER-STORE--------
    @abstractmethod
    def peer_info(self, peer_id: ID) -> PeerInfo:
        """
        Retrieve the peer information for the specified peer.

        Parameters
        ----------
        peer_id : ID
            The identifier of the peer.

        Returns
        -------
        PeerInfo
            The peer information object for the given peer.

        """

    @abstractmethod
    def peer_ids(self) -> list[ID]:
        """
        Retrieve all peer identifiers stored in the peer store.

        Returns
        -------
        list[ID]
            A list of all peer IDs in the store.

        """

    @abstractmethod
    def clear_peerdata(self, peer_id: ID) -> None:
        """Clear all stored data for the specified peer."""


# -------------------------- listener interface.py --------------------------

@@ -1315,6 +1732,60 @@ class IPeerData(ABC):

        """

    @abstractmethod
    def remove_protocols(self, protocols: Sequence[str]) -> None:
        """
        Removes the specified protocols from this peer's list of supported protocols.

        Parameters
        ----------
        protocols : Sequence[str]
            A sequence of protocol strings to be removed.

        """

    @abstractmethod
    def supports_protocols(self, protocols: Sequence[str]) -> list[str]:
        """
        Returns the list of protocols from the input sequence that are supported
        by this peer.

        Parameters
        ----------
        protocols : Sequence[str]
            A sequence of protocol strings to check against this peer's supported
            protocols.

        Returns
        -------
        list[str]
            A list of protocol strings that are supported.

        """

    @abstractmethod
    def first_supported_protocol(self, protocols: Sequence[str]) -> str:
        """
        Returns the first protocol from the input list that this peer supports.

        Parameters
        ----------
        protocols : Sequence[str]
            A sequence of protocol strings to check for support.

        Returns
        -------
        str
            The first matching protocol, or the sentinel string
            "None supported" if none are supported.

        """

    @abstractmethod
    def clear_protocol_data(self) -> None:
        """
        Clears all protocol data associated with this peer.
        """

    @abstractmethod
    def add_addrs(self, addrs: Sequence[Multiaddr]) -> None:
        """
@@ -1324,6 +1795,8 @@ class IPeerData(ABC):
        ----------
        addrs : Sequence[Multiaddr]
            A sequence of multiaddresses to add.
        ttl : int
            Time to live for the peer record.

        """

@@ -1382,6 +1855,12 @@ class IPeerData(ABC):

        """

    @abstractmethod
    def clear_metadata(self) -> None:
        """
        Clears all metadata entries associated with this peer.
        """

    @abstractmethod
    def add_pubkey(self, pubkey: PublicKey) -> None:
        """
@@ -1440,6 +1919,45 @@ class IPeerData(ABC):

        """

    @abstractmethod
    def clear_keydata(self) -> None:
        """
        Clears all cryptographic key data associated with this peer,
        including both public and private keys.
        """

    @abstractmethod
    def record_latency(self, new_latency: float) -> None:
        """
        Records a new latency measurement using an
        Exponentially Weighted Moving Average (EWMA).

        Parameters
        ----------
        new_latency : float
            The new round-trip time (RTT) latency value to incorporate
            into the EWMA calculation.

        """

    @abstractmethod
    def latency_EWMA(self) -> float:
        """
        Returns the current EWMA value of the recorded latency.

        Returns
        -------
        float
            The current latency estimate based on EWMA.

        """

    @abstractmethod
    def clear_metrics(self) -> None:
        """
        Clears all latency-related metrics and resets the internal state.
        """

    @abstractmethod
    def update_last_identified(self) -> None:
        """
@@ -18,6 +18,13 @@ from libp2p.crypto.keys import (
    PublicKey,
)

"""
Latency EWMA smoothing governs the decay of the EWMA (the speed at which
it changes). This must be a normalized (0-1) value:
1 is 100% change, 0 is no change.
"""
LATENCY_EWMA_SMOOTHING = 0.1

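A worked example of what this constant does to the update rule implemented in record_latency below (an editorial sketch, not part of the module):

# With s = 0.1, each new sample moves the estimate 10% of the way toward itself:
s = LATENCY_EWMA_SMOOTHING  # 0.1
prev, new_latency = 100.0, 50.0
updated = (1.0 - s) * prev + s * new_latency  # 0.9 * 100.0 + 0.1 * 50.0 == 95.0
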
class PeerData(IPeerData):
    pubkey: PublicKey | None
@@ -27,6 +34,7 @@ class PeerData(IPeerData):
    addrs: list[Multiaddr]
    last_identified: int
    ttl: int  # Keep ttl=0 by default for always valid
    latmap: float

    def __init__(self) -> None:
        self.pubkey = None
@@ -36,6 +44,9 @@ class PeerData(IPeerData):
        self.addrs = []
        self.last_identified = int(time.time())
        self.ttl = 0
        self.latmap = 0

    # --------PROTO-BOOK--------

    def get_protocols(self) -> list[str]:
        """
@@ -55,6 +66,37 @@ class PeerData(IPeerData):
        """
        self.protocols = list(protocols)

    def remove_protocols(self, protocols: Sequence[str]) -> None:
        """
        :param protocols: protocols to remove
        """
        for protocol in protocols:
            if protocol in self.protocols:
                self.protocols.remove(protocol)

    def supports_protocols(self, protocols: Sequence[str]) -> list[str]:
        """
        :param protocols: protocols to check from
        :return: all supported protocols in the given list
        """
        return [proto for proto in protocols if proto in self.protocols]

    def first_supported_protocol(self, protocols: Sequence[str]) -> str:
        """
        :param protocols: protocols to check from
        :return: first supported protocol in the given list
        """
        for protocol in protocols:
            if protocol in self.protocols:
                return protocol

        return "None supported"

    def clear_protocol_data(self) -> None:
        """Clear all protocols."""
        self.protocols = []

    # -------ADDR-BOOK---------
    def add_addrs(self, addrs: Sequence[Multiaddr]) -> None:
        """
        :param addrs: multiaddresses to add
@@ -73,6 +115,7 @@ class PeerData(IPeerData):
        """Clear all addresses."""
        self.addrs = []

    # -------METADATA-----------
    def put_metadata(self, key: str, val: Any) -> None:
        """
        :param key: key in KV pair
@@ -90,6 +133,11 @@ class PeerData(IPeerData):
            return self.metadata[key]
        raise PeerDataError("key not found")

    def clear_metadata(self) -> None:
        """Clears metadata."""
        self.metadata = {}

    # -------KEY-BOOK---------------
    def add_pubkey(self, pubkey: PublicKey) -> None:
        """
        :param pubkey:
@@ -120,9 +168,41 @@ class PeerData(IPeerData):
            raise PeerDataError("private key not found")
        return self.privkey

    def clear_keydata(self) -> None:
        """Clears key data."""
        self.pubkey = None
        self.privkey = None

    # ----------METRICS--------------
    def record_latency(self, new_latency: float) -> None:
        """
        Records a new latency measurement for the given peer
        using an Exponentially Weighted Moving Average (EWMA).

        :param new_latency: the new latency value
        """
        s = LATENCY_EWMA_SMOOTHING
        if s > 1 or s < 0:
            s = 0.1

        if self.latmap == 0:
            self.latmap = new_latency
        else:
            prev = self.latmap
            updated = ((1.0 - s) * prev) + (s * new_latency)
            self.latmap = updated

    def latency_EWMA(self) -> float:
        """Returns the latency EWMA value."""
        return self.latmap

    def clear_metrics(self) -> None:
        """Clear the latency metrics."""
        self.latmap = 0

    def update_last_identified(self) -> None:
        self.last_identified = int(time.time())

    # ----------TTL------------------
    def get_last_identified(self) -> int:
        """
        :return: last identified timestamp
@@ -2,6 +2,7 @@ from collections import (
    defaultdict,
)
from collections.abc import (
    AsyncIterable,
    Sequence,
)
from typing import (
@@ -11,6 +12,8 @@ from typing import (
from multiaddr import (
    Multiaddr,
)
import trio
from trio import MemoryReceiveChannel, MemorySendChannel

from libp2p.abc import (
    IPeerStore,
@@ -40,6 +43,7 @@ class PeerStore(IPeerStore):

    def __init__(self) -> None:
        self.peer_data_map = defaultdict(PeerData)
        self.addr_update_channels: dict[ID, MemorySendChannel[Multiaddr]] = {}

    def peer_info(self, peer_id: ID) -> PeerInfo:
        """
@@ -53,6 +57,29 @@ class PeerStore(IPeerStore):
            return PeerInfo(peer_id, peer_data.get_addrs())
        raise PeerStoreError("peer ID not found")

    def peer_ids(self) -> list[ID]:
        """
        :return: all of the peer IDs stored in the peer store
        """
        return list(self.peer_data_map.keys())

    def clear_peerdata(self, peer_id: ID) -> None:
        """Clears the stored data of the peer."""

    def valid_peer_ids(self) -> list[ID]:
        """
        :return: all of the valid peer IDs stored in the peer store
        """
        valid_peer_ids: list[ID] = []
        for peer_id, peer_data in self.peer_data_map.items():
            if not peer_data.is_expired():
                valid_peer_ids.append(peer_id)
            else:
                peer_data.clear_addrs()
        return valid_peer_ids

    # --------PROTO-BOOK--------
    def get_protocols(self, peer_id: ID) -> list[str]:
        """
        :param peer_id: peer ID to get protocols for
@@ -79,23 +106,31 @@ class PeerStore(IPeerStore):
        peer_data = self.peer_data_map[peer_id]
        peer_data.set_protocols(list(protocols))

    def peer_ids(self) -> list[ID]:
    def remove_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
        """
        :param peer_id: peer ID to get info for
        :param protocols: unsupported protocols to remove
        """
        peer_data = self.peer_data_map[peer_id]
        peer_data.remove_protocols(protocols)

    def supports_protocols(self, peer_id: ID, protocols: Sequence[str]) -> list[str]:
        """
        :return: all of the peer IDs stored in peer store
        """
        return list(self.peer_data_map.keys())
        peer_data = self.peer_data_map[peer_id]
        return peer_data.supports_protocols(protocols)

    def valid_peer_ids(self) -> list[ID]:
        """
        :return: all of the valid peer IDs stored in peer store
        """
        valid_peer_ids: list[ID] = []
        for peer_id, peer_data in self.peer_data_map.items():
            if not peer_data.is_expired():
                valid_peer_ids.append(peer_id)
            else:
                peer_data.clear_addrs()
        return valid_peer_ids
    def first_supported_protocol(self, peer_id: ID, protocols: Sequence[str]) -> str:
        peer_data = self.peer_data_map[peer_id]
        return peer_data.first_supported_protocol(protocols)

    def clear_protocol_data(self, peer_id: ID) -> None:
        """Clears protocol data."""
        peer_data = self.peer_data_map[peer_id]
        peer_data.clear_protocol_data()

    # ------METADATA---------

    def get(self, peer_id: ID, key: str) -> Any:
        """
@@ -121,6 +156,13 @@ class PeerStore(IPeerStore):
        peer_data = self.peer_data_map[peer_id]
        peer_data.put_metadata(key, val)

    def clear_metadata(self, peer_id: ID) -> None:
        """Clears metadata."""
        peer_data = self.peer_data_map[peer_id]
        peer_data.clear_metadata()

    # -------ADDR-BOOK--------

    def add_addr(self, peer_id: ID, addr: Multiaddr, ttl: int = 0) -> None:
        """
        :param peer_id: peer ID to add address for
@@ -140,6 +182,13 @@ class PeerStore(IPeerStore):
        peer_data.set_ttl(ttl)
        peer_data.update_last_identified()

        if peer_id in self.addr_update_channels:
            for addr in addrs:
                try:
                    self.addr_update_channels[peer_id].send_nowait(addr)
                except trio.WouldBlock:
                    pass  # Or consider logging / dropping / replacing stream

    def addrs(self, peer_id: ID) -> list[Multiaddr]:
        """
        :param peer_id: peer ID to get addrs for
@@ -165,7 +214,7 @@ class PeerStore(IPeerStore):

    def peers_with_addrs(self) -> list[ID]:
        """
        :return: all of the peer IDs which have addrs stored in the peer store
        """
        # Add all peers with at least 1 addr to the output
        output: list[ID] = []
@@ -179,6 +228,27 @@ class PeerStore(IPeerStore):
            peer_data.clear_addrs()
        return output

    async def addr_stream(self, peer_id: ID) -> AsyncIterable[Multiaddr]:
        """
        Returns an async stream of newly added addresses for the given peer.

        This function allows consumers to subscribe to address updates for a peer
        and receive each new address as it is added via `add_addr` or `add_addrs`.

        :param peer_id: The ID of the peer to monitor address updates for.
        :return: An async iterator yielding Multiaddr instances as they are added.
        """
        send: MemorySendChannel[Multiaddr]
        receive: MemoryReceiveChannel[Multiaddr]

        send, receive = trio.open_memory_channel(0)
        self.addr_update_channels[peer_id] = send

        async for addr in receive:
            yield addr

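A short consumer sketch for addr_stream; the watch_peer helper is hypothetical (not part of this commit) and must run inside a trio event loop alongside whatever calls add_addr:

# Hypothetical consumer of PeerStore.addr_stream.
async def watch_peer(store: "PeerStore", peer_id: ID) -> None:
    async for addr in store.addr_stream(peer_id):
        print(f"new address for {peer_id}: {addr}")
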
    # -------KEY-BOOK---------

    def add_pubkey(self, peer_id: ID, pubkey: PublicKey) -> None:
        """
        :param peer_id: peer ID to add public key for
@@ -239,6 +309,45 @@ class PeerStore(IPeerStore):
        self.add_pubkey(peer_id, key_pair.public_key)
        self.add_privkey(peer_id, key_pair.private_key)

    def peer_with_keys(self) -> list[ID]:
        """Returns the peer_ids for which keys are stored."""
        return [
            peer_id
            for peer_id, pdata in self.peer_data_map.items()
            if pdata.pubkey is not None
        ]

    def clear_keydata(self, peer_id: ID) -> None:
        """Clears the keys of the peer."""
        peer_data = self.peer_data_map[peer_id]
        peer_data.clear_keydata()

    # --------METRICS--------

    def record_latency(self, peer_id: ID, RTT: float) -> None:
        """
        Records a new latency measurement for the given peer
        using an Exponentially Weighted Moving Average (EWMA).

        :param peer_id: peer ID to record latency for
        :param RTT: the new latency value (round trip time)
        """
        peer_data = self.peer_data_map[peer_id]
        peer_data.record_latency(RTT)

    def latency_EWMA(self, peer_id: ID) -> float:
        """
        :param peer_id: peer ID to get the latency EWMA for
        :return: the latency EWMA value for that peer
        """
        peer_data = self.peer_data_map[peer_id]
        return peer_data.latency_EWMA()

    def clear_metrics(self, peer_id: ID) -> None:
        """Clear the latency metrics."""
        peer_data = self.peer_data_map[peer_id]
        peer_data.clear_metrics()


class PeerStoreError(KeyError):
    """Raised when peer ID is not found in peer store."""

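To see how the new books compose on one store, here is a brief illustrative sketch of the expanded PeerStore API; the peer ID bytes and the address are placeholders in the style of this commit's tests:

# Illustrative only: address, key, and metrics books on a single PeerStore.
from multiaddr import Multiaddr

from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.peer.id import ID
from libp2p.peer.peerstore import PeerStore

store = PeerStore()
pid = ID(b"peer1")
store.add_addr(pid, Multiaddr("/ip4/127.0.0.1/tcp/4001"), ttl=10)
store.add_key_pair(pid, create_new_key_pair())
store.record_latency(pid, 100.0)  # the first sample seeds the EWMA directly
assert store.latency_EWMA(pid) == 100.0
assert store.peer_with_keys() == [pid]
store.clear_keydata(pid)
assert store.peer_with_keys() == []
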
@@ -98,16 +98,32 @@ class YamuxStream(IMuxedStream):
        # Flow control: check if we have enough send window
        total_len = len(data)
        sent = 0

        logging.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes")
        while sent < total_len:
            # Wait for available window with timeout
            timeout = False
            async with self.window_lock:
                # Wait for available window
                while self.send_window == 0 and not self.closed:
                    # Release lock while waiting
                if self.send_window == 0:
                    logging.debug(
                        f"Stream {self.stream_id}: Window is zero, waiting for update"
                    )
                    # Release lock and wait with timeout
                    self.window_lock.release()
                    await trio.sleep(0.01)
                    # To avoid re-acquiring the lock immediately,
                    with trio.move_on_after(5.0) as cancel_scope:
                        while self.send_window == 0 and not self.closed:
                            await trio.sleep(0.01)
                    # If we timed out, cancel the scope
                    timeout = cancel_scope.cancelled_caught
                    # Re-acquire lock
                    await self.window_lock.acquire()

                # If we timed out waiting for a window update, raise an error
                if timeout:
                    raise MuxedStreamError(
                        "Timed out waiting for window update after 5 seconds."
                    )

                if self.closed:
                    raise MuxedStreamError("Stream is closed")

@@ -123,25 +139,45 @@ class YamuxStream(IMuxedStream):
            await self.conn.secured_conn.write(header + chunk)
            sent += to_send

        # If window is getting low, consider updating
        if self.send_window < DEFAULT_WINDOW_SIZE // 2:
            await self.send_window_update()

    async def send_window_update(self, increment: int | None = None) -> None:
        """Send a window update to peer."""
        if increment is None:
            increment = DEFAULT_WINDOW_SIZE - self.recv_window
    async def send_window_update(self, increment: int, skip_lock: bool = False) -> None:
        """
        Send a window update to the peer.

        :param increment: The amount to increment the window size by.
        :param skip_lock: If True, skips acquiring window_lock.
            This should only be used when calling from a context
            that already holds the lock.
        """
        if increment <= 0:
            # If increment is zero or negative, skip sending the update
            logging.debug(
                f"Stream {self.stream_id}: Skipping window update "
                f"(increment={increment})"
            )
            return
        logging.debug(
            f"Stream {self.stream_id}: Sending window update with increment={increment}"
        )

        async with self.window_lock:
            self.recv_window += increment
        async def _do_window_update() -> None:
            header = struct.pack(
                YAMUX_HEADER_FORMAT, 0, TYPE_WINDOW_UPDATE, 0, self.stream_id, increment
                YAMUX_HEADER_FORMAT,
                0,
                TYPE_WINDOW_UPDATE,
                0,
                self.stream_id,
                increment,
            )
            await self.conn.secured_conn.write(header)

        if skip_lock:
            await _do_window_update()
        else:
            async with self.window_lock:
                await _do_window_update()

    async def read(self, n: int | None = -1) -> bytes:
        # Handle None value for n by converting it to -1
        if n is None:
@@ -154,55 +190,68 @@ class YamuxStream(IMuxedStream):
            )
            raise MuxedStreamEOF("Stream is closed for receiving")

        # If reading until EOF (n == -1), block until the stream is closed
        if n == -1:
            while not self.recv_closed and not self.conn.event_shutting_down.is_set():
            data = b""
            while not self.conn.event_shutting_down.is_set():
                # Check if there's data in the buffer
                buffer = self.conn.stream_buffers.get(self.stream_id)
                if buffer and len(buffer) > 0:
                    # Wait for closure even if data is available
                    logging.debug(
                        f"Stream {self.stream_id}: Waiting for FIN before returning data"
                    )
                    await self.conn.stream_events[self.stream_id].wait()
                    self.conn.stream_events[self.stream_id] = trio.Event()
                else:
                    # No data, wait for data or closure
                    logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
                    await self.conn.stream_events[self.stream_id].wait()
                    self.conn.stream_events[self.stream_id] = trio.Event()

            # After loop, check if stream is closed or shutting down
            async with self.conn.streams_lock:
                if self.conn.event_shutting_down.is_set():
                    logging.debug(f"Stream {self.stream_id}: Connection shutting down")
                    raise MuxedStreamEOF("Connection shut down")
                if self.closed:
                    if self.reset_received:
                        logging.debug(f"Stream {self.stream_id}: Stream was reset")
                        raise MuxedStreamReset("Stream was reset")
                    else:
                        logging.debug(
                            f"Stream {self.stream_id}: Stream closed cleanly (EOF)"
                        )
                        raise MuxedStreamEOF("Stream closed cleanly (EOF)")
                buffer = self.conn.stream_buffers.get(self.stream_id)
                # If buffer is not available, check if stream is closed
                if buffer is None:
                    logging.debug(
                        f"Stream {self.stream_id}: Buffer gone, assuming closed"
                    )
                    logging.debug(f"Stream {self.stream_id}: No buffer available")
                    raise MuxedStreamEOF("Stream buffer closed")

                # If we have data in the buffer, process it
                if len(buffer) > 0:
                    chunk = bytes(buffer)
                    buffer.clear()
                    data += chunk

                    # Send window update for the chunk we just read
                    async with self.window_lock:
                        self.recv_window += len(chunk)
                        logging.debug(f"Stream {self.stream_id}: Update {len(chunk)}")
                        await self.send_window_update(len(chunk), skip_lock=True)

                # If stream is closed (FIN received) and buffer is empty, break
                if self.recv_closed and len(buffer) == 0:
                    logging.debug(f"Stream {self.stream_id}: EOF reached")
                    raise MuxedStreamEOF("Stream is closed for receiving")
                    # Return all buffered data
                    data = bytes(buffer)
                    buffer.clear()
                    logging.debug(f"Stream {self.stream_id}: Returning {len(data)} bytes")
                    logging.debug(f"Stream {self.stream_id}: Closed with empty buffer")
                    break

                # If stream was reset, raise reset error
                if self.reset_received:
                    logging.debug(f"Stream {self.stream_id}: Stream was reset")
                    raise MuxedStreamReset("Stream was reset")

                # Wait for more data or stream closure
                logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
                await self.conn.stream_events[self.stream_id].wait()
                self.conn.stream_events[self.stream_id] = trio.Event()

            # After loop exit, first check if we have data to return
            if data:
                logging.debug(
                    f"Stream {self.stream_id}: Returning {len(data)} bytes after loop"
                )
                return data

            # For a specific size read (n > 0), return available data immediately
            return await self.conn.read_stream(self.stream_id, n)
            # No data accumulated, now check why we exited the loop
            if self.conn.event_shutting_down.is_set():
                logging.debug(f"Stream {self.stream_id}: Connection shutting down")
                raise MuxedStreamEOF("Connection shut down")

            # Return empty data
            return b""
        else:
            data = await self.conn.read_stream(self.stream_id, n)
            async with self.window_lock:
                self.recv_window += len(data)
                logging.debug(
                    f"Stream {self.stream_id}: Sending window update after read, "
                    f"increment={len(data)}"
                )
                await self.send_window_update(len(data), skip_lock=True)
            return data

    async def close(self) -> None:
        if not self.send_closed:

6 newsfragments/639.feature.rst Normal file
@@ -0,0 +1,6 @@
Fixed several flow-control and concurrency issues in the `YamuxStream` class. Previously, stress testing revealed that transferring more than `DEFAULT_WINDOW_SIZE` bytes would break the stream due to inconsistent window-update handling and lock management. The fixes include:

- Removed sending of window updates during writes to maintain correct flow control.
- Added proper timeout handling when releasing and acquiring locks to prevent concurrency errors.
- Corrected the `read` function to properly handle window updates for both `read_until_EOF` and `read_n_bytes`.
- Added event logging at `send_window_updates` and `waiting_for_window_updates` for better observability.
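The timeout handling mentioned above uses trio's cancel-scope idiom; a minimal standalone sketch of the pattern (the window_is_open callable is a placeholder, not the library's API):

import trio

async def wait_for_window(window_is_open) -> None:
    # Poll until the send window opens, giving up after five seconds.
    with trio.move_on_after(5.0) as cancel_scope:
        while not window_is_open():
            await trio.sleep(0.01)
    if cancel_scope.cancelled_caught:
        raise RuntimeError("Timed out waiting for window update after 5 seconds.")
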
1 newsfragments/708.performance.rst Normal file
@@ -0,0 +1 @@
Added extra tests for the identify-push concurrency cap under high peer load.
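One plausible shape for such a concurrency cap, sketched with trio's CapacityLimiter; push_to_one_peer and the limit value are placeholders, not the library's actual implementation:

import trio

CONCURRENCY_LIMIT = 10  # placeholder value

async def push_to_all(peers, push_to_one_peer) -> None:
    limiter = trio.CapacityLimiter(CONCURRENCY_LIMIT)

    async def bounded_push(peer) -> None:
        async with limiter:  # at most CONCURRENCY_LIMIT pushes in flight
            await push_to_one_peer(peer)

    async with trio.open_nursery() as nursery:
        for peer in peers:
            nursery.start_soon(bounded_push, peer)
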
@@ -35,6 +35,8 @@ from tests.utils.factories import (
)
from tests.utils.utils import (
    create_mock_connections,
    run_host_forever,
    wait_until_listening,
)

logger = logging.getLogger("libp2p.identity.identify-push-test")
@@ -503,3 +505,91 @@ async def test_push_identify_to_peers_respects_concurrency_limit():
    assert state["max_observed"] <= CONCURRENCY_LIMIT, (
        f"Max concurrency observed: {state['max_observed']}"
    )


@pytest.mark.trio
async def test_all_peers_receive_identify_push_with_semaphore(security_protocol):
    dummy_peers = []

    async with host_pair_factory(security_protocol=security_protocol) as (host_a, _):
        # Create dummy peers
        for _ in range(50):
            key_pair = create_new_key_pair()
            dummy_host = new_host(key_pair=key_pair)
            dummy_host.set_stream_handler(
                ID_PUSH, identify_push_handler_for(dummy_host)
            )
            listen_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
            dummy_peers.append((dummy_host, listen_addr))

        async with trio.open_nursery() as nursery:
            # Start all dummy hosts
            for host, listen_addr in dummy_peers:
                nursery.start_soon(run_host_forever, host, listen_addr)

            # Wait for all hosts to finish setting up listeners
            for host, _ in dummy_peers:
                await wait_until_listening(host)

            # Now connect host_a → dummy peers
            for host, _ in dummy_peers:
                await host_a.connect(info_from_p2p_addr(host.get_addrs()[0]))

            await push_identify_to_peers(
                host_a,
            )

            await trio.sleep(0.5)

            peer_id_a = host_a.get_id()
            for host, _ in dummy_peers:
                dummy_peerstore = host.get_peerstore()
                assert peer_id_a in dummy_peerstore.peer_ids()

            nursery.cancel_scope.cancel()


@pytest.mark.trio
async def test_all_peers_receive_identify_push_with_semaphore_under_high_peer_load(
    security_protocol,
):
    dummy_peers = []

    async with host_pair_factory(security_protocol=security_protocol) as (host_a, _):
        # Create dummy peers.
        # Breaks with more than 500 peers:
        # trio has an async task limit of 1000.
        for _ in range(499):
            key_pair = create_new_key_pair()
            dummy_host = new_host(key_pair=key_pair)
            dummy_host.set_stream_handler(
                ID_PUSH, identify_push_handler_for(dummy_host)
            )
            listen_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
            dummy_peers.append((dummy_host, listen_addr))

        async with trio.open_nursery() as nursery:
            # Start all dummy hosts
            for host, listen_addr in dummy_peers:
                nursery.start_soon(run_host_forever, host, listen_addr)

            # Wait for all hosts to finish setting up listeners
            for host, _ in dummy_peers:
                await wait_until_listening(host)

            # Now connect host_a → dummy peers
            for host, _ in dummy_peers:
                await host_a.connect(info_from_p2p_addr(host.get_addrs()[0]))

            await push_identify_to_peers(
                host_a,
            )

            await trio.sleep(0.5)

            peer_id_a = host_a.get_id()
            for host, _ in dummy_peers:
                dummy_peerstore = host.get_peerstore()
                assert peer_id_a in dummy_peerstore.peer_ids()

            nursery.cancel_scope.cancel()

@@ -6,10 +6,12 @@ from multiaddr import Multiaddr
from libp2p.crypto.secp256k1 import (
    create_new_key_pair,
)
from libp2p.peer.id import ID
from libp2p.peer.peerdata import (
    PeerData,
    PeerDataError,
)
from libp2p.peer.peerstore import PeerStore

MOCK_ADDR = Multiaddr("/ip4/127.0.0.1/tcp/4001")
MOCK_KEYPAIR = create_new_key_pair()
@@ -39,6 +41,59 @@ def test_set_protocols():
    assert peer_data.get_protocols() == protocols


# Test case for removing protocols:
def test_remove_protocols():
    peer_data = PeerData()
    protocols: Sequence[str] = ["protocol1", "protocol2"]
    peer_data.set_protocols(protocols)

    peer_data.remove_protocols(["protocol1"])
    assert peer_data.get_protocols() == ["protocol2"]


# Test case for clearing the protocol list:
def test_clear_protocol_data():
    peer_data = PeerData()
    protocols: Sequence[str] = ["protocol1", "protocol2"]
    peer_data.set_protocols(protocols)

    peer_data.clear_protocol_data()
    assert peer_data.get_protocols() == []


# Test case for checking supported protocols:
def test_supports_protocols():
    peer_data = PeerData()
    peer_data.set_protocols(["protocol1", "protocol2", "protocol3"])

    input_protocols = ["protocol1", "protocol4", "protocol2"]
    supported = peer_data.supports_protocols(input_protocols)

    assert supported == ["protocol1", "protocol2"]


# Test case for when the first supported protocol is found
def test_first_supported_protocol_found():
    peer_data = PeerData()
    peer_data.set_protocols(["protocolA", "protocolB"])

    input_protocols = ["protocolC", "protocolB", "protocolA"]
    first = peer_data.first_supported_protocol(input_protocols)

    assert first == "protocolB"


# Test case for when no supported protocol is found
def test_first_supported_protocol_none():
    peer_data = PeerData()
    peer_data.set_protocols(["protocolX", "protocolY"])

    input_protocols = ["protocolA", "protocolB"]
    first = peer_data.first_supported_protocol(input_protocols)

    assert first == "None supported"


# Test case for adding addresses
def test_add_addrs():
    peer_data = PeerData()
@@ -81,6 +136,15 @@ def test_get_metadata_key_not_found():
        peer_data.get_metadata("nonexistent_key")


# Test case for clearing metadata
def test_clear_metadata():
    peer_data = PeerData()
    peer_data.metadata = {"key1": "value1", "key2": "value2"}

    peer_data.clear_metadata()
    assert peer_data.metadata == {}


# Test case for adding a public key
def test_add_pubkey():
    peer_data = PeerData()
@@ -107,3 +171,71 @@ def test_get_privkey_not_found():
    peer_data = PeerData()
    with pytest.raises(PeerDataError):
        peer_data.get_privkey()


# Test case for returning all the peers with stored keys
def test_peer_with_keys():
    peer_store = PeerStore()
    peer_id_1 = ID(b"peer1")
    peer_id_2 = ID(b"peer2")

    peer_data_1 = PeerData()
    peer_data_2 = PeerData()

    peer_data_1.pubkey = MOCK_PUBKEY
    peer_data_2.pubkey = None

    peer_store.peer_data_map = {
        peer_id_1: peer_data_1,
        peer_id_2: peer_data_2,
    }

    assert peer_store.peer_with_keys() == [peer_id_1]


# Test case for clearing the key book
def test_clear_keydata():
    peer_store = PeerStore()
    peer_id = ID(b"peer123")
    peer_data = PeerData()

    peer_data.pubkey = MOCK_PUBKEY
    peer_data.privkey = MOCK_PRIVKEY
    peer_store.peer_data_map = {peer_id: peer_data}

    peer_store.clear_keydata(peer_id)

    assert peer_data.pubkey is None
    assert peer_data.privkey is None


# Test case for recording latency for the first time
def test_record_latency_initial():
    peer_data = PeerData()
    assert peer_data.latency_EWMA() == 0

    peer_data.record_latency(100.0)
    assert peer_data.latency_EWMA() == 100.0


# Test case for updating latency
def test_record_latency_updates_ewma():
    peer_data = PeerData()
    peer_data.record_latency(100.0)  # first measurement
    first = peer_data.latency_EWMA()

    peer_data.record_latency(50.0)  # second measurement
    second = peer_data.latency_EWMA()

    assert second < first  # EWMA should have smoothed downward
    assert second > 50.0  # Not as low as the new latency
    assert second != first


def test_clear_metrics():
    peer_data = PeerData()
    peer_data.record_latency(200.0)
    assert peer_data.latency_EWMA() == 200.0

    peer_data.clear_metrics()
    assert peer_data.latency_EWMA() == 0

@@ -2,6 +2,7 @@ import time

import pytest
from multiaddr import Multiaddr
import trio

from libp2p.peer.id import ID
from libp2p.peer.peerstore import (
@@ -89,3 +90,33 @@ def test_peers():
    store.add_addr(ID(b"peer3"), Multiaddr("/ip4/127.0.0.1/tcp/4001"), 10)

    assert set(store.peer_ids()) == {ID(b"peer1"), ID(b"peer2"), ID(b"peer3")}


@pytest.mark.trio
async def test_addr_stream_yields_new_addrs():
    store = PeerStore()
    peer_id = ID(b"peer1")
    addr1 = Multiaddr("/ip4/127.0.0.1/tcp/4001")
    addr2 = Multiaddr("/ip4/127.0.0.1/tcp/4002")

    collected = []

    async def consume_addrs():
        async for addr in store.addr_stream(peer_id):
            collected.append(addr)
            if len(collected) == 2:
                break

    async with trio.open_nursery() as nursery:
        nursery.start_soon(consume_addrs)
        await trio.sleep(2)  # Give time for the stream to start

        store.add_addr(peer_id, addr1, ttl=10)
        await trio.sleep(0.2)
        store.add_addr(peer_id, addr2, ttl=10)
        await trio.sleep(0.2)

        # After collecting the expected addresses, cancel the stream
        nursery.cancel_scope.cancel()

    assert collected == [addr1, addr2]

199 tests/core/stream_muxer/test_yamux_interleaving.py Normal file
@@ -0,0 +1,199 @@
import logging

import pytest
import trio
from trio.testing import (
    memory_stream_pair,
)

from libp2p.abc import IRawConnection
from libp2p.crypto.ed25519 import (
    create_new_key_pair,
)
from libp2p.peer.id import (
    ID,
)
from libp2p.security.insecure.transport import (
    InsecureTransport,
)
from libp2p.stream_muxer.yamux.yamux import (
    Yamux,
    YamuxStream,
)


class TrioStreamAdapter(IRawConnection):
    """Adapter to make trio memory streams work with libp2p."""

    def __init__(self, send_stream, receive_stream, is_initiator=False):
        self.send_stream = send_stream
        self.receive_stream = receive_stream
        self.is_initiator = is_initiator

    async def write(self, data: bytes) -> None:
        logging.debug(f"Attempting to write {len(data)} bytes")
        with trio.move_on_after(2):
            await self.send_stream.send_all(data)

    async def read(self, n: int | None = None) -> bytes:
        if n is None or n <= 0:
            raise ValueError("Reading unbounded or zero bytes not supported")
        logging.debug(f"Attempting to read {n} bytes")
        with trio.move_on_after(2):
            data = await self.receive_stream.receive_some(n)
            logging.debug(f"Read {len(data)} bytes")
            return data

    async def close(self) -> None:
        logging.debug("Closing stream")
        await self.send_stream.aclose()
        await self.receive_stream.aclose()

    def get_remote_address(self) -> tuple[str, int] | None:
        """Return None since this is a test adapter without real network info."""
        return None

@pytest.fixture
def key_pair():
    return create_new_key_pair()


@pytest.fixture
def peer_id(key_pair):
    return ID.from_pubkey(key_pair.public_key)


@pytest.fixture
async def secure_conn_pair(key_pair, peer_id):
    """Create a pair of secure connections for testing."""
    logging.debug("Setting up secure_conn_pair")
    client_send, server_receive = memory_stream_pair()
    server_send, client_receive = memory_stream_pair()

    client_rw = TrioStreamAdapter(client_send, client_receive)
    server_rw = TrioStreamAdapter(server_send, server_receive)

    insecure_transport = InsecureTransport(key_pair)

    async def run_outbound(nursery_results):
        with trio.move_on_after(5):
            client_conn = await insecure_transport.secure_outbound(client_rw, peer_id)
            logging.debug("Outbound handshake complete")
            nursery_results["client"] = client_conn

    async def run_inbound(nursery_results):
        with trio.move_on_after(5):
            server_conn = await insecure_transport.secure_inbound(server_rw)
            logging.debug("Inbound handshake complete")
            nursery_results["server"] = server_conn

    nursery_results = {}
    async with trio.open_nursery() as nursery:
        nursery.start_soon(run_outbound, nursery_results)
        nursery.start_soon(run_inbound, nursery_results)
        await trio.sleep(0.1)  # Give tasks a chance to finish

    client_conn = nursery_results.get("client")
    server_conn = nursery_results.get("server")

    if client_conn is None or server_conn is None:
        raise RuntimeError("Handshake failed: client_conn or server_conn is None")

    logging.debug("secure_conn_pair setup complete")
    return client_conn, server_conn

@pytest.fixture
async def yamux_pair(secure_conn_pair, peer_id):
    """Create a pair of Yamux multiplexers for testing."""
    logging.debug("Setting up yamux_pair")
    client_conn, server_conn = secure_conn_pair
    client_yamux = Yamux(client_conn, peer_id, is_initiator=True)
    server_yamux = Yamux(server_conn, peer_id, is_initiator=False)
    async with trio.open_nursery() as nursery:
        with trio.move_on_after(5):
            nursery.start_soon(client_yamux.start)
            nursery.start_soon(server_yamux.start)
            await trio.sleep(0.1)
            logging.debug("yamux_pair started")
            yield client_yamux, server_yamux
            logging.debug("yamux_pair cleanup")

@pytest.mark.trio
async def test_yamux_race_condition_without_locks(yamux_pair):
    """
    Test for races/interleaving in Yamux streams when reading in
    fixed-size segments.

    This launches concurrent writers and readers on both sides of a
    stream and verifies that structured messages arrive intact and in
    order. Without proper locking, concurrent read/write operations
    could corrupt or interleave the data, which this test will catch.
    """
    client_yamux, server_yamux = yamux_pair
    client_stream: YamuxStream = await client_yamux.open_stream()
    server_stream: YamuxStream = await server_yamux.accept_stream()
    MSG_COUNT = 10
    MSG_SIZE = 256 * 1024  # At most DEFAULT_WINDOW_SIZE bytes can be read at once
    client_msgs = [
        f"CLIENT-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"C") for i in range(MSG_COUNT)
    ]
    server_msgs = [
        f"SERVER-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"S") for i in range(MSG_COUNT)
    ]
    client_received = []
    server_received = []

    async def writer(stream, msgs, name):
        """Write messages with minimal delays to encourage race conditions."""
        for i, msg in enumerate(msgs):
            await stream.write(msg)
            # Yield control frequently to encourage interleaving
            if i % 5 == 0:
                await trio.sleep(0.005)

    async def reader(stream, received, name):
        """Read messages and store them for verification."""
        for i in range(MSG_COUNT):
            data = await stream.read(MSG_SIZE)
            received.append(data)
            if i % 3 == 0:
                await trio.sleep(0.001)

    # Run all operations concurrently
    async with trio.open_nursery() as nursery:
        nursery.start_soon(writer, client_stream, client_msgs, "client")
        nursery.start_soon(writer, server_stream, server_msgs, "server")
        nursery.start_soon(reader, client_stream, client_received, "client")
        nursery.start_soon(reader, server_stream, server_received, "server")

    assert len(client_received) == MSG_COUNT, (
        f"Client received {len(client_received)} messages, expected {MSG_COUNT}"
    )
    assert len(server_received) == MSG_COUNT, (
        f"Server received {len(server_received)} messages, expected {MSG_COUNT}"
    )
    assert client_received == server_msgs, (
        "Client did not receive server messages in order or intact!"
    )
    assert server_received == client_msgs, (
        "Server did not receive client messages in order or intact!"
    )
    for i, msg in enumerate(client_received):
        assert len(msg) == MSG_SIZE, (
            f"Client message {i} has wrong size: {len(msg)} != {MSG_SIZE}"
        )

    for i, msg in enumerate(server_received):
        assert len(msg) == MSG_SIZE, (
            f"Server message {i} has wrong size: {len(msg)} != {MSG_SIZE}"
        )

    await client_stream.close()
    await server_stream.close()
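The reader above assumes that each stream.read(MSG_SIZE) call returns one
complete message, which holds here because every message fits inside the
flow-control window. A more defensive variant would tolerate short reads
(a minimal sketch; read_exactly is a hypothetical helper, not part of the
Yamux API):

async def read_exactly(stream, n: int) -> bytes:
    # Hypothetical helper: loop until exactly n bytes have arrived,
    # tolerating short reads from the muxer.
    chunks: list[bytes] = []
    remaining = n
    while remaining > 0:
        chunk = await stream.read(remaining)
        if not chunk:
            raise EOFError(f"stream ended with {remaining} bytes missing")
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)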
195
tests/core/stream_muxer/test_yamux_interleaving_EOF.py
Normal file
@ -0,0 +1,195 @@
import logging

import pytest
import trio
from trio.testing import (
    memory_stream_pair,
)

from libp2p.abc import IRawConnection
from libp2p.crypto.ed25519 import (
    create_new_key_pair,
)
from libp2p.peer.id import (
    ID,
)
from libp2p.security.insecure.transport import (
    InsecureTransport,
)
from libp2p.stream_muxer.exceptions import MuxedStreamEOF
from libp2p.stream_muxer.yamux.yamux import (
    Yamux,
    YamuxStream,
)


class TrioStreamAdapter(IRawConnection):
    """Adapter to make trio memory streams work with libp2p."""

    def __init__(self, send_stream, receive_stream, is_initiator=False):
        self.send_stream = send_stream
        self.receive_stream = receive_stream
        self.is_initiator = is_initiator

    async def write(self, data: bytes) -> None:
        logging.debug(f"Attempting to write {len(data)} bytes")
        # Best-effort write: give up if the peer does not accept the data
        # within 2 seconds.
        with trio.move_on_after(2):
            await self.send_stream.send_all(data)

    async def read(self, n: int | None = None) -> bytes:
        if n is None or n <= 0:
            raise ValueError("Reading unbounded or zero bytes not supported")
        logging.debug(f"Attempting to read {n} bytes")
        with trio.move_on_after(2):
            data = await self.receive_stream.receive_some(n)
            logging.debug(f"Read {len(data)} bytes")
            return data
        # Only reached if the 2-second timeout fired before any data arrived.
        raise TimeoutError("read timed out after 2 seconds")

    async def close(self) -> None:
        logging.debug("Closing stream")
        await self.send_stream.aclose()
        await self.receive_stream.aclose()

    def get_remote_address(self) -> tuple[str, int] | None:
        """Return None since this is a test adapter without real network info."""
        return None


@pytest.fixture
def key_pair():
    return create_new_key_pair()


@pytest.fixture
def peer_id(key_pair):
    return ID.from_pubkey(key_pair.public_key)


@pytest.fixture
async def secure_conn_pair(key_pair, peer_id):
    """Create a pair of secure connections for testing."""
    logging.debug("Setting up secure_conn_pair")
    client_send, server_receive = memory_stream_pair()
    server_send, client_receive = memory_stream_pair()

    client_rw = TrioStreamAdapter(client_send, client_receive)
    server_rw = TrioStreamAdapter(server_send, server_receive)

    insecure_transport = InsecureTransport(key_pair)

    async def run_outbound(nursery_results):
        with trio.move_on_after(5):
            client_conn = await insecure_transport.secure_outbound(client_rw, peer_id)
            logging.debug("Outbound handshake complete")
            nursery_results["client"] = client_conn

    async def run_inbound(nursery_results):
        with trio.move_on_after(5):
            server_conn = await insecure_transport.secure_inbound(server_rw)
            logging.debug("Inbound handshake complete")
            nursery_results["server"] = server_conn

    nursery_results = {}
    async with trio.open_nursery() as nursery:
        nursery.start_soon(run_outbound, nursery_results)
        nursery.start_soon(run_inbound, nursery_results)
        await trio.sleep(0.1)  # Give tasks a chance to finish

    client_conn = nursery_results.get("client")
    server_conn = nursery_results.get("server")

    if client_conn is None or server_conn is None:
        raise RuntimeError("Handshake failed: client_conn or server_conn is None")

    logging.debug("secure_conn_pair setup complete")
    return client_conn, server_conn


@pytest.fixture
async def yamux_pair(secure_conn_pair, peer_id):
    """Create a pair of Yamux multiplexers for testing."""
    logging.debug("Setting up yamux_pair")
    client_conn, server_conn = secure_conn_pair
    client_yamux = Yamux(client_conn, peer_id, is_initiator=True)
    server_yamux = Yamux(server_conn, peer_id, is_initiator=False)
    async with trio.open_nursery() as nursery:
        with trio.move_on_after(5):
            nursery.start_soon(client_yamux.start)
            nursery.start_soon(server_yamux.start)
            await trio.sleep(0.1)
        logging.debug("yamux_pair started")
        yield client_yamux, server_yamux
        logging.debug("yamux_pair cleanup")


@pytest.mark.trio
async def test_yamux_race_condition_without_locks(yamux_pair):
    """
    Test for races/interleaving in Yamux streams when reading until EOF.

    This launches concurrent writers and readers on both sides of a
    stream and verifies that structured messages arrive intact and in
    order. Without proper locking, concurrent read/write operations
    could corrupt or interleave the data, which this test will catch.
    """
    client_yamux, server_yamux = yamux_pair
    client_stream: YamuxStream = await client_yamux.open_stream()
    server_stream: YamuxStream = await server_yamux.accept_stream()
    MSG_COUNT = 1
    MSG_SIZE = 512 * 1024
    client_msgs = [
        f"CLIENT-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"C") for i in range(MSG_COUNT)
    ]
    server_msgs = [
        f"SERVER-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"S") for i in range(MSG_COUNT)
    ]
    client_received = []
    server_received = []

    async def writer(stream, msgs, name):
        """Write messages with minimal delays to encourage race conditions."""
        for i, msg in enumerate(msgs):
            await stream.write(msg)
            # Yield control frequently to encourage interleaving
            if i % 5 == 0:
                await trio.sleep(0.005)

    async def reader(stream, received, name):
        """Read until EOF and store the result for verification."""
        try:
            data = await stream.read()
            if data:
                received.append(data)
        except MuxedStreamEOF:
            pass

    # Run all operations concurrently
    async with trio.open_nursery() as nursery:
        nursery.start_soon(writer, client_stream, client_msgs, "client")
        nursery.start_soon(writer, server_stream, server_msgs, "server")
        nursery.start_soon(reader, client_stream, client_received, "client")
        nursery.start_soon(reader, server_stream, server_received, "server")

    assert client_received == server_msgs, (
        "Client did not receive server messages in order or intact!"
    )
    assert server_received == client_msgs, (
        "Server did not receive client messages in order or intact!"
    )
    for i, msg in enumerate(client_received):
        assert len(msg) == MSG_SIZE, (
            f"Client message {i} has wrong size: {len(msg)} != {MSG_SIZE}"
        )

    for i, msg in enumerate(server_received):
        assert len(msg) == MSG_SIZE, (
            f"Server message {i} has wrong size: {len(msg)} != {MSG_SIZE}"
        )

    await client_stream.close()
    await server_stream.close()
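Reading with no size argument, as the reader above does, drains the stream
until EOF, so the read only completes once the remote side closes its end
of the stream. A minimal sketch of that pairing (assuming, as this test
does, that closing a YamuxStream signals EOF to the peer):

async def write_then_close(stream, payload: bytes) -> None:
    # Send the full payload, then close so the peer's read-until-EOF
    # terminates instead of blocking forever.
    await stream.write(payload)
    await stream.close()


async def read_until_eof(stream) -> bytes:
    # Collect everything the peer sent before it closed its end.
    try:
        return await stream.read()
    except MuxedStreamEOF:
        return b""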
@ -2,13 +2,30 @@ from unittest.mock import (
     MagicMock,
 )
 
 import trio
 
-def create_mock_connections() -> dict:
+from libp2p.abc import IHost
+
+
+def create_mock_connections(count: int = 50) -> dict:
     connections = {}
 
-    for i in range(1, 31):
+    for i in range(1, count):
         peer_id = f"peer-{i}"
         mock_conn = MagicMock(name=f"INetConn-{i}")
         connections[peer_id] = mock_conn
 
     return connections
+
+
+async def run_host_forever(host: IHost, addr):
+    async with host.run([addr]):
+        await trio.sleep_forever()
+
+
+async def wait_until_listening(host, timeout=3):
+    with trio.move_on_after(timeout):
+        while not host.get_addrs():
+            await trio.sleep(0.05)
+        return
+    raise RuntimeError("Timed out waiting for host to get an address")
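A usage sketch for the two helpers above (a hypothetical test body; the
loopback multiaddr and the surrounding wiring are assumptions, not part of
this diff):

import multiaddr
import trio

async def exercise_running_host(host):
    # `host` is assumed to implement IHost.
    listen_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
    async with trio.open_nursery() as nursery:
        nursery.start_soon(run_host_forever, host, listen_addr)
        await wait_until_listening(host)
        assert host.get_addrs()  # the host is now reachable
        nursery.cancel_scope.cancel()  # tear down run_host_forever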