mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-11 23:51:07 +00:00
Compare commits
118 Commits
concurrenc
...
356192d793
| Author | SHA1 | Date | |
|---|---|---|---|
| 356192d793 | |||
| 5fcfc677f3 | |||
| dd14aad47c | |||
| 505d3b2a8f | |||
| f4eb0158fe | |||
| b716d64184 | |||
| 58b33ba2e8 | |||
| 0679efb299 | |||
| b21591f8d5 | |||
| d1c31483bd | |||
| 51c08de1bc | |||
| faeacf686a | |||
| 9943697054 | |||
| ff966bbfa0 | |||
| 1b025e552c | |||
| 4e53327079 | |||
| 3d369bc142 | |||
| 5de458482c | |||
| f3d8cbf968 | |||
| e6f96d32e2 | |||
| 51313a5909 | |||
| 8d2b889605 | |||
| bfe3dee781 | |||
| a7d122a0f9 | |||
| 8bfd4bde94 | |||
| 383d7cb722 | |||
| a89ba8ef81 | |||
| 31b6a6f237 | |||
| 5ac4fc1aba | |||
| 00ba846f7b | |||
| f96fe0c1b6 | |||
| 3403689495 | |||
| 2ee23fdec1 | |||
| 96c41773ea | |||
| 83b42b7850 | |||
| 5a95212697 | |||
| 007527ef75 | |||
| 975ea1bd8e | |||
| 88db4ceb21 | |||
| ad0b5505ba | |||
| 572c6915f6 | |||
| 98438916ad | |||
| 966cef58de | |||
| 01319638cd | |||
| a7eb9b5fbd | |||
| fee4208d89 | |||
| 4df454ebdc | |||
| 715e528a56 | |||
| d0e73f5438 | |||
| 8753024add | |||
| 621ea321ab | |||
| 6d8f695778 | |||
| ba231dab79 | |||
| 83d11db852 | |||
| cbe2b4bd99 | |||
| 0038ef99d4 | |||
| 4c739f6259 | |||
| 5306bdc8cb | |||
| 647034221a | |||
| ee2c979e71 | |||
| ad87e50eb7 | |||
| 92c9ba7e46 | |||
| ac51d87046 | |||
| 520e555b82 | |||
| 6be05639e4 | |||
| e8e0cf74d1 | |||
| 211e951678 | |||
| ef16f3c993 | |||
| 2201d9e8d2 | |||
| 7ed33e9a55 | |||
| 4eff928a6d | |||
| 644fc77b6c | |||
| 4947578139 | |||
| 460f502bb9 | |||
| 983a4a001c | |||
| 1fb3f9c72b | |||
| 8afb99c5b1 | |||
| ae16909f79 | |||
| b7d62c0f85 | |||
| c914818f48 | |||
| 5262566f6a | |||
| f274d20715 | |||
| f12ca4e9c1 | |||
| 4d8afa6448 | |||
| e50f9fc8e5 | |||
| 724375e1fa | |||
| 28d0e5759a | |||
| 9adf9aa499 | |||
| dcc8bbb619 | |||
| b258ff3ea2 | |||
| 31b694aa29 | |||
| 293087bd06 | |||
| 35248f8167 | |||
| e018af09ae | |||
| 7135e6cd4d | |||
| 77a9788a69 | |||
| 555e389109 | |||
| 8f0762f95c | |||
| 67bcad1674 | |||
| 3b53120092 | |||
| 89ed86d903 | |||
| 387f4879d1 | |||
| e2f95f4df3 | |||
| f43e7e367a | |||
| 3262749db7 | |||
| cd7eaba4a4 | |||
| 6add1cb685 | |||
| 742bc7bca3 | |||
| cbd4f9b502 | |||
| d7cdae8a0f | |||
| df17788ec3 | |||
| 209deffc8a | |||
| 0a7e13f0ed | |||
| 8bddbfb9bb | |||
| c33ab32c33 | |||
| 01b9e89e83 | |||
| d733b78dba | |||
| e397ce25a6 |
64
docs/examples.mDNS.rst
Normal file
64
docs/examples.mDNS.rst
Normal file
@ -0,0 +1,64 @@
|
||||
mDNS Peer Discovery Example
|
||||
===========================
|
||||
|
||||
This example demonstrates how to use mDNS (Multicast DNS) for peer discovery in py-libp2p.
|
||||
|
||||
Prerequisites
|
||||
-------------
|
||||
|
||||
First, ensure you have py-libp2p installed and your environment is activated:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python -m pip install libp2p
|
||||
|
||||
Running the Example
|
||||
-------------------
|
||||
|
||||
The mDNS demo script allows you to discover peers on your local network using mDNS. To start a peer, run:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ mdns-demo
|
||||
|
||||
You should see output similar to:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
Run this from another console to start another peer on a different port:
|
||||
|
||||
python mdns-demo -p <ANOTHER_PORT>
|
||||
|
||||
Waiting for mDNS peer discovery events...
|
||||
|
||||
2025-06-20 23:28:12,052 - libp2p.example.discovery.mdns - INFO - Starting peer Discovery
|
||||
|
||||
To discover peers, open another terminal and run the same command with a different port:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python mdns-demo -p 9001
|
||||
|
||||
You should see output indicating that a new peer has been discovered:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
Run this from the same folder in another console to start another peer on a different port:
|
||||
|
||||
python mdns-demo -p <ANOTHER_PORT>
|
||||
|
||||
Waiting for mDNS peer discovery events...
|
||||
|
||||
2025-06-20 23:43:43,786 - libp2p.example.discovery.mdns - INFO - Starting peer Discovery
|
||||
2025-06-20 23:43:43,790 - libp2p.example.discovery.mdns - INFO - Discovered: 16Uiu2HAmGxy5NdQEjZWtrYUMrzdp3Syvg7MB2E5Lx8weA9DanYxj
|
||||
|
||||
When a new peer is discovered, its peer ID will be printed in the console output.
|
||||
|
||||
How it Works
|
||||
------------
|
||||
|
||||
- Each node advertises itself on the local network using mDNS.
|
||||
- When a new peer is discovered, the handler prints its peer ID.
|
||||
- This is useful for local peer discovery without requiring a DHT or bootstrap nodes.
|
||||
|
||||
You can modify the script to perform additional actions when peers are discovered, such as opening streams or exchanging messages.
|
||||
@ -13,3 +13,4 @@ Examples
|
||||
examples.pubsub
|
||||
examples.circuit_relay
|
||||
examples.kademlia
|
||||
examples.mDNS
|
||||
|
||||
21
docs/libp2p.discovery.events.rst
Normal file
21
docs/libp2p.discovery.events.rst
Normal file
@ -0,0 +1,21 @@
|
||||
libp2p.discovery.events package
|
||||
===============================
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
libp2p.discovery.events.peerDiscovery module
|
||||
--------------------------------------------
|
||||
|
||||
.. automodule:: libp2p.discovery.events.peerDiscovery
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Module contents
|
||||
---------------
|
||||
|
||||
.. automodule:: libp2p.discovery.events
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
45
docs/libp2p.discovery.mdns.rst
Normal file
45
docs/libp2p.discovery.mdns.rst
Normal file
@ -0,0 +1,45 @@
|
||||
libp2p.discovery.mdns package
|
||||
=============================
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
libp2p.discovery.mdns.broadcaster module
|
||||
----------------------------------------
|
||||
|
||||
.. automodule:: libp2p.discovery.mdns.broadcaster
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.discovery.mdns.listener module
|
||||
-------------------------------------
|
||||
|
||||
.. automodule:: libp2p.discovery.mdns.listener
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.discovery.mdns.mdns module
|
||||
---------------------------------
|
||||
|
||||
.. automodule:: libp2p.discovery.mdns.mdns
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.discovery.mdns.utils module
|
||||
----------------------------------
|
||||
|
||||
.. automodule:: libp2p.discovery.mdns.utils
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Module contents
|
||||
---------------
|
||||
|
||||
.. automodule:: libp2p.discovery.mdns
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
22
docs/libp2p.discovery.rst
Normal file
22
docs/libp2p.discovery.rst
Normal file
@ -0,0 +1,22 @@
|
||||
libp2p.discovery package
|
||||
========================
|
||||
|
||||
Subpackages
|
||||
-----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
libp2p.discovery.events
|
||||
libp2p.discovery.mdns
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
Module contents
|
||||
---------------
|
||||
|
||||
.. automodule:: libp2p.discovery
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
@ -8,6 +8,7 @@ Subpackages
|
||||
:maxdepth: 4
|
||||
|
||||
libp2p.crypto
|
||||
libp2p.discovery
|
||||
libp2p.host
|
||||
libp2p.identity
|
||||
libp2p.io
|
||||
|
||||
@ -3,6 +3,65 @@ Release Notes
|
||||
|
||||
.. towncrier release notes start
|
||||
|
||||
py-libp2p v0.2.9 (2025-07-09)
|
||||
-----------------------------
|
||||
|
||||
Breaking Changes
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
- Reordered the arguments to ``upgrade_security`` to place ``is_initiator`` before ``peer_id``, and made ``peer_id`` optional.
|
||||
This allows the method to reflect the fact that peer identity is not required for inbound connections. (`#681 <https://github.com/libp2p/py-libp2p/issues/681>`__)
|
||||
|
||||
|
||||
Bugfixes
|
||||
~~~~~~~~
|
||||
|
||||
- Add timeout wrappers in:
|
||||
1. ``multiselect.py``: ``negotiate`` function
|
||||
2. ``multiselect_client.py``: ``select_one_of`` , ``query_multistream_command`` functions
|
||||
to prevent indefinite hangs when a remote peer does not respond. (`#696 <https://github.com/libp2p/py-libp2p/issues/696>`__)
|
||||
- Align stream creation logic with yamux specification (`#701 <https://github.com/libp2p/py-libp2p/issues/701>`__)
|
||||
- Fixed an issue in ``Pubsub`` where async validators were not handled reliably under concurrency. Now uses a safe aggregator list for consistent behavior. (`#702 <https://github.com/libp2p/py-libp2p/issues/702>`__)
|
||||
|
||||
|
||||
Features
|
||||
~~~~~~~~
|
||||
|
||||
- Added support for ``Kademlia DHT`` in py-libp2p. (`#579 <https://github.com/libp2p/py-libp2p/issues/579>`__)
|
||||
- Limit concurrency in ``push_identify_to_peers`` to prevent resource congestion under high peer counts. (`#621 <https://github.com/libp2p/py-libp2p/issues/621>`__)
|
||||
- Store public key and peer ID in peerstore during handshake
|
||||
|
||||
Modified the InsecureTransport class to accept an optional peerstore parameter and updated the handshake process to store the received public key and peer ID in the peerstore when available.
|
||||
|
||||
Added test cases to verify:
|
||||
1. The peerstore remains unchanged when handshake fails due to peer ID mismatch
|
||||
2. The handshake correctly adds a public key to a peer ID that already exists in the peerstore but doesn't have a public key yet (`#631 <https://github.com/libp2p/py-libp2p/issues/631>`__)
|
||||
- Fixed several flow-control and concurrency issues in the ``YamuxStream`` class. Previously, stress-testing revealed that transferring data over ``DEFAULT_WINDOW_SIZE`` would break the stream due to inconsistent window update handling and lock management. The fixes include:
|
||||
|
||||
- Removed sending of window updates during writes to maintain correct flow-control.
|
||||
- Added proper timeout handling when releasing and acquiring locks to prevent concurrency errors.
|
||||
- Corrected the ``read`` function to properly handle window updates for both ``read_until_EOF`` and ``read_n_bytes``.
|
||||
- Added event logging at ``send_window_updates`` and ``waiting_for_window_updates`` for better observability. (`#639 <https://github.com/libp2p/py-libp2p/issues/639>`__)
|
||||
- Added support for ``Multicast DNS`` in py-libp2p (`#649 <https://github.com/libp2p/py-libp2p/issues/649>`__)
|
||||
- Optimized pubsub publishing to send multiple topics in a single message instead of separate messages per topic. (`#685 <https://github.com/libp2p/py-libp2p/issues/685>`__)
|
||||
- Optimized pubsub message writing by implementing a write_msg() method that uses pre-allocated buffers and single write operations, improving performance by eliminating separate varint prefix encoding and write operations in FloodSub and GossipSub. (`#687 <https://github.com/libp2p/py-libp2p/issues/687>`__)
|
||||
- Added peer exchange and backoff logic as part of Gossipsub v1.1 upgrade (`#690 <https://github.com/libp2p/py-libp2p/issues/690>`__)
|
||||
|
||||
|
||||
Internal Changes - for py-libp2p Contributors
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
- Added sparse connect utility function to pubsub test utilities for creating test networks with configurable connectivity. (`#679 <https://github.com/libp2p/py-libp2p/issues/679>`__)
|
||||
- Added comprehensive tests for pubsub connection utility functions to verify degree limits are enforced, excess peers are handled correctly, and edge cases (degree=0, negative values, empty lists) are managed gracefully. (`#707 <https://github.com/libp2p/py-libp2p/issues/707>`__)
|
||||
- Added extra tests for identify push concurrency cap under high peer load (`#708 <https://github.com/libp2p/py-libp2p/issues/708>`__)
|
||||
|
||||
|
||||
Miscellaneous Changes
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
- `#678 <https://github.com/libp2p/py-libp2p/issues/678>`__, `#684 <https://github.com/libp2p/py-libp2p/issues/684>`__
|
||||
|
||||
|
||||
py-libp2p v0.2.8 (2025-06-10)
|
||||
-----------------------------
|
||||
|
||||
|
||||
74
examples/mDNS/mDNS.py
Normal file
74
examples/mDNS/mDNS.py
Normal file
@ -0,0 +1,74 @@
|
||||
import argparse
|
||||
import logging
|
||||
import secrets
|
||||
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p import (
|
||||
new_host,
|
||||
)
|
||||
from libp2p.abc import PeerInfo
|
||||
from libp2p.crypto.secp256k1 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.discovery.events.peerDiscovery import peerDiscovery
|
||||
|
||||
logger = logging.getLogger("libp2p.discovery.mdns")
|
||||
logger.setLevel(logging.INFO)
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(
|
||||
logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
)
|
||||
logger.addHandler(handler)
|
||||
|
||||
# Set root logger to DEBUG to capture all logs from dependencies
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def onPeerDiscovery(peerinfo: PeerInfo):
|
||||
logger.info(f"Discovered: {peerinfo.peer_id}")
|
||||
|
||||
|
||||
async def run(port: int) -> None:
|
||||
secret = secrets.token_bytes(32)
|
||||
key_pair = create_new_key_pair(secret)
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
|
||||
peerDiscovery.register_peer_discovered_handler(onPeerDiscovery)
|
||||
|
||||
print(
|
||||
"Run this from the same folder in another console to "
|
||||
"start another peer on a different port:\n\n"
|
||||
"mdns-demo -p <ANOTHER_PORT>\n"
|
||||
)
|
||||
print("Waiting for mDNS peer discovery events...\n")
|
||||
|
||||
logger.info("Starting peer Discovery")
|
||||
host = new_host(key_pair=key_pair, enable_mDNS=True)
|
||||
async with host.run(listen_addrs=[listen_addr]):
|
||||
await trio.sleep_forever()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
description = """
|
||||
This program demonstrates mDNS peer discovery using libp2p.
|
||||
To use it, run 'mdns-demo -p <PORT>', where <PORT> is the port number.
|
||||
Start multiple peers on different ports to see discovery in action.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
||||
parser.add_argument(
|
||||
"-v", "--verbose", action="store_true", help="Enable verbose output"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if args.verbose:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
try:
|
||||
trio.run(run, args.port)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Exiting...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -32,6 +32,9 @@ from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
TSecurityOptions,
|
||||
)
|
||||
from libp2p.discovery.mdns.mdns import (
|
||||
MDNSDiscovery,
|
||||
)
|
||||
from libp2p.host.basic_host import (
|
||||
BasicHost,
|
||||
)
|
||||
@ -81,6 +84,8 @@ DEFAULT_MUXER = "YAMUX"
|
||||
# Multiplexer options
|
||||
MUXER_YAMUX = "YAMUX"
|
||||
MUXER_MPLEX = "MPLEX"
|
||||
DEFAULT_NEGOTIATE_TIMEOUT = 5
|
||||
|
||||
|
||||
|
||||
def set_default_muxer(muxer_name: Literal["YAMUX", "MPLEX"]) -> None:
|
||||
@ -245,6 +250,8 @@ def new_host(
|
||||
disc_opt: IPeerRouting | None = None,
|
||||
muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
|
||||
listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
|
||||
enable_mDNS: bool = False,
|
||||
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
) -> IHost:
|
||||
"""
|
||||
Create a new libp2p host based on the given parameters.
|
||||
@ -256,6 +263,7 @@ def new_host(
|
||||
:param disc_opt: optional discovery
|
||||
:param muxer_preference: optional explicit muxer preference
|
||||
:param listen_addrs: optional list of multiaddrs to listen on
|
||||
:param enable_mDNS: whether to enable mDNS discovery
|
||||
:return: return a host instance
|
||||
"""
|
||||
swarm = new_swarm(
|
||||
@ -268,8 +276,7 @@ def new_host(
|
||||
)
|
||||
|
||||
if disc_opt is not None:
|
||||
return RoutedHost(swarm, disc_opt)
|
||||
return BasicHost(swarm)
|
||||
|
||||
return RoutedHost(swarm, disc_opt, enable_mDNS)
|
||||
return BasicHost(network=swarm,enable_mDNS=enable_mDNS , negotitate_timeout=negotiate_timeout)
|
||||
|
||||
__version__ = __version("libp2p")
|
||||
|
||||
764
libp2p/abc.py
764
libp2p/abc.py
@ -50,6 +50,11 @@ if TYPE_CHECKING:
|
||||
Pubsub,
|
||||
)
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libp2p.protocol_muxer.multiselect import Multiselect
|
||||
|
||||
from libp2p.pubsub.pb import (
|
||||
rpc_pb2,
|
||||
)
|
||||
@ -385,6 +390,18 @@ class IPeerMetadata(ABC):
|
||||
:raises Exception: If the operation is unsuccessful.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def clear_metadata(self, peer_id: ID) -> None:
|
||||
"""
|
||||
Remove all stored metadata for the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer identifier whose metadata are to be removed.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# -------------------------- addrbook interface.py --------------------------
|
||||
|
||||
@ -476,10 +493,272 @@ class IAddrBook(ABC):
|
||||
"""
|
||||
|
||||
|
||||
# -------------------------- keybook interface.py --------------------------
|
||||
|
||||
|
||||
class IKeyBook(ABC):
|
||||
"""
|
||||
Interface for an key book.
|
||||
|
||||
Provides methods for managing cryptographic keys.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def pubkey(self, peer_id: ID) -> PublicKey:
|
||||
"""
|
||||
Returns the public key of the specified peer
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer identifier whose public key is to be returned.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def privkey(self, peer_id: ID) -> PrivateKey:
|
||||
"""
|
||||
Returns the private key of the specified peer
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer identifier whose private key is to be returned.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_pubkey(self, peer_id: ID, pubkey: PublicKey) -> None:
|
||||
"""
|
||||
Adds the public key for a specified peer
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer identifier whose public key is to be added
|
||||
pubkey: PublicKey
|
||||
The public key of the peer
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_privkey(self, peer_id: ID, privkey: PrivateKey) -> None:
|
||||
"""
|
||||
Adds the private key for a specified peer
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer identifier whose private key is to be added
|
||||
privkey: PrivateKey
|
||||
The private key of the peer
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_key_pair(self, peer_id: ID, key_pair: KeyPair) -> None:
|
||||
"""
|
||||
Adds the key pair for a specified peer
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer identifier whose key pair is to be added
|
||||
key_pair: KeyPair
|
||||
The key pair of the peer
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def peer_with_keys(self) -> list[ID]:
|
||||
"""Returns all the peer IDs stored in the AddrBook"""
|
||||
|
||||
@abstractmethod
|
||||
def clear_keydata(self, peer_id: ID) -> None:
|
||||
"""
|
||||
Remove all stored keydata for the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer identifier whose keys are to be removed.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# -------------------------- metrics interface.py --------------------------
|
||||
|
||||
|
||||
class IMetrics(ABC):
|
||||
"""
|
||||
Interface for metrics of peer interaction.
|
||||
|
||||
Provides methods for managing the metrics.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def record_latency(self, peer_id: ID, RTT: float) -> None:
|
||||
"""
|
||||
Records a new round-trip time (RTT) latency value for the specified peer
|
||||
using Exponentially Weighted Moving Average (EWMA).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer for which latency is being recorded.
|
||||
|
||||
RTT : float
|
||||
The round-trip time latency value to record.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def latency_EWMA(self, peer_id: ID) -> float:
|
||||
"""
|
||||
Returns the current latency value for the specified peer using
|
||||
Exponentially Weighted Moving Average (EWMA).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer whose latency EWMA is to be returned.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def clear_metrics(self, peer_id: ID) -> None:
|
||||
"""
|
||||
Clears the stored latency metrics for the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer whose latency metrics are to be cleared.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# -------------------------- protobook interface.py --------------------------
|
||||
|
||||
|
||||
class IProtoBook(ABC):
|
||||
"""
|
||||
Interface for a protocol book.
|
||||
|
||||
Provides methods for managing the list of supported protocols.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_protocols(self, peer_id: ID) -> list[str]:
|
||||
"""
|
||||
Returns the list of protocols associated with the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer whose supported protocols are to be returned.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
|
||||
"""
|
||||
Adds the given protocols to the specified peer's protocol list.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer to which protocols will be added.
|
||||
|
||||
protocols : Sequence[str]
|
||||
A sequence of protocol strings to add.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
|
||||
"""
|
||||
Replaces the existing protocols of the specified peer with the given list.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer whose protocols are to be set.
|
||||
|
||||
protocols : Sequence[str]
|
||||
A sequence of protocol strings to assign.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def remove_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
|
||||
"""
|
||||
Removes the specified protocols from the peer's protocol list.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer from which protocols will be removed.
|
||||
|
||||
protocols : Sequence[str]
|
||||
A sequence of protocol strings to remove.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def supports_protocols(self, peer_id: ID, protocols: Sequence[str]) -> list[str]:
|
||||
"""
|
||||
Returns the list of protocols from the input sequence that the peer supports.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer to check for protocol support.
|
||||
|
||||
protocols : Sequence[str]
|
||||
A sequence of protocol strings to check against the peer's
|
||||
supported protocols.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def first_supported_protocol(self, peer_id: ID, protocols: Sequence[str]) -> str:
|
||||
"""
|
||||
Returns the first protocol from the input list that the peer supports.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer to check for supported protocols.
|
||||
|
||||
protocols : Sequence[str]
|
||||
A sequence of protocol strings to check.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The first matching protocol string, or an empty string
|
||||
if none are supported.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def clear_protocol_data(self, peer_id: ID) -> None:
|
||||
"""
|
||||
Clears all protocol data associated with the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer whose protocol data will be cleared.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# -------------------------- peerstore interface.py --------------------------
|
||||
|
||||
|
||||
class IPeerStore(IAddrBook, IPeerMetadata):
|
||||
class IPeerStore(IPeerMetadata, IAddrBook, IKeyBook, IMetrics, IProtoBook):
|
||||
"""
|
||||
Interface for a peer store.
|
||||
|
||||
@ -487,85 +766,7 @@ class IPeerStore(IAddrBook, IPeerMetadata):
|
||||
management, protocol handling, and key storage.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def peer_info(self, peer_id: ID) -> PeerInfo:
|
||||
"""
|
||||
Retrieve the peer information for the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer.
|
||||
|
||||
Returns
|
||||
-------
|
||||
PeerInfo
|
||||
The peer information object for the given peer.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_protocols(self, peer_id: ID) -> list[str]:
|
||||
"""
|
||||
Retrieve the protocols associated with the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[str]
|
||||
A list of protocol identifiers.
|
||||
|
||||
Raises
|
||||
------
|
||||
PeerStoreError
|
||||
If the peer ID is not found.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
|
||||
"""
|
||||
Add additional protocols for the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer.
|
||||
protocols : Sequence[str]
|
||||
The protocols to add.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
|
||||
"""
|
||||
Set the protocols for the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer.
|
||||
protocols : Sequence[str]
|
||||
The protocols to set.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def peer_ids(self) -> list[ID]:
|
||||
"""
|
||||
Retrieve all peer identifiers stored in the peer store.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[ID]
|
||||
A list of all peer IDs in the store.
|
||||
|
||||
"""
|
||||
|
||||
# -------METADATA---------
|
||||
@abstractmethod
|
||||
def get(self, peer_id: ID, key: str) -> Any:
|
||||
"""
|
||||
@ -606,6 +807,19 @@ class IPeerStore(IAddrBook, IPeerMetadata):
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def clear_metadata(self, peer_id: ID) -> None:
|
||||
"""
|
||||
Clears the stored latency metrics for the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer whose latency metrics are to be cleared.
|
||||
|
||||
"""
|
||||
|
||||
# --------ADDR-BOOK---------
|
||||
@abstractmethod
|
||||
def add_addr(self, peer_id: ID, addr: Multiaddr, ttl: int) -> None:
|
||||
"""
|
||||
@ -679,25 +893,7 @@ class IPeerStore(IAddrBook, IPeerMetadata):
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_pubkey(self, peer_id: ID, pubkey: PublicKey) -> None:
|
||||
"""
|
||||
Add a public key for the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer.
|
||||
pubkey : PublicKey
|
||||
The public key to add.
|
||||
|
||||
Raises
|
||||
------
|
||||
PeerStoreError
|
||||
If the peer already has a public key set.
|
||||
|
||||
"""
|
||||
|
||||
# --------KEY-BOOK----------
|
||||
@abstractmethod
|
||||
def pubkey(self, peer_id: ID) -> PublicKey:
|
||||
"""
|
||||
@ -720,25 +916,6 @@ class IPeerStore(IAddrBook, IPeerMetadata):
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_privkey(self, peer_id: ID, privkey: PrivateKey) -> None:
|
||||
"""
|
||||
Add a private key for the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer.
|
||||
privkey : PrivateKey
|
||||
The private key to add.
|
||||
|
||||
Raises
|
||||
------
|
||||
PeerStoreError
|
||||
If the peer already has a private key set.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def privkey(self, peer_id: ID) -> PrivateKey:
|
||||
"""
|
||||
@ -761,6 +938,44 @@ class IPeerStore(IAddrBook, IPeerMetadata):
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_pubkey(self, peer_id: ID, pubkey: PublicKey) -> None:
|
||||
"""
|
||||
Add a public key for the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer.
|
||||
pubkey : PublicKey
|
||||
The public key to add.
|
||||
|
||||
Raises
|
||||
------
|
||||
PeerStoreError
|
||||
If the peer already has a public key set.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_privkey(self, peer_id: ID, privkey: PrivateKey) -> None:
|
||||
"""
|
||||
Add a private key for the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer.
|
||||
privkey : PrivateKey
|
||||
The private key to add.
|
||||
|
||||
Raises
|
||||
------
|
||||
PeerStoreError
|
||||
If the peer already has a private key set.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_key_pair(self, peer_id: ID, key_pair: KeyPair) -> None:
|
||||
"""
|
||||
@ -780,6 +995,213 @@ class IPeerStore(IAddrBook, IPeerMetadata):
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def peer_with_keys(self) -> list[ID]:
|
||||
"""Returns all the peer IDs stored in the AddrBook"""
|
||||
|
||||
@abstractmethod
|
||||
def clear_keydata(self, peer_id: ID) -> None:
|
||||
"""
|
||||
Remove all stored keydata for the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer identifier whose keys are to be removed.
|
||||
|
||||
"""
|
||||
|
||||
# -------METRICS---------
|
||||
@abstractmethod
|
||||
def record_latency(self, peer_id: ID, RTT: float) -> None:
|
||||
"""
|
||||
Records a new round-trip time (RTT) latency value for the specified peer
|
||||
using Exponentially Weighted Moving Average (EWMA).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer for which latency is being recorded.
|
||||
|
||||
RTT : float
|
||||
The round-trip time latency value to record.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def latency_EWMA(self, peer_id: ID) -> float:
|
||||
"""
|
||||
Returns the current latency value for the specified peer using
|
||||
Exponentially Weighted Moving Average (EWMA).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer whose latency EWMA is to be returned.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def clear_metrics(self, peer_id: ID) -> None:
|
||||
"""
|
||||
Clears the stored latency metrics for the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer whose latency metrics are to be cleared.
|
||||
|
||||
"""
|
||||
|
||||
# --------PROTO-BOOK----------
|
||||
@abstractmethod
|
||||
def get_protocols(self, peer_id: ID) -> list[str]:
|
||||
"""
|
||||
Retrieve the protocols associated with the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[str]
|
||||
A list of protocol identifiers.
|
||||
|
||||
Raises
|
||||
------
|
||||
PeerStoreError
|
||||
If the peer ID is not found.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
|
||||
"""
|
||||
Add additional protocols for the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer.
|
||||
protocols : Sequence[str]
|
||||
The protocols to add.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
|
||||
"""
|
||||
Set the protocols for the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer.
|
||||
protocols : Sequence[str]
|
||||
The protocols to set.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def remove_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
|
||||
"""
|
||||
Removes the specified protocols from the peer's protocol list.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer from which protocols will be removed.
|
||||
|
||||
protocols : Sequence[str]
|
||||
A sequence of protocol strings to remove.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def supports_protocols(self, peer_id: ID, protocols: Sequence[str]) -> list[str]:
|
||||
"""
|
||||
Returns the list of protocols from the input sequence that the peer supports.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer to check for protocol support.
|
||||
|
||||
protocols : Sequence[str]
|
||||
A sequence of protocol strings to check against the peer's
|
||||
supported protocols.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def first_supported_protocol(self, peer_id: ID, protocols: Sequence[str]) -> str:
|
||||
"""
|
||||
Returns the first protocol from the input list that the peer supports.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer to check for supported protocols.
|
||||
|
||||
protocols : Sequence[str]
|
||||
A sequence of protocol strings to check.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The first matching protocol string, or an empty string
|
||||
if none are supported.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def clear_protocol_data(self, peer_id: ID) -> None:
|
||||
"""
|
||||
Clears all protocol data associated with the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer whose protocol data will be cleared.
|
||||
|
||||
"""
|
||||
|
||||
# --------PEER-STORE--------
|
||||
@abstractmethod
|
||||
def peer_info(self, peer_id: ID) -> PeerInfo:
|
||||
"""
|
||||
Retrieve the peer information for the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The identifier of the peer.
|
||||
|
||||
Returns
|
||||
-------
|
||||
PeerInfo
|
||||
The peer information object for the given peer.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def peer_ids(self) -> list[ID]:
|
||||
"""
|
||||
Retrieve all peer identifiers stored in the peer store.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[ID]
|
||||
A list of all peer IDs in the store.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def clear_peerdata(self, peer_id: ID) -> None:
|
||||
"""clear_peerdata"""
|
||||
|
||||
|
||||
# -------------------------- listener interface.py --------------------------
|
||||
|
||||
@ -1128,9 +1550,8 @@ class IHost(ABC):
|
||||
|
||||
"""
|
||||
|
||||
# FIXME: Replace with correct return type
|
||||
@abstractmethod
|
||||
def get_mux(self) -> Any:
|
||||
def get_mux(self) -> "Multiselect":
|
||||
"""
|
||||
Retrieve the muxer instance for the host.
|
||||
|
||||
@ -1315,6 +1736,60 @@ class IPeerData(ABC):
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def remove_protocols(self, protocols: Sequence[str]) -> None:
|
||||
"""
|
||||
Removes the specified protocols from this peer's list of supported protocols.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
protocols : Sequence[str]
|
||||
A sequence of protocol strings to be removed.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def supports_protocols(self, protocols: Sequence[str]) -> list[str]:
|
||||
"""
|
||||
Returns the list of protocols from the input sequence that are supported
|
||||
by this peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
protocols : Sequence[str]
|
||||
A sequence of protocol strings to check against this peer's supported
|
||||
protocols.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[str]
|
||||
A list of protocol strings that are supported.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def first_supported_protocol(self, protocols: Sequence[str]) -> str:
|
||||
"""
|
||||
Returns the first protocol from the input list that this peer supports.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
protocols : Sequence[str]
|
||||
A sequence of protocol strings to check for support.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The first matching protocol, or an empty string if none are supported.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def clear_protocol_data(self) -> None:
|
||||
"""
|
||||
Clears all protocol data associated with this peer.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_addrs(self, addrs: Sequence[Multiaddr]) -> None:
|
||||
"""
|
||||
@ -1324,6 +1799,8 @@ class IPeerData(ABC):
|
||||
----------
|
||||
addrs : Sequence[Multiaddr]
|
||||
A sequence of multiaddresses to add.
|
||||
ttl: inr
|
||||
Time to live for the peer record
|
||||
|
||||
"""
|
||||
|
||||
@ -1382,6 +1859,12 @@ class IPeerData(ABC):
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def clear_metadata(self) -> None:
|
||||
"""
|
||||
Clears all metadata entries associated with this peer.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_pubkey(self, pubkey: PublicKey) -> None:
|
||||
"""
|
||||
@ -1440,6 +1923,45 @@ class IPeerData(ABC):
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def clear_keydata(self) -> None:
|
||||
"""
|
||||
Clears all cryptographic key data associated with this peer,
|
||||
including both public and private keys.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def record_latency(self, new_latency: float) -> None:
|
||||
"""
|
||||
Records a new latency measurement using
|
||||
Exponentially Weighted Moving Average (EWMA).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
new_latency : float
|
||||
The new round-trip time (RTT) latency value to incorporate
|
||||
into the EWMA calculation.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def latency_EWMA(self) -> float:
|
||||
"""
|
||||
Returns the current EWMA value of the recorded latency.
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
The current latency estimate based on EWMA.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def clear_metrics(self) -> None:
|
||||
"""
|
||||
Clears all latency-related metrics and resets the internal state.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def update_last_identified(self) -> None:
|
||||
"""
|
||||
@ -1640,6 +2162,7 @@ class IMultiselectMuxer(ABC):
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_protocols(self) -> tuple[TProtocol | None, ...]:
|
||||
"""
|
||||
Retrieve the protocols for which handlers have been registered.
|
||||
@ -1650,7 +2173,6 @@ class IMultiselectMuxer(ABC):
|
||||
A tuple of registered protocol names.
|
||||
|
||||
"""
|
||||
return tuple(self.handlers.keys())
|
||||
|
||||
@abstractmethod
|
||||
async def negotiate(
|
||||
|
||||
0
libp2p/discovery/__init__.py
Normal file
0
libp2p/discovery/__init__.py
Normal file
0
libp2p/discovery/events/__init__.py
Normal file
0
libp2p/discovery/events/__init__.py
Normal file
26
libp2p/discovery/events/peerDiscovery.py
Normal file
26
libp2p/discovery/events/peerDiscovery.py
Normal file
@ -0,0 +1,26 @@
|
||||
from collections.abc import (
|
||||
Callable,
|
||||
)
|
||||
|
||||
from libp2p.abc import (
|
||||
PeerInfo,
|
||||
)
|
||||
|
||||
TTL: int = 60 * 60 # Time-to-live for discovered peers in seconds
|
||||
|
||||
|
||||
class PeerDiscovery:
|
||||
def __init__(self) -> None:
|
||||
self._peer_discovered_handlers: list[Callable[[PeerInfo], None]] = []
|
||||
|
||||
def register_peer_discovered_handler(
|
||||
self, handler: Callable[[PeerInfo], None]
|
||||
) -> None:
|
||||
self._peer_discovered_handlers.append(handler)
|
||||
|
||||
def emit_peer_discovered(self, peer_info: PeerInfo) -> None:
|
||||
for handler in self._peer_discovered_handlers:
|
||||
handler(peer_info)
|
||||
|
||||
|
||||
peerDiscovery = PeerDiscovery()
|
||||
0
libp2p/discovery/mdns/__init__.py
Normal file
0
libp2p/discovery/mdns/__init__.py
Normal file
91
libp2p/discovery/mdns/broadcaster.py
Normal file
91
libp2p/discovery/mdns/broadcaster.py
Normal file
@ -0,0 +1,91 @@
|
||||
import logging
|
||||
import socket
|
||||
|
||||
from zeroconf import (
|
||||
EventLoopBlocked,
|
||||
ServiceInfo,
|
||||
Zeroconf,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("libp2p.discovery.mdns.broadcaster")
|
||||
|
||||
|
||||
class PeerBroadcaster:
|
||||
"""
|
||||
Broadcasts this peer's presence on the local network using mDNS/zeroconf.
|
||||
Registers a service with the peer's ID in the TXT record as per libp2p spec.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
zeroconf: Zeroconf,
|
||||
service_type: str,
|
||||
service_name: str,
|
||||
peer_id: str,
|
||||
port: int,
|
||||
):
|
||||
self.zeroconf = zeroconf
|
||||
self.service_type = service_type
|
||||
self.peer_id = peer_id
|
||||
self.port = port
|
||||
self.service_name = service_name
|
||||
|
||||
# Get the local IP address
|
||||
local_ip = self._get_local_ip()
|
||||
hostname = socket.gethostname()
|
||||
|
||||
self.service_info = ServiceInfo(
|
||||
type_=self.service_type,
|
||||
name=self.service_name,
|
||||
port=self.port,
|
||||
properties={b"id": self.peer_id.encode()},
|
||||
server=f"{hostname}.local.",
|
||||
addresses=[socket.inet_aton(local_ip)],
|
||||
)
|
||||
|
||||
def _get_local_ip(self) -> str:
|
||||
"""Get the local IP address of this machine"""
|
||||
try:
|
||||
# Connect to a remote address to determine the local IP
|
||||
# This doesn't actually send data
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
|
||||
s.connect(("8.8.8.8", 80))
|
||||
local_ip = s.getsockname()[0]
|
||||
return local_ip
|
||||
except Exception:
|
||||
# Fallback to localhost if we can't determine the IP
|
||||
return "127.0.0.1"
|
||||
|
||||
def register(self) -> None:
|
||||
"""Register the peer's mDNS service on the network."""
|
||||
try:
|
||||
self.zeroconf.register_service(self.service_info)
|
||||
logger.debug(f"mDNS service registered: {self.service_name}")
|
||||
except EventLoopBlocked as e:
|
||||
logger.warning(
|
||||
"EventLoopBlocked while registering mDNS '%s': %s", self.service_name, e
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Unexpected error during mDNS registration for '%s': %r",
|
||||
self.service_name,
|
||||
e,
|
||||
)
|
||||
|
||||
def unregister(self) -> None:
|
||||
"""Unregister the peer's mDNS service from the network."""
|
||||
try:
|
||||
self.zeroconf.unregister_service(self.service_info)
|
||||
logger.debug(f"mDNS service unregistered: {self.service_name}")
|
||||
except EventLoopBlocked as e:
|
||||
logger.warning(
|
||||
"EventLoopBlocked while unregistering mDNS '%s': %s",
|
||||
self.service_name,
|
||||
e,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Unexpected error during mDNS unregistration for '%s': %r",
|
||||
self.service_name,
|
||||
e,
|
||||
)
|
||||
83
libp2p/discovery/mdns/listener.py
Normal file
83
libp2p/discovery/mdns/listener.py
Normal file
@ -0,0 +1,83 @@
|
||||
import logging
|
||||
import socket
|
||||
|
||||
from zeroconf import (
|
||||
ServiceBrowser,
|
||||
ServiceInfo,
|
||||
ServiceListener,
|
||||
Zeroconf,
|
||||
)
|
||||
|
||||
from libp2p.abc import IPeerStore, Multiaddr
|
||||
from libp2p.discovery.events.peerDiscovery import peerDiscovery
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
|
||||
logger = logging.getLogger("libp2p.discovery.mdns.listener")
|
||||
|
||||
|
||||
class PeerListener(ServiceListener):
|
||||
"""mDNS listener — now a true ServiceListener subclass."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
peerstore: IPeerStore,
|
||||
zeroconf: Zeroconf,
|
||||
service_type: str,
|
||||
service_name: str,
|
||||
) -> None:
|
||||
self.peerstore = peerstore
|
||||
self.zeroconf = zeroconf
|
||||
self.service_type = service_type
|
||||
self.service_name = service_name
|
||||
self.discovered_services: dict[str, ID] = {}
|
||||
self.browser = ServiceBrowser(self.zeroconf, self.service_type, listener=self)
|
||||
|
||||
def add_service(self, zc: Zeroconf, type_: str, name: str) -> None:
|
||||
if name == self.service_name:
|
||||
return
|
||||
logger.debug(f"Adding service: {name}")
|
||||
info = zc.get_service_info(type_, name, timeout=5000)
|
||||
if not info:
|
||||
return
|
||||
peer_info = self._extract_peer_info(info)
|
||||
if peer_info:
|
||||
self.discovered_services[name] = peer_info.peer_id
|
||||
self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 10)
|
||||
peerDiscovery.emit_peer_discovered(peer_info)
|
||||
logger.debug(f"Discovered Peer: {peer_info.peer_id}")
|
||||
|
||||
def remove_service(self, zc: Zeroconf, type_: str, name: str) -> None:
|
||||
if name == self.service_name:
|
||||
return
|
||||
logger.debug(f"Removing service: {name}")
|
||||
peer_id = self.discovered_services.pop(name)
|
||||
self.peerstore.clear_addrs(peer_id)
|
||||
logger.debug(f"Removed Peer: {peer_id}")
|
||||
|
||||
def update_service(self, zc: Zeroconf, type_: str, name: str) -> None:
|
||||
info = zc.get_service_info(type_, name, timeout=5000)
|
||||
if not info:
|
||||
return
|
||||
peer_info = self._extract_peer_info(info)
|
||||
if peer_info:
|
||||
self.peerstore.clear_addrs(peer_info.peer_id)
|
||||
self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 10)
|
||||
logger.debug(f"Updated Peer {peer_info.peer_id}")
|
||||
|
||||
def _extract_peer_info(self, info: ServiceInfo) -> PeerInfo | None:
|
||||
try:
|
||||
addrs = [
|
||||
Multiaddr(f"/ip4/{socket.inet_ntoa(addr)}/tcp/{info.port}")
|
||||
for addr in info.addresses
|
||||
]
|
||||
pid_bytes = info.properties.get(b"id")
|
||||
if not pid_bytes:
|
||||
return None
|
||||
pid = ID.from_base58(pid_bytes.decode())
|
||||
return PeerInfo(peer_id=pid, addrs=addrs)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def stop(self) -> None:
|
||||
self.browser.cancel()
|
||||
73
libp2p/discovery/mdns/mdns.py
Normal file
73
libp2p/discovery/mdns/mdns.py
Normal file
@ -0,0 +1,73 @@
|
||||
"""
|
||||
mDNS-based peer discovery for py-libp2p.
|
||||
Conforms to https://github.com/libp2p/specs/blob/master/discovery/mdns.md
|
||||
Uses zeroconf for mDNS broadcast/listen. Async operations use trio.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from zeroconf import (
|
||||
Zeroconf,
|
||||
)
|
||||
|
||||
from libp2p.abc import (
|
||||
INetworkService,
|
||||
)
|
||||
|
||||
from .broadcaster import (
|
||||
PeerBroadcaster,
|
||||
)
|
||||
from .listener import (
|
||||
PeerListener,
|
||||
)
|
||||
from .utils import (
|
||||
stringGen,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("libp2p.discovery.mdns")
|
||||
|
||||
SERVICE_TYPE = "_p2p._udp.local."
|
||||
MCAST_PORT = 5353
|
||||
MCAST_ADDR = "224.0.0.251"
|
||||
|
||||
|
||||
class MDNSDiscovery:
|
||||
"""
|
||||
mDNS-based peer discovery for py-libp2p, using zeroconf.
|
||||
Conforms to the libp2p mDNS discovery spec.
|
||||
"""
|
||||
|
||||
def __init__(self, swarm: INetworkService, port: int = 8000):
|
||||
self.peer_id = str(swarm.get_peer_id())
|
||||
self.port = port
|
||||
self.zeroconf = Zeroconf()
|
||||
self.serviceName = f"{stringGen()}.{SERVICE_TYPE}"
|
||||
self.peerstore = swarm.peerstore
|
||||
self.swarm = swarm
|
||||
self.broadcaster = PeerBroadcaster(
|
||||
zeroconf=self.zeroconf,
|
||||
service_type=SERVICE_TYPE,
|
||||
service_name=self.serviceName,
|
||||
peer_id=self.peer_id,
|
||||
port=self.port,
|
||||
)
|
||||
self.listener = PeerListener(
|
||||
zeroconf=self.zeroconf,
|
||||
peerstore=self.peerstore,
|
||||
service_type=SERVICE_TYPE,
|
||||
service_name=self.serviceName,
|
||||
)
|
||||
|
||||
def start(self) -> None:
|
||||
"""Register this peer and start listening for others."""
|
||||
logger.debug(
|
||||
f"Starting mDNS discovery for peer {self.peer_id} on port {self.port}"
|
||||
)
|
||||
self.broadcaster.register()
|
||||
# Listener is started in constructor
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Unregister this peer and clean up zeroconf resources."""
|
||||
logger.debug("Stopping mDNS discovery")
|
||||
self.broadcaster.unregister()
|
||||
self.zeroconf.close()
|
||||
11
libp2p/discovery/mdns/utils.py
Normal file
11
libp2p/discovery/mdns/utils.py
Normal file
@ -0,0 +1,11 @@
|
||||
import random
|
||||
import string
|
||||
|
||||
|
||||
def stringGen(len: int = 63) -> str:
|
||||
"""Generate a random string of lowercase letters and digits."""
|
||||
charset = string.ascii_lowercase + string.digits
|
||||
result = []
|
||||
for _ in range(len):
|
||||
result.append(random.choice(charset))
|
||||
return "".join(result)
|
||||
@ -29,6 +29,7 @@ from libp2p.custom_types import (
|
||||
StreamHandlerFn,
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.discovery.mdns.mdns import MDNSDiscovery
|
||||
from libp2p.host.defaults import (
|
||||
get_default_protocols,
|
||||
)
|
||||
@ -70,6 +71,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
logger = logging.getLogger("libp2p.network.basic_host")
|
||||
DEFAULT_NEGOTIATE_TIMEOUT = 5
|
||||
|
||||
|
||||
class BasicHost(IHost):
|
||||
@ -89,15 +91,20 @@ class BasicHost(IHost):
|
||||
def __init__(
|
||||
self,
|
||||
network: INetworkService,
|
||||
enable_mDNS: bool = False,
|
||||
default_protocols: Optional["OrderedDict[TProtocol, StreamHandlerFn]"] = None,
|
||||
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
) -> None:
|
||||
self._network = network
|
||||
self._network.set_stream_handler(self._swarm_stream_handler)
|
||||
self.peerstore = self._network.peerstore
|
||||
self.negotiate_timeout = negotitate_timeout
|
||||
# Protocol muxing
|
||||
default_protocols = default_protocols or get_default_protocols(self)
|
||||
self.multiselect = Multiselect(dict(default_protocols.items()))
|
||||
self.multiselect_client = MultiselectClient()
|
||||
if enable_mDNS:
|
||||
self.mDNS = MDNSDiscovery(network)
|
||||
|
||||
def get_id(self) -> ID:
|
||||
"""
|
||||
@ -162,7 +169,14 @@ class BasicHost(IHost):
|
||||
network = self.get_network()
|
||||
async with background_trio_service(network):
|
||||
await network.listen(*listen_addrs)
|
||||
yield
|
||||
if hasattr(self, "mDNS") and self.mDNS is not None:
|
||||
logger.debug("Starting mDNS Discovery")
|
||||
self.mDNS.start()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if hasattr(self, "mDNS") and self.mDNS is not None:
|
||||
self.mDNS.stop()
|
||||
|
||||
return _run()
|
||||
|
||||
@ -178,7 +192,10 @@ class BasicHost(IHost):
|
||||
self.multiselect.add_handler(protocol_id, stream_handler)
|
||||
|
||||
async def new_stream(
|
||||
self, peer_id: ID, protocol_ids: Sequence[TProtocol]
|
||||
self,
|
||||
peer_id: ID,
|
||||
protocol_ids: Sequence[TProtocol],
|
||||
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
) -> INetStream:
|
||||
"""
|
||||
:param peer_id: peer_id that host is connecting
|
||||
@ -190,7 +207,9 @@ class BasicHost(IHost):
|
||||
# Perform protocol muxing to determine protocol to use
|
||||
try:
|
||||
selected_protocol = await self.multiselect_client.select_one_of(
|
||||
list(protocol_ids), MultiselectCommunicator(net_stream)
|
||||
list(protocol_ids),
|
||||
MultiselectCommunicator(net_stream),
|
||||
negotitate_timeout,
|
||||
)
|
||||
except MultiselectClientError as error:
|
||||
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
|
||||
@ -200,7 +219,12 @@ class BasicHost(IHost):
|
||||
net_stream.set_protocol(selected_protocol)
|
||||
return net_stream
|
||||
|
||||
async def send_command(self, peer_id: ID, command: str) -> list[str]:
|
||||
async def send_command(
|
||||
self,
|
||||
peer_id: ID,
|
||||
command: str,
|
||||
response_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Send a multistream-select command to the specified peer and return
|
||||
the response.
|
||||
@ -214,7 +238,7 @@ class BasicHost(IHost):
|
||||
|
||||
try:
|
||||
response = await self.multiselect_client.query_multistream_command(
|
||||
MultiselectCommunicator(new_stream), command
|
||||
MultiselectCommunicator(new_stream), command, response_timeout
|
||||
)
|
||||
except MultiselectClientError as error:
|
||||
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
|
||||
@ -253,7 +277,7 @@ class BasicHost(IHost):
|
||||
# Perform protocol muxing to determine protocol to use
|
||||
try:
|
||||
protocol, handler = await self.multiselect.negotiate(
|
||||
MultiselectCommunicator(net_stream)
|
||||
MultiselectCommunicator(net_stream), self.negotiate_timeout
|
||||
)
|
||||
except MultiselectError as error:
|
||||
peer_id = net_stream.muxed_conn.peer_id
|
||||
|
||||
@ -18,8 +18,10 @@ from libp2p.peer.peerinfo import (
|
||||
class RoutedHost(BasicHost):
|
||||
_router: IPeerRouting
|
||||
|
||||
def __init__(self, network: INetworkService, router: IPeerRouting):
|
||||
super().__init__(network)
|
||||
def __init__(
|
||||
self, network: INetworkService, router: IPeerRouting, enable_mDNS: bool = False
|
||||
):
|
||||
super().__init__(network, enable_mDNS)
|
||||
self._router = router
|
||||
|
||||
async def connect(self, peer_info: PeerInfo) -> None:
|
||||
|
||||
@ -59,7 +59,7 @@ def _mk_identify_protobuf(
|
||||
) -> Identify:
|
||||
public_key = host.get_public_key()
|
||||
laddrs = host.get_addrs()
|
||||
protocols = host.get_mux().get_protocols()
|
||||
protocols = tuple(str(p) for p in host.get_mux().get_protocols() if p is not None)
|
||||
|
||||
observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b""
|
||||
return Identify(
|
||||
|
||||
@ -40,6 +40,7 @@ logger = logging.getLogger(__name__)
|
||||
ID_PUSH = TProtocol("/ipfs/id/push/1.0.0")
|
||||
PROTOCOL_VERSION = "ipfs/0.1.0"
|
||||
AGENT_VERSION = get_agent_version()
|
||||
CONCURRENCY_LIMIT = 10
|
||||
|
||||
|
||||
def identify_push_handler_for(host: IHost) -> StreamHandlerFn:
|
||||
@ -132,7 +133,10 @@ async def _update_peerstore_from_identify(
|
||||
|
||||
|
||||
async def push_identify_to_peer(
|
||||
host: IHost, peer_id: ID, observed_multiaddr: Multiaddr | None = None
|
||||
host: IHost,
|
||||
peer_id: ID,
|
||||
observed_multiaddr: Multiaddr | None = None,
|
||||
limit: trio.Semaphore = trio.Semaphore(CONCURRENCY_LIMIT),
|
||||
) -> bool:
|
||||
"""
|
||||
Push an identify message to a specific peer.
|
||||
@ -146,25 +150,26 @@ async def push_identify_to_peer(
|
||||
True if the push was successful, False otherwise.
|
||||
|
||||
"""
|
||||
try:
|
||||
# Create a new stream to the peer using the identify/push protocol
|
||||
stream = await host.new_stream(peer_id, [ID_PUSH])
|
||||
async with limit:
|
||||
try:
|
||||
# Create a new stream to the peer using the identify/push protocol
|
||||
stream = await host.new_stream(peer_id, [ID_PUSH])
|
||||
|
||||
# Create the identify message
|
||||
identify_msg = _mk_identify_protobuf(host, observed_multiaddr)
|
||||
response = identify_msg.SerializeToString()
|
||||
# Create the identify message
|
||||
identify_msg = _mk_identify_protobuf(host, observed_multiaddr)
|
||||
response = identify_msg.SerializeToString()
|
||||
|
||||
# Send the identify message
|
||||
await stream.write(response)
|
||||
# Send the identify message
|
||||
await stream.write(response)
|
||||
|
||||
# Close the stream
|
||||
await stream.close()
|
||||
# Close the stream
|
||||
await stream.close()
|
||||
|
||||
logger.debug("Successfully pushed identify to peer %s", peer_id)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Error pushing identify to peer %s: %s", peer_id, e)
|
||||
return False
|
||||
logger.debug("Successfully pushed identify to peer %s", peer_id)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Error pushing identify to peer %s: %s", peer_id, e)
|
||||
return False
|
||||
|
||||
|
||||
async def push_identify_to_peers(
|
||||
@ -179,13 +184,10 @@ async def push_identify_to_peers(
|
||||
"""
|
||||
if peer_ids is None:
|
||||
# Get all connected peers
|
||||
peer_ids = set(host.get_peerstore().peer_ids())
|
||||
peer_ids = set(host.get_connected_peers())
|
||||
|
||||
# Push to each peer in parallel using a trio.Nursery
|
||||
# TODO: Consider using a bounded nursery to limit concurrency
|
||||
# and avoid overwhelming the network. This can be done by using
|
||||
# trio.open_nursery(max_concurrent=10) or similar.
|
||||
# For now, we will use an unbounded nursery for simplicity.
|
||||
# limiting concurrent connections to 10
|
||||
async with trio.open_nursery() as nursery:
|
||||
for peer_id in peer_ids:
|
||||
nursery.start_soon(push_identify_to_peer, host, peer_id, observed_multiaddr)
|
||||
|
||||
@ -18,6 +18,13 @@ from libp2p.crypto.keys import (
|
||||
PublicKey,
|
||||
)
|
||||
|
||||
"""
|
||||
Latency EWMA Smoothing governs the deacy of the EWMA (the speed at which
|
||||
is changes). This must be a normalized (0-1) value.
|
||||
1 is 100% change, 0 is no change.
|
||||
"""
|
||||
LATENCY_EWMA_SMOOTHING = 0.1
|
||||
|
||||
|
||||
class PeerData(IPeerData):
|
||||
pubkey: PublicKey | None
|
||||
@ -27,6 +34,7 @@ class PeerData(IPeerData):
|
||||
addrs: list[Multiaddr]
|
||||
last_identified: int
|
||||
ttl: int # Keep ttl=0 by default for always valid
|
||||
latmap: float
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.pubkey = None
|
||||
@ -36,6 +44,9 @@ class PeerData(IPeerData):
|
||||
self.addrs = []
|
||||
self.last_identified = int(time.time())
|
||||
self.ttl = 0
|
||||
self.latmap = 0
|
||||
|
||||
# --------PROTO-BOOK--------
|
||||
|
||||
def get_protocols(self) -> list[str]:
|
||||
"""
|
||||
@ -55,6 +66,37 @@ class PeerData(IPeerData):
|
||||
"""
|
||||
self.protocols = list(protocols)
|
||||
|
||||
def remove_protocols(self, protocols: Sequence[str]) -> None:
|
||||
"""
|
||||
:param protocols: protocols to remove
|
||||
"""
|
||||
for protocol in protocols:
|
||||
if protocol in self.protocols:
|
||||
self.protocols.remove(protocol)
|
||||
|
||||
def supports_protocols(self, protocols: Sequence[str]) -> list[str]:
|
||||
"""
|
||||
:param protocols: protocols to check from
|
||||
:return: all supported protocols in the given list
|
||||
"""
|
||||
return [proto for proto in protocols if proto in self.protocols]
|
||||
|
||||
def first_supported_protocol(self, protocols: Sequence[str]) -> str:
|
||||
"""
|
||||
:param protocols: protocols to check from
|
||||
:return: first supported protocol in the given list
|
||||
"""
|
||||
for protocol in protocols:
|
||||
if protocol in self.protocols:
|
||||
return protocol
|
||||
|
||||
return "None supported"
|
||||
|
||||
def clear_protocol_data(self) -> None:
|
||||
"""Clear all protocols"""
|
||||
self.protocols = []
|
||||
|
||||
# -------ADDR-BOOK---------
|
||||
def add_addrs(self, addrs: Sequence[Multiaddr]) -> None:
|
||||
"""
|
||||
:param addrs: multiaddresses to add
|
||||
@ -73,6 +115,7 @@ class PeerData(IPeerData):
|
||||
"""Clear all addresses."""
|
||||
self.addrs = []
|
||||
|
||||
# -------METADATA-----------
|
||||
def put_metadata(self, key: str, val: Any) -> None:
|
||||
"""
|
||||
:param key: key in KV pair
|
||||
@ -90,6 +133,11 @@ class PeerData(IPeerData):
|
||||
return self.metadata[key]
|
||||
raise PeerDataError("key not found")
|
||||
|
||||
def clear_metadata(self) -> None:
|
||||
"""Clears metadata."""
|
||||
self.metadata = {}
|
||||
|
||||
# -------KEY-BOOK---------------
|
||||
def add_pubkey(self, pubkey: PublicKey) -> None:
|
||||
"""
|
||||
:param pubkey:
|
||||
@ -120,9 +168,41 @@ class PeerData(IPeerData):
|
||||
raise PeerDataError("private key not found")
|
||||
return self.privkey
|
||||
|
||||
def clear_keydata(self) -> None:
|
||||
"""Clears keydata"""
|
||||
self.pubkey = None
|
||||
self.privkey = None
|
||||
|
||||
# ----------METRICS--------------
|
||||
def record_latency(self, new_latency: float) -> None:
|
||||
"""
|
||||
Records a new latency measurement for the given peer
|
||||
using Exponentially Weighted Moving Average (EWMA)
|
||||
:param new_latency: the new latency value
|
||||
"""
|
||||
s = LATENCY_EWMA_SMOOTHING
|
||||
if s > 1 or s < 0:
|
||||
s = 0.1
|
||||
|
||||
if self.latmap == 0:
|
||||
self.latmap = new_latency
|
||||
else:
|
||||
prev = self.latmap
|
||||
updated = ((1.0 - s) * prev) + (s * new_latency)
|
||||
self.latmap = updated
|
||||
|
||||
def latency_EWMA(self) -> float:
|
||||
"""Returns the latency EWMA value"""
|
||||
return self.latmap
|
||||
|
||||
def clear_metrics(self) -> None:
|
||||
"""Clear the latency metrics"""
|
||||
self.latmap = 0
|
||||
|
||||
def update_last_identified(self) -> None:
|
||||
self.last_identified = int(time.time())
|
||||
|
||||
# ----------TTL------------------
|
||||
def get_last_identified(self) -> int:
|
||||
"""
|
||||
:return: last identified timestamp
|
||||
|
||||
@ -2,6 +2,7 @@ from collections import (
|
||||
defaultdict,
|
||||
)
|
||||
from collections.abc import (
|
||||
AsyncIterable,
|
||||
Sequence,
|
||||
)
|
||||
from typing import (
|
||||
@ -11,6 +12,8 @@ from typing import (
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
)
|
||||
import trio
|
||||
from trio import MemoryReceiveChannel, MemorySendChannel
|
||||
|
||||
from libp2p.abc import (
|
||||
IPeerStore,
|
||||
@ -40,6 +43,7 @@ class PeerStore(IPeerStore):
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.peer_data_map = defaultdict(PeerData)
|
||||
self.addr_update_channels: dict[ID, MemorySendChannel[Multiaddr]] = {}
|
||||
|
||||
def peer_info(self, peer_id: ID) -> PeerInfo:
|
||||
"""
|
||||
@ -53,6 +57,33 @@ class PeerStore(IPeerStore):
|
||||
return PeerInfo(peer_id, peer_data.get_addrs())
|
||||
raise PeerStoreError("peer ID not found")
|
||||
|
||||
def peer_ids(self) -> list[ID]:
|
||||
"""
|
||||
:return: all of the peer IDs stored in peer store
|
||||
"""
|
||||
return list(self.peer_data_map.keys())
|
||||
|
||||
def clear_peerdata(self, peer_id: ID) -> None:
|
||||
"""Clears all data associated with the given peer_id."""
|
||||
if peer_id in self.peer_data_map:
|
||||
del self.peer_data_map[peer_id]
|
||||
else:
|
||||
raise PeerStoreError("peer ID not found")
|
||||
|
||||
def valid_peer_ids(self) -> list[ID]:
|
||||
"""
|
||||
:return: all of the valid peer IDs stored in peer store
|
||||
"""
|
||||
valid_peer_ids: list[ID] = []
|
||||
for peer_id, peer_data in self.peer_data_map.items():
|
||||
if not peer_data.is_expired():
|
||||
valid_peer_ids.append(peer_id)
|
||||
else:
|
||||
peer_data.clear_addrs()
|
||||
return valid_peer_ids
|
||||
|
||||
# --------PROTO-BOOK--------
|
||||
|
||||
def get_protocols(self, peer_id: ID) -> list[str]:
|
||||
"""
|
||||
:param peer_id: peer ID to get protocols for
|
||||
@ -79,23 +110,31 @@ class PeerStore(IPeerStore):
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
peer_data.set_protocols(list(protocols))
|
||||
|
||||
def peer_ids(self) -> list[ID]:
|
||||
def remove_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
|
||||
"""
|
||||
:param peer_id: peer ID to get info for
|
||||
:param protocols: unsupported protocols to remove
|
||||
"""
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
peer_data.remove_protocols(protocols)
|
||||
|
||||
def supports_protocols(self, peer_id: ID, protocols: Sequence[str]) -> list[str]:
|
||||
"""
|
||||
:return: all of the peer IDs stored in peer store
|
||||
"""
|
||||
return list(self.peer_data_map.keys())
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
return peer_data.supports_protocols(protocols)
|
||||
|
||||
def valid_peer_ids(self) -> list[ID]:
|
||||
"""
|
||||
:return: all of the valid peer IDs stored in peer store
|
||||
"""
|
||||
valid_peer_ids: list[ID] = []
|
||||
for peer_id, peer_data in self.peer_data_map.items():
|
||||
if not peer_data.is_expired():
|
||||
valid_peer_ids.append(peer_id)
|
||||
else:
|
||||
peer_data.clear_addrs()
|
||||
return valid_peer_ids
|
||||
def first_supported_protocol(self, peer_id: ID, protocols: Sequence[str]) -> str:
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
return peer_data.first_supported_protocol(protocols)
|
||||
|
||||
def clear_protocol_data(self, peer_id: ID) -> None:
|
||||
"""Clears prtocoldata"""
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
peer_data.clear_protocol_data()
|
||||
|
||||
# ------METADATA---------
|
||||
|
||||
def get(self, peer_id: ID, key: str) -> Any:
|
||||
"""
|
||||
@ -121,6 +160,13 @@ class PeerStore(IPeerStore):
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
peer_data.put_metadata(key, val)
|
||||
|
||||
def clear_metadata(self, peer_id: ID) -> None:
|
||||
"""Clears metadata"""
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
peer_data.clear_metadata()
|
||||
|
||||
# -------ADDR-BOOK--------
|
||||
|
||||
def add_addr(self, peer_id: ID, addr: Multiaddr, ttl: int = 0) -> None:
|
||||
"""
|
||||
:param peer_id: peer ID to add address for
|
||||
@ -140,6 +186,13 @@ class PeerStore(IPeerStore):
|
||||
peer_data.set_ttl(ttl)
|
||||
peer_data.update_last_identified()
|
||||
|
||||
if peer_id in self.addr_update_channels:
|
||||
for addr in addrs:
|
||||
try:
|
||||
self.addr_update_channels[peer_id].send_nowait(addr)
|
||||
except trio.WouldBlock:
|
||||
pass # Or consider logging / dropping / replacing stream
|
||||
|
||||
def addrs(self, peer_id: ID) -> list[Multiaddr]:
|
||||
"""
|
||||
:param peer_id: peer ID to get addrs for
|
||||
@ -165,7 +218,7 @@ class PeerStore(IPeerStore):
|
||||
|
||||
def peers_with_addrs(self) -> list[ID]:
|
||||
"""
|
||||
:return: all of the peer IDs which has addrs stored in peer store
|
||||
:return: all of the peer IDs which has addrsfloat stored in peer store
|
||||
"""
|
||||
# Add all peers with addrs at least 1 to output
|
||||
output: list[ID] = []
|
||||
@ -179,6 +232,27 @@ class PeerStore(IPeerStore):
|
||||
peer_data.clear_addrs()
|
||||
return output
|
||||
|
||||
async def addr_stream(self, peer_id: ID) -> AsyncIterable[Multiaddr]:
|
||||
"""
|
||||
Returns an async stream of newly added addresses for the given peer.
|
||||
|
||||
This function allows consumers to subscribe to address updates for a peer
|
||||
and receive each new address as it is added via `add_addr` or `add_addrs`.
|
||||
|
||||
:param peer_id: The ID of the peer to monitor address updates for.
|
||||
:return: An async iterator yielding Multiaddr instances as they are added.
|
||||
"""
|
||||
send: MemorySendChannel[Multiaddr]
|
||||
receive: MemoryReceiveChannel[Multiaddr]
|
||||
|
||||
send, receive = trio.open_memory_channel(0)
|
||||
self.addr_update_channels[peer_id] = send
|
||||
|
||||
async for addr in receive:
|
||||
yield addr
|
||||
|
||||
# -------KEY-BOOK---------
|
||||
|
||||
def add_pubkey(self, peer_id: ID, pubkey: PublicKey) -> None:
|
||||
"""
|
||||
:param peer_id: peer ID to add public key for
|
||||
@ -239,6 +313,45 @@ class PeerStore(IPeerStore):
|
||||
self.add_pubkey(peer_id, key_pair.public_key)
|
||||
self.add_privkey(peer_id, key_pair.private_key)
|
||||
|
||||
def peer_with_keys(self) -> list[ID]:
|
||||
"""Returns the peer_ids for which keys are stored"""
|
||||
return [
|
||||
peer_id
|
||||
for peer_id, pdata in self.peer_data_map.items()
|
||||
if pdata.pubkey is not None
|
||||
]
|
||||
|
||||
def clear_keydata(self, peer_id: ID) -> None:
|
||||
"""Clears the keys of the peer"""
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
peer_data.clear_keydata()
|
||||
|
||||
# --------METRICS--------
|
||||
|
||||
def record_latency(self, peer_id: ID, RTT: float) -> None:
|
||||
"""
|
||||
Records a new latency measurement for the given peer
|
||||
using Exponentially Weighted Moving Average (EWMA)
|
||||
|
||||
:param peer_id: peer ID to get private key for
|
||||
:param RTT: the new latency value (round trip time)
|
||||
"""
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
peer_data.record_latency(RTT)
|
||||
|
||||
def latency_EWMA(self, peer_id: ID) -> float:
|
||||
"""
|
||||
:param peer_id: peer ID to get private key for
|
||||
:return: The latency EWMA value for that peer
|
||||
"""
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
return peer_data.latency_EWMA()
|
||||
|
||||
def clear_metrics(self, peer_id: ID) -> None:
|
||||
"""Clear the latency metrics"""
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
peer_data.clear_metrics()
|
||||
|
||||
|
||||
class PeerStoreError(KeyError):
|
||||
"""Raised when peer ID is not found in peer store."""
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
IMultiselectCommunicator,
|
||||
IMultiselectMuxer,
|
||||
@ -14,6 +16,7 @@ from .exceptions import (
|
||||
|
||||
MULTISELECT_PROTOCOL_ID = "/multistream/1.0.0"
|
||||
PROTOCOL_NOT_FOUND_MSG = "na"
|
||||
DEFAULT_NEGOTIATE_TIMEOUT = 5
|
||||
|
||||
|
||||
class Multiselect(IMultiselectMuxer):
|
||||
@ -47,47 +50,68 @@ class Multiselect(IMultiselectMuxer):
|
||||
|
||||
# FIXME: Make TProtocol Optional[TProtocol] to keep types consistent
|
||||
async def negotiate(
|
||||
self, communicator: IMultiselectCommunicator
|
||||
self,
|
||||
communicator: IMultiselectCommunicator,
|
||||
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
) -> tuple[TProtocol, StreamHandlerFn | None]:
|
||||
"""
|
||||
Negotiate performs protocol selection.
|
||||
|
||||
:param stream: stream to negotiate on
|
||||
:param negotiate_timeout: timeout for negotiation
|
||||
:return: selected protocol name, handler function
|
||||
:raise MultiselectError: raised when negotiation failed
|
||||
"""
|
||||
await self.handshake(communicator)
|
||||
try:
|
||||
with trio.fail_after(negotiate_timeout):
|
||||
await self.handshake(communicator)
|
||||
|
||||
while True:
|
||||
try:
|
||||
command = await communicator.read()
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectError() from error
|
||||
|
||||
if command == "ls":
|
||||
supported_protocols = [p for p in self.handlers.keys() if p is not None]
|
||||
response = "\n".join(supported_protocols) + "\n"
|
||||
|
||||
try:
|
||||
await communicator.write(response)
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectError() from error
|
||||
|
||||
else:
|
||||
protocol = TProtocol(command)
|
||||
if protocol in self.handlers:
|
||||
while True:
|
||||
try:
|
||||
await communicator.write(protocol)
|
||||
command = await communicator.read()
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectError() from error
|
||||
|
||||
return protocol, self.handlers[protocol]
|
||||
try:
|
||||
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectError() from error
|
||||
if command == "ls":
|
||||
supported_protocols = [
|
||||
p for p in self.handlers.keys() if p is not None
|
||||
]
|
||||
response = "\n".join(supported_protocols) + "\n"
|
||||
|
||||
raise MultiselectError("Negotiation failed: no matching protocol")
|
||||
try:
|
||||
await communicator.write(response)
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectError() from error
|
||||
|
||||
else:
|
||||
protocol = TProtocol(command)
|
||||
if protocol in self.handlers:
|
||||
try:
|
||||
await communicator.write(protocol)
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectError() from error
|
||||
|
||||
return protocol, self.handlers[protocol]
|
||||
try:
|
||||
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectError() from error
|
||||
|
||||
raise MultiselectError("Negotiation failed: no matching protocol")
|
||||
except trio.TooSlowError:
|
||||
raise MultiselectError("handshake read timeout")
|
||||
|
||||
def get_protocols(self) -> tuple[TProtocol | None, ...]:
|
||||
"""
|
||||
Retrieve the protocols for which handlers have been registered.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[TProtocol, ...]
|
||||
A tuple of registered protocol names.
|
||||
|
||||
"""
|
||||
return tuple(self.handlers.keys())
|
||||
|
||||
async def handshake(self, communicator: IMultiselectCommunicator) -> None:
|
||||
"""
|
||||
|
||||
@ -2,6 +2,8 @@ from collections.abc import (
|
||||
Sequence,
|
||||
)
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
IMultiselectClient,
|
||||
IMultiselectCommunicator,
|
||||
@ -17,6 +19,7 @@ from .exceptions import (
|
||||
|
||||
MULTISELECT_PROTOCOL_ID = "/multistream/1.0.0"
|
||||
PROTOCOL_NOT_FOUND_MSG = "na"
|
||||
DEFAULT_NEGOTIATE_TIMEOUT = 5
|
||||
|
||||
|
||||
class MultiselectClient(IMultiselectClient):
|
||||
@ -40,6 +43,7 @@ class MultiselectClient(IMultiselectClient):
|
||||
|
||||
try:
|
||||
handshake_contents = await communicator.read()
|
||||
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
|
||||
@ -47,7 +51,10 @@ class MultiselectClient(IMultiselectClient):
|
||||
raise MultiselectClientError("multiselect protocol ID mismatch")
|
||||
|
||||
async def select_one_of(
|
||||
self, protocols: Sequence[TProtocol], communicator: IMultiselectCommunicator
|
||||
self,
|
||||
protocols: Sequence[TProtocol],
|
||||
communicator: IMultiselectCommunicator,
|
||||
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
) -> TProtocol:
|
||||
"""
|
||||
For each protocol, send message to multiselect selecting protocol and
|
||||
@ -56,22 +63,32 @@ class MultiselectClient(IMultiselectClient):
|
||||
|
||||
:param protocol: protocol to select
|
||||
:param communicator: communicator to use to communicate with counterparty
|
||||
:param negotiate_timeout: timeout for negotiation
|
||||
:return: selected protocol
|
||||
:raise MultiselectClientError: raised when protocol negotiation failed
|
||||
"""
|
||||
await self.handshake(communicator)
|
||||
try:
|
||||
with trio.fail_after(negotitate_timeout):
|
||||
await self.handshake(communicator)
|
||||
|
||||
for protocol in protocols:
|
||||
try:
|
||||
selected_protocol = await self.try_select(communicator, protocol)
|
||||
return selected_protocol
|
||||
except MultiselectClientError:
|
||||
pass
|
||||
for protocol in protocols:
|
||||
try:
|
||||
selected_protocol = await self.try_select(
|
||||
communicator, protocol
|
||||
)
|
||||
return selected_protocol
|
||||
except MultiselectClientError:
|
||||
pass
|
||||
|
||||
raise MultiselectClientError("protocols not supported")
|
||||
raise MultiselectClientError("protocols not supported")
|
||||
except trio.TooSlowError:
|
||||
raise MultiselectClientError("response timed out")
|
||||
|
||||
async def query_multistream_command(
|
||||
self, communicator: IMultiselectCommunicator, command: str
|
||||
self,
|
||||
communicator: IMultiselectCommunicator,
|
||||
command: str,
|
||||
response_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Send a multistream-select command over the given communicator and return
|
||||
@ -79,26 +96,32 @@ class MultiselectClient(IMultiselectClient):
|
||||
|
||||
:param communicator: communicator to use to communicate with counterparty
|
||||
:param command: supported multistream-select command(e.g., ls)
|
||||
:param negotiate_timeout: timeout for negotiation
|
||||
:raise MultiselectClientError: If the communicator fails to process data.
|
||||
:return: list of strings representing the response from peer.
|
||||
"""
|
||||
await self.handshake(communicator)
|
||||
|
||||
if command == "ls":
|
||||
try:
|
||||
await communicator.write("ls")
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
else:
|
||||
raise ValueError("Command not supported")
|
||||
|
||||
try:
|
||||
response = await communicator.read()
|
||||
response_list = response.strip().splitlines()
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
with trio.fail_after(response_timeout):
|
||||
await self.handshake(communicator)
|
||||
|
||||
return response_list
|
||||
if command == "ls":
|
||||
try:
|
||||
await communicator.write("ls")
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
else:
|
||||
raise ValueError("Command not supported")
|
||||
|
||||
try:
|
||||
response = await communicator.read()
|
||||
response_list = response.strip().splitlines()
|
||||
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
|
||||
return response_list
|
||||
except trio.TooSlowError:
|
||||
raise MultiselectClientError("command response timed out")
|
||||
|
||||
async def try_select(
|
||||
self, communicator: IMultiselectCommunicator, protocol: TProtocol
|
||||
@ -118,6 +141,7 @@ class MultiselectClient(IMultiselectClient):
|
||||
|
||||
try:
|
||||
response = await communicator.read()
|
||||
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
|
||||
|
||||
@ -12,15 +12,9 @@ from libp2p.abc import (
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.network.stream.exceptions import (
|
||||
StreamClosed,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.utils import (
|
||||
encode_varint_prefixed,
|
||||
)
|
||||
|
||||
from .exceptions import (
|
||||
PubsubRouterError,
|
||||
@ -120,13 +114,7 @@ class FloodSub(IPubsubRouter):
|
||||
if peer_id not in pubsub.peers:
|
||||
continue
|
||||
stream = pubsub.peers[peer_id]
|
||||
# FIXME: We should add a `WriteMsg` similar to write delimited messages.
|
||||
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
|
||||
try:
|
||||
await stream.write(encode_varint_prefixed(rpc_msg.SerializeToString()))
|
||||
except StreamClosed:
|
||||
logger.debug("Fail to publish message to %s: stream closed", peer_id)
|
||||
pubsub._handle_dead_peer(peer_id)
|
||||
await pubsub.write_msg(stream, rpc_msg)
|
||||
|
||||
async def join(self, topic: str) -> None:
|
||||
"""
|
||||
|
||||
@ -24,9 +24,6 @@ from libp2p.abc import (
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.network.stream.exceptions import (
|
||||
StreamClosed,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
@ -44,9 +41,6 @@ from libp2p.pubsub import (
|
||||
from libp2p.tools.async_service import (
|
||||
Service,
|
||||
)
|
||||
from libp2p.utils import (
|
||||
encode_varint_prefixed,
|
||||
)
|
||||
|
||||
from .exceptions import (
|
||||
NoPubsubAttached,
|
||||
@ -267,13 +261,10 @@ class GossipSub(IPubsubRouter, Service):
|
||||
if peer_id not in self.pubsub.peers:
|
||||
continue
|
||||
stream = self.pubsub.peers[peer_id]
|
||||
# FIXME: We should add a `WriteMsg` similar to write delimited messages.
|
||||
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
|
||||
try:
|
||||
await stream.write(encode_varint_prefixed(rpc_msg.SerializeToString()))
|
||||
except StreamClosed:
|
||||
logger.debug("Fail to publish message to %s: stream closed", peer_id)
|
||||
self.pubsub._handle_dead_peer(peer_id)
|
||||
|
||||
# TODO: Go use `sendRPC`, which possibly piggybacks gossip/control messages.
|
||||
await self.pubsub.write_msg(stream, rpc_msg)
|
||||
|
||||
for topic in pubsub_msg.topicIDs:
|
||||
self.time_since_last_publish[topic] = int(time.time())
|
||||
|
||||
@ -829,8 +820,6 @@ class GossipSub(IPubsubRouter, Service):
|
||||
|
||||
packet.publish.extend(msgs_to_forward)
|
||||
|
||||
# 2) Serialize that packet
|
||||
rpc_msg: bytes = packet.SerializeToString()
|
||||
if self.pubsub is None:
|
||||
raise NoPubsubAttached
|
||||
|
||||
@ -844,14 +833,7 @@ class GossipSub(IPubsubRouter, Service):
|
||||
peer_stream = self.pubsub.peers[sender_peer_id]
|
||||
|
||||
# 4) And write the packet to the stream
|
||||
try:
|
||||
await peer_stream.write(encode_varint_prefixed(rpc_msg))
|
||||
except StreamClosed:
|
||||
logger.debug(
|
||||
"Fail to responed to iwant request from %s: stream closed",
|
||||
sender_peer_id,
|
||||
)
|
||||
self.pubsub._handle_dead_peer(sender_peer_id)
|
||||
await self.pubsub.write_msg(peer_stream, packet)
|
||||
|
||||
async def handle_graft(
|
||||
self, graft_msg: rpc_pb2.ControlGraft, sender_peer_id: ID
|
||||
@ -993,8 +975,6 @@ class GossipSub(IPubsubRouter, Service):
|
||||
packet: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||
packet.control.CopyFrom(control_msg)
|
||||
|
||||
rpc_msg: bytes = packet.SerializeToString()
|
||||
|
||||
# Get stream for peer from pubsub
|
||||
if to_peer not in self.pubsub.peers:
|
||||
logger.debug(
|
||||
@ -1004,8 +984,4 @@ class GossipSub(IPubsubRouter, Service):
|
||||
peer_stream = self.pubsub.peers[to_peer]
|
||||
|
||||
# Write rpc to stream
|
||||
try:
|
||||
await peer_stream.write(encode_varint_prefixed(rpc_msg))
|
||||
except StreamClosed:
|
||||
logger.debug("Fail to emit control message to %s: stream closed", to_peer)
|
||||
self.pubsub._handle_dead_peer(to_peer)
|
||||
await self.pubsub.write_msg(peer_stream, packet)
|
||||
|
||||
@ -66,6 +66,7 @@ from libp2p.utils import (
|
||||
encode_varint_prefixed,
|
||||
read_varint_prefixed_bytes,
|
||||
)
|
||||
from libp2p.utils.varint import encode_uvarint
|
||||
|
||||
from .pb import (
|
||||
rpc_pb2,
|
||||
@ -778,3 +779,43 @@ class Pubsub(Service, IPubsub):
|
||||
|
||||
def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
|
||||
return any(topic in self.topic_ids for topic in msg.topicIDs)
|
||||
|
||||
async def write_msg(self, stream: INetStream, rpc_msg: rpc_pb2.RPC) -> bool:
|
||||
"""
|
||||
Write an RPC message to a stream with proper error handling.
|
||||
|
||||
Implements WriteMsg similar to go-msgio which is used in go-libp2p
|
||||
Ref: https://github.com/libp2p/go-msgio/blob/master/protoio/uvarint_writer.go#L56
|
||||
|
||||
|
||||
:param stream: stream to write the message to
|
||||
:param rpc_msg: RPC message to write
|
||||
:return: True if successful, False if stream was closed
|
||||
"""
|
||||
try:
|
||||
# Calculate message size first
|
||||
msg_bytes = rpc_msg.SerializeToString()
|
||||
msg_size = len(msg_bytes)
|
||||
|
||||
# Calculate varint size and allocate exact buffer size needed
|
||||
|
||||
varint_bytes = encode_uvarint(msg_size)
|
||||
varint_size = len(varint_bytes)
|
||||
|
||||
# Allocate buffer with exact size (like Go's pool.Get())
|
||||
buf = bytearray(varint_size + msg_size)
|
||||
|
||||
# Write varint length prefix to buffer (like Go's binary.PutUvarint())
|
||||
buf[:varint_size] = varint_bytes
|
||||
|
||||
# Write serialized message after varint (like Go's rpc.MarshalTo())
|
||||
buf[varint_size:] = msg_bytes
|
||||
|
||||
# Single write operation (like Go's s.Write(buf))
|
||||
await stream.write(bytes(buf))
|
||||
return True
|
||||
except StreamClosed:
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
logger.debug("Fail to write message to %s: stream closed", peer_id)
|
||||
self._handle_dead_peer(peer_id)
|
||||
return False
|
||||
|
||||
@ -234,7 +234,8 @@ class RelayDiscovery(Service):
|
||||
|
||||
if not callable(proto_getter):
|
||||
return None
|
||||
|
||||
if peer_id not in peerstore.peer_ids():
|
||||
return None
|
||||
try:
|
||||
# Try to get protocols
|
||||
proto_result = proto_getter(peer_id)
|
||||
@ -283,8 +284,6 @@ class RelayDiscovery(Service):
|
||||
return None
|
||||
|
||||
mux = self.host.get_mux()
|
||||
if not hasattr(mux, "protocols"):
|
||||
return None
|
||||
|
||||
peer_protocols = set()
|
||||
# Get protocols from mux with proper type safety
|
||||
@ -293,7 +292,9 @@ class RelayDiscovery(Service):
|
||||
# Get protocols with proper typing
|
||||
mux_protocols = mux.get_protocols()
|
||||
if isinstance(mux_protocols, (list, tuple)):
|
||||
available_protocols = list(mux_protocols)
|
||||
available_protocols = [
|
||||
p for p in mux.get_protocols() if p is not None
|
||||
]
|
||||
|
||||
for protocol in available_protocols:
|
||||
try:
|
||||
@ -313,7 +314,7 @@ class RelayDiscovery(Service):
|
||||
|
||||
self._protocol_cache[peer_id] = peer_protocols
|
||||
protocol_str = str(PROTOCOL_ID)
|
||||
for protocol in peer_protocols:
|
||||
for protocol in map(TProtocol, peer_protocols):
|
||||
if protocol == protocol_str:
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -31,9 +31,6 @@ from libp2p.stream_muxer.yamux.yamux import (
|
||||
Yamux,
|
||||
)
|
||||
|
||||
# FIXME: add negotiate timeout to `MuxerMultistream`
|
||||
DEFAULT_NEGOTIATE_TIMEOUT = 60
|
||||
|
||||
|
||||
class MuxerMultistream:
|
||||
"""
|
||||
|
||||
@ -98,16 +98,32 @@ class YamuxStream(IMuxedStream):
|
||||
# Flow control: Check if we have enough send window
|
||||
total_len = len(data)
|
||||
sent = 0
|
||||
|
||||
logging.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
|
||||
while sent < total_len:
|
||||
# Wait for available window with timeout
|
||||
timeout = False
|
||||
async with self.window_lock:
|
||||
# Wait for available window
|
||||
while self.send_window == 0 and not self.closed:
|
||||
# Release lock while waiting
|
||||
if self.send_window == 0:
|
||||
logging.debug(
|
||||
f"Stream {self.stream_id}: Window is zero, waiting for update"
|
||||
)
|
||||
# Release lock and wait with timeout
|
||||
self.window_lock.release()
|
||||
await trio.sleep(0.01)
|
||||
# To avoid re-acquiring the lock immediately,
|
||||
with trio.move_on_after(5.0) as cancel_scope:
|
||||
while self.send_window == 0 and not self.closed:
|
||||
await trio.sleep(0.01)
|
||||
# If we timed out, cancel the scope
|
||||
timeout = cancel_scope.cancelled_caught
|
||||
# Re-acquire lock
|
||||
await self.window_lock.acquire()
|
||||
|
||||
# If we timed out waiting for window update, raise an error
|
||||
if timeout:
|
||||
raise MuxedStreamError(
|
||||
"Timed out waiting for window update after 5 seconds."
|
||||
)
|
||||
|
||||
if self.closed:
|
||||
raise MuxedStreamError("Stream is closed")
|
||||
|
||||
@ -123,25 +139,45 @@ class YamuxStream(IMuxedStream):
|
||||
await self.conn.secured_conn.write(header + chunk)
|
||||
sent += to_send
|
||||
|
||||
# If window is getting low, consider updating
|
||||
if self.send_window < DEFAULT_WINDOW_SIZE // 2:
|
||||
await self.send_window_update()
|
||||
|
||||
async def send_window_update(self, increment: int | None = None) -> None:
|
||||
"""Send a window update to peer."""
|
||||
if increment is None:
|
||||
increment = DEFAULT_WINDOW_SIZE - self.recv_window
|
||||
async def send_window_update(self, increment: int, skip_lock: bool = False) -> None:
|
||||
"""
|
||||
Send a window update to peer.
|
||||
|
||||
param:increment: The amount to increment the window size by.
|
||||
If None, uses the difference between DEFAULT_WINDOW_SIZE
|
||||
and current receive window.
|
||||
param:skip_lock (bool): If True, skips acquiring window_lock.
|
||||
This should only be used when calling from a context
|
||||
that already holds the lock.
|
||||
"""
|
||||
if increment <= 0:
|
||||
# If increment is zero or negative, skip sending update
|
||||
logging.debug(
|
||||
f"Stream {self.stream_id}: Skipping window update"
|
||||
f"(increment={increment})"
|
||||
)
|
||||
return
|
||||
logging.debug(
|
||||
f"Stream {self.stream_id}: Sending window update with increment={increment}"
|
||||
)
|
||||
|
||||
async with self.window_lock:
|
||||
self.recv_window += increment
|
||||
async def _do_window_update() -> None:
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_WINDOW_UPDATE, 0, self.stream_id, increment
|
||||
YAMUX_HEADER_FORMAT,
|
||||
0,
|
||||
TYPE_WINDOW_UPDATE,
|
||||
0,
|
||||
self.stream_id,
|
||||
increment,
|
||||
)
|
||||
await self.conn.secured_conn.write(header)
|
||||
|
||||
if skip_lock:
|
||||
await _do_window_update()
|
||||
else:
|
||||
async with self.window_lock:
|
||||
await _do_window_update()
|
||||
|
||||
async def read(self, n: int | None = -1) -> bytes:
|
||||
# Handle None value for n by converting it to -1
|
||||
if n is None:
|
||||
@ -154,55 +190,68 @@ class YamuxStream(IMuxedStream):
|
||||
)
|
||||
raise MuxedStreamEOF("Stream is closed for receiving")
|
||||
|
||||
# If reading until EOF (n == -1), block until stream is closed
|
||||
if n == -1:
|
||||
while not self.recv_closed and not self.conn.event_shutting_down.is_set():
|
||||
data = b""
|
||||
while not self.conn.event_shutting_down.is_set():
|
||||
# Check if there's data in the buffer
|
||||
buffer = self.conn.stream_buffers.get(self.stream_id)
|
||||
if buffer and len(buffer) > 0:
|
||||
# Wait for closure even if data is available
|
||||
logging.debug(
|
||||
f"Stream {self.stream_id}:Waiting for FIN before returning data"
|
||||
)
|
||||
await self.conn.stream_events[self.stream_id].wait()
|
||||
self.conn.stream_events[self.stream_id] = trio.Event()
|
||||
else:
|
||||
# No data, wait for data or closure
|
||||
logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
|
||||
await self.conn.stream_events[self.stream_id].wait()
|
||||
self.conn.stream_events[self.stream_id] = trio.Event()
|
||||
|
||||
# After loop, check if stream is closed or shutting down
|
||||
async with self.conn.streams_lock:
|
||||
if self.conn.event_shutting_down.is_set():
|
||||
logging.debug(f"Stream {self.stream_id}: Connection shutting down")
|
||||
raise MuxedStreamEOF("Connection shut down")
|
||||
if self.closed:
|
||||
if self.reset_received:
|
||||
logging.debug(f"Stream {self.stream_id}: Stream was reset")
|
||||
raise MuxedStreamReset("Stream was reset")
|
||||
else:
|
||||
logging.debug(
|
||||
f"Stream {self.stream_id}: Stream closed cleanly (EOF)"
|
||||
)
|
||||
raise MuxedStreamEOF("Stream closed cleanly (EOF)")
|
||||
buffer = self.conn.stream_buffers.get(self.stream_id)
|
||||
# If buffer is not available, check if stream is closed
|
||||
if buffer is None:
|
||||
logging.debug(
|
||||
f"Stream {self.stream_id}: Buffer gone, assuming closed"
|
||||
)
|
||||
logging.debug(f"Stream {self.stream_id}: No buffer available")
|
||||
raise MuxedStreamEOF("Stream buffer closed")
|
||||
|
||||
# If we have data in buffer, process it
|
||||
if len(buffer) > 0:
|
||||
chunk = bytes(buffer)
|
||||
buffer.clear()
|
||||
data += chunk
|
||||
|
||||
# Send window update for the chunk we just read
|
||||
async with self.window_lock:
|
||||
self.recv_window += len(chunk)
|
||||
logging.debug(f"Stream {self.stream_id}: Update {len(chunk)}")
|
||||
await self.send_window_update(len(chunk), skip_lock=True)
|
||||
|
||||
# If stream is closed (FIN received) and buffer is empty, break
|
||||
if self.recv_closed and len(buffer) == 0:
|
||||
logging.debug(f"Stream {self.stream_id}: EOF reached")
|
||||
raise MuxedStreamEOF("Stream is closed for receiving")
|
||||
# Return all buffered data
|
||||
data = bytes(buffer)
|
||||
buffer.clear()
|
||||
logging.debug(f"Stream {self.stream_id}: Returning {len(data)} bytes")
|
||||
logging.debug(f"Stream {self.stream_id}: Closed with empty buffer")
|
||||
break
|
||||
|
||||
# If stream was reset, raise reset error
|
||||
if self.reset_received:
|
||||
logging.debug(f"Stream {self.stream_id}: Stream was reset")
|
||||
raise MuxedStreamReset("Stream was reset")
|
||||
|
||||
# Wait for more data or stream closure
|
||||
logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
|
||||
await self.conn.stream_events[self.stream_id].wait()
|
||||
self.conn.stream_events[self.stream_id] = trio.Event()
|
||||
|
||||
# After loop exit, first check if we have data to return
|
||||
if data:
|
||||
logging.debug(
|
||||
f"Stream {self.stream_id}: Returning {len(data)} bytes after loop"
|
||||
)
|
||||
return data
|
||||
|
||||
# For specific size read (n > 0), return available data immediately
|
||||
return await self.conn.read_stream(self.stream_id, n)
|
||||
# No data accumulated, now check why we exited the loop
|
||||
if self.conn.event_shutting_down.is_set():
|
||||
logging.debug(f"Stream {self.stream_id}: Connection shutting down")
|
||||
raise MuxedStreamEOF("Connection shut down")
|
||||
|
||||
# Return empty data
|
||||
return b""
|
||||
else:
|
||||
data = await self.conn.read_stream(self.stream_id, n)
|
||||
async with self.window_lock:
|
||||
self.recv_window += len(data)
|
||||
logging.debug(
|
||||
f"Stream {self.stream_id}: Sending window update after read, "
|
||||
f"increment={len(data)}"
|
||||
)
|
||||
await self.send_window_update(len(data), skip_lock=True)
|
||||
return data
|
||||
|
||||
async def close(self) -> None:
|
||||
if not self.send_closed:
|
||||
@ -493,7 +542,7 @@ class Yamux(IMuxedConn):
|
||||
f"type={typ}, flags={flags}, stream_id={stream_id},"
|
||||
f"length={length}"
|
||||
)
|
||||
if typ == TYPE_DATA and flags & FLAG_SYN:
|
||||
if (typ == TYPE_DATA or typ == TYPE_WINDOW_UPDATE) and flags & FLAG_SYN:
|
||||
async with self.streams_lock:
|
||||
if stream_id not in self.streams:
|
||||
stream = YamuxStream(stream_id, self, False)
|
||||
|
||||
@ -1 +0,0 @@
|
||||
Added support for ``Kademlia DHT`` in py-libp2p.
|
||||
@ -1,7 +0,0 @@
|
||||
Store public key and peer ID in peerstore during handshake
|
||||
|
||||
Modified the InsecureTransport class to accept an optional peerstore parameter and updated the handshake process to store the received public key and peer ID in the peerstore when available.
|
||||
|
||||
Added test cases to verify:
|
||||
1. The peerstore remains unchanged when handshake fails due to peer ID mismatch
|
||||
2. The handshake correctly adds a public key to a peer ID that already exists in the peerstore but doesn't have a public key yet
|
||||
@ -1 +0,0 @@
|
||||
Refactored gossipsub heartbeat logic to use a single helper method `_handle_topic_heartbeat` that handles both fanout and gossip heartbeats.
|
||||
@ -1 +0,0 @@
|
||||
Added sparse connect utility function to pubsub test utilities for creating test networks with configurable connectivity.
|
||||
@ -1,2 +0,0 @@
|
||||
Reordered the arguments to `upgrade_security` to place `is_initiator` before `peer_id`, and made `peer_id` optional.
|
||||
This allows the method to reflect the fact that peer identity is not required for inbound connections.
|
||||
@ -1 +0,0 @@
|
||||
Uses the `decapsulate` method of the `Multiaddr` class to clean up the observed address.
|
||||
@ -1 +0,0 @@
|
||||
Optimized pubsub publishing to send multiple topics in a single message instead of separate messages per topic.
|
||||
@ -1 +0,0 @@
|
||||
added peer exchange and backoff logic as part of Gossipsub v1.1 upgrade
|
||||
@ -1 +0,0 @@
|
||||
Fixed an issue in `Pubsub` where async validators were not handled reliably under concurrency. Now uses a safe aggregator list for consistent behavior.
|
||||
1
newsfragments/732.deprecation.rst
Normal file
1
newsfragments/732.deprecation.rst
Normal file
@ -0,0 +1 @@
|
||||
update cryptographic dependencies: pycryptodome to ≥3.19.1, pynacl to ≥1.5.0, coincurve to ≥21.0.0
|
||||
3
newsfragments/746.bugfix.rst
Normal file
3
newsfragments/746.bugfix.rst
Normal file
@ -0,0 +1,3 @@
|
||||
Improved type safety in `get_mux()` and `get_protocols()` by returning properly typed values instead
|
||||
of `Any`. Also updated `identify.py` and `discovery.py` to handle `None` values safely and
|
||||
compare protocols correctly.
|
||||
1
newsfragments/749.internal.rst
Normal file
1
newsfragments/749.internal.rst
Normal file
@ -0,0 +1 @@
|
||||
Add comprehensive tests for relay_discovery method in circuit_relay_v2
|
||||
1
newsfragments/750.feature.rst
Normal file
1
newsfragments/750.feature.rst
Normal file
@ -0,0 +1 @@
|
||||
Add logic to clear_peerdata method in peerstore
|
||||
@ -1,11 +1,10 @@
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=42", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "libp2p"
|
||||
version = "0.2.8"
|
||||
version = "0.2.9"
|
||||
description = "libp2p: The Python implementation of the libp2p networking stack"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10, <4.0"
|
||||
@ -16,21 +15,22 @@ authors = [
|
||||
]
|
||||
dependencies = [
|
||||
"base58>=1.0.3",
|
||||
"coincurve>=10.0.0",
|
||||
"coincurve>=21.0.0",
|
||||
"exceptiongroup>=1.2.0; python_version < '3.11'",
|
||||
"grpcio>=1.41.0",
|
||||
"lru-dict>=1.1.6",
|
||||
"multiaddr>=0.0.9",
|
||||
"mypy-protobuf>=3.0.0",
|
||||
"noiseprotocol>=0.3.0",
|
||||
"protobuf>=3.20.1,<4.0.0",
|
||||
"pycryptodome>=3.9.2",
|
||||
"pycryptodome>=3.19.1",
|
||||
"protobuf>=4.21.0,<5.0.0",
|
||||
"pymultihash>=0.8.2",
|
||||
"pynacl>=1.3.0",
|
||||
"pynacl>=1.5.0",
|
||||
"rpcudp>=3.0.0",
|
||||
"trio-typing>=0.0.4",
|
||||
"trio>=0.26.0",
|
||||
"fastecdsa==2.3.2; sys_platform != 'win32'",
|
||||
"zeroconf (>=0.147.0,<0.148.0)",
|
||||
]
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
@ -54,6 +54,7 @@ identify-demo = "examples.identify.identify:main"
|
||||
identify-push-demo = "examples.identify_push.identify_push_demo:run_main"
|
||||
identify-push-listener-dialer-demo = "examples.identify_push.identify_push_listener_dialer:main"
|
||||
pubsub-demo = "examples.pubsub.pubsub:main"
|
||||
mdns-demo = "examples.mDNS.mDNS:main"
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
@ -187,7 +188,7 @@ name = "Removals"
|
||||
showcontent = true
|
||||
|
||||
[tool.bumpversion]
|
||||
current_version = "0.2.8"
|
||||
current_version = "0.2.9"
|
||||
parse = """
|
||||
(?P<major>\\d+)
|
||||
\\.(?P<minor>\\d+)
|
||||
|
||||
@ -1,4 +1,7 @@
|
||||
import logging
|
||||
from unittest.mock import (
|
||||
patch,
|
||||
)
|
||||
|
||||
import pytest
|
||||
import multiaddr
|
||||
@ -17,6 +20,7 @@ from libp2p.identity.identify.pb.identify_pb2 import (
|
||||
Identify,
|
||||
)
|
||||
from libp2p.identity.identify_push.identify_push import (
|
||||
CONCURRENCY_LIMIT,
|
||||
ID_PUSH,
|
||||
_update_peerstore_from_identify,
|
||||
identify_push_handler_for,
|
||||
@ -29,6 +33,11 @@ from libp2p.peer.peerinfo import (
|
||||
from tests.utils.factories import (
|
||||
host_pair_factory,
|
||||
)
|
||||
from tests.utils.utils import (
|
||||
create_mock_connections,
|
||||
run_host_forever,
|
||||
wait_until_listening,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("libp2p.identity.identify-push-test")
|
||||
|
||||
@ -175,6 +184,7 @@ async def test_identify_push_to_peers(security_protocol):
|
||||
host_c = new_host(key_pair=key_pair_c)
|
||||
|
||||
# Set up the identify/push handlers
|
||||
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
|
||||
host_b.set_stream_handler(ID_PUSH, identify_push_handler_for(host_b))
|
||||
host_c.set_stream_handler(ID_PUSH, identify_push_handler_for(host_c))
|
||||
|
||||
@ -204,6 +214,20 @@ async def test_identify_push_to_peers(security_protocol):
|
||||
# Check that the peer is in the peerstore
|
||||
assert peer_id_a in peerstore_c.peer_ids()
|
||||
|
||||
# Test for push_identify to only connected peers and not all peers
|
||||
# Disconnect a from c.
|
||||
await host_c.disconnect(host_a.get_id())
|
||||
|
||||
await push_identify_to_peers(host_c)
|
||||
|
||||
# Wait a bit for the push to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Check that host_a's peerstore has not been updated with host_c's info
|
||||
assert host_c.get_id() not in host_a.get_peerstore().peer_ids()
|
||||
# Check that host_b's peerstore has been updated with host_c's info
|
||||
assert host_c.get_id() in host_b.get_peerstore().peer_ids()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_push_identify_to_peers_with_explicit_params(security_protocol):
|
||||
@ -412,3 +436,160 @@ async def test_partial_update_peerstore_from_identify(security_protocol):
|
||||
host_a_public_key = host_a.get_public_key().serialize()
|
||||
peerstore_public_key = peerstore.pubkey(peer_id).serialize()
|
||||
assert host_a_public_key == peerstore_public_key
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_push_identify_to_peers_respects_concurrency_limit():
|
||||
"""
|
||||
Test bounded concurrency for the identify/push protocol to prevent
|
||||
network congestion.
|
||||
|
||||
This test verifies:
|
||||
1. The number of concurrent tasks executing the identify push is always
|
||||
less than or equal to CONCURRENCY_LIMIT.
|
||||
2. An error is raised if concurrency exceeds the defined limit.
|
||||
|
||||
It mocks `push_identify_to_peer` to simulate delay using sleep,
|
||||
allowing the test to measure and assert actual concurrency behavior.
|
||||
"""
|
||||
state = {
|
||||
"concurrency_counter": 0,
|
||||
"max_observed": 0,
|
||||
}
|
||||
lock = trio.Lock()
|
||||
|
||||
async def mock_push_identify_to_peer(
|
||||
host, peer_id, observed_multiaddr=None, limit=trio.Semaphore(CONCURRENCY_LIMIT)
|
||||
) -> bool:
|
||||
"""
|
||||
Mock function to test concurrency by simulating an identify message.
|
||||
|
||||
This function patches push_identify_to_peer for testing purpose
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the push was successful, False otherwise.
|
||||
|
||||
"""
|
||||
async with limit:
|
||||
async with lock:
|
||||
state["concurrency_counter"] += 1
|
||||
if state["concurrency_counter"] > CONCURRENCY_LIMIT:
|
||||
raise RuntimeError(
|
||||
f"Concurrency limit exceeded: {state['concurrency_counter']}"
|
||||
)
|
||||
state["max_observed"] = max(
|
||||
state["max_observed"], state["concurrency_counter"]
|
||||
)
|
||||
|
||||
logger.debug("Successfully pushed identify to peer %s", peer_id)
|
||||
await trio.sleep(0.05)
|
||||
|
||||
async with lock:
|
||||
state["concurrency_counter"] -= 1
|
||||
|
||||
return True
|
||||
|
||||
# Create a mock host.
|
||||
key_pair_host = create_new_key_pair()
|
||||
host = new_host(key_pair=key_pair_host)
|
||||
|
||||
# Create a mock network and add mock connections to the host
|
||||
host.get_network().connections = create_mock_connections()
|
||||
with patch(
|
||||
"libp2p.identity.identify_push.identify_push.push_identify_to_peer",
|
||||
new=mock_push_identify_to_peer,
|
||||
):
|
||||
await push_identify_to_peers(host)
|
||||
assert state["max_observed"] <= CONCURRENCY_LIMIT, (
|
||||
f"Max concurrency observed: {state['max_observed']}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_all_peers_receive_identify_push_with_semaphore(security_protocol):
|
||||
dummy_peers = []
|
||||
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (host_a, _):
|
||||
# Create dummy peers
|
||||
for _ in range(50):
|
||||
key_pair = create_new_key_pair()
|
||||
dummy_host = new_host(key_pair=key_pair)
|
||||
dummy_host.set_stream_handler(
|
||||
ID_PUSH, identify_push_handler_for(dummy_host)
|
||||
)
|
||||
listen_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
|
||||
dummy_peers.append((dummy_host, listen_addr))
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Start all dummy hosts
|
||||
for host, listen_addr in dummy_peers:
|
||||
nursery.start_soon(run_host_forever, host, listen_addr)
|
||||
|
||||
# Wait for all hosts to finish setting up listeners
|
||||
for host, _ in dummy_peers:
|
||||
await wait_until_listening(host)
|
||||
|
||||
# Now connect host_a → dummy peers
|
||||
for host, _ in dummy_peers:
|
||||
await host_a.connect(info_from_p2p_addr(host.get_addrs()[0]))
|
||||
|
||||
await push_identify_to_peers(
|
||||
host_a,
|
||||
)
|
||||
|
||||
await trio.sleep(0.5)
|
||||
|
||||
peer_id_a = host_a.get_id()
|
||||
for host, _ in dummy_peers:
|
||||
dummy_peerstore = host.get_peerstore()
|
||||
assert peer_id_a in dummy_peerstore.peer_ids()
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_all_peers_receive_identify_push_with_semaphore_under_high_peer_load(
|
||||
security_protocol,
|
||||
):
|
||||
dummy_peers = []
|
||||
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (host_a, _):
|
||||
# Create dummy peers
|
||||
# Breaking with more than 500 peers
|
||||
# Trio have a async tasks limit of 1000
|
||||
for _ in range(499):
|
||||
key_pair = create_new_key_pair()
|
||||
dummy_host = new_host(key_pair=key_pair)
|
||||
dummy_host.set_stream_handler(
|
||||
ID_PUSH, identify_push_handler_for(dummy_host)
|
||||
)
|
||||
listen_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
|
||||
dummy_peers.append((dummy_host, listen_addr))
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Start all dummy hosts
|
||||
for host, listen_addr in dummy_peers:
|
||||
nursery.start_soon(run_host_forever, host, listen_addr)
|
||||
|
||||
# Wait for all hosts to finish setting up listeners
|
||||
for host, _ in dummy_peers:
|
||||
await wait_until_listening(host)
|
||||
|
||||
# Now connect host_a → dummy peers
|
||||
for host, _ in dummy_peers:
|
||||
await host_a.connect(info_from_p2p_addr(host.get_addrs()[0]))
|
||||
|
||||
await push_identify_to_peers(
|
||||
host_a,
|
||||
)
|
||||
|
||||
await trio.sleep(0.5)
|
||||
|
||||
peer_id_a = host_a.get_id()
|
||||
for host, _ in dummy_peers:
|
||||
dummy_peerstore = host.get_peerstore()
|
||||
assert peer_id_a in dummy_peerstore.peer_ids()
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
@ -6,10 +6,12 @@ from multiaddr import Multiaddr
|
||||
from libp2p.crypto.secp256k1 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerdata import (
|
||||
PeerData,
|
||||
PeerDataError,
|
||||
)
|
||||
from libp2p.peer.peerstore import PeerStore
|
||||
|
||||
MOCK_ADDR = Multiaddr("/ip4/127.0.0.1/tcp/4001")
|
||||
MOCK_KEYPAIR = create_new_key_pair()
|
||||
@ -39,6 +41,59 @@ def test_set_protocols():
|
||||
assert peer_data.get_protocols() == protocols
|
||||
|
||||
|
||||
# Test case when removing protocols:
|
||||
def test_remove_protocols():
|
||||
peer_data = PeerData()
|
||||
protocols: Sequence[str] = ["protocol1", "protocol2"]
|
||||
peer_data.set_protocols(protocols)
|
||||
|
||||
peer_data.remove_protocols(["protocol1"])
|
||||
assert peer_data.get_protocols() == ["protocol2"]
|
||||
|
||||
|
||||
# Test case when clearing the protocol list:
|
||||
def test_clear_protocol_data():
|
||||
peer_data = PeerData()
|
||||
protocols: Sequence[str] = ["protocol1", "protocol2"]
|
||||
peer_data.set_protocols(protocols)
|
||||
|
||||
peer_data.clear_protocol_data()
|
||||
assert peer_data.get_protocols() == []
|
||||
|
||||
|
||||
# Test case when supports protocols:
|
||||
def test_supports_protocols():
|
||||
peer_data = PeerData()
|
||||
peer_data.set_protocols(["protocol1", "protocol2", "protocol3"])
|
||||
|
||||
input_protocols = ["protocol1", "protocol4", "protocol2"]
|
||||
supported = peer_data.supports_protocols(input_protocols)
|
||||
|
||||
assert supported == ["protocol1", "protocol2"]
|
||||
|
||||
|
||||
# Test case for first supported protocol is found
|
||||
def test_first_supported_protocol_found():
|
||||
peer_data = PeerData()
|
||||
peer_data.set_protocols(["protocolA", "protocolB"])
|
||||
|
||||
input_protocols = ["protocolC", "protocolB", "protocolA"]
|
||||
first = peer_data.first_supported_protocol(input_protocols)
|
||||
|
||||
assert first == "protocolB"
|
||||
|
||||
|
||||
# Test case for first supported protocol not found
|
||||
def test_first_supported_protocol_none():
|
||||
peer_data = PeerData()
|
||||
peer_data.set_protocols(["protocolX", "protocolY"])
|
||||
|
||||
input_protocols = ["protocolA", "protocolB"]
|
||||
first = peer_data.first_supported_protocol(input_protocols)
|
||||
|
||||
assert first == "None supported"
|
||||
|
||||
|
||||
# Test case when adding addresses
|
||||
def test_add_addrs():
|
||||
peer_data = PeerData()
|
||||
@ -81,6 +136,15 @@ def test_get_metadata_key_not_found():
|
||||
peer_data.get_metadata("nonexistent_key")
|
||||
|
||||
|
||||
# Test case for clearing metadata
|
||||
def test_clear_metadata():
|
||||
peer_data = PeerData()
|
||||
peer_data.metadata = {"key1": "value1", "key2": "value2"}
|
||||
|
||||
peer_data.clear_metadata()
|
||||
assert peer_data.metadata == {}
|
||||
|
||||
|
||||
# Test case for adding public key
|
||||
def test_add_pubkey():
|
||||
peer_data = PeerData()
|
||||
@ -107,3 +171,71 @@ def test_get_privkey_not_found():
|
||||
peer_data = PeerData()
|
||||
with pytest.raises(PeerDataError):
|
||||
peer_data.get_privkey()
|
||||
|
||||
|
||||
# Test case for returning all the peers with stored keys
|
||||
def test_peer_with_keys():
|
||||
peer_store = PeerStore()
|
||||
peer_id_1 = ID(b"peer1")
|
||||
peer_id_2 = ID(b"peer2")
|
||||
|
||||
peer_data_1 = PeerData()
|
||||
peer_data_2 = PeerData()
|
||||
|
||||
peer_data_1.pubkey = MOCK_PUBKEY
|
||||
peer_data_2.pubkey = None
|
||||
|
||||
peer_store.peer_data_map = {
|
||||
peer_id_1: peer_data_1,
|
||||
peer_id_2: peer_data_2,
|
||||
}
|
||||
|
||||
assert peer_store.peer_with_keys() == [peer_id_1]
|
||||
|
||||
|
||||
# Test case for clearing the key book
|
||||
def test_clear_keydata():
|
||||
peer_store = PeerStore()
|
||||
peer_id = ID(b"peer123")
|
||||
peer_data = PeerData()
|
||||
|
||||
peer_data.pubkey = MOCK_PUBKEY
|
||||
peer_data.privkey = MOCK_PRIVKEY
|
||||
peer_store.peer_data_map = {peer_id: peer_data}
|
||||
|
||||
peer_store.clear_keydata(peer_id)
|
||||
|
||||
assert peer_data.pubkey is None
|
||||
assert peer_data.privkey is None
|
||||
|
||||
|
||||
# Test case for recording latency for the first time
|
||||
def test_record_latency_initial():
|
||||
peer_data = PeerData()
|
||||
assert peer_data.latency_EWMA() == 0
|
||||
|
||||
peer_data.record_latency(100.0)
|
||||
assert peer_data.latency_EWMA() == 100.0
|
||||
|
||||
|
||||
# Test case for updating latency
|
||||
def test_record_latency_updates_ewma():
|
||||
peer_data = PeerData()
|
||||
peer_data.record_latency(100.0) # first measurement
|
||||
first = peer_data.latency_EWMA()
|
||||
|
||||
peer_data.record_latency(50.0) # second measurement
|
||||
second = peer_data.latency_EWMA()
|
||||
|
||||
assert second < first # EWMA should have smoothed downward
|
||||
assert second > 50.0 # Not as low as the new latency
|
||||
assert second != first
|
||||
|
||||
|
||||
def test_clear_metrics():
|
||||
peer_data = PeerData()
|
||||
peer_data.record_latency(200.0)
|
||||
assert peer_data.latency_EWMA() == 200.0
|
||||
|
||||
peer_data.clear_metrics()
|
||||
assert peer_data.latency_EWMA() == 0
|
||||
|
||||
@ -2,6 +2,7 @@ import time
|
||||
|
||||
import pytest
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerstore import (
|
||||
@ -89,3 +90,33 @@ def test_peers():
|
||||
store.add_addr(ID(b"peer3"), Multiaddr("/ip4/127.0.0.1/tcp/4001"), 10)
|
||||
|
||||
assert set(store.peer_ids()) == {ID(b"peer1"), ID(b"peer2"), ID(b"peer3")}
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_addr_stream_yields_new_addrs():
|
||||
store = PeerStore()
|
||||
peer_id = ID(b"peer1")
|
||||
addr1 = Multiaddr("/ip4/127.0.0.1/tcp/4001")
|
||||
addr2 = Multiaddr("/ip4/127.0.0.1/tcp/4002")
|
||||
|
||||
collected = []
|
||||
|
||||
async def consume_addrs():
|
||||
async for addr in store.addr_stream(peer_id):
|
||||
collected.append(addr)
|
||||
if len(collected) == 2:
|
||||
break
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(consume_addrs)
|
||||
await trio.sleep(2) # Give time for the stream to start
|
||||
|
||||
store.add_addr(peer_id, addr1, ttl=10)
|
||||
await trio.sleep(0.2)
|
||||
store.add_addr(peer_id, addr2, ttl=10)
|
||||
await trio.sleep(0.2)
|
||||
|
||||
# After collecting expected addresses, cancel the stream
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
assert collected == [addr1, addr2]
|
||||
|
||||
59
tests/core/protocol_muxer/test_negotiate_timeout.py
Normal file
59
tests/core/protocol_muxer/test_negotiate_timeout.py
Normal file
@ -0,0 +1,59 @@
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
IMultiselectCommunicator,
|
||||
)
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectClientError,
|
||||
MultiselectError,
|
||||
)
|
||||
from libp2p.protocol_muxer.multiselect import Multiselect
|
||||
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
|
||||
|
||||
|
||||
class DummyMultiselectCommunicator(IMultiselectCommunicator):
|
||||
"""
|
||||
Dummy MultiSelectCommunicator to test out negotiate timmeout.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
return
|
||||
|
||||
async def write(self, msg_str: str) -> None:
|
||||
"""Goes into infinite loop when .write is called"""
|
||||
await trio.sleep_forever()
|
||||
|
||||
async def read(self) -> str:
|
||||
"""Returns a dummy read"""
|
||||
return "dummy_read"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_select_one_of_timeout():
|
||||
ECHO = TProtocol("/echo/1.0.0")
|
||||
communicator = DummyMultiselectCommunicator()
|
||||
|
||||
client = MultiselectClient()
|
||||
|
||||
with pytest.raises(MultiselectClientError, match="response timed out"):
|
||||
await client.select_one_of([ECHO], communicator, 2)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_query_multistream_command_timeout():
|
||||
communicator = DummyMultiselectCommunicator()
|
||||
client = MultiselectClient()
|
||||
|
||||
with pytest.raises(MultiselectClientError, match="response timed out"):
|
||||
await client.query_multistream_command(communicator, "ls", 2)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_timeout():
|
||||
communicator = DummyMultiselectCommunicator()
|
||||
server = Multiselect()
|
||||
|
||||
with pytest.raises(MultiselectError, match="handshake read timeout"):
|
||||
await server.negotiate(communicator, 2)
|
||||
@ -3,6 +3,7 @@ import pytest
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.protocol_muxer.multiselect import Multiselect
|
||||
from libp2p.tools.utils import (
|
||||
create_echo_stream_handler,
|
||||
)
|
||||
@ -138,3 +139,23 @@ async def test_multistream_command(security_protocol):
|
||||
# Dialer asks for unspoorted command
|
||||
with pytest.raises(ValueError, match="Command not supported"):
|
||||
await dialer.send_command(listener.get_id(), "random")
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_get_protocols_returns_all_registered_protocols():
|
||||
ms = Multiselect()
|
||||
|
||||
async def dummy_handler(stream):
|
||||
pass
|
||||
|
||||
p1 = TProtocol("/echo/1.0.0")
|
||||
p2 = TProtocol("/foo/1.0.0")
|
||||
p3 = TProtocol("/bar/1.0.0")
|
||||
|
||||
ms.add_handler(p1, dummy_handler)
|
||||
ms.add_handler(p2, dummy_handler)
|
||||
ms.add_handler(p3, dummy_handler)
|
||||
|
||||
protocols = ms.get_protocols()
|
||||
|
||||
assert set(protocols) == {p1, p2, p3}
|
||||
|
||||
@ -15,6 +15,7 @@ from tests.utils.factories import (
|
||||
PubsubFactory,
|
||||
)
|
||||
from tests.utils.pubsub.utils import (
|
||||
connect_some,
|
||||
dense_connect,
|
||||
one_to_all_connect,
|
||||
sparse_connect,
|
||||
@ -590,3 +591,166 @@ async def test_sparse_connect():
|
||||
f"received the message. Ideally all nodes should receive it, but at "
|
||||
f"minimum {min_required} required for sparse network scalability."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connect_some_with_fewer_hosts_than_degree():
|
||||
"""Test connect_some when there are fewer hosts than degree."""
|
||||
# Create 3 hosts with degree=5
|
||||
async with PubsubFactory.create_batch_with_floodsub(3) as pubsubs_fsub:
|
||||
hosts = [pubsub.host for pubsub in pubsubs_fsub]
|
||||
degree = 5
|
||||
|
||||
await connect_some(hosts, degree)
|
||||
await trio.sleep(0.1) # Allow connections to establish
|
||||
|
||||
# Each host should connect to all other hosts (since there are only 2 others)
|
||||
for i, pubsub in enumerate(pubsubs_fsub):
|
||||
connected_peers = len(pubsub.peers)
|
||||
expected_max_connections = len(hosts) - 1 # All others
|
||||
assert connected_peers <= expected_max_connections, (
|
||||
f"Host {i} has {connected_peers} connections, "
|
||||
f"but can only connect to {expected_max_connections} others"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connect_some_degree_limit_enforced():
|
||||
"""Test that connect_some enforces degree limits and creates expected topology."""
|
||||
# Test with small network where we can verify exact behavior
|
||||
async with PubsubFactory.create_batch_with_floodsub(6) as pubsubs_fsub:
|
||||
hosts = [pubsub.host for pubsub in pubsubs_fsub]
|
||||
degree = 2
|
||||
|
||||
await connect_some(hosts, degree)
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# With 6 hosts and degree=2, expected connections:
|
||||
# Host 0 → connects to hosts 1,2 (2 peers total)
|
||||
# Host 1 → connects to hosts 2,3 (3 peers: 0,2,3)
|
||||
# Host 2 → connects to hosts 3,4 (4 peers: 0,1,3,4)
|
||||
# Host 3 → connects to hosts 4,5 (3 peers: 1,2,4,5) - wait, that's 4!
|
||||
# Host 4 → connects to host 5 (3 peers: 2,3,5)
|
||||
# Host 5 → (2 peers: 3,4)
|
||||
|
||||
peer_counts = [len(pubsub.peers) for pubsub in pubsubs_fsub]
|
||||
|
||||
# First and last hosts should have exactly degree connections
|
||||
assert peer_counts[0] == degree, (
|
||||
f"Host 0 should have {degree} peers, got {peer_counts[0]}"
|
||||
)
|
||||
assert peer_counts[-1] <= degree, (
|
||||
f"Last host should have ≤ {degree} peers, got {peer_counts[-1]}"
|
||||
)
|
||||
|
||||
# Middle hosts may have more due to bidirectional connections
|
||||
# but the pattern should be consistent with degree limit
|
||||
total_connections = sum(peer_counts)
|
||||
|
||||
# Should be less than full mesh (each host connected to all others)
|
||||
full_mesh_connections = len(hosts) * (len(hosts) - 1)
|
||||
assert total_connections < full_mesh_connections, (
|
||||
f"Got {total_connections} total connections, "
|
||||
f"but full mesh would be {full_mesh_connections}"
|
||||
)
|
||||
|
||||
# Should be more than just a chain (each host connected to next only)
|
||||
chain_connections = 2 * (len(hosts) - 1) # bidirectional chain
|
||||
assert total_connections > chain_connections, (
|
||||
f"Got {total_connections} total connections, which is too few "
|
||||
f"(chain would be {chain_connections})"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connect_some_degree_zero():
|
||||
"""Test edge case: degree=0 should result in no connections."""
|
||||
# Create 5 hosts with degree=0
|
||||
async with PubsubFactory.create_batch_with_floodsub(5) as pubsubs_fsub:
|
||||
hosts = [pubsub.host for pubsub in pubsubs_fsub]
|
||||
degree = 0
|
||||
|
||||
await connect_some(hosts, degree)
|
||||
await trio.sleep(0.1) # Allow any potential connections to establish
|
||||
|
||||
# Verify no connections were made
|
||||
for i, pubsub in enumerate(pubsubs_fsub):
|
||||
connected_peers = len(pubsub.peers)
|
||||
assert connected_peers == 0, (
|
||||
f"Host {i} has {connected_peers} connections, "
|
||||
f"but degree=0 should result in no connections"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connect_some_negative_degree():
|
||||
"""Test edge case: negative degree should be handled gracefully."""
|
||||
# Create 5 hosts with degree=-1
|
||||
async with PubsubFactory.create_batch_with_floodsub(5) as pubsubs_fsub:
|
||||
hosts = [pubsub.host for pubsub in pubsubs_fsub]
|
||||
degree = -1
|
||||
|
||||
await connect_some(hosts, degree)
|
||||
await trio.sleep(0.1) # Allow any potential connections to establish
|
||||
|
||||
# Verify no connections were made (negative degree should behave like 0)
|
||||
for i, pubsub in enumerate(pubsubs_fsub):
|
||||
connected_peers = len(pubsub.peers)
|
||||
assert connected_peers == 0, (
|
||||
f"Host {i} has {connected_peers} connections, "
|
||||
f"but negative degree should result in no connections"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_sparse_connect_degree_zero():
|
||||
"""Test sparse_connect with degree=0."""
|
||||
async with PubsubFactory.create_batch_with_floodsub(8) as pubsubs_fsub:
|
||||
hosts = [pubsub.host for pubsub in pubsubs_fsub]
|
||||
degree = 0
|
||||
|
||||
await sparse_connect(hosts, degree)
|
||||
await trio.sleep(0.1) # Allow connections to establish
|
||||
|
||||
# With degree=0, sparse_connect should still create neighbor connections
|
||||
# for connectivity (this is part of the algorithm design)
|
||||
for i, pubsub in enumerate(pubsubs_fsub):
|
||||
connected_peers = len(pubsub.peers)
|
||||
# Should have some connections due to neighbor connectivity
|
||||
# (each node connects to immediate neighbors)
|
||||
expected_neighbors = 2 # previous and next in ring
|
||||
assert connected_peers >= expected_neighbors, (
|
||||
f"Host {i} has {connected_peers} connections, "
|
||||
f"expected at least {expected_neighbors} neighbor connections"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_empty_host_list():
|
||||
"""Test edge case: empty host list should be handled gracefully."""
|
||||
hosts = []
|
||||
|
||||
# All functions should handle empty lists gracefully
|
||||
await connect_some(hosts, 5)
|
||||
await sparse_connect(hosts, 3)
|
||||
await dense_connect(hosts)
|
||||
|
||||
# If we reach here without exceptions, the test passes
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_single_host():
|
||||
"""Test edge case: single host should be handled gracefully."""
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
hosts = [pubsub.host for pubsub in pubsubs_fsub]
|
||||
|
||||
# All functions should handle single host gracefully
|
||||
await connect_some(hosts, 5)
|
||||
await sparse_connect(hosts, 3)
|
||||
await dense_connect(hosts)
|
||||
|
||||
# Single host should have no connections
|
||||
connected_peers = len(pubsubs_fsub[0].peers)
|
||||
assert connected_peers == 0, (
|
||||
f"Single host has {connected_peers} connections, expected 0"
|
||||
)
|
||||
|
||||
@ -105,11 +105,11 @@ async def test_relay_discovery_initialization():
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_relay_discovery_find_relay():
|
||||
"""Test finding a relay node via discovery."""
|
||||
async def test_relay_discovery_find_relay_peerstore_method():
|
||||
"""Test finding a relay node via discovery using the peerstore method."""
|
||||
async with HostFactory.create_batch_and_listen(2) as hosts:
|
||||
relay_host, client_host = hosts
|
||||
logger.info("Created hosts for test_relay_discovery_find_relay")
|
||||
logger.info("Created host for test_relay_discovery_find_relay_peerstore_method")
|
||||
logger.info("Relay host ID: %s", relay_host.get_id())
|
||||
logger.info("Client host ID: %s", client_host.get_id())
|
||||
|
||||
@ -144,19 +144,19 @@ async def test_relay_discovery_find_relay():
|
||||
# Start discovery service
|
||||
async with background_trio_service(client_discovery):
|
||||
await client_discovery.event_started.wait()
|
||||
logger.info("Client discovery service started")
|
||||
logger.info("Client discovery service started (peerstore method)")
|
||||
|
||||
# Wait for discovery to find the relay
|
||||
logger.info("Waiting for relay discovery...")
|
||||
# Wait for discovery to find the relay using the peerstore method
|
||||
logger.info("Waiting for relay discovery using peerstore...")
|
||||
|
||||
# Manually trigger discovery instead of waiting
|
||||
# Manually trigger discovery which uses peerstore as default
|
||||
await client_discovery.discover_relays()
|
||||
|
||||
# Check if relay was found
|
||||
with trio.fail_after(DISCOVERY_TIMEOUT):
|
||||
for _ in range(20): # Try multiple times
|
||||
if relay_host.get_id() in client_discovery._discovered_relays:
|
||||
logger.info("Relay discovered successfully")
|
||||
logger.info("Relay discovered successfully (peerstore method)")
|
||||
break
|
||||
|
||||
# Wait and try again
|
||||
@ -164,14 +164,194 @@ async def test_relay_discovery_find_relay():
|
||||
# Manually trigger discovery again
|
||||
await client_discovery.discover_relays()
|
||||
else:
|
||||
pytest.fail("Failed to discover relay node within timeout")
|
||||
pytest.fail(
|
||||
"Failed to discover relay node within timeout(peerstore method)"
|
||||
)
|
||||
|
||||
# Verify that relay was found and is valid
|
||||
assert relay_host.get_id() in client_discovery._discovered_relays, (
|
||||
"Relay should be discovered"
|
||||
"Relay should be discovered (peerstore method)"
|
||||
)
|
||||
relay_info = client_discovery._discovered_relays[relay_host.get_id()]
|
||||
assert relay_info.peer_id == relay_host.get_id(), "Peer ID should match"
|
||||
assert relay_info.peer_id == relay_host.get_id(), (
|
||||
"Peer ID should match (peerstore method)"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_relay_discovery_find_relay_direct_connection_method():
|
||||
"""Test finding a relay node via discovery using the direct connection method."""
|
||||
async with HostFactory.create_batch_and_listen(2) as hosts:
|
||||
relay_host, client_host = hosts
|
||||
logger.info("Created hosts for test_relay_discovery_find_relay_direct_method")
|
||||
logger.info("Relay host ID: %s", relay_host.get_id())
|
||||
logger.info("Client host ID: %s", client_host.get_id())
|
||||
|
||||
# Explicitly register the protocol handlers on relay_host
|
||||
relay_host.set_stream_handler(PROTOCOL_ID, simple_stream_handler)
|
||||
relay_host.set_stream_handler(STOP_PROTOCOL_ID, simple_stream_handler)
|
||||
|
||||
# Manually add protocol to peerstore for testing, then remove to force fallback
|
||||
client_host.get_peerstore().add_protocols(
|
||||
relay_host.get_id(), [str(PROTOCOL_ID)]
|
||||
)
|
||||
|
||||
# Set up discovery on the client host
|
||||
client_discovery = RelayDiscovery(
|
||||
client_host, discovery_interval=5
|
||||
) # Use shorter interval for testing
|
||||
|
||||
try:
|
||||
# Connect peers so they can discover each other
|
||||
with trio.fail_after(CONNECT_TIMEOUT):
|
||||
logger.info("Connecting client host to relay host")
|
||||
await connect(client_host, relay_host)
|
||||
assert relay_host.get_network().connections[client_host.get_id()], (
|
||||
"Peers not connected"
|
||||
)
|
||||
logger.info("Connection established between peers")
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect peers: %s", str(e))
|
||||
raise
|
||||
|
||||
# Remove the relay from the peerstore to test fallback to direct connection
|
||||
client_host.get_peerstore().clear_peerdata(relay_host.get_id())
|
||||
# Make sure that peer_id is not present in peerstore
|
||||
assert relay_host.get_id() not in client_host.get_peerstore().peer_ids()
|
||||
|
||||
# Start discovery service
|
||||
async with background_trio_service(client_discovery):
|
||||
await client_discovery.event_started.wait()
|
||||
logger.info("Client discovery service started (direct connection method)")
|
||||
|
||||
# Wait for discovery to find the relay using the direct connection method
|
||||
logger.info(
|
||||
"Waiting for relay discovery using direct connection fallback..."
|
||||
)
|
||||
|
||||
# Manually trigger discovery which should fallback to direct connection
|
||||
await client_discovery.discover_relays()
|
||||
|
||||
# Check if relay was found
|
||||
with trio.fail_after(DISCOVERY_TIMEOUT):
|
||||
for _ in range(20): # Try multiple times
|
||||
if relay_host.get_id() in client_discovery._discovered_relays:
|
||||
logger.info("Relay discovered successfully (direct method)")
|
||||
break
|
||||
|
||||
# Wait and try again
|
||||
await trio.sleep(1)
|
||||
# Manually trigger discovery again
|
||||
await client_discovery.discover_relays()
|
||||
else:
|
||||
pytest.fail(
|
||||
"Failed to discover relay node within timeout (direct method)"
|
||||
)
|
||||
|
||||
# Verify that relay was found and is valid
|
||||
assert relay_host.get_id() in client_discovery._discovered_relays, (
|
||||
"Relay should be discovered (direct method)"
|
||||
)
|
||||
relay_info = client_discovery._discovered_relays[relay_host.get_id()]
|
||||
assert relay_info.peer_id == relay_host.get_id(), (
|
||||
"Peer ID should match (direct method)"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_relay_discovery_find_relay_mux_method():
|
||||
"""
|
||||
Test finding a relay node via discovery using the mux method
|
||||
(fallback after direct connection fails).
|
||||
"""
|
||||
async with HostFactory.create_batch_and_listen(2) as hosts:
|
||||
relay_host, client_host = hosts
|
||||
logger.info("Created hosts for test_relay_discovery_find_relay_mux_method")
|
||||
logger.info("Relay host ID: %s", relay_host.get_id())
|
||||
logger.info("Client host ID: %s", client_host.get_id())
|
||||
|
||||
# Explicitly register the protocol handlers on relay_host
|
||||
relay_host.set_stream_handler(PROTOCOL_ID, simple_stream_handler)
|
||||
relay_host.set_stream_handler(STOP_PROTOCOL_ID, simple_stream_handler)
|
||||
|
||||
client_host.set_stream_handler(PROTOCOL_ID, simple_stream_handler)
|
||||
client_host.set_stream_handler(STOP_PROTOCOL_ID, simple_stream_handler)
|
||||
|
||||
# Set up discovery on the client host
|
||||
client_discovery = RelayDiscovery(
|
||||
client_host, discovery_interval=5
|
||||
) # Use shorter interval for testing
|
||||
|
||||
try:
|
||||
# Connect peers so they can discover each other
|
||||
with trio.fail_after(CONNECT_TIMEOUT):
|
||||
logger.info("Connecting client host to relay host")
|
||||
await connect(client_host, relay_host)
|
||||
assert relay_host.get_network().connections[client_host.get_id()], (
|
||||
"Peers not connected"
|
||||
)
|
||||
logger.info("Connection established between peers")
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect peers: %s", str(e))
|
||||
raise
|
||||
|
||||
# Remove the relay from the peerstore to test fallback
|
||||
client_host.get_peerstore().clear_peerdata(relay_host.get_id())
|
||||
# Make sure that peer_id is not present in peerstore
|
||||
assert relay_host.get_id() not in client_host.get_peerstore().peer_ids()
|
||||
|
||||
# Mock the _check_via_direct_connection method to return None
|
||||
# This forces the discovery to fall back to the mux method
|
||||
async def mock_direct_check_fails(peer_id):
|
||||
"""Mock that always returns None to force mux fallback."""
|
||||
return None
|
||||
|
||||
client_discovery._check_via_direct_connection = mock_direct_check_fails
|
||||
|
||||
# Start discovery service
|
||||
async with background_trio_service(client_discovery):
|
||||
await client_discovery.event_started.wait()
|
||||
logger.info("Client discovery service started (mux method)")
|
||||
|
||||
# Wait for discovery to find the relay using the mux method
|
||||
logger.info("Waiting for relay discovery using mux fallback...")
|
||||
|
||||
# Manually trigger discovery which should fallback to mux method
|
||||
await client_discovery.discover_relays()
|
||||
|
||||
# Check if relay was found
|
||||
with trio.fail_after(DISCOVERY_TIMEOUT):
|
||||
for _ in range(20): # Try multiple times
|
||||
if relay_host.get_id() in client_discovery._discovered_relays:
|
||||
logger.info("Relay discovered successfully (mux method)")
|
||||
break
|
||||
|
||||
# Wait and try again
|
||||
await trio.sleep(1)
|
||||
# Manually trigger discovery again
|
||||
await client_discovery.discover_relays()
|
||||
else:
|
||||
pytest.fail(
|
||||
"Failed to discover relay node within timeout (mux method)"
|
||||
)
|
||||
|
||||
# Verify that relay was found and is valid
|
||||
assert relay_host.get_id() in client_discovery._discovered_relays, (
|
||||
"Relay should be discovered (mux method)"
|
||||
)
|
||||
relay_info = client_discovery._discovered_relays[relay_host.get_id()]
|
||||
assert relay_info.peer_id == relay_host.get_id(), (
|
||||
"Peer ID should match (mux method)"
|
||||
)
|
||||
|
||||
# Verify that the protocol was cached via mux method
|
||||
assert relay_host.get_id() in client_discovery._protocol_cache, (
|
||||
"Protocol should be cached (mux method)"
|
||||
)
|
||||
assert (
|
||||
str(PROTOCOL_ID)
|
||||
in client_discovery._protocol_cache[relay_host.get_id()]
|
||||
), "Relay protocol should be in cache (mux method)"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
|
||||
199
tests/core/stream_muxer/test_yamux_interleaving.py
Normal file
199
tests/core/stream_muxer/test_yamux_interleaving.py
Normal file
@ -0,0 +1,199 @@
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
from trio.testing import (
|
||||
memory_stream_pair,
|
||||
)
|
||||
|
||||
from libp2p.abc import IRawConnection
|
||||
from libp2p.crypto.ed25519 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.security.insecure.transport import (
|
||||
InsecureTransport,
|
||||
)
|
||||
from libp2p.stream_muxer.yamux.yamux import (
|
||||
Yamux,
|
||||
YamuxStream,
|
||||
)
|
||||
|
||||
|
||||
class TrioStreamAdapter(IRawConnection):
|
||||
"""Adapter to make trio memory streams work with libp2p."""
|
||||
|
||||
def __init__(self, send_stream, receive_stream, is_initiator=False):
|
||||
self.send_stream = send_stream
|
||||
self.receive_stream = receive_stream
|
||||
self.is_initiator = is_initiator
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
logging.debug(f"Attempting to write {len(data)} bytes")
|
||||
with trio.move_on_after(2):
|
||||
await self.send_stream.send_all(data)
|
||||
|
||||
async def read(self, n: int | None = None) -> bytes:
|
||||
if n is None or n <= 0:
|
||||
raise ValueError("Reading unbounded or zero bytes not supported")
|
||||
logging.debug(f"Attempting to read {n} bytes")
|
||||
with trio.move_on_after(2):
|
||||
data = await self.receive_stream.receive_some(n)
|
||||
logging.debug(f"Read {len(data)} bytes")
|
||||
return data
|
||||
|
||||
async def close(self) -> None:
|
||||
logging.debug("Closing stream")
|
||||
await self.send_stream.aclose()
|
||||
await self.receive_stream.aclose()
|
||||
|
||||
def get_remote_address(self) -> tuple[str, int] | None:
|
||||
"""Return None since this is a test adapter without real network info."""
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def key_pair():
|
||||
return create_new_key_pair()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def peer_id(key_pair):
|
||||
return ID.from_pubkey(key_pair.public_key)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def secure_conn_pair(key_pair, peer_id):
|
||||
"""Create a pair of secure connections for testing."""
|
||||
logging.debug("Setting up secure_conn_pair")
|
||||
client_send, server_receive = memory_stream_pair()
|
||||
server_send, client_receive = memory_stream_pair()
|
||||
|
||||
client_rw = TrioStreamAdapter(client_send, client_receive)
|
||||
server_rw = TrioStreamAdapter(server_send, server_receive)
|
||||
|
||||
insecure_transport = InsecureTransport(key_pair)
|
||||
|
||||
async def run_outbound(nursery_results):
|
||||
with trio.move_on_after(5):
|
||||
client_conn = await insecure_transport.secure_outbound(client_rw, peer_id)
|
||||
logging.debug("Outbound handshake complete")
|
||||
nursery_results["client"] = client_conn
|
||||
|
||||
async def run_inbound(nursery_results):
|
||||
with trio.move_on_after(5):
|
||||
server_conn = await insecure_transport.secure_inbound(server_rw)
|
||||
logging.debug("Inbound handshake complete")
|
||||
nursery_results["server"] = server_conn
|
||||
|
||||
nursery_results = {}
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(run_outbound, nursery_results)
|
||||
nursery.start_soon(run_inbound, nursery_results)
|
||||
await trio.sleep(0.1) # Give tasks a chance to finish
|
||||
|
||||
client_conn = nursery_results.get("client")
|
||||
server_conn = nursery_results.get("server")
|
||||
|
||||
if client_conn is None or server_conn is None:
|
||||
raise RuntimeError("Handshake failed: client_conn or server_conn is None")
|
||||
|
||||
logging.debug("secure_conn_pair setup complete")
|
||||
return client_conn, server_conn
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def yamux_pair(secure_conn_pair, peer_id):
|
||||
"""Create a pair of Yamux multiplexers for testing."""
|
||||
logging.debug("Setting up yamux_pair")
|
||||
client_conn, server_conn = secure_conn_pair
|
||||
client_yamux = Yamux(client_conn, peer_id, is_initiator=True)
|
||||
server_yamux = Yamux(server_conn, peer_id, is_initiator=False)
|
||||
async with trio.open_nursery() as nursery:
|
||||
with trio.move_on_after(5):
|
||||
nursery.start_soon(client_yamux.start)
|
||||
nursery.start_soon(server_yamux.start)
|
||||
await trio.sleep(0.1)
|
||||
logging.debug("yamux_pair started")
|
||||
yield client_yamux, server_yamux
|
||||
logging.debug("yamux_pair cleanup")
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_yamux_race_condition_without_locks(yamux_pair):
|
||||
"""
|
||||
Test for race-around/interleaving in Yamux streams,when reading in
|
||||
segments of data.
|
||||
This launches concurrent writers/readers on both sides of a stream.
|
||||
If there is no proper locking, the received data may be interleaved
|
||||
or corrupted.
|
||||
|
||||
The test creates structured messages and verifies they are received
|
||||
intact and in order.
|
||||
Without proper locking, concurrent read/write operations could cause
|
||||
data corruption
|
||||
or message interleaving, which this test will catch.
|
||||
"""
|
||||
client_yamux, server_yamux = yamux_pair
|
||||
client_stream: YamuxStream = await client_yamux.open_stream()
|
||||
server_stream: YamuxStream = await server_yamux.accept_stream()
|
||||
MSG_COUNT = 10
|
||||
MSG_SIZE = 256 * 1024 # At max,only DEFAULT_WINDOW_SIZE bytes can be read
|
||||
client_msgs = [
|
||||
f"CLIENT-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"C") for i in range(MSG_COUNT)
|
||||
]
|
||||
server_msgs = [
|
||||
f"SERVER-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"S") for i in range(MSG_COUNT)
|
||||
]
|
||||
client_received = []
|
||||
server_received = []
|
||||
|
||||
async def writer(stream, msgs, name):
|
||||
"""Write messages with minimal delays to encourage race conditions."""
|
||||
for i, msg in enumerate(msgs):
|
||||
await stream.write(msg)
|
||||
# Yield control frequently to encourage interleaving
|
||||
if i % 5 == 0:
|
||||
await trio.sleep(0.005)
|
||||
|
||||
async def reader(stream, received, name):
|
||||
"""Read messages and store them for verification."""
|
||||
for i in range(MSG_COUNT):
|
||||
data = await stream.read(MSG_SIZE)
|
||||
received.append(data)
|
||||
if i % 3 == 0:
|
||||
await trio.sleep(0.001)
|
||||
|
||||
# Running all operations concurrently
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(writer, client_stream, client_msgs, "client")
|
||||
nursery.start_soon(writer, server_stream, server_msgs, "server")
|
||||
nursery.start_soon(reader, client_stream, client_received, "client")
|
||||
nursery.start_soon(reader, server_stream, server_received, "server")
|
||||
|
||||
assert len(client_received) == MSG_COUNT, (
|
||||
f"Client received {len(client_received)} messages, expected {MSG_COUNT}"
|
||||
)
|
||||
assert len(server_received) == MSG_COUNT, (
|
||||
f"Server received {len(server_received)} messages, expected {MSG_COUNT}"
|
||||
)
|
||||
assert client_received == server_msgs, (
|
||||
"Client did not receive server messages in order or intact!"
|
||||
)
|
||||
assert server_received == client_msgs, (
|
||||
"Server did not receive client messages in order or intact!"
|
||||
)
|
||||
for i, msg in enumerate(client_received):
|
||||
assert len(msg) == MSG_SIZE, (
|
||||
f"Client message {i} has wrong size: {len(msg)} != {MSG_SIZE}"
|
||||
)
|
||||
|
||||
for i, msg in enumerate(server_received):
|
||||
assert len(msg) == MSG_SIZE, (
|
||||
f"Server message {i} has wrong size: {len(msg)} != {MSG_SIZE}"
|
||||
)
|
||||
|
||||
await client_stream.close()
|
||||
await server_stream.close()
|
||||
195
tests/core/stream_muxer/test_yamux_interleaving_EOF.py
Normal file
195
tests/core/stream_muxer/test_yamux_interleaving_EOF.py
Normal file
@ -0,0 +1,195 @@
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
from trio.testing import (
|
||||
memory_stream_pair,
|
||||
)
|
||||
|
||||
from libp2p.abc import IRawConnection
|
||||
from libp2p.crypto.ed25519 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.security.insecure.transport import (
|
||||
InsecureTransport,
|
||||
)
|
||||
from libp2p.stream_muxer.exceptions import MuxedStreamEOF
|
||||
from libp2p.stream_muxer.yamux.yamux import (
|
||||
Yamux,
|
||||
YamuxStream,
|
||||
)
|
||||
|
||||
|
||||
class TrioStreamAdapter(IRawConnection):
|
||||
"""Adapter to make trio memory streams work with libp2p."""
|
||||
|
||||
def __init__(self, send_stream, receive_stream, is_initiator=False):
|
||||
self.send_stream = send_stream
|
||||
self.receive_stream = receive_stream
|
||||
self.is_initiator = is_initiator
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
logging.debug(f"Attempting to write {len(data)} bytes")
|
||||
with trio.move_on_after(2):
|
||||
await self.send_stream.send_all(data)
|
||||
|
||||
async def read(self, n: int | None = None) -> bytes:
|
||||
if n is None or n <= 0:
|
||||
raise ValueError("Reading unbounded or zero bytes not supported")
|
||||
logging.debug(f"Attempting to read {n} bytes")
|
||||
with trio.move_on_after(2):
|
||||
data = await self.receive_stream.receive_some(n)
|
||||
logging.debug(f"Read {len(data)} bytes")
|
||||
return data
|
||||
|
||||
async def close(self) -> None:
|
||||
logging.debug("Closing stream")
|
||||
await self.send_stream.aclose()
|
||||
await self.receive_stream.aclose()
|
||||
|
||||
def get_remote_address(self) -> tuple[str, int] | None:
|
||||
"""Return None since this is a test adapter without real network info."""
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def key_pair():
|
||||
return create_new_key_pair()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def peer_id(key_pair):
|
||||
return ID.from_pubkey(key_pair.public_key)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def secure_conn_pair(key_pair, peer_id):
|
||||
"""Create a pair of secure connections for testing."""
|
||||
logging.debug("Setting up secure_conn_pair")
|
||||
client_send, server_receive = memory_stream_pair()
|
||||
server_send, client_receive = memory_stream_pair()
|
||||
|
||||
client_rw = TrioStreamAdapter(client_send, client_receive)
|
||||
server_rw = TrioStreamAdapter(server_send, server_receive)
|
||||
|
||||
insecure_transport = InsecureTransport(key_pair)
|
||||
|
||||
async def run_outbound(nursery_results):
|
||||
with trio.move_on_after(5):
|
||||
client_conn = await insecure_transport.secure_outbound(client_rw, peer_id)
|
||||
logging.debug("Outbound handshake complete")
|
||||
nursery_results["client"] = client_conn
|
||||
|
||||
async def run_inbound(nursery_results):
|
||||
with trio.move_on_after(5):
|
||||
server_conn = await insecure_transport.secure_inbound(server_rw)
|
||||
logging.debug("Inbound handshake complete")
|
||||
nursery_results["server"] = server_conn
|
||||
|
||||
nursery_results = {}
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(run_outbound, nursery_results)
|
||||
nursery.start_soon(run_inbound, nursery_results)
|
||||
await trio.sleep(0.1) # Give tasks a chance to finish
|
||||
|
||||
client_conn = nursery_results.get("client")
|
||||
server_conn = nursery_results.get("server")
|
||||
|
||||
if client_conn is None or server_conn is None:
|
||||
raise RuntimeError("Handshake failed: client_conn or server_conn is None")
|
||||
|
||||
logging.debug("secure_conn_pair setup complete")
|
||||
return client_conn, server_conn
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def yamux_pair(secure_conn_pair, peer_id):
|
||||
"""Create a pair of Yamux multiplexers for testing."""
|
||||
logging.debug("Setting up yamux_pair")
|
||||
client_conn, server_conn = secure_conn_pair
|
||||
client_yamux = Yamux(client_conn, peer_id, is_initiator=True)
|
||||
server_yamux = Yamux(server_conn, peer_id, is_initiator=False)
|
||||
async with trio.open_nursery() as nursery:
|
||||
with trio.move_on_after(5):
|
||||
nursery.start_soon(client_yamux.start)
|
||||
nursery.start_soon(server_yamux.start)
|
||||
await trio.sleep(0.1)
|
||||
logging.debug("yamux_pair started")
|
||||
yield client_yamux, server_yamux
|
||||
logging.debug("yamux_pair cleanup")
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_yamux_race_condition_without_locks(yamux_pair):
|
||||
"""
|
||||
Test for race-around/interleaving in Yamux streams,when reading till
|
||||
EOF is being used.
|
||||
This launches concurrent writers/readers on both sides of a stream.
|
||||
If there is no proper locking, the received data may be interleaved
|
||||
or corrupted.
|
||||
|
||||
The test creates structured messages and verifies they are received
|
||||
intact and in order.
|
||||
Without proper locking, concurrent read/write operations could cause
|
||||
data corruption
|
||||
or message interleaving, which this test will catch.
|
||||
"""
|
||||
client_yamux, server_yamux = yamux_pair
|
||||
client_stream: YamuxStream = await client_yamux.open_stream()
|
||||
server_stream: YamuxStream = await server_yamux.accept_stream()
|
||||
MSG_COUNT = 1
|
||||
MSG_SIZE = 512 * 1024
|
||||
client_msgs = [
|
||||
f"CLIENT-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"C") for i in range(MSG_COUNT)
|
||||
]
|
||||
server_msgs = [
|
||||
f"SERVER-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"S") for i in range(MSG_COUNT)
|
||||
]
|
||||
client_received = []
|
||||
server_received = []
|
||||
|
||||
async def writer(stream, msgs, name):
|
||||
"""Write messages with minimal delays to encourage race conditions."""
|
||||
for i, msg in enumerate(msgs):
|
||||
await stream.write(msg)
|
||||
# Yield control frequently to encourage interleaving
|
||||
if i % 5 == 0:
|
||||
await trio.sleep(0.005)
|
||||
|
||||
async def reader(stream, received, name):
|
||||
"""Read messages and store them for verification."""
|
||||
try:
|
||||
data = await stream.read()
|
||||
if data:
|
||||
received.append(data)
|
||||
except MuxedStreamEOF:
|
||||
pass
|
||||
|
||||
# Running all operations concurrently
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(writer, client_stream, client_msgs, "client")
|
||||
nursery.start_soon(writer, server_stream, server_msgs, "server")
|
||||
nursery.start_soon(reader, client_stream, client_received, "client")
|
||||
nursery.start_soon(reader, server_stream, server_received, "server")
|
||||
|
||||
assert client_received == server_msgs, (
|
||||
"Client did not receive server messages in order or intact!"
|
||||
)
|
||||
assert server_received == client_msgs, (
|
||||
"Server did not receive client messages in order or intact!"
|
||||
)
|
||||
for i, msg in enumerate(client_received):
|
||||
assert len(msg) == MSG_SIZE, (
|
||||
f"Client message {i} has wrong size: {len(msg)} != {MSG_SIZE}"
|
||||
)
|
||||
|
||||
for i, msg in enumerate(server_received):
|
||||
assert len(msg) == MSG_SIZE, (
|
||||
f"Server message {i} has wrong size: {len(msg)} != {MSG_SIZE}"
|
||||
)
|
||||
|
||||
await client_stream.close()
|
||||
await server_stream.close()
|
||||
0
tests/discovery/__init__.py
Normal file
0
tests/discovery/__init__.py
Normal file
0
tests/discovery/mdns/__init__.py
Normal file
0
tests/discovery/mdns/__init__.py
Normal file
91
tests/discovery/mdns/test_broadcaster.py
Normal file
91
tests/discovery/mdns/test_broadcaster.py
Normal file
@ -0,0 +1,91 @@
|
||||
"""
|
||||
Unit tests for mDNS broadcaster component.
|
||||
"""
|
||||
|
||||
from zeroconf import Zeroconf
|
||||
|
||||
from libp2p.discovery.mdns.broadcaster import PeerBroadcaster
|
||||
from libp2p.peer.id import ID
|
||||
|
||||
|
||||
class TestPeerBroadcaster:
|
||||
"""Unit tests for PeerBroadcaster."""
|
||||
|
||||
def test_broadcaster_initialization(self):
|
||||
"""Test that broadcaster initializes correctly."""
|
||||
zeroconf = Zeroconf()
|
||||
service_type = "_p2p._udp.local."
|
||||
service_name = "test-peer._p2p._udp.local."
|
||||
peer_id = (
|
||||
"QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN" # String, not ID object
|
||||
)
|
||||
port = 8000
|
||||
|
||||
broadcaster = PeerBroadcaster(
|
||||
zeroconf=zeroconf,
|
||||
service_type=service_type,
|
||||
service_name=service_name,
|
||||
peer_id=peer_id,
|
||||
port=port,
|
||||
)
|
||||
|
||||
assert broadcaster.zeroconf == zeroconf
|
||||
assert broadcaster.service_type == service_type
|
||||
assert broadcaster.service_name == service_name
|
||||
assert broadcaster.peer_id == peer_id
|
||||
assert broadcaster.port == port
|
||||
|
||||
# Clean up
|
||||
zeroconf.close()
|
||||
|
||||
def test_broadcaster_service_creation(self):
|
||||
"""Test that broadcaster creates valid service info."""
|
||||
zeroconf = Zeroconf()
|
||||
service_type = "_p2p._udp.local."
|
||||
service_name = "test-peer2._p2p._udp.local."
|
||||
peer_id_obj = ID.from_base58("QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN")
|
||||
peer_id = str(peer_id_obj) # Convert to string
|
||||
port = 8000
|
||||
|
||||
broadcaster = PeerBroadcaster(
|
||||
zeroconf=zeroconf,
|
||||
service_type=service_type,
|
||||
service_name=service_name,
|
||||
peer_id=peer_id,
|
||||
port=port,
|
||||
)
|
||||
|
||||
# Verify service was created and registered
|
||||
service_info = broadcaster.service_info
|
||||
assert service_info is not None
|
||||
assert service_info.type == service_type
|
||||
assert service_info.name == service_name
|
||||
assert service_info.port == port
|
||||
assert b"id" in service_info.properties
|
||||
assert service_info.properties[b"id"] == peer_id.encode()
|
||||
|
||||
# Clean up
|
||||
zeroconf.close()
|
||||
|
||||
def test_broadcaster_start_stop(self):
|
||||
"""Test that broadcaster can start and stop correctly."""
|
||||
zeroconf = Zeroconf()
|
||||
service_type = "_p2p._udp.local."
|
||||
service_name = "test-start-stop._p2p._udp.local."
|
||||
peer_id_obj = ID.from_base58("QmYyQSo1c1Ym7orWxLYvCrM2EmxFTANf8wXmmE7DWjhx5N")
|
||||
peer_id = str(peer_id_obj) # Convert to string
|
||||
port = 8001
|
||||
|
||||
broadcaster = PeerBroadcaster(
|
||||
zeroconf=zeroconf,
|
||||
service_type=service_type,
|
||||
service_name=service_name,
|
||||
peer_id=peer_id,
|
||||
port=port,
|
||||
)
|
||||
|
||||
# Service should be registered
|
||||
assert broadcaster.service_info is not None
|
||||
|
||||
# Clean up
|
||||
zeroconf.close()
|
||||
114
tests/discovery/mdns/test_listener.py
Normal file
114
tests/discovery/mdns/test_listener.py
Normal file
@ -0,0 +1,114 @@
|
||||
"""
|
||||
Unit tests for mDNS listener component.
|
||||
"""
|
||||
|
||||
import socket
|
||||
|
||||
from zeroconf import ServiceInfo, Zeroconf
|
||||
|
||||
from libp2p.abc import Multiaddr
|
||||
from libp2p.discovery.mdns.listener import PeerListener
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerstore import PeerStore
|
||||
|
||||
|
||||
class TestPeerListener:
|
||||
"""Unit tests for PeerListener."""
|
||||
|
||||
def test_listener_initialization(self):
|
||||
"""Test that listener initializes correctly."""
|
||||
peerstore = PeerStore()
|
||||
zeroconf = Zeroconf()
|
||||
service_type = "_p2p._udp.local."
|
||||
service_name = "local-peer._p2p._udp.local."
|
||||
|
||||
listener = PeerListener(
|
||||
peerstore=peerstore,
|
||||
zeroconf=zeroconf,
|
||||
service_type=service_type,
|
||||
service_name=service_name,
|
||||
)
|
||||
|
||||
assert listener.peerstore == peerstore
|
||||
assert listener.zeroconf == zeroconf
|
||||
assert listener.service_type == service_type
|
||||
assert listener.service_name == service_name
|
||||
assert listener.discovered_services == {}
|
||||
|
||||
# Clean up
|
||||
listener.stop()
|
||||
zeroconf.close()
|
||||
|
||||
def test_listener_extract_peer_info_success(self):
|
||||
"""Test successful PeerInfo extraction from ServiceInfo."""
|
||||
peerstore = PeerStore()
|
||||
zeroconf = Zeroconf()
|
||||
|
||||
listener = PeerListener(
|
||||
peerstore=peerstore,
|
||||
zeroconf=zeroconf,
|
||||
service_type="_p2p._udp.local.",
|
||||
service_name="local._p2p._udp.local.",
|
||||
)
|
||||
|
||||
# Create sample service info
|
||||
sample_peer_id = ID.from_base58(
|
||||
"QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN"
|
||||
)
|
||||
hostname = socket.gethostname()
|
||||
local_ip = "192.168.1.100"
|
||||
|
||||
sample_service_info = ServiceInfo(
|
||||
type_="_p2p._udp.local.",
|
||||
name="test-peer._p2p._udp.local.",
|
||||
port=8000,
|
||||
properties={b"id": str(sample_peer_id).encode()},
|
||||
server=f"{hostname}.local.",
|
||||
addresses=[socket.inet_aton(local_ip)],
|
||||
)
|
||||
|
||||
peer_info = listener._extract_peer_info(sample_service_info)
|
||||
|
||||
assert peer_info is not None
|
||||
assert isinstance(peer_info.peer_id, ID)
|
||||
assert len(peer_info.addrs) > 0
|
||||
assert all(isinstance(addr, Multiaddr) for addr in peer_info.addrs)
|
||||
|
||||
# Check that protocol is TCP since we always use TCP
|
||||
assert "/tcp/" in str(peer_info.addrs[0])
|
||||
|
||||
# Clean up
|
||||
listener.stop()
|
||||
zeroconf.close()
|
||||
|
||||
def test_listener_extract_peer_info_invalid_id(self):
|
||||
"""Test PeerInfo extraction fails with invalid peer ID."""
|
||||
peerstore = PeerStore()
|
||||
zeroconf = Zeroconf()
|
||||
|
||||
listener = PeerListener(
|
||||
peerstore=peerstore,
|
||||
zeroconf=zeroconf,
|
||||
service_type="_p2p._udp.local.",
|
||||
service_name="local._p2p._udp.local.",
|
||||
)
|
||||
|
||||
# Create service info with invalid peer ID
|
||||
hostname = socket.gethostname()
|
||||
local_ip = "192.168.1.100"
|
||||
|
||||
service_info = ServiceInfo(
|
||||
type_="_p2p._udp.local.",
|
||||
name="invalid-peer._p2p._udp.local.",
|
||||
port=8000,
|
||||
properties={b"id": b"invalid_peer_id_format"},
|
||||
server=f"{hostname}.local.",
|
||||
addresses=[socket.inet_aton(local_ip)],
|
||||
)
|
||||
|
||||
peer_info = listener._extract_peer_info(service_info)
|
||||
assert peer_info is None
|
||||
|
||||
# Clean up
|
||||
listener.stop()
|
||||
zeroconf.close()
|
||||
121
tests/discovery/mdns/test_mdns.py
Normal file
121
tests/discovery/mdns/test_mdns.py
Normal file
@ -0,0 +1,121 @@
|
||||
"""
|
||||
Comprehensive integration tests for mDNS discovery functionality.
|
||||
"""
|
||||
|
||||
import socket
|
||||
|
||||
from zeroconf import Zeroconf
|
||||
|
||||
from libp2p.discovery.mdns.broadcaster import PeerBroadcaster
|
||||
from libp2p.discovery.mdns.listener import PeerListener
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerstore import PeerStore
|
||||
|
||||
|
||||
class TestMDNSDiscovery:
|
||||
"""Comprehensive integration tests for mDNS peer discovery."""
|
||||
|
||||
def test_one_host_finds_another(self):
|
||||
"""Test that one host can find another host using mDNS."""
|
||||
# Create two separate Zeroconf instances to simulate different hosts
|
||||
host1_zeroconf = Zeroconf()
|
||||
host2_zeroconf = Zeroconf()
|
||||
|
||||
try:
|
||||
# Host 1: Set up as broadcaster (the host to be discovered)
|
||||
host1_peer_id_obj = ID.from_base58(
|
||||
"QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN"
|
||||
)
|
||||
host1_peer_id = str(host1_peer_id_obj) # Convert to string
|
||||
host1_broadcaster = PeerBroadcaster(
|
||||
zeroconf=host1_zeroconf,
|
||||
service_type="_p2p._udp.local.",
|
||||
service_name="host1._p2p._udp.local.",
|
||||
peer_id=host1_peer_id,
|
||||
port=8000,
|
||||
)
|
||||
|
||||
# Host 2: Set up as listener (the host that discovers others)
|
||||
host2_peerstore = PeerStore()
|
||||
host2_listener = PeerListener(
|
||||
peerstore=host2_peerstore,
|
||||
zeroconf=host2_zeroconf,
|
||||
service_type="_p2p._udp.local.",
|
||||
service_name="host2._p2p._udp.local.",
|
||||
)
|
||||
|
||||
# Host 1 registers its service for discovery
|
||||
host1_broadcaster.register()
|
||||
|
||||
# Verify that host2 discovered host1
|
||||
assert len(host2_listener.discovered_services) > 0
|
||||
assert "host1._p2p._udp.local." in host2_listener.discovered_services
|
||||
|
||||
# Verify that host1's peer info was added to host2's peerstore
|
||||
discovered_peer_id = host2_listener.discovered_services[
|
||||
"host1._p2p._udp.local."
|
||||
]
|
||||
assert str(discovered_peer_id) == host1_peer_id
|
||||
|
||||
# Verify addresses were added to peerstore
|
||||
try:
|
||||
addrs = host2_peerstore.addrs(discovered_peer_id)
|
||||
assert len(addrs) > 0
|
||||
# Should be TCP since we always use TCP protocol
|
||||
assert "/tcp/8000" in str(addrs[0])
|
||||
except Exception:
|
||||
# If no addresses found, the discovery didn't work properly
|
||||
assert False, "Host1 addresses should be in Host2's peerstore"
|
||||
|
||||
# Clean up
|
||||
host1_broadcaster.unregister()
|
||||
host2_listener.stop()
|
||||
|
||||
finally:
|
||||
host1_zeroconf.close()
|
||||
host2_zeroconf.close()
|
||||
|
||||
def test_service_info_extraction(self):
|
||||
"""Test service info extraction functionality."""
|
||||
peerstore = PeerStore()
|
||||
zeroconf = Zeroconf()
|
||||
|
||||
try:
|
||||
listener = PeerListener(
|
||||
peerstore=peerstore,
|
||||
zeroconf=zeroconf,
|
||||
service_type="_p2p._udp.local.",
|
||||
service_name="test-listener._p2p._udp.local.",
|
||||
)
|
||||
|
||||
# Create a test service info
|
||||
test_peer_id = ID.from_base58(
|
||||
"QmYyQSo1c1Ym7orWxLYvCrM2EmxFTANf8wXmmE7DWjhx5N"
|
||||
)
|
||||
hostname = socket.gethostname()
|
||||
|
||||
from zeroconf import ServiceInfo
|
||||
|
||||
service_info = ServiceInfo(
|
||||
type_="_p2p._udp.local.",
|
||||
name="test-service._p2p._udp.local.",
|
||||
port=8001,
|
||||
properties={b"id": str(test_peer_id).encode()},
|
||||
server=f"{hostname}.local.",
|
||||
addresses=[socket.inet_aton("192.168.1.100")],
|
||||
)
|
||||
|
||||
# Test extraction
|
||||
peer_info = listener._extract_peer_info(service_info)
|
||||
|
||||
assert peer_info is not None
|
||||
assert peer_info.peer_id == test_peer_id
|
||||
assert len(peer_info.addrs) == 1
|
||||
assert "/tcp/8001" in str(peer_info.addrs[0])
|
||||
|
||||
print("✅ Service info extraction test successful!")
|
||||
print(f" Extracted peer ID: {peer_info.peer_id}")
|
||||
print(f" Extracted addresses: {[str(addr) for addr in peer_info.addrs]}")
|
||||
|
||||
finally:
|
||||
zeroconf.close()
|
||||
39
tests/discovery/mdns/test_utils.py
Normal file
39
tests/discovery/mdns/test_utils.py
Normal file
@ -0,0 +1,39 @@
|
||||
"""
|
||||
Basic unit tests for mDNS utils module.
|
||||
"""
|
||||
|
||||
import string
|
||||
|
||||
from libp2p.discovery.mdns.utils import stringGen
|
||||
|
||||
|
||||
class TestStringGen:
|
||||
"""Unit tests for stringGen function."""
|
||||
|
||||
def test_stringgen_default_length(self):
|
||||
"""Test stringGen with default length (63)."""
|
||||
result = stringGen()
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert len(result) == 63
|
||||
|
||||
# Check that all characters are from the expected charset
|
||||
charset = string.ascii_lowercase + string.digits
|
||||
for char in result:
|
||||
assert char in charset
|
||||
|
||||
def test_stringgen_custom_length(self):
|
||||
"""Test stringGen with custom lengths."""
|
||||
# Test various lengths
|
||||
test_lengths = [1, 5, 10, 20, 50, 100]
|
||||
|
||||
for length in test_lengths:
|
||||
result = stringGen(length)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert len(result) == length
|
||||
|
||||
# Check that all characters are from the expected charset
|
||||
charset = string.ascii_lowercase + string.digits
|
||||
for char in result:
|
||||
assert char in charset
|
||||
@ -24,16 +24,22 @@ def make_pubsub_msg(
|
||||
)
|
||||
|
||||
|
||||
# TODO: Implement sparse connect
|
||||
async def dense_connect(hosts: Sequence[IHost]) -> None:
|
||||
await connect_some(hosts, 10)
|
||||
|
||||
|
||||
# FIXME: `degree` is not used at all
|
||||
async def connect_some(hosts: Sequence[IHost], degree: int) -> None:
|
||||
"""
|
||||
Connect each host to up to 'degree' number of other hosts.
|
||||
Creates a sparse network topology where each node has limited connections.
|
||||
"""
|
||||
for i, host in enumerate(hosts):
|
||||
for host2 in hosts[i + 1 :]:
|
||||
await connect(host, host2)
|
||||
connections_made = 0
|
||||
for j in range(i + 1, len(hosts)):
|
||||
if connections_made >= degree:
|
||||
break
|
||||
await connect(host, hosts[j])
|
||||
connections_made += 1
|
||||
|
||||
|
||||
async def one_to_all_connect(hosts: Sequence[IHost], central_host_index: int) -> None:
|
||||
|
||||
31
tests/utils/utils.py
Normal file
31
tests/utils/utils.py
Normal file
@ -0,0 +1,31 @@
|
||||
from unittest.mock import (
|
||||
MagicMock,
|
||||
)
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.abc import IHost
|
||||
|
||||
|
||||
def create_mock_connections(count: int = 50) -> dict:
|
||||
connections = {}
|
||||
|
||||
for i in range(1, count):
|
||||
peer_id = f"peer-{i}"
|
||||
mock_conn = MagicMock(name=f"INetConn-{i}")
|
||||
connections[peer_id] = mock_conn
|
||||
|
||||
return connections
|
||||
|
||||
|
||||
async def run_host_forever(host: IHost, addr):
|
||||
async with host.run([addr]):
|
||||
await trio.sleep_forever()
|
||||
|
||||
|
||||
async def wait_until_listening(host, timeout=3):
|
||||
with trio.move_on_after(timeout):
|
||||
while not host.get_addrs():
|
||||
await trio.sleep(0.05)
|
||||
return
|
||||
raise RuntimeError("Timed out waiting for host to get an address")
|
||||
Reference in New Issue
Block a user