163 Commits

Author SHA1 Message Date
80b58a2ae0 Merge branch 'main' into noise-arch-change 2025-09-05 02:55:55 +05:30
9370101a84 Merge pull request #843 from unniznd/fix_pubsub_msg_id_type_inconsistency
fix: message id type inconsistency in handle ihave and message id parsing improvement in handle iwant
2025-09-04 23:39:14 +05:30
56732a1506 Merge branch 'main' into fix_pubsub_msg_id_type_inconsistency 2025-09-04 16:26:01 +05:30
2a249b1792 Merge pull request #849 from ankur12-1610/issue-798
Enhance Bootstrap module to dial peers after address resolution.
2025-09-04 16:12:56 +05:30
5ec1671608 Merge branch 'main' into issue-798 2025-09-04 14:14:31 +05:30
431a4807fb Merge pull request #886 from yashksaini-coder/fix/cross_platform_path_tests
Fix: Cross-Platform Path Handling Standardization
2025-09-04 14:12:04 +05:30
f54a14b713 Merge branch 'main' into issue-798 2025-09-04 13:41:45 +05:30
37a4d96f90 add rst 2025-09-02 22:23:11 +05:30
b8217bb8a8 Merge branch 'main' into fix_pubsub_msg_id_type_inconsistency 2025-09-02 10:16:17 +05:30
333d56dc00 Merge branch 'main' into noise-arch-change 2025-09-02 03:40:54 +05:30
d385cb45cf Merge branch 'libp2p:main' into fix/cross_platform_path_tests 2025-09-02 03:22:56 +05:30
2535305123 Merge pull request #838 from unniznd/fix_multiselect_negotiate_type
fix: Added multiselect type consistency in negotiate method
2025-09-02 03:03:56 +05:30
9df542f97f Merge branch 'main' into fix_multiselect_negotiate_type 2025-09-02 02:38:33 +05:30
93fe070cfb Merge pull request #884 from acul71/fix/issue-883-transport-issues-todos
fix: remove unused upgrade_listener function (Issue 2 from #726)
2025-09-02 02:38:17 +05:30
7a4c955c98 Merge branch 'main' into fix/issue-883-transport-issues-todos 2025-09-02 01:50:14 +05:30
14a74fdbd1 Merge branch 'main' into fix/cross_platform_path_tests 2025-09-02 01:42:11 +05:30
934f49af83 Merge branch 'main' into fix_multiselect_negotiate_type 2025-09-02 01:40:40 +05:30
970b535b25 Merge pull request #889 from lla-dane/pubsub-record
Signed-Peer-Record support in Pubsub/Gossipsub message transfer
2025-09-02 01:39:44 +05:30
145727a9ba Refactor logging code: Remove unnecessary blank lines in logging setup and cleanup functions for improved readability. Update tests to reflect formatting changes. 2025-09-02 01:39:24 +05:30
84c1a7031a Enhance logging cleanup: Introduce global handler management for proper resource cleanup on exit and during logging setup. Update tests to ensure file handlers are closed correctly across platforms. 2025-09-02 01:23:12 +05:30
fc6b290c56 Merge branch 'main' into fix_multiselect_negotiate_type 2025-09-02 01:08:21 +05:30
20edc3830a Merge branch 'main' into fix_pubsub_msg_id_type_inconsistency 2025-09-02 01:07:16 +05:30
ef6557518c Merge branch 'main' into pubsub-record 2025-09-02 01:01:08 +05:30
6742dd38f7 Merge branch 'main' into fix/cross_platform_path_tests 2025-09-02 01:00:23 +05:30
1783a6b0b9 Merge pull request #876 from bomanaps/feat/swarm-multi-connection-support
feat(swarm): enhance swarm with retry backoff
2025-09-02 00:35:48 +05:30
1077516196 update newsfragment 2025-09-01 18:11:22 +05:30
aad87f983f Adress documentation comment 2025-09-01 11:58:42 +01:00
69680e9c1f Added negative testcases 2025-09-01 10:30:25 +05:30
7d6eb28d7c message inconsistency fixed 2025-09-01 09:48:08 +05:30
fcb35084b3 fix(docs): Update tomllib import handling and streamline pyproject path resolution 2025-09-01 03:14:09 +05:30
42c8937a8d build(app): Add fallback to os.path.join + newsfragment 886 2025-09-01 02:53:53 +05:30
64ccce17eb fix(app): 882 Comprehensive cross-platform path handling utilities 2025-09-01 02:03:51 +05:30
6a24b138dd feat: Add cross-platform path utilities module 2025-09-01 01:35:32 +05:30
9a06ee429f Fix documentation build issues and add _build/ to .gitignore 2025-08-31 02:01:39 +01:00
526b65e1d5 style: apply ruff formatting fixes 2025-08-31 01:43:27 +01:00
59e1d9ae39 address architectural refactoring discussed 2025-08-31 01:38:29 +01:00
d620270eaf docs: add newsfragment for issue 883 - remove unused upgrade_listener function 2025-08-31 00:10:15 +02:00
31040931ea fix: remove unused upgrade_listener function (Issue 2 from #726) 2025-08-30 23:44:49 +02:00
96e2149f4d added newsfragment 2025-08-29 18:06:27 +05:30
cb5bfeda39 Use the same comment in maybe_consume_peer_record function 2025-08-29 18:06:27 +05:30
b26e8333bd updated as per the suggestions in #815 2025-08-29 18:06:25 +05:30
d99b67eafa now ignoring pubsub messages upon receving invalid-signed-records 2025-08-29 18:06:09 +05:30
cdfb083c06 added tests to see if transfer works correctly 2025-08-29 18:06:09 +05:30
d4c387f923 add reissuing mechanism of records if addrs dont change as done in #815 2025-08-29 18:06:09 +05:30
56526b4870 signed-peer-record transfer integrated with pubsub rpc message trasfer 2025-08-29 18:05:46 +05:30
8f5dd3bd11 remove excessive use of trio nursery 2025-08-29 17:34:47 +05:30
997094e5b7 resolve linting errors 2025-08-29 12:55:18 +05:30
3c52b859ba improved the error message 2025-08-29 11:30:17 +05:30
426aae7efb Merge branch 'main' into fix_multiselect_negotiate_type 2025-08-29 03:25:12 +05:30
40dad64949 Merge branch 'main' into fix_pubsub_msg_id_type_inconsistency 2025-08-29 03:24:53 +05:30
999315a74a Merge branch 'main' into noise-arch-change 2025-08-29 03:23:05 +05:30
df39e240e7 Merge branch 'main' into feat/swarm-multi-connection-support 2025-08-29 03:11:08 +05:30
5c11ac20e7 Merge pull request #815 from lla-dane/kad-record
Signed-Peer-Record support in KAD-DHT message transfer mechanism.
2025-08-29 03:09:07 +05:30
9fa3afbb04 fix: format code to pass CI lint 2025-08-28 22:18:33 +01:00
3d1c36419c remove checkpoints, resolve logs, ttl and fix minor issues 2025-08-29 02:05:34 +05:30
c577fd2f71 feat(swarm): enhance swarm with retry backoff 2025-08-28 20:59:36 +01:00
9f80dbae12 added the testcase for StreamFailure 2025-08-27 22:05:19 +05:30
c08007feda improve error message in basic host 2025-08-27 21:54:05 +05:30
c2c4228591 added test for ADD_PROVIDER record processing 2025-08-27 13:02:32 +05:30
943bcc4d36 fix the logic error in add_provider handling 2025-08-27 10:17:40 +05:30
8100a5cd20 removed redudant check in seen seqnos and peers and added test cases of handle iwant and handle ihave 2025-08-26 21:49:12 +05:30
2006b2c92c added newsfragment 2025-08-26 12:59:18 +05:30
fe3f7adc1b fix typos 2025-08-26 12:49:51 +05:30
7b2d637382 Now using env_to_send_in_RPC for issuing records in Identify rpc messages 2025-08-26 12:49:51 +05:30
91bee9df89 Moved env_to_send_in_RPC function to libp2p/peer/peerstore.py 2025-08-26 12:49:51 +05:30
5bf9c7b537 Fix spinx error 2025-08-26 12:49:51 +05:30
8958c0fac3 Moved env_to_send_in_RPC function to libp2p/init.py 2025-08-26 12:49:51 +05:30
091ac082b9 Commented out the bool variable from env_to_send_in_RPC() at places 2025-08-26 12:49:51 +05:30
15f4a399ec Added and docstrings and removed typos 2025-08-26 12:49:51 +05:30
3917d7b596 verify peer_id in signed-record matches authenticated sender 2025-08-26 12:49:51 +05:30
3aacb3a391 remove the timeout bound from the kad-dht test 2025-08-26 12:49:51 +05:30
ba39e91a2e added test for req rejection upon invalid record transfer 2025-08-26 12:49:51 +05:30
57d1c9d807 reject dht-msgs upon receiving invalid records 2025-08-26 12:49:51 +05:30
efc899e872 fix abc.py file 2025-08-26 12:49:51 +05:30
cea1985c5c add reissuing mechanism of records if addrs dont change 2025-08-26 12:49:51 +05:30
702ad4876e remove too much repeatitive code 2025-08-26 12:49:51 +05:30
a21d9e878b recompile protobuf schema and remove typos 2025-08-26 12:49:51 +05:30
5ab68026d6 removed redundant logs 2025-08-26 12:49:51 +05:30
d1792588f9 added tests for signed-peee-record transfer in kad-dht 2025-08-26 12:49:51 +05:30
53db128f69 fix typos 2025-08-26 12:49:51 +05:30
cacb3c8aca feat: add webtransport certhashes field to NoiseExtensions and implement serialization test
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
2025-08-26 12:49:21 +05:30
c940dac1e6 simplify bootstrap discovery with optimized timeouts 2025-08-26 01:42:25 +05:30
6214697349 removed redundant imports 2025-08-25 23:01:35 +05:30
fb544d6db2 fixed the merge conflict gossipsub module. 2025-08-25 21:12:45 +05:30
b40d84fc26 Merge remote-tracking branch 'origin/main' into fix_pubsub_msg_id_type_inconsistency 2025-08-25 21:11:55 +05:30
cda50e0ead Merge remote-tracking branch 'origin/main' into fix_multiselect_negotiate_type 2025-08-25 21:07:49 +05:30
3b27b02a8b Merge branch 'main' into issue-798 2025-08-25 16:30:40 +05:30
05fde3ad40 Merge branch 'main' into noise-arch-change 2025-08-25 16:21:43 +05:30
292bd1a942 Merge pull request #811 from yashksaini-coder/feat/804-add-thin-waist-address
 Feat: add Thin Waist address validation utilities and integrate into echo example
2025-08-25 15:52:36 +05:30
c9795e3138 Merge branch 'main' into feat/804-add-thin-waist-address 2025-08-25 15:52:14 +05:30
b80817b5ae Merge pull request #855 from bomanaps/tests/notifee-coverage
Add listener lifecycle tests
2025-08-25 15:29:22 +05:30
6c6adf7459 chore(app): 804 Suggested changes - Remove the comment 2025-08-25 12:43:18 +05:30
79f3a173f4 renamed newsfragments to internal 2025-08-25 06:09:40 +01:00
7fb3c2da9f Add newsfragment for PR #855 (PubsubNotifee integration tests) 2025-08-24 23:31:39 +01:00
6b7f50be3d Merge branch 'libp2p:main' into tests/notifee-coverage 2025-08-24 23:03:42 +01:00
6a0a7c21e8 chore(app): Add newsfragment for 811.feature.rst 2025-08-25 01:31:30 +05:30
fde8c8f127 Merge branch 'main' into feat/804-add-thin-waist-address 2025-08-24 23:46:17 +05:30
bc1b1ed6ae fix_gossipsub_mid_type (#859)
* fix_gossipsub_mid_type

* Fix lint and add newsfragments

* Fix mid formation

* Revert "Fix mid formation"

This reverts commit 835d4ca7af58f0716db51a00a8a7aa6cc15ac0a6.
2025-08-24 12:11:36 -06:00
63a8458d45 add import to __init__ 2025-08-24 23:40:05 +05:30
ed91ee0c31 refactor(app): 804 refactored find_free_port() in address_validation.py 2025-08-24 23:28:02 +05:30
75ffb791ac fix: Ensure newline at end of file in address_validation.py and update news fragment formatting 2025-08-24 22:06:07 +05:30
cf48d2e9a4 chore(app): Add 811.internal.rst 2025-08-24 22:03:31 +05:30
88a1f0a390 cherry pick 7a1198c8c6/libp2p/utils/address_validation.py 2025-08-24 21:17:29 +05:30
b38d504fc1 Merge pull request #1 from acul71/fix/multi-address-listening-bug
Fix/multi address listening bug
2025-08-24 13:02:56 +05:30
3bd6d1f579 doc: add newsfragment 2025-08-24 02:29:23 +02:00
b6cbd78943 Fix multi-address listening bug in swarm.listen()
- Fix early return in swarm.listen() that prevented listening on all addresses
- Add comprehensive tests for multi-address listening functionality
- Ensure all available interfaces are properly bound and connectable
2025-08-24 01:49:42 +02:00
ed2716c1bf feat: Enhance echo example to dynamically find free ports and improve address handling
- Added a function to find a free port on localhost.
- Updated the run function to use the new port finding logic when a non-positive port is provided.
- Modified address printing to handle multiple listen addresses correctly.
- Improved the get_available_interfaces function to ensure the IPv4 loopback address is included.
2025-08-22 11:48:37 +05:30
5a2fca32a0 Add ip4 and tcp address resolution and fallback connection attempts 2025-08-22 02:12:42 +05:30
9efc5a1bd1 Merge branch 'libp2p:main' into tests/notifee-coverage 2025-08-21 08:07:53 +01:00
8d9b7f413d Add trio nursery address resolution and connection attempts 2025-08-21 11:20:21 +05:30
5b9bec8e28 fix: Enhance error handling in echo stream handler to manage stream closure and exceptions 2025-08-20 18:29:35 +05:30
c2c91b8c58 refactor: Improve comment formatting in test_echo_thin_waist.py for clarity 2025-08-20 18:05:20 +05:30
8a2d1f7045 Merge branch 'main' into feat/804-add-thin-waist-address 2025-08-20 18:04:45 +05:30
94d695c6bc feat: Implement Random walk in py-libp2p (#822)
* Implementing random walk in py libp2p

* Add documentation for Random Walk module implementation in py-libp2p

* Add Random Walk example for py-libp2p Kademlia DHT

* refactor: peer eviction from routing table stopped

* refactored location of random walk

* add nodesin routing table  from peerstore

* random walk working as expected

* removed extra functions

* Removed all manual triggers

* added newsfragments

* fix linting issues

* refacored logs and cleaned example file

* refactor: update RandomWalk and RTRefreshManager to use query function for peer discovery

* docs: added Random Walk example docs

* added optional argument to use random walk in kademlia DHT

* enabled random walk in example file

* Added tests for RandomWalk module

* fixed lint issues

* Update refresh interval and some more tests are added.

* Removed Random Walk module documentation file

* Extra parentheses have been removed from the random walk logs.

Co-authored-by: Paul Robinson <5199899+pacrob@users.noreply.github.com>

---------

Co-authored-by: Manu Sheel Gupta <manusheel.edu@gmail.com>
Co-authored-by: Paul Robinson <5199899+pacrob@users.noreply.github.com>
2025-08-20 05:10:06 -06:00
905f3a5708 Merge branch 'main' into feat/804-add-thin-waist-address 2025-08-20 09:59:48 +05:30
dabb3a0962 FIXME: Make TProtocol Optional[TProtocol] to keep types consistent (#770)
* FIXME: Make TProtocol Optional[TProtocol] to keep types consistent

* correct test case of test_protocol_muxer

* add newsfragment

* unit test added

---------

Co-authored-by: Manu Sheel Gupta <manusheel.edu@gmail.com>
2025-08-19 19:20:37 -06:00
69d5274891 fix: update listening address parameter in echo example to accept a list 2025-08-19 22:32:26 +05:30
3ff5728209 Merge branch 'feat/804-add-thin-waist-address' of
https://github.com/yashksaini-coder/py-libp2p into feat/804-add-thin-waist-address
2025-08-19 20:47:44 +05:30
a1b16248d3 fix: correct listening address variable in echo example and streamline address printing 2025-08-19 20:47:18 +05:30
55dd8835a7 Merge branch 'main' into feat/804-add-thin-waist-address 2025-08-19 20:02:09 +05:30
e20a9a3814 update maintainer addresses (#856) 2025-08-19 08:29:55 -06:00
7f6469d5d4 Merge remote-tracking branch 'acul71/feat/804-add-thin-waist-address' into feat/804-add-thin-waist-address 2025-08-19 19:56:20 +05:30
ee66958e7f style: fix trailing blank lines in test files 2025-08-19 11:34:40 +01:00
c306400bd9 Add initial listener lifecycle tests; pubsub integration + perf scenarios not yet implemented 2025-08-19 10:49:05 +01:00
05b372b1eb Fix linting and type checking issues for Thin Waist feature 2025-08-19 01:11:48 +02:00
e4ab3cb2c5 Add early data support to Noise protocol
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
2025-08-19 04:41:14 +05:30
a9f184be6a Merge branch 'main' into issue-798 2025-08-18 22:02:34 +05:30
a9a6ed6767 Merge branch 'main' into fix_pubsub_msg_id_type_inconsistency 2025-08-18 22:02:20 +05:30
fe71c479dc Merge branch 'main' into feat/804-add-thin-waist-address 2025-08-18 21:55:02 +05:30
95e1f62870 Merge pull request #845 from bomanaps/feat/implement-listen_close-mynotifee
Implement listen_close notification and tests
2025-08-18 21:03:12 +05:30
9378490dcb fix: ensure loopback addresses are included in available interfaces 2025-08-18 12:40:38 +05:30
a2fcf33bc1 refactor: migrate echo example test to use Trio for process handling 2025-08-18 12:38:10 +05:30
b363d1d6d0 fix: update listening address handling to use all available interfaces 2025-08-18 12:38:04 +05:30
9a0f224a1c Merge branch 'libp2p:main' into feat/804-add-thin-waist-address 2025-08-18 11:31:28 +05:30
13379e38d8 Merge branch 'libp2p:main' into feat/implement-listen_close-mynotifee 2025-08-17 20:54:59 +01:00
09d2110d65 Remove redundant local import of Multiaddr in close() method 2025-08-17 20:29:35 +01:00
6931092eea Merge branch 'main' into issue-798 2025-08-17 17:02:20 +05:30
cff0bfc17d Merge pull request #846 from sumanjeet0012/bugfix/kbucket_split_fix
Fix: kbucket splitting in routing table.
2025-08-17 17:00:47 +05:30
163cc35cb0 Enhance Bootstrap module to dial peers after address resolution. 2025-08-17 02:12:09 +05:30
a2ad10b1e4 added newsfragments 2025-08-16 18:31:07 +05:30
7c2014087f Merge branch 'libp2p:main' into bugfix/kbucket_split_fix 2025-08-16 13:05:26 +05:30
37df8d679d fix: fixed kbucket splitting behavior in RoutingTable 2025-08-16 11:51:37 +05:30
5c78a41552 Implement closed_stream notification and tests 2025-08-15 16:02:58 +01:00
388302baa7 Added newsfragment 2025-08-15 13:57:21 +05:30
dc04270c19 fix: message id type inonsistency in handle ihave and message id parsing improvement in handle iwant 2025-08-15 13:53:24 +05:30
90f143cd88 update pyproject.toml with current maintainers (#799)
* replace ethereum author with current maintainers

* use my github handle instead of full name

* replace no-prod warning with current status message

* update maintainers blurb

---------

Co-authored-by: Manu Sheel Gupta <manusheel.edu@gmail.com>
2025-08-14 13:13:34 -06:00
1ecff5437c fixed newsfragment filename issue. 2025-08-14 07:29:06 +05:30
aa7276c863 Implement closed_stream event handling and enable related tests (#834)
* Implement closed_stream event handling and enable related tests

* Fix linting issues and ensure all tests pass

* Add logging for exception in SwarmConn and create newsfragment for closed_stream feature
2025-08-13 16:19:53 -06:00
b838a0e3b6 added none type to return value of negotiate and changed caller handles to handle none. Added newsfragment. 2025-08-12 21:50:10 +05:30
b01596ad92 Revert "Compile release notes for v0.2.10"
This reverts commit 2730db4285.
2025-08-12 07:34:33 -06:00
1565d409e8 Revert "Bump version: 0.2.9 → 0.2.10"
This reverts commit 400ee9b896.
2025-08-12 07:25:21 -06:00
400ee9b896 Bump version: 0.2.9 → 0.2.10 2025-08-12 07:22:43 -06:00
2730db4285 Compile release notes for v0.2.10 2025-08-12 07:21:56 -06:00
bb896dac2c Merge pull request #818 from varun-r-mallya/varun-r-mallya/protobuf-update
Update protobufs
2025-08-12 00:22:28 +05:30
a14c42ef73 Merge branch 'main' into feat/804-add-thin-waist-address 2025-08-10 21:35:10 +05:30
af61523c87 Merge branch 'main' into varun-r-mallya/protobuf-update 2025-08-10 16:56:10 +05:30
d2fdf70692 Merge pull request #819 from lla-dane/remove-todos
Remove completed TODO task comments in Peerstore
2025-08-10 16:54:41 +05:30
09cd8b37ed Merge branch 'main' into feat/804-add-thin-waist-address 2025-08-10 13:13:44 +05:30
f4247faa51 added newsfragment 2025-08-10 11:39:34 +05:30
eb3121b818 remove completed TODO task comments 2025-08-10 11:28:11 +05:30
59a898c8ce Add tests for echo example and address validation utilities
- Introduced `test_echo_thin_waist.py` to validate the echo example's output for Thin Waist lines.
- Added `test_address_validation.py` to cover functions for available interfaces, optimal binding addresses, and wildcard address expansion.
- Included parameterized tests and environment checks for IPv6 support.
2025-08-09 01:24:14 +05:30
fa174230ba Refactor echo example to use optimal binding address
- Replaced hardcoded listen address with `get_optimal_binding_address` for improved flexibility.
- Imported address validation utilities in `echo.py` and updated `__init__.py` to include new functions.
2025-08-09 01:22:17 +05:30
b840eaa7e1 Implement advanced network discovery example and address validation utilities
- Added `network_discover.py` to demonstrate Thin Waist address handling.
- Introduced `address_validation.py` with functions for discovering available network interfaces, expanding wildcard addresses, and determining optimal binding addresses.
- Included fallback mechanisms for environments lacking Thin Waist support.
2025-08-09 01:22:03 +05:30
109 changed files with 7020 additions and 794 deletions

3
.gitignore vendored
View File

@ -178,3 +178,6 @@ env.bak/
#lockfiles
uv.lock
poetry.lock
# Sphinx documentation build
_build/

View File

@ -12,13 +12,13 @@
[![Build Status](https://img.shields.io/github/actions/workflow/status/libp2p/py-libp2p/tox.yml?branch=main&label=build%20status)](https://github.com/libp2p/py-libp2p/actions/workflows/tox.yml)
[![Docs build](https://readthedocs.org/projects/py-libp2p/badge/?version=latest)](http://py-libp2p.readthedocs.io/en/latest/?badge=latest)
> ⚠️ **Warning:** py-libp2p is an experimental and work-in-progress repo under development. We do not yet recommend using py-libp2p in production environments.
> py-libp2p has moved beyond its experimental roots and is steadily progressing toward production readiness. The core features are stable, and were focused on refining performance, expanding protocol support, and ensuring smooth interop with other libp2p implementations. We welcome contributions and real-world usage feedback to help us reach full production maturity.
Read more in the [documentation on ReadTheDocs](https://py-libp2p.readthedocs.io/). [View the release notes](https://py-libp2p.readthedocs.io/en/latest/release_notes.html).
## Maintainers
Currently maintained by [@pacrob](https://github.com/pacrob), [@seetadev](https://github.com/seetadev) and [@dhuseby](https://github.com/dhuseby), looking for assistance!
Currently maintained by [@pacrob](https://github.com/pacrob), [@seetadev](https://github.com/seetadev) and [@dhuseby](https://github.com/dhuseby). Please reach out to us for collaboration or active feedback. If you have questions, feel free to open a new [discussion](https://github.com/libp2p/py-libp2p/discussions). We are also available on the libp2p Discord — join us at #py-libp2p [sub-channel](https://discord.gg/d92MEugb).
## Feature Breakdown

View File

@ -0,0 +1,194 @@
Multiple Connections Per Peer
=============================
This example demonstrates how to use the multiple connections per peer feature in py-libp2p.
Overview
--------
The multiple connections per peer feature allows a libp2p node to maintain multiple network connections to the same peer. This provides several benefits:
- **Improved reliability**: If one connection fails, others remain available
- **Better performance**: Load can be distributed across multiple connections
- **Enhanced throughput**: Multiple streams can be created in parallel
- **Fault tolerance**: Redundant connections provide backup paths
Configuration
-------------
The feature is configured through the `ConnectionConfig` class:
.. code-block:: python
from libp2p.network.swarm import ConnectionConfig
# Default configuration
config = ConnectionConfig()
print(f"Max connections per peer: {config.max_connections_per_peer}")
print(f"Load balancing strategy: {config.load_balancing_strategy}")
# Custom configuration
custom_config = ConnectionConfig(
max_connections_per_peer=5,
connection_timeout=60.0,
load_balancing_strategy="least_loaded"
)
Load Balancing Strategies
-------------------------
Two load balancing strategies are available:
**Round Robin** (default)
Cycles through connections in order, distributing load evenly.
**Least Loaded**
Selects the connection with the fewest active streams.
API Usage
---------
The new API provides direct access to multiple connections:
.. code-block:: python
from libp2p import new_swarm
# Create swarm with multiple connections support
swarm = new_swarm()
# Dial a peer - returns list of connections
connections = await swarm.dial_peer(peer_id)
print(f"Established {len(connections)} connections")
# Get all connections to a peer
peer_connections = swarm.get_connections(peer_id)
# Get all connections (across all peers)
all_connections = swarm.get_connections()
# Get the complete connections map
connections_map = swarm.get_connections_map()
# Backward compatibility - get single connection
single_conn = swarm.get_connection(peer_id)
Backward Compatibility
----------------------
Existing code continues to work through backward compatibility features:
.. code-block:: python
# Legacy 1:1 mapping (returns first connection for each peer)
legacy_connections = swarm.connections_legacy
# Single connection access (returns first available connection)
conn = swarm.get_connection(peer_id)
Example
-------
A complete working example is available in the `examples/doc-examples/multiple_connections_example.py` file.
Production Configuration
-------------------------
For production use, consider these settings:
**RetryConfig Parameters**
The `RetryConfig` class controls connection retry behavior with exponential backoff:
- **max_retries**: Maximum number of retry attempts before giving up (default: 3)
- **initial_delay**: Initial delay in seconds before the first retry (default: 0.1s)
- **max_delay**: Maximum delay cap to prevent excessive wait times (default: 30.0s)
- **backoff_multiplier**: Exponential backoff multiplier - each retry multiplies delay by this factor (default: 2.0)
- **jitter_factor**: Random jitter (0.0-1.0) to prevent synchronized retries (default: 0.1)
**ConnectionConfig Parameters**
The `ConnectionConfig` class manages multi-connection behavior:
- **max_connections_per_peer**: Maximum connections allowed to a single peer (default: 3)
- **connection_timeout**: Timeout for establishing new connections in seconds (default: 30.0s)
- **load_balancing_strategy**: Strategy for distributing streams ("round_robin" or "least_loaded")
**Load Balancing Strategies Explained**
- **round_robin**: Cycles through connections in order, distributing load evenly. Simple and predictable.
- **least_loaded**: Selects the connection with the fewest active streams. Better for performance but more complex.
.. code-block:: python
from libp2p.network.swarm import ConnectionConfig, RetryConfig
# Production-ready configuration
retry_config = RetryConfig(
max_retries=3, # Maximum retry attempts before giving up
initial_delay=0.1, # Start with 100ms delay
max_delay=30.0, # Cap exponential backoff at 30 seconds
backoff_multiplier=2.0, # Double delay each retry (100ms -> 200ms -> 400ms)
jitter_factor=0.1 # Add 10% random jitter to prevent thundering herd
)
connection_config = ConnectionConfig(
max_connections_per_peer=3, # Allow up to 3 connections per peer
connection_timeout=30.0, # 30 second timeout for new connections
load_balancing_strategy="round_robin" # Simple, predictable load distribution
)
swarm = new_swarm(
retry_config=retry_config,
connection_config=connection_config
)
**How RetryConfig Works in Practice**
With the configuration above, connection retries follow this pattern:
1. **Attempt 1**: Immediate connection attempt
2. **Attempt 2**: Wait 100ms ± 10ms jitter, then retry
3. **Attempt 3**: Wait 200ms ± 20ms jitter, then retry
4. **Attempt 4**: Wait 400ms ± 40ms jitter, then retry
5. **Attempt 5**: Wait 800ms ± 80ms jitter, then retry
6. **Attempt 6**: Wait 1.6s ± 160ms jitter, then retry
7. **Attempt 7**: Wait 3.2s ± 320ms jitter, then retry
8. **Attempt 8**: Wait 6.4s ± 640ms jitter, then retry
9. **Attempt 9**: Wait 12.8s ± 1.28s jitter, then retry
10. **Attempt 10**: Wait 25.6s ± 2.56s jitter, then retry
11. **Attempt 11**: Wait 30.0s (capped) ± 3.0s jitter, then retry
12. **Attempt 12**: Wait 30.0s (capped) ± 3.0s jitter, then retry
13. **Give up**: After 12 retries (3 initial + 9 retries), connection fails
The jitter prevents multiple clients from retrying simultaneously, reducing server load.
**Parameter Tuning Guidelines**
**For Development/Testing:**
- Use lower `max_retries` (1-2) and shorter delays for faster feedback
- Example: `RetryConfig(max_retries=2, initial_delay=0.01, max_delay=0.1)`
**For Production:**
- Use moderate `max_retries` (3-5) with reasonable delays for reliability
- Example: `RetryConfig(max_retries=5, initial_delay=0.1, max_delay=60.0)`
**For High-Latency Networks:**
- Use higher `max_retries` (5-10) with longer delays
- Example: `RetryConfig(max_retries=8, initial_delay=0.5, max_delay=120.0)`
**For Load Balancing:**
- Use `round_robin` for simple, predictable behavior
- Use `least_loaded` when you need optimal performance and can handle complexity
Architecture
------------
The implementation follows the same architectural patterns as the Go and JavaScript reference implementations:
- **Core data structure**: `dict[ID, list[INetConn]]` for 1:many mapping
- **API consistency**: Methods like `get_connections()` match reference implementations
- **Load balancing**: Integrated at the API level for optimal performance
- **Backward compatibility**: Maintains existing interfaces for gradual migration
This design ensures consistency across libp2p implementations while providing the benefits of multiple connections per peer.

View File

@ -0,0 +1,131 @@
Random Walk Example
===================
This example demonstrates the Random Walk module's peer discovery capabilities using real libp2p hosts and Kademlia DHT.
It shows how the Random Walk module automatically discovers new peers and maintains routing table health.
The Random Walk implementation performs the following key operations:
* **Automatic Peer Discovery**: Generates random peer IDs and queries the DHT network to discover new peers
* **Routing Table Maintenance**: Periodically refreshes the routing table to maintain network connectivity
* **Connection Management**: Maintains optimal connections to healthy peers in the network
* **Real-time Statistics**: Displays routing table size, connected peers, and peerstore statistics
.. code-block:: console
$ python -m pip install libp2p
Collecting libp2p
...
Successfully installed libp2p-x.x.x
$ cd examples/random_walk
$ python random_walk.py --mode server
2025-08-12 19:51:25,424 - random-walk-example - INFO - === Random Walk Example for py-libp2p ===
2025-08-12 19:51:25,424 - random-walk-example - INFO - Mode: server, Port: 0 Demo interval: 30s
2025-08-12 19:51:25,426 - random-walk-example - INFO - Starting server node on port 45123
2025-08-12 19:51:25,426 - random-walk-example - INFO - Node peer ID: 16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
2025-08-12 19:51:25,426 - random-walk-example - INFO - Node address: /ip4/0.0.0.0/tcp/45123/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
2025-08-12 19:51:25,427 - random-walk-example - INFO - Initial routing table size: 0
2025-08-12 19:51:25,427 - random-walk-example - INFO - DHT service started in SERVER mode
2025-08-12 19:51:25,430 - libp2p.discovery.random_walk.rt_refresh_manager - INFO - RT Refresh Manager started
2025-08-12 19:51:55,432 - random-walk-example - INFO - --- Iteration 1 ---
2025-08-12 19:51:55,432 - random-walk-example - INFO - Routing table size: 15
2025-08-12 19:51:55,432 - random-walk-example - INFO - Connected peers: 8
2025-08-12 19:51:55,432 - random-walk-example - INFO - Peerstore size: 42
You can also run the example in client mode:
.. code-block:: console
$ python random_walk.py --mode client
2025-08-12 19:52:15,424 - random-walk-example - INFO - === Random Walk Example for py-libp2p ===
2025-08-12 19:52:15,424 - random-walk-example - INFO - Mode: client, Port: 0 Demo interval: 30s
2025-08-12 19:52:15,426 - random-walk-example - INFO - Starting client node on port 51234
2025-08-12 19:52:15,426 - random-walk-example - INFO - Node peer ID: 16Uiu2HAmAbc123xyz...
2025-08-12 19:52:15,427 - random-walk-example - INFO - DHT service started in CLIENT mode
2025-08-12 19:52:45,432 - random-walk-example - INFO - --- Iteration 1 ---
2025-08-12 19:52:45,432 - random-walk-example - INFO - Routing table size: 8
2025-08-12 19:52:45,432 - random-walk-example - INFO - Connected peers: 5
2025-08-12 19:52:45,432 - random-walk-example - INFO - Peerstore size: 25
Command Line Options
--------------------
The example supports several command-line options:
.. code-block:: console
$ python random_walk.py --help
usage: random_walk.py [-h] [--mode {server,client}] [--port PORT]
[--demo-interval DEMO_INTERVAL] [--verbose]
Random Walk Example for py-libp2p Kademlia DHT
optional arguments:
-h, --help show this help message and exit
--mode {server,client}
Node mode: server (DHT server), or client (DHT client)
--port PORT Port to listen on (0 for random)
--demo-interval DEMO_INTERVAL
Interval between random walk demonstrations in seconds
--verbose Enable verbose logging
Key Features Demonstrated
-------------------------
**Automatic Random Walk Discovery**
The example shows how the Random Walk module automatically:
* Generates random 256-bit peer IDs for discovery queries
* Performs concurrent random walks to maximize peer discovery
* Validates discovered peers and adds them to the routing table
* Maintains routing table health through periodic refreshes
**Real-time Network Statistics**
The example displays live statistics every 30 seconds (configurable):
* **Routing Table Size**: Number of peers in the Kademlia routing table
* **Connected Peers**: Number of actively connected peers
* **Peerstore Size**: Total number of known peers with addresses
**Connection Management**
The example includes sophisticated connection management:
* Automatically maintains connections to healthy peers
* Filters for compatible peers (TCP + IPv4 addresses)
* Reconnects to maintain optimal network connectivity
* Handles connection failures gracefully
**DHT Integration**
Shows seamless integration between Random Walk and Kademlia DHT:
* RT Refresh Manager coordinates with the DHT routing table
* Peer discovery feeds directly into DHT operations
* Both SERVER and CLIENT modes supported
* Bootstrap connectivity to public IPFS nodes
Understanding the Output
------------------------
When you run the example, you'll see periodic statistics that show how the Random Walk module is working:
* **Initial Phase**: Routing table starts empty and quickly discovers peers
* **Growth Phase**: Routing table size increases as more peers are discovered
* **Maintenance Phase**: Routing table size stabilizes as the system maintains optimal peer connections
The Random Walk module runs automatically in the background, performing peer discovery queries every few minutes to ensure the routing table remains populated with fresh, reachable peers.
Configuration
-------------
The Random Walk module can be configured through the following parameters in ``libp2p.discovery.random_walk.config``:
* ``RANDOM_WALK_ENABLED``: Enable/disable automatic random walks (default: True)
* ``REFRESH_INTERVAL``: Time between automatic refreshes in seconds (default: 300)
* ``RANDOM_WALK_CONCURRENCY``: Number of concurrent random walks (default: 3)
* ``MIN_RT_REFRESH_THRESHOLD``: Minimum routing table size before triggering refresh (default: 4)
See Also
--------
* :doc:`examples.kademlia` - Kademlia DHT value storage and content routing
* :doc:`libp2p.discovery.random_walk` - Random Walk module API documentation

View File

@ -14,3 +14,5 @@ Examples
examples.circuit_relay
examples.kademlia
examples.mDNS
examples.random_walk
examples.multiple_connections

View File

@ -0,0 +1,48 @@
libp2p.discovery.random_walk package
====================================
The Random Walk module implements a peer discovery mechanism.
It performs random walks through the DHT network to discover new peers and maintain routing table health through periodic refreshes.
Submodules
----------
libp2p.discovery.random_walk.config module
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: libp2p.discovery.random_walk.config
:members:
:undoc-members:
:show-inheritance:
libp2p.discovery.random_walk.exceptions module
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: libp2p.discovery.random_walk.exceptions
:members:
:undoc-members:
:show-inheritance:
libp2p.discovery.random_walk.random_walk module
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: libp2p.discovery.random_walk.random_walk
:members:
:undoc-members:
:show-inheritance:
libp2p.discovery.random_walk.rt_refresh_manager module
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: libp2p.discovery.random_walk.rt_refresh_manager
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: libp2p.discovery.random_walk
:members:
:undoc-members:
:show-inheritance:

View File

@ -10,6 +10,7 @@ Subpackages
libp2p.discovery.bootstrap
libp2p.discovery.events
libp2p.discovery.mdns
libp2p.discovery.random_walk
Submodules
----------

View File

@ -0,0 +1,63 @@
"""
Advanced demonstration of Thin Waist address handling.
Run:
python -m examples.advanced.network_discovery
"""
from __future__ import annotations
from multiaddr import Multiaddr
try:
from libp2p.utils.address_validation import (
expand_wildcard_address,
get_available_interfaces,
get_optimal_binding_address,
)
except ImportError:
# Fallbacks if utilities are missing
def get_available_interfaces(port: int, protocol: str = "tcp"):
return [Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")]
def expand_wildcard_address(addr: Multiaddr, port: int | None = None):
if port is None:
return [addr]
addr_str = str(addr).rsplit("/", 1)[0]
return [Multiaddr(addr_str + f"/{port}")]
def get_optimal_binding_address(port: int, protocol: str = "tcp"):
return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")
def main() -> None:
port = 8080
interfaces = get_available_interfaces(port)
print(f"Discovered interfaces for port {port}:")
for a in interfaces:
print(f" - {a}")
wildcard_v4 = Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
expanded_v4 = expand_wildcard_address(wildcard_v4)
print("\nExpanded IPv4 wildcard:")
for a in expanded_v4:
print(f" - {a}")
wildcard_v6 = Multiaddr(f"/ip6/::/tcp/{port}")
expanded_v6 = expand_wildcard_address(wildcard_v6)
print("\nExpanded IPv6 wildcard:")
for a in expanded_v6:
print(f" - {a}")
print("\nOptimal binding address heuristic result:")
print(f" -> {get_optimal_binding_address(port)}")
override_port = 9000
overridden = expand_wildcard_address(wildcard_v4, port=override_port)
print(f"\nPort override expansion to {override_port}:")
for a in overridden:
print(f" - {a}")
if __name__ == "__main__":
main()

View File

@ -24,13 +24,8 @@ async def main():
noise_transport = NoiseTransport(
# local_key_pair: The key pair used for libp2p identity and authentication
libp2p_keypair=key_pair,
# noise_privkey: The private key used for Noise protocol encryption
noise_privkey=key_pair.private_key,
# early_data: Optional data to send during the handshake
# (None means no early data)
early_data=None,
# with_noise_pipes: Whether to use Noise pipes for additional security features
with_noise_pipes=False,
# TODO: add early data
)
# Create a security options dictionary mapping protocol ID to transport

View File

@ -28,9 +28,7 @@ async def main():
noise_privkey=key_pair.private_key,
# early_data: Optional data to send during the handshake
# (None means no early data)
early_data=None,
# with_noise_pipes: Whether to use Noise pipes for additional security features
with_noise_pipes=False,
# TODO: add early data
)
# Create a security options dictionary mapping protocol ID to transport

View File

@ -31,9 +31,7 @@ async def main():
noise_privkey=key_pair.private_key,
# early_data: Optional data to send during the handshake
# (None means no early data)
early_data=None,
# with_noise_pipes: Whether to use Noise pipes for additional security features
with_noise_pipes=False,
# TODO: add early data
)
# Create a security options dictionary mapping protocol ID to transport

View File

@ -28,9 +28,7 @@ async def main():
noise_privkey=key_pair.private_key,
# early_data: Optional data to send during the handshake
# (None means no early data)
early_data=None,
# with_noise_pipes: Whether to use Noise pipes for additional security features
with_noise_pipes=False,
# TODO: add early data
)
# Create a security options dictionary mapping protocol ID to transport

View File

@ -0,0 +1,170 @@
#!/usr/bin/env python3
"""
Example demonstrating multiple connections per peer support in libp2p.
This example shows how to:
1. Configure multiple connections per peer
2. Use different load balancing strategies
3. Access multiple connections through the new API
4. Maintain backward compatibility
"""
import logging
import trio
from libp2p import new_swarm
from libp2p.network.swarm import ConnectionConfig, RetryConfig
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def example_basic_multiple_connections() -> None:
"""Example of basic multiple connections per peer usage."""
logger.info("Creating swarm with multiple connections support...")
# Create swarm with default configuration
swarm = new_swarm()
default_connection = ConnectionConfig()
logger.info(f"Swarm created with peer ID: {swarm.get_peer_id()}")
logger.info(
f"Connection config: max_connections_per_peer="
f"{default_connection.max_connections_per_peer}"
)
await swarm.close()
logger.info("Basic multiple connections example completed")
async def example_custom_connection_config() -> None:
"""Example of custom connection configuration."""
logger.info("Creating swarm with custom connection configuration...")
# Custom connection configuration for high-performance scenarios
connection_config = ConnectionConfig(
max_connections_per_peer=5, # More connections per peer
connection_timeout=60.0, # Longer timeout
load_balancing_strategy="least_loaded", # Use least loaded strategy
)
# Create swarm with custom connection config
swarm = new_swarm(connection_config=connection_config)
logger.info("Custom connection config applied:")
logger.info(
f" Max connections per peer: {connection_config.max_connections_per_peer}"
)
logger.info(f" Connection timeout: {connection_config.connection_timeout}s")
logger.info(
f" Load balancing strategy: {connection_config.load_balancing_strategy}"
)
await swarm.close()
logger.info("Custom connection config example completed")
async def example_multiple_connections_api() -> None:
"""Example of using the new multiple connections API."""
logger.info("Demonstrating multiple connections API...")
connection_config = ConnectionConfig(
max_connections_per_peer=3, load_balancing_strategy="round_robin"
)
swarm = new_swarm(connection_config=connection_config)
logger.info("Multiple connections API features:")
logger.info(" - dial_peer() returns list[INetConn]")
logger.info(" - get_connections(peer_id) returns list[INetConn]")
logger.info(" - get_connections_map() returns dict[ID, list[INetConn]]")
logger.info(
" - get_connection(peer_id) returns INetConn | None (backward compatibility)"
)
await swarm.close()
logger.info("Multiple connections API example completed")
async def example_backward_compatibility() -> None:
"""Example of backward compatibility features."""
logger.info("Demonstrating backward compatibility...")
swarm = new_swarm()
logger.info("Backward compatibility features:")
logger.info(" - connections_legacy property provides 1:1 mapping")
logger.info(" - get_connection() method for single connection access")
logger.info(" - Existing code continues to work")
await swarm.close()
logger.info("Backward compatibility example completed")
async def example_production_ready_config() -> None:
"""Example of production-ready configuration."""
logger.info("Creating swarm with production-ready configuration...")
# Production-ready retry configuration
retry_config = RetryConfig(
max_retries=3, # Reasonable retry limit
initial_delay=0.1, # Quick initial retry
max_delay=30.0, # Cap exponential backoff
backoff_multiplier=2.0, # Standard exponential backoff
jitter_factor=0.1, # Small jitter to prevent thundering herd
)
# Production-ready connection configuration
connection_config = ConnectionConfig(
max_connections_per_peer=3, # Balance between performance and resource usage
connection_timeout=30.0, # Reasonable timeout
load_balancing_strategy="round_robin", # Simple, predictable strategy
)
# Create swarm with production config
swarm = new_swarm(retry_config=retry_config, connection_config=connection_config)
logger.info("Production-ready configuration applied:")
logger.info(
f" Retry: {retry_config.max_retries} retries, "
f"{retry_config.max_delay}s max delay"
)
logger.info(f" Connections: {connection_config.max_connections_per_peer} per peer")
logger.info(f" Load balancing: {connection_config.load_balancing_strategy}")
await swarm.close()
logger.info("Production-ready configuration example completed")
async def main() -> None:
"""Run all examples."""
logger.info("Multiple Connections Per Peer Examples")
logger.info("=" * 50)
try:
await example_basic_multiple_connections()
logger.info("-" * 30)
await example_custom_connection_config()
logger.info("-" * 30)
await example_multiple_connections_api()
logger.info("-" * 30)
await example_backward_compatibility()
logger.info("-" * 30)
await example_production_ready_config()
logger.info("-" * 30)
logger.info("All examples completed successfully!")
except Exception as e:
logger.error(f"Example failed: {e}")
raise
if __name__ == "__main__":
trio.run(main)

View File

@ -1,4 +1,6 @@
import argparse
import random
import secrets
import multiaddr
import trio
@ -12,40 +14,54 @@ from libp2p.crypto.secp256k1 import (
from libp2p.custom_types import (
TProtocol,
)
from libp2p.network.stream.exceptions import (
StreamEOF,
)
from libp2p.network.stream.net_stream import (
INetStream,
)
from libp2p.peer.peerinfo import (
info_from_p2p_addr,
)
from libp2p.utils.address_validation import (
find_free_port,
get_available_interfaces,
)
PROTOCOL_ID = TProtocol("/echo/1.0.0")
MAX_READ_LEN = 2**32 - 1
async def _echo_stream_handler(stream: INetStream) -> None:
# Wait until EOF
msg = await stream.read(MAX_READ_LEN)
await stream.write(msg)
await stream.close()
try:
peer_id = stream.muxed_conn.peer_id
print(f"Received connection from {peer_id}")
# Wait until EOF
msg = await stream.read(MAX_READ_LEN)
print(f"Echoing message: {msg.decode('utf-8')}")
await stream.write(msg)
except StreamEOF:
print("Stream closed by remote peer.")
except Exception as e:
print(f"Error in echo handler: {e}")
finally:
await stream.close()
async def run(port: int, destination: str, seed: int | None = None) -> None:
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
if port <= 0:
port = find_free_port()
listen_addr = get_available_interfaces(port)
if seed:
import random
random.seed(seed)
secret_number = random.getrandbits(32 * 8)
secret = secret_number.to_bytes(length=32, byteorder="big")
else:
import secrets
secret = secrets.token_bytes(32)
host = new_host(key_pair=create_new_key_pair(secret))
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
async with host.run(listen_addrs=listen_addr), trio.open_nursery() as nursery:
# Start the peer-store cleanup task
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
@ -54,10 +70,15 @@ async def run(port: int, destination: str, seed: int | None = None) -> None:
if not destination: # its the server
host.set_stream_handler(PROTOCOL_ID, _echo_stream_handler)
# Print all listen addresses with peer ID (JS parity)
print("Listener ready, listening on:\n")
peer_id = host.get_id().to_string()
for addr in listen_addr:
print(f"{addr}/p2p/{peer_id}")
print(
"Run this from the same folder in another console:\n\n"
f"echo-demo "
f"-d {host.get_addrs()[0]}\n"
"\nRun this from the same folder in another console:\n\n"
f"echo-demo -d {host.get_addrs()[0]}\n"
)
print("Waiting for incoming connections...")
await trio.sleep_forever()

View File

@ -41,6 +41,7 @@ from libp2p.tools.async_service import (
from libp2p.tools.utils import (
info_from_p2p_addr,
)
from libp2p.utils.paths import get_script_dir, join_paths
# Configure logging
logging.basicConfig(
@ -53,8 +54,8 @@ logger = logging.getLogger("kademlia-example")
# Configure DHT module loggers to inherit from the parent logger
# This ensures all kademlia-example.* loggers use the same configuration
# Get the directory where this script is located
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
SERVER_ADDR_LOG = os.path.join(SCRIPT_DIR, "server_node_addr.txt")
SCRIPT_DIR = get_script_dir(__file__)
SERVER_ADDR_LOG = join_paths(SCRIPT_DIR, "server_node_addr.txt")
# Set the level for all child loggers
for module in [
@ -227,7 +228,7 @@ async def run_node(
# Keep the node running
while True:
logger.debug(
logger.info(
"Status - Connected peers: %d,"
"Peers in store: %d, Values in store: %d",
len(dht.host.get_connected_peers()),

View File

@ -1,6 +1,5 @@
import argparse
import logging
import socket
import base58
import multiaddr
@ -31,6 +30,9 @@ from libp2p.stream_muxer.mplex.mplex import (
from libp2p.tools.async_service.trio_service import (
background_trio_service,
)
from libp2p.utils.address_validation import (
find_free_port,
)
# Configure logging
logging.basicConfig(
@ -77,13 +79,6 @@ async def publish_loop(pubsub, topic, termination_event):
await trio.sleep(1) # Avoid tight loop on error
def find_free_port():
"""Find a free port on localhost."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0)) # Bind to a free port provided by the OS
return s.getsockname()[1]
async def monitor_peer_topics(pubsub, nursery, termination_event):
"""
Monitor for new topics that peers are subscribed to and

View File

@ -0,0 +1,221 @@
"""
Random Walk Example for py-libp2p Kademlia DHT
This example demonstrates the Random Walk module's peer discovery capabilities
using real libp2p hosts and Kademlia DHT. It shows how the Random Walk module
automatically discovers new peers and maintains routing table health.
Usage:
# Start server nodes (they will discover peers via random walk)
python3 random_walk.py --mode server
"""
import argparse
import logging
import random
import secrets
import sys
from multiaddr import Multiaddr
import trio
from libp2p import new_host
from libp2p.abc import IHost
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.kad_dht.kad_dht import DHTMode, KadDHT
from libp2p.tools.async_service import background_trio_service
# Simple logging configuration
def setup_logging(verbose: bool = False):
"""Setup unified logging configuration."""
level = logging.DEBUG if verbose else logging.INFO
logging.basicConfig(
level=level,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[logging.StreamHandler()],
)
# Configure key module loggers
for module in ["libp2p.discovery.random_walk", "libp2p.kad_dht"]:
logging.getLogger(module).setLevel(level)
# Suppress noisy logs
logging.getLogger("multiaddr").setLevel(logging.WARNING)
logger = logging.getLogger("random-walk-example")
# Default bootstrap nodes
DEFAULT_BOOTSTRAP_NODES = [
"/ip4/104.131.131.82/tcp/4001/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ"
]
def filter_compatible_peer_info(peer_info) -> bool:
"""Filter peer info to check if it has compatible addresses (TCP + IPv4)."""
if not hasattr(peer_info, "addrs") or not peer_info.addrs:
return False
for addr in peer_info.addrs:
addr_str = str(addr)
if "/tcp/" in addr_str and "/ip4/" in addr_str and "/quic" not in addr_str:
return True
return False
async def maintain_connections(host: IHost) -> None:
"""Maintain connections to ensure the host remains connected to healthy peers."""
while True:
try:
connected_peers = host.get_connected_peers()
list_peers = host.get_peerstore().peers_with_addrs()
if len(connected_peers) < 20:
logger.debug("Reconnecting to maintain peer connections...")
# Find compatible peers
compatible_peers = []
for peer_id in list_peers:
try:
peer_info = host.get_peerstore().peer_info(peer_id)
if filter_compatible_peer_info(peer_info):
compatible_peers.append(peer_id)
except Exception:
continue
# Connect to random subset of compatible peers
if compatible_peers:
random_peers = random.sample(
compatible_peers, min(50, len(compatible_peers))
)
for peer_id in random_peers:
if peer_id not in connected_peers:
try:
with trio.move_on_after(5):
peer_info = host.get_peerstore().peer_info(peer_id)
await host.connect(peer_info)
logger.debug(f"Connected to peer: {peer_id}")
except Exception as e:
logger.debug(f"Failed to connect to {peer_id}: {e}")
await trio.sleep(15)
except Exception as e:
logger.error(f"Error maintaining connections: {e}")
async def demonstrate_random_walk_discovery(dht: KadDHT, interval: int = 30) -> None:
"""Demonstrate Random Walk peer discovery with periodic statistics."""
iteration = 0
while True:
iteration += 1
logger.info(f"--- Iteration {iteration} ---")
logger.info(f"Routing table size: {dht.get_routing_table_size()}")
logger.info(f"Connected peers: {len(dht.host.get_connected_peers())}")
logger.info(f"Peerstore size: {len(dht.host.get_peerstore().peer_ids())}")
await trio.sleep(interval)
async def run_node(port: int, mode: str, demo_interval: int = 30) -> None:
"""Run a node that demonstrates Random Walk peer discovery."""
try:
if port <= 0:
port = random.randint(10000, 60000)
logger.info(f"Starting {mode} node on port {port}")
# Determine DHT mode
dht_mode = DHTMode.SERVER if mode == "server" else DHTMode.CLIENT
# Create host and DHT
key_pair = create_new_key_pair(secrets.token_bytes(32))
host = new_host(key_pair=key_pair, bootstrap=DEFAULT_BOOTSTRAP_NODES)
listen_addr = Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
# Start maintenance tasks
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
nursery.start_soon(maintain_connections, host)
peer_id = host.get_id().pretty()
logger.info(f"Node peer ID: {peer_id}")
logger.info(f"Node address: /ip4/0.0.0.0/tcp/{port}/p2p/{peer_id}")
# Create and start DHT with Random Walk enabled
dht = KadDHT(host, dht_mode, enable_random_walk=True)
logger.info(f"Initial routing table size: {dht.get_routing_table_size()}")
async with background_trio_service(dht):
logger.info(f"DHT service started in {dht_mode.value} mode")
logger.info(f"Random Walk enabled: {dht.is_random_walk_enabled()}")
async with trio.open_nursery() as task_nursery:
# Start demonstration and status reporting
task_nursery.start_soon(
demonstrate_random_walk_discovery, dht, demo_interval
)
# Periodic status updates
async def status_reporter():
while True:
await trio.sleep(30)
logger.debug(
f"Connected: {len(dht.host.get_connected_peers())}, "
f"Routing table: {dht.get_routing_table_size()}, "
f"Peerstore: {len(dht.host.get_peerstore().peer_ids())}"
)
task_nursery.start_soon(status_reporter)
await trio.sleep_forever()
except Exception as e:
logger.error(f"Node error: {e}", exc_info=True)
sys.exit(1)
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="Random Walk Example for py-libp2p Kademlia DHT",
)
parser.add_argument(
"--mode",
choices=["server", "client"],
default="server",
help="Node mode: server (DHT server), or client (DHT client)",
)
parser.add_argument(
"--port", type=int, default=0, help="Port to listen on (0 for random)"
)
parser.add_argument(
"--demo-interval",
type=int,
default=30,
help="Interval between random walk demonstrations in seconds",
)
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
return parser.parse_args()
def main():
"""Main entry point for the random walk example."""
try:
args = parse_args()
setup_logging(args.verbose)
logger.info("=== Random Walk Example for py-libp2p ===")
logger.info(
f"Mode: {args.mode}, Port: {args.port} Demo interval: {args.demo_interval}s"
)
trio.run(run_node, args.port, args.mode, args.demo_interval)
except KeyboardInterrupt:
logger.info("Received interrupt signal, shutting down...")
except Exception as e:
logger.critical(f"Example failed: {e}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":
main()

View File

@ -1,3 +1,5 @@
"""Libp2p Python implementation."""
from collections.abc import (
Mapping,
Sequence,
@ -6,15 +8,12 @@ from importlib.metadata import version as __version
from typing import (
Literal,
Optional,
Type,
cast,
)
import multiaddr
from libp2p.abc import (
IHost,
IMuxedConn,
INetworkService,
IPeerRouting,
IPeerStore,
@ -32,9 +31,6 @@ from libp2p.custom_types import (
TProtocol,
TSecurityOptions,
)
from libp2p.discovery.mdns.mdns import (
MDNSDiscovery,
)
from libp2p.host.basic_host import (
BasicHost,
)
@ -42,6 +38,8 @@ from libp2p.host.routed_host import (
RoutedHost,
)
from libp2p.network.swarm import (
ConnectionConfig,
RetryConfig,
Swarm,
)
from libp2p.peer.id import (
@ -49,22 +47,25 @@ from libp2p.peer.id import (
)
from libp2p.peer.peerstore import (
PeerStore,
create_signed_peer_record,
)
from libp2p.security.insecure.transport import (
PLAINTEXT_PROTOCOL_ID,
InsecureTransport,
)
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
from libp2p.security.noise.transport import Transport as NoiseTransport
from libp2p.security.noise.transport import (
PROTOCOL_ID as NOISE_PROTOCOL_ID,
Transport as NoiseTransport,
)
import libp2p.security.secio.transport as secio
from libp2p.stream_muxer.mplex.mplex import (
MPLEX_PROTOCOL_ID,
Mplex,
)
from libp2p.stream_muxer.yamux.yamux import (
PROTOCOL_ID as YAMUX_PROTOCOL_ID,
Yamux,
)
from libp2p.stream_muxer.yamux.yamux import PROTOCOL_ID as YAMUX_PROTOCOL_ID
from libp2p.transport.tcp.tcp import (
TCP,
)
@ -87,7 +88,6 @@ MUXER_MPLEX = "MPLEX"
DEFAULT_NEGOTIATE_TIMEOUT = 5
def set_default_muxer(muxer_name: Literal["YAMUX", "MPLEX"]) -> None:
"""
Set the default multiplexer protocol to use.
@ -155,7 +155,6 @@ def get_default_muxer_options() -> TMuxerOptions:
else: # YAMUX is default
return create_yamux_muxer_option()
def new_swarm(
key_pair: KeyPair | None = None,
muxer_opt: TMuxerOptions | None = None,
@ -163,6 +162,8 @@ def new_swarm(
peerstore_opt: IPeerStore | None = None,
muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
retry_config: Optional["RetryConfig"] = None,
connection_config: Optional["ConnectionConfig"] = None,
) -> INetworkService:
"""
Create a swarm instance based on the parameters.
@ -239,7 +240,14 @@ def new_swarm(
# Store our key pair in peerstore
peerstore.add_key_pair(id_opt, key_pair)
return Swarm(id_opt, peerstore, upgrader, transport)
return Swarm(
id_opt,
peerstore,
upgrader,
transport,
retry_config=retry_config,
connection_config=connection_config
)
def new_host(
@ -279,6 +287,12 @@ def new_host(
if disc_opt is not None:
return RoutedHost(swarm, disc_opt, enable_mDNS, bootstrap)
return BasicHost(network=swarm,enable_mDNS=enable_mDNS , bootstrap=bootstrap, negotitate_timeout=negotiate_timeout)
return BasicHost(
network=swarm,
enable_mDNS=enable_mDNS,
bootstrap=bootstrap,
negotitate_timeout=negotiate_timeout
)
__version__ = __version("libp2p")

View File

@ -970,6 +970,14 @@ class IPeerStore(
# --------CERTIFIED-ADDR-BOOK----------
@abstractmethod
def get_local_record(self) -> Optional["Envelope"]:
"""Get the local-peer-record wrapped in Envelope"""
@abstractmethod
def set_local_record(self, envelope: "Envelope") -> None:
"""Set the local-peer-record wrapped in Envelope"""
@abstractmethod
def consume_peer_record(self, envelope: "Envelope", ttl: int) -> bool:
"""
@ -1404,15 +1412,16 @@ class INetwork(ABC):
----------
peerstore : IPeerStore
The peer store for managing peer information.
connections : dict[ID, INetConn]
A mapping of peer IDs to network connections.
connections : dict[ID, list[INetConn]]
A mapping of peer IDs to lists of network connections
(multiple connections per peer).
listeners : dict[str, IListener]
A mapping of listener identifiers to listener instances.
"""
peerstore: IPeerStore
connections: dict[ID, INetConn]
connections: dict[ID, list[INetConn]]
listeners: dict[str, IListener]
@abstractmethod
@ -1428,9 +1437,56 @@ class INetwork(ABC):
"""
@abstractmethod
async def dial_peer(self, peer_id: ID) -> INetConn:
def get_connections(self, peer_id: ID | None = None) -> list[INetConn]:
"""
Create a connection to the specified peer.
Get connections for peer (like JS getConnections, Go ConnsToPeer).
Parameters
----------
peer_id : ID | None
The peer ID to get connections for. If None, returns all connections.
Returns
-------
list[INetConn]
List of connections to the specified peer, or all connections
if peer_id is None.
"""
@abstractmethod
def get_connections_map(self) -> dict[ID, list[INetConn]]:
"""
Get all connections map (like JS getConnectionsMap).
Returns
-------
dict[ID, list[INetConn]]
The complete mapping of peer IDs to their connection lists.
"""
@abstractmethod
def get_connection(self, peer_id: ID) -> INetConn | None:
"""
Get single connection for backward compatibility.
Parameters
----------
peer_id : ID
The peer ID to get a connection for.
Returns
-------
INetConn | None
The first available connection, or None if no connections exist.
"""
@abstractmethod
async def dial_peer(self, peer_id: ID) -> list[INetConn]:
"""
Create connections to the specified peer with load balancing.
Parameters
----------
@ -1439,8 +1495,8 @@ class INetwork(ABC):
Returns
-------
INetConn
The network connection instance to the specified peer.
list[INetConn]
List of established connections to the peer.
Raises
------

View File

@ -37,3 +37,4 @@ SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
UnsubscribeFn = Callable[[], Awaitable[None]]
MessageID = NewType("MessageID", str)

View File

@ -2,15 +2,20 @@ import logging
from multiaddr import Multiaddr
from multiaddr.resolvers import DNSResolver
import trio
from libp2p.abc import ID, INetworkService, PeerInfo
from libp2p.discovery.bootstrap.utils import validate_bootstrap_addresses
from libp2p.discovery.events.peerDiscovery import peerDiscovery
from libp2p.network.exceptions import SwarmException
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.peer.peerstore import PERMANENT_ADDR_TTL
logger = logging.getLogger("libp2p.discovery.bootstrap")
resolver = DNSResolver()
DEFAULT_CONNECTION_TIMEOUT = 10
class BootstrapDiscovery:
"""
@ -19,68 +24,147 @@ class BootstrapDiscovery:
"""
def __init__(self, swarm: INetworkService, bootstrap_addrs: list[str]):
"""
Initialize BootstrapDiscovery.
Args:
swarm: The network service (swarm) instance
bootstrap_addrs: List of bootstrap peer multiaddresses
"""
self.swarm = swarm
self.peerstore = swarm.peerstore
self.bootstrap_addrs = bootstrap_addrs or []
self.discovered_peers: set[str] = set()
self.connection_timeout: int = DEFAULT_CONNECTION_TIMEOUT
async def start(self) -> None:
"""Process bootstrap addresses and emit peer discovery events."""
logger.debug(
"""Process bootstrap addresses and emit peer discovery events in parallel."""
logger.info(
f"Starting bootstrap discovery with "
f"{len(self.bootstrap_addrs)} bootstrap addresses"
)
# Show all bootstrap addresses being processed
for i, addr in enumerate(self.bootstrap_addrs):
logger.debug(f"{i + 1}. {addr}")
# Validate and filter bootstrap addresses
self.bootstrap_addrs = validate_bootstrap_addresses(self.bootstrap_addrs)
logger.info(f"Valid addresses after validation: {len(self.bootstrap_addrs)}")
for addr_str in self.bootstrap_addrs:
try:
await self._process_bootstrap_addr(addr_str)
except Exception as e:
logger.debug(f"Failed to process bootstrap address {addr_str}: {e}")
# Use Trio nursery for PARALLEL address processing
try:
async with trio.open_nursery() as nursery:
logger.debug(
f"Starting {len(self.bootstrap_addrs)} parallel address "
f"processing tasks"
)
# Start all bootstrap address processing tasks in parallel
for addr_str in self.bootstrap_addrs:
logger.debug(f"Starting parallel task for: {addr_str}")
nursery.start_soon(self._process_bootstrap_addr, addr_str)
# The nursery will wait for all address processing tasks to complete
logger.debug(
"Nursery active - waiting for address processing tasks to complete"
)
except trio.Cancelled:
logger.debug("Bootstrap address processing cancelled - cleaning up tasks")
raise
except Exception as e:
logger.error(f"Bootstrap address processing failed: {e}")
raise
logger.info("Bootstrap discovery startup complete - all tasks finished")
def stop(self) -> None:
"""Clean up bootstrap discovery resources."""
logger.debug("Stopping bootstrap discovery")
logger.info("Stopping bootstrap discovery and cleaning up tasks")
# Clear discovered peers
self.discovered_peers.clear()
logger.debug("Bootstrap discovery cleanup completed")
async def _process_bootstrap_addr(self, addr_str: str) -> None:
"""Convert string address to PeerInfo and add to peerstore."""
try:
multiaddr = Multiaddr(addr_str)
try:
multiaddr = Multiaddr(addr_str)
except Exception as e:
logger.debug(f"Invalid multiaddr format '{addr_str}': {e}")
return
if self.is_dns_addr(multiaddr):
resolved_addrs = await resolver.resolve(multiaddr)
if resolved_addrs is None:
logger.warning(f"DNS resolution returned None for: {addr_str}")
return
peer_id_str = multiaddr.get_peer_id()
if peer_id_str is None:
logger.warning(f"Missing peer ID in DNS address: {addr_str}")
return
peer_id = ID.from_base58(peer_id_str)
addrs = [addr for addr in resolved_addrs]
if not addrs:
logger.warning(f"No addresses resolved for DNS address: {addr_str}")
return
peer_info = PeerInfo(peer_id, addrs)
await self.add_addr(peer_info)
else:
peer_info = info_from_p2p_addr(multiaddr)
await self.add_addr(peer_info)
except Exception as e:
logger.debug(f"Invalid multiaddr format '{addr_str}': {e}")
return
if self.is_dns_addr(multiaddr):
resolved_addrs = await resolver.resolve(multiaddr)
peer_id_str = multiaddr.get_peer_id()
if peer_id_str is None:
logger.warning(f"Missing peer ID in DNS address: {addr_str}")
return
peer_id = ID.from_base58(peer_id_str)
addrs = [addr for addr in resolved_addrs]
if not addrs:
logger.warning(f"No addresses resolved for DNS address: {addr_str}")
return
peer_info = PeerInfo(peer_id, addrs)
self.add_addr(peer_info)
else:
self.add_addr(info_from_p2p_addr(multiaddr))
logger.warning(f"Failed to process bootstrap address {addr_str}: {e}")
def is_dns_addr(self, addr: Multiaddr) -> bool:
"""Check if the address is a DNS address."""
return any(protocol.name == "dnsaddr" for protocol in addr.protocols())
def add_addr(self, peer_info: PeerInfo) -> None:
"""Add a peer to the peerstore and emit discovery event."""
async def add_addr(self, peer_info: PeerInfo) -> None:
"""
Add a peer to the peerstore, emit discovery event,
and attempt connection in parallel.
"""
logger.debug(
f"Adding peer {peer_info.peer_id} with {len(peer_info.addrs)} addresses"
)
# Skip if it's our own peer
if peer_info.peer_id == self.swarm.get_peer_id():
logger.debug(f"Skipping own peer ID: {peer_info.peer_id}")
return
# Always add addresses to peerstore (allows multiple addresses for same peer)
self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 10)
# Filter addresses to only include IPv4+TCP (only supported protocol)
ipv4_tcp_addrs = []
filtered_out_addrs = []
for addr in peer_info.addrs:
if self._is_ipv4_tcp_addr(addr):
ipv4_tcp_addrs.append(addr)
else:
filtered_out_addrs.append(addr)
# Log filtering results
logger.debug(
f"Address filtering for {peer_info.peer_id}: "
f"{len(ipv4_tcp_addrs)} IPv4+TCP, {len(filtered_out_addrs)} filtered"
)
# Skip peer if no IPv4+TCP addresses available
if not ipv4_tcp_addrs:
logger.warning(
f"❌ No IPv4+TCP addresses for {peer_info.peer_id} - "
f"skipping connection attempts"
)
return
# Add only IPv4+TCP addresses to peerstore
self.peerstore.add_addrs(peer_info.peer_id, ipv4_tcp_addrs, PERMANENT_ADDR_TTL)
# Only emit discovery event if this is the first time we see this peer
peer_id_str = str(peer_info.peer_id)
@ -89,6 +173,140 @@ class BootstrapDiscovery:
self.discovered_peers.add(peer_id_str)
# Emit peer discovery event
peerDiscovery.emit_peer_discovered(peer_info)
logger.debug(f"Peer discovered: {peer_info.peer_id}")
logger.info(f"Peer discovered: {peer_info.peer_id}")
# Connect to peer (parallel across different bootstrap addresses)
logger.debug("Connecting to discovered peer...")
await self._connect_to_peer(peer_info.peer_id)
else:
logger.debug(f"Additional addresses added for peer: {peer_info.peer_id}")
logger.debug(
f"Additional addresses added for existing peer: {peer_info.peer_id}"
)
# Even for existing peers, try to connect if not already connected
if peer_info.peer_id not in self.swarm.connections:
logger.debug("Connecting to existing peer...")
await self._connect_to_peer(peer_info.peer_id)
async def _connect_to_peer(self, peer_id: ID) -> None:
"""
Attempt to establish a connection to a peer with timeout.
Uses swarm.dial_peer to connect using addresses stored in peerstore.
Times out after self.connection_timeout seconds to prevent hanging.
"""
logger.debug(f"Connection attempt for peer: {peer_id}")
# Pre-connection validation: Check if already connected
if peer_id in self.swarm.connections:
logger.debug(
f"Already connected to {peer_id} - skipping connection attempt"
)
return
# Check available addresses before attempting connection
available_addrs = self.peerstore.addrs(peer_id)
logger.debug(f"Connecting to {peer_id} ({len(available_addrs)} addresses)")
if not available_addrs:
logger.error(f"❌ No addresses available for {peer_id} - cannot connect")
return
# Record start time for connection attempt monitoring
connection_start_time = trio.current_time()
try:
with trio.move_on_after(self.connection_timeout):
# Log connection attempt
logger.debug(
f"Attempting connection to {peer_id} using "
f"{len(available_addrs)} addresses"
)
# Use swarm.dial_peer to connect using stored addresses
await self.swarm.dial_peer(peer_id)
# Calculate connection time
connection_time = trio.current_time() - connection_start_time
# Post-connection validation: Verify connection was actually established
if peer_id in self.swarm.connections:
logger.info(
f"✅ Connected to {peer_id} (took {connection_time:.2f}s)"
)
else:
logger.warning(
f"Dial succeeded but connection not found for {peer_id}"
)
except trio.TooSlowError:
logger.warning(
f"❌ Connection to {peer_id} timed out after {self.connection_timeout}s"
)
except SwarmException as e:
# Calculate failed connection time
failed_connection_time = trio.current_time() - connection_start_time
# Enhanced error logging
error_msg = str(e)
if "no addresses established a successful connection" in error_msg:
logger.warning(
f"❌ Failed to connect to {peer_id} after trying all "
f"{len(available_addrs)} addresses "
f"(took {failed_connection_time:.2f}s)"
)
# Log individual address failures if this is a MultiError
if (
e.__cause__ is not None
and hasattr(e.__cause__, "exceptions")
and getattr(e.__cause__, "exceptions", None) is not None
):
exceptions_list = getattr(e.__cause__, "exceptions")
logger.debug("📋 Individual address failure details:")
for i, addr_exception in enumerate(exceptions_list, 1):
logger.debug(f"Address {i}: {addr_exception}")
# Also log the actual address that failed
if i <= len(available_addrs):
logger.debug(f"Failed address: {available_addrs[i - 1]}")
else:
logger.warning("No detailed exception information available")
else:
logger.warning(
f"❌ Failed to connect to {peer_id}: {e} "
f"(took {failed_connection_time:.2f}s)"
)
except Exception as e:
# Handle unexpected errors that aren't swarm-specific
failed_connection_time = trio.current_time() - connection_start_time
logger.error(
f"❌ Unexpected error connecting to {peer_id}: "
f"{e} (took {failed_connection_time:.2f}s)"
)
# Don't re-raise to prevent killing the nursery and other parallel tasks
def _is_ipv4_tcp_addr(self, addr: Multiaddr) -> bool:
"""
Check if address is IPv4 with TCP protocol only.
Filters out IPv6, UDP, QUIC, WebSocket, and other unsupported protocols.
Only IPv4+TCP addresses are supported by the current transport.
"""
try:
protocols = addr.protocols()
# Must have IPv4 protocol
has_ipv4 = any(p.name == "ip4" for p in protocols)
if not has_ipv4:
return False
# Must have TCP protocol
has_tcp = any(p.name == "tcp" for p in protocols)
if not has_tcp:
return False
return True
except Exception:
# If we can't parse the address, don't use it
return False

View File

@ -0,0 +1,17 @@
"""Random walk discovery modules for py-libp2p."""
from .rt_refresh_manager import RTRefreshManager
from .random_walk import RandomWalk
from .exceptions import (
RoutingTableRefreshError,
RandomWalkError,
PeerValidationError,
)
__all__ = [
"RTRefreshManager",
"RandomWalk",
"RoutingTableRefreshError",
"RandomWalkError",
"PeerValidationError",
]

View File

@ -0,0 +1,16 @@
from typing import Final
# Timing constants (matching go-libp2p)
PEER_PING_TIMEOUT: Final[float] = 10.0 # seconds
REFRESH_QUERY_TIMEOUT: Final[float] = 60.0 # seconds
REFRESH_INTERVAL: Final[float] = 300.0 # 5 minutes
SUCCESSFUL_OUTBOUND_QUERY_GRACE_PERIOD: Final[float] = 60.0 # 1 minute
# Routing table thresholds
MIN_RT_REFRESH_THRESHOLD: Final[int] = 4 # Minimum peers before triggering refresh
MAX_N_BOOTSTRAPPERS: Final[int] = 2 # Maximum bootstrap peers to try
# Random walk specific
RANDOM_WALK_CONCURRENCY: Final[int] = 3 # Number of concurrent random walks
RANDOM_WALK_ENABLED: Final[bool] = True # Enable automatic random walks
RANDOM_WALK_RT_THRESHOLD: Final[int] = 20 # RT size threshold for peerstore fallback

View File

@ -0,0 +1,19 @@
from libp2p.exceptions import BaseLibp2pError
class RoutingTableRefreshError(BaseLibp2pError):
"""Base exception for routing table refresh operations."""
pass
class RandomWalkError(RoutingTableRefreshError):
"""Exception raised during random walk operations."""
pass
class PeerValidationError(RoutingTableRefreshError):
"""Exception raised when peer validation fails."""
pass

View File

@ -0,0 +1,218 @@
from collections.abc import Awaitable, Callable
import logging
import secrets
import trio
from libp2p.abc import IHost
from libp2p.discovery.random_walk.config import (
RANDOM_WALK_CONCURRENCY,
RANDOM_WALK_RT_THRESHOLD,
REFRESH_QUERY_TIMEOUT,
)
from libp2p.discovery.random_walk.exceptions import RandomWalkError
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo
logger = logging.getLogger("libp2p.discovery.random_walk")
class RandomWalk:
"""
Random Walk implementation for peer discovery in Kademlia DHT.
Generates random peer IDs and performs FIND_NODE queries to discover
new peers and populate the routing table.
"""
def __init__(
self,
host: IHost,
local_peer_id: ID,
query_function: Callable[[bytes], Awaitable[list[ID]]],
):
"""
Initialize Random Walk module.
Args:
host: The libp2p host instance
local_peer_id: Local peer ID
query_function: Function to query for closest peers given target key bytes
"""
self.host = host
self.local_peer_id = local_peer_id
self.query_function = query_function
def generate_random_peer_id(self) -> str:
"""
Generate a completely random peer ID
for random walk queries.
Returns:
Random peer ID as string
"""
# Generate 32 random bytes (256 bits) - same as go-libp2p
random_bytes = secrets.token_bytes(32)
# Convert to hex string for query
return random_bytes.hex()
async def perform_random_walk(self) -> list[PeerInfo]:
"""
Perform a single random walk operation.
Returns:
List of validated peers discovered during the walk
"""
try:
# Generate random peer ID
random_peer_id = self.generate_random_peer_id()
logger.info(f"Starting random walk for peer ID: {random_peer_id}")
# Perform FIND_NODE query
discovered_peer_ids: list[ID] = []
with trio.move_on_after(REFRESH_QUERY_TIMEOUT):
# Call the query function with target key bytes
target_key = bytes.fromhex(random_peer_id)
discovered_peer_ids = await self.query_function(target_key) or []
if not discovered_peer_ids:
logger.debug(f"No peers discovered in random walk for {random_peer_id}")
return []
logger.info(
f"Discovered {len(discovered_peer_ids)} peers in random walk "
f"for {random_peer_id[:8]}..." # Show only first 8 chars for brevity
)
# Convert peer IDs to PeerInfo objects and validate
validated_peers: list[PeerInfo] = []
for peer_id in discovered_peer_ids:
try:
# Get addresses from peerstore
addrs = self.host.get_peerstore().addrs(peer_id)
if addrs:
peer_info = PeerInfo(peer_id, addrs)
validated_peers.append(peer_info)
except Exception as e:
logger.debug(f"Failed to create PeerInfo for {peer_id}: {e}")
continue
return validated_peers
except Exception as e:
logger.error(f"Random walk failed: {e}")
raise RandomWalkError(f"Random walk operation failed: {e}") from e
async def run_concurrent_random_walks(
self, count: int = RANDOM_WALK_CONCURRENCY, current_routing_table_size: int = 0
) -> list[PeerInfo]:
"""
Run multiple random walks concurrently.
Args:
count: Number of concurrent random walks to perform
current_routing_table_size: Current size of routing table (for optimization)
Returns:
Combined list of all validated peers discovered
"""
all_validated_peers: list[PeerInfo] = []
logger.info(f"Starting {count} concurrent random walks")
# First, try to add peers from peerstore if routing table is small
if current_routing_table_size < RANDOM_WALK_RT_THRESHOLD:
try:
peerstore_peers = self._get_peerstore_peers()
if peerstore_peers:
logger.debug(
f"RT size ({current_routing_table_size}) below threshold, "
f"adding {len(peerstore_peers)} peerstore peers"
)
all_validated_peers.extend(peerstore_peers)
except Exception as e:
logger.warning(f"Error processing peerstore peers: {e}")
async def single_walk() -> None:
try:
peers = await self.perform_random_walk()
all_validated_peers.extend(peers)
except Exception as e:
logger.warning(f"Concurrent random walk failed: {e}")
return
# Run concurrent random walks
async with trio.open_nursery() as nursery:
for _ in range(count):
nursery.start_soon(single_walk)
# Remove duplicates based on peer ID
unique_peers = {}
for peer in all_validated_peers:
unique_peers[peer.peer_id] = peer
result = list(unique_peers.values())
logger.info(
f"Concurrent random walks completed: {len(result)} unique peers discovered"
)
return result
def _get_peerstore_peers(self) -> list[PeerInfo]:
"""
Get peer info objects from the host's peerstore.
Returns:
List of PeerInfo objects from peerstore
"""
try:
peerstore = self.host.get_peerstore()
peer_ids = peerstore.peers_with_addrs()
peer_infos = []
for peer_id in peer_ids:
try:
# Skip local peer
if peer_id == self.local_peer_id:
continue
peer_info = peerstore.peer_info(peer_id)
if peer_info and peer_info.addrs:
# Filter for compatible addresses (TCP + IPv4)
if self._has_compatible_addresses(peer_info):
peer_infos.append(peer_info)
except Exception as e:
logger.debug(f"Error getting peer info for {peer_id}: {e}")
return peer_infos
except Exception as e:
logger.warning(f"Error accessing peerstore: {e}")
return []
def _has_compatible_addresses(self, peer_info: PeerInfo) -> bool:
"""
Check if a peer has TCP+IPv4 compatible addresses.
Args:
peer_info: PeerInfo to check
Returns:
True if peer has compatible addresses
"""
if not peer_info.addrs:
return False
for addr in peer_info.addrs:
addr_str = str(addr)
# Check for TCP and IPv4 compatibility, avoid QUIC
if "/tcp/" in addr_str and "/ip4/" in addr_str and "/quic" not in addr_str:
return True
return False

View File

@ -0,0 +1,208 @@
from collections.abc import Awaitable, Callable
import logging
import time
from typing import Protocol
import trio
from libp2p.abc import IHost
from libp2p.discovery.random_walk.config import (
MIN_RT_REFRESH_THRESHOLD,
RANDOM_WALK_CONCURRENCY,
RANDOM_WALK_ENABLED,
REFRESH_INTERVAL,
)
from libp2p.discovery.random_walk.exceptions import RoutingTableRefreshError
from libp2p.discovery.random_walk.random_walk import RandomWalk
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo
class RoutingTableProtocol(Protocol):
"""Protocol for routing table operations needed by RT refresh manager."""
def size(self) -> int:
"""Return the current size of the routing table."""
...
async def add_peer(self, peer_obj: PeerInfo) -> bool:
"""Add a peer to the routing table."""
...
logger = logging.getLogger("libp2p.discovery.random_walk.rt_refresh_manager")
class RTRefreshManager:
"""
Routing Table Refresh Manager for py-libp2p.
Manages periodic routing table refreshes and random walk operations
to maintain routing table health and discover new peers.
"""
def __init__(
self,
host: IHost,
routing_table: RoutingTableProtocol,
local_peer_id: ID,
query_function: Callable[[bytes], Awaitable[list[ID]]],
enable_auto_refresh: bool = RANDOM_WALK_ENABLED,
refresh_interval: float = REFRESH_INTERVAL,
min_refresh_threshold: int = MIN_RT_REFRESH_THRESHOLD,
):
"""
Initialize RT Refresh Manager.
Args:
host: The libp2p host instance
routing_table: Routing table of host
local_peer_id: Local peer ID
query_function: Function to query for closest peers given target key bytes
enable_auto_refresh: Whether to enable automatic refresh
refresh_interval: Interval between refreshes in seconds
min_refresh_threshold: Minimum RT size before triggering refresh
"""
self.host = host
self.routing_table = routing_table
self.local_peer_id = local_peer_id
self.query_function = query_function
self.enable_auto_refresh = enable_auto_refresh
self.refresh_interval = refresh_interval
self.min_refresh_threshold = min_refresh_threshold
# Initialize random walk module
self.random_walk = RandomWalk(
host=host,
local_peer_id=self.local_peer_id,
query_function=query_function,
)
# Control variables
self._running = False
self._nursery: trio.Nursery | None = None
# Tracking
self._last_refresh_time = 0.0
self._refresh_done_callbacks: list[Callable[[], None]] = []
async def start(self) -> None:
"""Start the RT Refresh Manager."""
if self._running:
logger.warning("RT Refresh Manager is already running")
return
self._running = True
logger.info("Starting RT Refresh Manager")
# Start the main loop
async with trio.open_nursery() as nursery:
self._nursery = nursery
nursery.start_soon(self._main_loop)
async def stop(self) -> None:
"""Stop the RT Refresh Manager."""
if not self._running:
return
logger.info("Stopping RT Refresh Manager")
self._running = False
async def _main_loop(self) -> None:
"""Main loop for the RT Refresh Manager."""
logger.info("RT Refresh Manager main loop started")
# Initial refresh if auto-refresh is enabled
if self.enable_auto_refresh:
await self._do_refresh(force=True)
try:
while self._running:
async with trio.open_nursery() as nursery:
# Schedule periodic refresh if enabled
if self.enable_auto_refresh:
nursery.start_soon(self._periodic_refresh_task)
except Exception as e:
logger.error(f"RT Refresh Manager main loop error: {e}")
finally:
logger.info("RT Refresh Manager main loop stopped")
async def _periodic_refresh_task(self) -> None:
"""Task for periodic refreshes."""
while self._running:
await trio.sleep(self.refresh_interval)
if self._running:
await self._do_refresh()
async def _do_refresh(self, force: bool = False) -> None:
"""
Perform routing table refresh operation.
Args:
force: Whether to force refresh regardless of timing
"""
try:
current_time = time.time()
# Check if refresh is needed
if not force:
if current_time - self._last_refresh_time < self.refresh_interval:
logger.debug("Skipping refresh: interval not elapsed")
return
if self.routing_table.size() >= self.min_refresh_threshold:
logger.debug("Skipping refresh: routing table size above threshold")
return
logger.info(f"Starting routing table refresh (force={force})")
start_time = current_time
# Perform random walks to discover new peers
logger.info("Running concurrent random walks to discover new peers")
current_rt_size = self.routing_table.size()
discovered_peers = await self.random_walk.run_concurrent_random_walks(
count=RANDOM_WALK_CONCURRENCY,
current_routing_table_size=current_rt_size,
)
# Add discovered peers to routing table
added_count = 0
for peer_info in discovered_peers:
result = await self.routing_table.add_peer(peer_info)
if result:
added_count += 1
self._last_refresh_time = current_time
duration = time.time() - start_time
logger.info(
f"Routing table refresh completed: "
f"{added_count}/{len(discovered_peers)} peers added, "
f"RT size: {self.routing_table.size()}, "
f"duration: {duration:.2f}s"
)
# Notify refresh completion
for callback in self._refresh_done_callbacks:
try:
callback()
except Exception as e:
logger.warning(f"Refresh callback error: {e}")
except Exception as e:
logger.error(f"Routing table refresh failed: {e}")
raise RoutingTableRefreshError(f"Refresh operation failed: {e}") from e
def add_refresh_done_callback(self, callback: Callable[[], None]) -> None:
"""Add a callback to be called when refresh completes."""
self._refresh_done_callbacks.append(callback)
def remove_refresh_done_callback(self, callback: Callable[[], None]) -> None:
"""Remove a refresh completion callback."""
if callback in self._refresh_done_callbacks:
self._refresh_done_callbacks.remove(callback)

View File

@ -43,6 +43,7 @@ from libp2p.peer.id import (
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.peer.peerstore import create_signed_peer_record
from libp2p.protocol_muxer.exceptions import (
MultiselectClientError,
MultiselectError,
@ -110,6 +111,14 @@ class BasicHost(IHost):
if bootstrap:
self.bootstrap = BootstrapDiscovery(network, bootstrap)
# Cache a signed-record if the local-node in the PeerStore
envelope = create_signed_peer_record(
self.get_id(),
self.get_addrs(),
self.get_private_key(),
)
self.get_peerstore().set_local_record(envelope)
def get_id(self) -> ID:
"""
:return: peer_id of host
@ -288,6 +297,11 @@ class BasicHost(IHost):
protocol, handler = await self.multiselect.negotiate(
MultiselectCommunicator(net_stream), self.negotiate_timeout
)
if protocol is None:
await net_stream.reset()
raise StreamFailure(
"Failed to negotiate protocol: no protocol selected"
)
except MultiselectError as error:
peer_id = net_stream.muxed_conn.peer_id
logger.debug(
@ -295,6 +309,13 @@ class BasicHost(IHost):
)
await net_stream.reset()
return
if protocol is None:
logger.debug(
"no protocol negotiated, closing stream from peer %s",
net_stream.muxed_conn.peer_id,
)
await net_stream.reset()
return
net_stream.set_protocol(protocol)
if handler is None:
logger.debug(
@ -322,7 +343,7 @@ class BasicHost(IHost):
:param peer_id: ID of the peer to check
:return: True if peer has an active connection, False otherwise
"""
return peer_id in self._network.connections
return len(self._network.get_connections(peer_id)) > 0
def get_peer_connection_info(self, peer_id: ID) -> INetConn | None:
"""
@ -331,4 +352,4 @@ class BasicHost(IHost):
:param peer_id: ID of the peer to get info for
:return: Connection object if peer is connected, None otherwise
"""
return self._network.connections.get(peer_id)
return self._network.get_connection(peer_id)

View File

@ -15,8 +15,7 @@ from libp2p.custom_types import (
from libp2p.network.stream.exceptions import (
StreamClosed,
)
from libp2p.peer.envelope import seal_record
from libp2p.peer.peer_record import PeerRecord
from libp2p.peer.peerstore import env_to_send_in_RPC
from libp2p.utils import (
decode_varint_with_size,
get_agent_version,
@ -66,9 +65,7 @@ def _mk_identify_protobuf(
protocols = tuple(str(p) for p in host.get_mux().get_protocols() if p is not None)
# Create a signed peer-record for the remote peer
record = PeerRecord(host.get_id(), host.get_addrs())
envelope = seal_record(record, host.get_private_key())
protobuf = envelope.marshal_envelope()
envelope_bytes, _ = env_to_send_in_RPC(host)
observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b""
return Identify(
@ -78,7 +75,7 @@ def _mk_identify_protobuf(
listen_addrs=map(_multiaddr_to_bytes, laddrs),
observed_addr=observed_addr,
protocols=protocols,
signedPeerRecord=protobuf,
signedPeerRecord=envelope_bytes,
)

View File

@ -5,6 +5,7 @@ This module provides a complete Distributed Hash Table (DHT)
implementation based on the Kademlia algorithm and protocol.
"""
from collections.abc import Awaitable, Callable
from enum import (
Enum,
)
@ -20,15 +21,19 @@ import varint
from libp2p.abc import (
IHost,
)
from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager
from libp2p.kad_dht.utils import maybe_consume_signed_record
from libp2p.network.stream.net_stream import (
INetStream,
)
from libp2p.peer.envelope import Envelope
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.peer.peerstore import env_to_send_in_RPC
from libp2p.tools.async_service import (
Service,
)
@ -73,14 +78,27 @@ class KadDHT(Service):
This class provides a DHT implementation that combines routing table management,
peer discovery, content routing, and value storage.
Optional Random Walk feature enhances peer discovery by automatically
performing periodic random queries to discover new peers and maintain
routing table health.
Example:
# Basic DHT without random walk (default)
dht = KadDHT(host, DHTMode.SERVER)
# DHT with random walk enabled for enhanced peer discovery
dht = KadDHT(host, DHTMode.SERVER, enable_random_walk=True)
"""
def __init__(self, host: IHost, mode: DHTMode):
def __init__(self, host: IHost, mode: DHTMode, enable_random_walk: bool = False):
"""
Initialize a new Kademlia DHT node.
:param host: The libp2p host.
:param mode: The mode of host (Client or Server) - must be DHTMode enum
:param enable_random_walk: Whether to enable automatic random walk
"""
super().__init__()
@ -92,6 +110,7 @@ class KadDHT(Service):
raise TypeError(f"mode must be DHTMode enum, got {type(mode)}")
self.mode = mode
self.enable_random_walk = enable_random_walk
# Initialize the routing table
self.routing_table = RoutingTable(self.local_peer_id, self.host)
@ -108,13 +127,56 @@ class KadDHT(Service):
# Last time we republished provider records
self._last_provider_republish = time.time()
# Initialize RT Refresh Manager (only if random walk is enabled)
self.rt_refresh_manager: RTRefreshManager | None = None
if self.enable_random_walk:
self.rt_refresh_manager = RTRefreshManager(
host=self.host,
routing_table=self.routing_table,
local_peer_id=self.local_peer_id,
query_function=self._create_query_function(),
enable_auto_refresh=True,
)
# Set protocol handlers
host.set_stream_handler(PROTOCOL_ID, self.handle_stream)
def _create_query_function(self) -> Callable[[bytes], Awaitable[list[ID]]]:
"""
Create a query function that wraps peer_routing.find_closest_peers_network.
This function is used by the RandomWalk module to query for peers without
directly importing PeerRouting, avoiding circular import issues.
Returns:
Callable that takes target_key bytes and returns list of peer IDs
"""
async def query_function(target_key: bytes) -> list[ID]:
"""Query for closest peers to target key."""
return await self.peer_routing.find_closest_peers_network(target_key)
return query_function
async def run(self) -> None:
"""Run the DHT service."""
logger.info(f"Starting Kademlia DHT with peer ID {self.local_peer_id}")
# Start the RT Refresh Manager in parallel with the main DHT service
async with trio.open_nursery() as nursery:
# Start the RT Refresh Manager only if random walk is enabled
if self.rt_refresh_manager is not None:
nursery.start_soon(self.rt_refresh_manager.start)
logger.info("RT Refresh Manager started - Random Walk is now active")
else:
logger.info("Random Walk is disabled - RT Refresh Manager not started")
# Start the main DHT service loop
nursery.start_soon(self._run_main_loop)
async def _run_main_loop(self) -> None:
"""Run the main DHT service loop."""
# Main service loop
while self.manager.is_running:
# Periodically refresh the routing table
@ -135,6 +197,17 @@ class KadDHT(Service):
# Wait before next maintenance cycle
await trio.sleep(ROUTING_TABLE_REFRESH_INTERVAL)
async def stop(self) -> None:
"""Stop the DHT service and cleanup resources."""
logger.info("Stopping Kademlia DHT")
# Stop the RT Refresh Manager only if it was started
if self.rt_refresh_manager is not None:
await self.rt_refresh_manager.stop()
logger.info("RT Refresh Manager stopped")
else:
logger.info("RT Refresh Manager was not running (Random Walk disabled)")
async def switch_mode(self, new_mode: DHTMode) -> DHTMode:
"""
Switch the DHT mode.
@ -164,6 +237,9 @@ class KadDHT(Service):
await self.add_peer(peer_id)
logger.debug(f"Added peer {peer_id} to routing table")
closer_peer_envelope: Envelope | None = None
provider_peer_envelope: Envelope | None = None
try:
# Read varint-prefixed length for the message
length_prefix = b""
@ -204,6 +280,14 @@ class KadDHT(Service):
)
logger.debug(f"Found {len(closest_peers)} peers close to target")
# Consume the source signed_peer_record if sent
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
# Build response message with protobuf
response = Message()
response.type = Message.MessageType.FIND_NODE
@ -228,6 +312,21 @@ class KadDHT(Service):
except Exception:
pass
# Add the signed-peer-record for each peer in the peer-proto
# if cached in the peerstore
closer_peer_envelope = (
self.host.get_peerstore().get_peer_record(peer)
)
if closer_peer_envelope is not None:
peer_proto.signedRecord = (
closer_peer_envelope.marshal_envelope()
)
# Create sender_signed_peer_record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Serialize and send response
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
@ -242,6 +341,14 @@ class KadDHT(Service):
key = message.key
logger.debug(f"Received ADD_PROVIDER for key {key.hex()}")
# Consume the source signed-peer-record if sent
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
# Extract provider information
for provider_proto in message.providerPeers:
try:
@ -268,6 +375,17 @@ class KadDHT(Service):
logger.debug(
f"Added provider {provider_id} for key {key.hex()}"
)
# Process the signed-records of provider if sent
if not maybe_consume_signed_record(
provider_proto, self.host
):
logger.error(
"Received an invalid-signed-record,"
"dropping the stream"
)
await stream.close()
return
except Exception as e:
logger.warning(f"Failed to process provider info: {e}")
@ -276,6 +394,10 @@ class KadDHT(Service):
response.type = Message.MessageType.ADD_PROVIDER
response.key = key
# Add sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
await stream.write(response_bytes)
@ -287,6 +409,14 @@ class KadDHT(Service):
key = message.key
logger.debug(f"Received GET_PROVIDERS request for key {key.hex()}")
# Consume the source signed_peer_record if sent
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
# Find providers for the key
providers = self.provider_store.get_providers(key)
logger.debug(
@ -298,12 +428,28 @@ class KadDHT(Service):
response.type = Message.MessageType.GET_PROVIDERS
response.key = key
# Create sender_signed_peer_record for the response
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Add provider information to response
for provider_info in providers:
provider_proto = response.providerPeers.add()
provider_proto.id = provider_info.peer_id.to_bytes()
provider_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add provider signed-records if cached
provider_peer_envelope = (
self.host.get_peerstore().get_peer_record(
provider_info.peer_id
)
)
if provider_peer_envelope is not None:
provider_proto.signedRecord = (
provider_peer_envelope.marshal_envelope()
)
# Add addresses if available
for addr in provider_info.addrs:
provider_proto.addrs.append(addr.to_bytes())
@ -327,6 +473,16 @@ class KadDHT(Service):
peer_proto.id = peer.to_bytes()
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add the signed-records of closest_peers if cached
closer_peer_envelope = (
self.host.get_peerstore().get_peer_record(peer)
)
if closer_peer_envelope is not None:
peer_proto.signedRecord = (
closer_peer_envelope.marshal_envelope()
)
# Add addresses if available
try:
addrs = self.host.get_peerstore().addrs(peer)
@ -347,6 +503,14 @@ class KadDHT(Service):
key = message.key
logger.debug(f"Received GET_VALUE request for key {key.hex()}")
# Consume the sender_signed_peer_record
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
value = self.value_store.get(key)
if value:
logger.debug(f"Found value for key {key.hex()}")
@ -361,6 +525,10 @@ class KadDHT(Service):
response.record.value = value
response.record.timeReceived = str(time.time())
# Create sender_signed_peer_record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Serialize and send response
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
@ -374,6 +542,10 @@ class KadDHT(Service):
response.type = Message.MessageType.GET_VALUE
response.key = key
# Create sender_signed_peer_record for the response
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Add closest peers to key
closest_peers = self.routing_table.find_local_closest_peers(
key, 20
@ -392,6 +564,16 @@ class KadDHT(Service):
peer_proto.id = peer.to_bytes()
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add signed-records of closer-peers if cached
closer_peer_envelope = (
self.host.get_peerstore().get_peer_record(peer)
)
if closer_peer_envelope is not None:
peer_proto.signedRecord = (
closer_peer_envelope.marshal_envelope()
)
# Add addresses if available
try:
addrs = self.host.get_peerstore().addrs(peer)
@ -414,6 +596,15 @@ class KadDHT(Service):
key = message.record.key
value = message.record.value
success = False
# Consume the source signed_peer_record if sent
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
try:
if not (key and value):
raise ValueError(
@ -434,6 +625,12 @@ class KadDHT(Service):
response.type = Message.MessageType.PUT_VALUE
if success:
response.key = key
# Create sender_signed_peer_record for the response
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Serialize and send response
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
await stream.write(response_bytes)
@ -614,3 +811,15 @@ class KadDHT(Service):
"""
return self.value_store.size()
def is_random_walk_enabled(self) -> bool:
"""
Check if random walk peer discovery is enabled.
Returns
-------
bool
True if random walk is enabled, False otherwise.
"""
return self.enable_random_walk

View File

@ -27,6 +27,7 @@ message Message {
bytes id = 1;
repeated bytes addrs = 2;
ConnectionType connection = 3;
optional bytes signedRecord = 4; // Envelope(PeerRecord) encoded
}
MessageType type = 1;
@ -35,4 +36,6 @@ message Message {
Record record = 3;
repeated Peer closerPeers = 8;
repeated Peer providerPeers = 9;
optional bytes senderRecord = 11; // Envelope(PeerRecord) encoded
}

View File

@ -1,11 +1,12 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: libp2p/kad_dht/pb/kademlia.proto
# Protobuf Python Version: 4.25.3
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
@ -13,21 +14,21 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xca\x03\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x1aN\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xa2\x04\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x12\x19\n\x0csenderRecord\x18\x0b \x01(\x0cH\x00\x88\x01\x01\x1az\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\x12\x19\n\x0csignedRecord\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x42\x0f\n\r_signedRecord\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x42\x0f\n\r_senderRecordb\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', globals())
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_RECORD._serialized_start=36
_RECORD._serialized_end=94
_MESSAGE._serialized_start=97
_MESSAGE._serialized_end=555
_MESSAGE_PEER._serialized_start=281
_MESSAGE_PEER._serialized_end=359
_MESSAGE_MESSAGETYPE._serialized_start=361
_MESSAGE_MESSAGETYPE._serialized_end=466
_MESSAGE_CONNECTIONTYPE._serialized_start=468
_MESSAGE_CONNECTIONTYPE._serialized_end=555
_globals['_RECORD']._serialized_start=36
_globals['_RECORD']._serialized_end=94
_globals['_MESSAGE']._serialized_start=97
_globals['_MESSAGE']._serialized_end=643
_globals['_MESSAGE_PEER']._serialized_start=308
_globals['_MESSAGE_PEER']._serialized_end=430
_globals['_MESSAGE_MESSAGETYPE']._serialized_start=432
_globals['_MESSAGE_MESSAGETYPE']._serialized_end=537
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_start=539
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_end=626
# @@protoc_insertion_point(module_scope)

View File

@ -1,133 +1,70 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
from google.protobuf.internal import containers as _containers
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import sys
import typing
DESCRIPTOR: _descriptor.FileDescriptor
if sys.version_info >= (3, 10):
import typing as typing_extensions
else:
import typing_extensions
class Record(_message.Message):
__slots__ = ("key", "value", "timeReceived")
KEY_FIELD_NUMBER: _ClassVar[int]
VALUE_FIELD_NUMBER: _ClassVar[int]
TIMERECEIVED_FIELD_NUMBER: _ClassVar[int]
key: bytes
value: bytes
timeReceived: str
def __init__(self, key: _Optional[bytes] = ..., value: _Optional[bytes] = ..., timeReceived: _Optional[str] = ...) -> None: ...
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class Record(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
TIMERECEIVED_FIELD_NUMBER: builtins.int
key: builtins.bytes
value: builtins.bytes
timeReceived: builtins.str
def __init__(
self,
*,
key: builtins.bytes = ...,
value: builtins.bytes = ...,
timeReceived: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "timeReceived", b"timeReceived", "value", b"value"]) -> None: ...
global___Record = Record
@typing.final
class Message(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class _MessageType:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _MessageTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._MessageType.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
PUT_VALUE: Message._MessageType.ValueType # 0
GET_VALUE: Message._MessageType.ValueType # 1
ADD_PROVIDER: Message._MessageType.ValueType # 2
GET_PROVIDERS: Message._MessageType.ValueType # 3
FIND_NODE: Message._MessageType.ValueType # 4
PING: Message._MessageType.ValueType # 5
class MessageType(_MessageType, metaclass=_MessageTypeEnumTypeWrapper): ...
PUT_VALUE: Message.MessageType.ValueType # 0
GET_VALUE: Message.MessageType.ValueType # 1
ADD_PROVIDER: Message.MessageType.ValueType # 2
GET_PROVIDERS: Message.MessageType.ValueType # 3
FIND_NODE: Message.MessageType.ValueType # 4
PING: Message.MessageType.ValueType # 5
class _ConnectionType:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _ConnectionTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._ConnectionType.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
NOT_CONNECTED: Message._ConnectionType.ValueType # 0
CONNECTED: Message._ConnectionType.ValueType # 1
CAN_CONNECT: Message._ConnectionType.ValueType # 2
CANNOT_CONNECT: Message._ConnectionType.ValueType # 3
class ConnectionType(_ConnectionType, metaclass=_ConnectionTypeEnumTypeWrapper): ...
NOT_CONNECTED: Message.ConnectionType.ValueType # 0
CONNECTED: Message.ConnectionType.ValueType # 1
CAN_CONNECT: Message.ConnectionType.ValueType # 2
CANNOT_CONNECT: Message.ConnectionType.ValueType # 3
@typing.final
class Peer(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
ID_FIELD_NUMBER: builtins.int
ADDRS_FIELD_NUMBER: builtins.int
CONNECTION_FIELD_NUMBER: builtins.int
id: builtins.bytes
connection: global___Message.ConnectionType.ValueType
@property
def addrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
def __init__(
self,
*,
id: builtins.bytes = ...,
addrs: collections.abc.Iterable[builtins.bytes] | None = ...,
connection: global___Message.ConnectionType.ValueType = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["addrs", b"addrs", "connection", b"connection", "id", b"id"]) -> None: ...
TYPE_FIELD_NUMBER: builtins.int
CLUSTERLEVELRAW_FIELD_NUMBER: builtins.int
KEY_FIELD_NUMBER: builtins.int
RECORD_FIELD_NUMBER: builtins.int
CLOSERPEERS_FIELD_NUMBER: builtins.int
PROVIDERPEERS_FIELD_NUMBER: builtins.int
type: global___Message.MessageType.ValueType
clusterLevelRaw: builtins.int
key: builtins.bytes
@property
def record(self) -> global___Record: ...
@property
def closerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ...
@property
def providerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ...
def __init__(
self,
*,
type: global___Message.MessageType.ValueType = ...,
clusterLevelRaw: builtins.int = ...,
key: builtins.bytes = ...,
record: global___Record | None = ...,
closerPeers: collections.abc.Iterable[global___Message.Peer] | None = ...,
providerPeers: collections.abc.Iterable[global___Message.Peer] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["record", b"record"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["closerPeers", b"closerPeers", "clusterLevelRaw", b"clusterLevelRaw", "key", b"key", "providerPeers", b"providerPeers", "record", b"record", "type", b"type"]) -> None: ...
global___Message = Message
class Message(_message.Message):
__slots__ = ("type", "clusterLevelRaw", "key", "record", "closerPeers", "providerPeers", "senderRecord")
class MessageType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
PUT_VALUE: _ClassVar[Message.MessageType]
GET_VALUE: _ClassVar[Message.MessageType]
ADD_PROVIDER: _ClassVar[Message.MessageType]
GET_PROVIDERS: _ClassVar[Message.MessageType]
FIND_NODE: _ClassVar[Message.MessageType]
PING: _ClassVar[Message.MessageType]
PUT_VALUE: Message.MessageType
GET_VALUE: Message.MessageType
ADD_PROVIDER: Message.MessageType
GET_PROVIDERS: Message.MessageType
FIND_NODE: Message.MessageType
PING: Message.MessageType
class ConnectionType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
NOT_CONNECTED: _ClassVar[Message.ConnectionType]
CONNECTED: _ClassVar[Message.ConnectionType]
CAN_CONNECT: _ClassVar[Message.ConnectionType]
CANNOT_CONNECT: _ClassVar[Message.ConnectionType]
NOT_CONNECTED: Message.ConnectionType
CONNECTED: Message.ConnectionType
CAN_CONNECT: Message.ConnectionType
CANNOT_CONNECT: Message.ConnectionType
class Peer(_message.Message):
__slots__ = ("id", "addrs", "connection", "signedRecord")
ID_FIELD_NUMBER: _ClassVar[int]
ADDRS_FIELD_NUMBER: _ClassVar[int]
CONNECTION_FIELD_NUMBER: _ClassVar[int]
SIGNEDRECORD_FIELD_NUMBER: _ClassVar[int]
id: bytes
addrs: _containers.RepeatedScalarFieldContainer[bytes]
connection: Message.ConnectionType
signedRecord: bytes
def __init__(self, id: _Optional[bytes] = ..., addrs: _Optional[_Iterable[bytes]] = ..., connection: _Optional[_Union[Message.ConnectionType, str]] = ..., signedRecord: _Optional[bytes] = ...) -> None: ...
TYPE_FIELD_NUMBER: _ClassVar[int]
CLUSTERLEVELRAW_FIELD_NUMBER: _ClassVar[int]
KEY_FIELD_NUMBER: _ClassVar[int]
RECORD_FIELD_NUMBER: _ClassVar[int]
CLOSERPEERS_FIELD_NUMBER: _ClassVar[int]
PROVIDERPEERS_FIELD_NUMBER: _ClassVar[int]
SENDERRECORD_FIELD_NUMBER: _ClassVar[int]
type: Message.MessageType
clusterLevelRaw: int
key: bytes
record: Record
closerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer]
providerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer]
senderRecord: bytes
def __init__(self, type: _Optional[_Union[Message.MessageType, str]] = ..., clusterLevelRaw: _Optional[int] = ..., key: _Optional[bytes] = ..., record: _Optional[_Union[Record, _Mapping]] = ..., closerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping]]] = ..., providerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping]]] = ..., senderRecord: _Optional[bytes] = ...) -> None: ... # type: ignore

View File

@ -15,12 +15,14 @@ from libp2p.abc import (
INetStream,
IPeerRouting,
)
from libp2p.peer.envelope import Envelope
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.peer.peerstore import env_to_send_in_RPC
from .common import (
ALPHA,
@ -33,6 +35,7 @@ from .routing_table import (
RoutingTable,
)
from .utils import (
maybe_consume_signed_record,
sort_peer_ids_by_distance,
)
@ -170,7 +173,7 @@ class PeerRouting(IPeerRouting):
# Return early if we have no peers to start with
if not closest_peers:
logger.warning("No local peers available for network lookup")
logger.debug("No local peers available for network lookup")
return []
# Iterative lookup until convergence
@ -255,6 +258,10 @@ class PeerRouting(IPeerRouting):
find_node_msg.type = Message.MessageType.FIND_NODE
find_node_msg.key = target_key # Set target key directly as bytes
# Create sender_signed_peer_record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
find_node_msg.senderRecord = envelope_bytes
# Serialize and send the protobuf message with varint length prefix
proto_bytes = find_node_msg.SerializeToString()
logger.debug(
@ -299,7 +306,22 @@ class PeerRouting(IPeerRouting):
# Process closest peers from response
if response_msg.type == Message.MessageType.FIND_NODE:
# Consume the sender_signed_peer_record
if not maybe_consume_signed_record(response_msg, self.host, peer):
logger.error(
"Received an invalid-signed-record,ignoring the response"
)
return []
for peer_data in response_msg.closerPeers:
# Consume the received closer_peers signed-records, peer-id is
# sent with the peer-data
if not maybe_consume_signed_record(peer_data, self.host):
logger.error(
"Received an invalid-signed-record,ignoring the response"
)
return []
new_peer_id = ID(peer_data.id)
if new_peer_id not in results:
results.append(new_peer_id)
@ -332,6 +354,7 @@ class PeerRouting(IPeerRouting):
"""
try:
# Read message length
peer_id = stream.muxed_conn.peer_id
length_bytes = await stream.read(4)
if not length_bytes:
return
@ -345,10 +368,18 @@ class PeerRouting(IPeerRouting):
# Parse protobuf message
kad_message = Message()
closer_peer_envelope: Envelope | None = None
try:
kad_message.ParseFromString(message_bytes)
if kad_message.type == Message.MessageType.FIND_NODE:
# Consume the sender's signed-peer-record if sent
if not maybe_consume_signed_record(kad_message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
return
# Get target key directly from protobuf message
target_key = kad_message.key
@ -361,12 +392,26 @@ class PeerRouting(IPeerRouting):
response = Message()
response.type = Message.MessageType.FIND_NODE
# Create sender_signed_peer_record for the response
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Add peer information to response
for peer_id in closest_peers:
peer_proto = response.closerPeers.add()
peer_proto.id = peer_id.to_bytes()
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add the signed-records of closest_peers if cached
closer_peer_envelope = (
self.host.get_peerstore().get_peer_record(peer_id)
)
if isinstance(closer_peer_envelope, Envelope):
peer_proto.signedRecord = (
closer_peer_envelope.marshal_envelope()
)
# Add addresses if available
try:
addrs = self.host.get_peerstore().addrs(peer_id)

View File

@ -22,12 +22,14 @@ from libp2p.abc import (
from libp2p.custom_types import (
TProtocol,
)
from libp2p.kad_dht.utils import maybe_consume_signed_record
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.peer.peerstore import env_to_send_in_RPC
from .common import (
ALPHA,
@ -240,11 +242,18 @@ class ProviderStore:
message.type = Message.MessageType.ADD_PROVIDER
message.key = key
# Create sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
message.senderRecord = envelope_bytes
# Add our provider info
provider = message.providerPeers.add()
provider.id = self.local_peer_id.to_bytes()
provider.addrs.extend(addrs)
# Add the provider's signed-peer-record
provider.signedRecord = envelope_bytes
# Serialize and send the message
proto_bytes = message.SerializeToString()
await stream.write(varint.encode(len(proto_bytes)))
@ -276,10 +285,15 @@ class ProviderStore:
response = Message()
response.ParseFromString(response_bytes)
# Check response type
response.type == Message.MessageType.ADD_PROVIDER
if response.type:
result = True
if response.type == Message.MessageType.ADD_PROVIDER:
# Consume the sender's signed-peer-record if sent
if not maybe_consume_signed_record(response, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, ignoring the response"
)
result = False
else:
result = True
except Exception as e:
logger.warning(f"Error sending ADD_PROVIDER to {peer_id}: {e}")
@ -380,6 +394,10 @@ class ProviderStore:
message.type = Message.MessageType.GET_PROVIDERS
message.key = key
# Create sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
message.senderRecord = envelope_bytes
# Serialize and send the message
proto_bytes = message.SerializeToString()
await stream.write(varint.encode(len(proto_bytes)))
@ -414,10 +432,26 @@ class ProviderStore:
if response.type != Message.MessageType.GET_PROVIDERS:
return []
# Consume the sender's signed-peer-record if sent
if not maybe_consume_signed_record(response, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, ignoring the response"
)
return []
# Extract provider information
providers = []
for provider_proto in response.providerPeers:
try:
# Consume the provider's signed-peer-record if sent, peer-id
# already sent with the provider-proto
if not maybe_consume_signed_record(provider_proto, self.host):
logger.error(
"Received an invalid-signed-record, "
"ignoring the response"
)
return []
# Create peer ID from bytes
provider_id = ID(provider_proto.id)
@ -431,6 +465,7 @@ class ProviderStore:
# Create PeerInfo and add to result
providers.append(PeerInfo(provider_id, addrs))
except Exception as e:
logger.warning(f"Failed to parse provider info: {e}")

View File

@ -8,6 +8,7 @@ from collections import (
import logging
import time
import multihash
import trio
from libp2p.abc import (
@ -40,6 +41,22 @@ PEER_REFRESH_INTERVAL = 60 # Interval to refresh peers in seconds
STALE_PEER_THRESHOLD = 3600 # Time in seconds after which a peer is considered stale
def peer_id_to_key(peer_id: ID) -> bytes:
"""
Convert a peer ID to a 256-bit key for routing table operations.
This normalizes all peer IDs to exactly 256 bits by hashing them with SHA-256.
:param peer_id: The peer ID to convert
:return: 32-byte (256-bit) key for routing table operations
"""
return multihash.digest(peer_id.to_bytes(), "sha2-256").digest
def key_to_int(key: bytes) -> int:
"""Convert a 256-bit key to an integer for range calculations."""
return int.from_bytes(key, byteorder="big")
class KBucket:
"""
A k-bucket implementation for the Kademlia DHT.
@ -357,9 +374,24 @@ class KBucket:
True if the key is in range, False otherwise
"""
key_int = int.from_bytes(key, byteorder="big")
key_int = key_to_int(key)
return self.min_range <= key_int < self.max_range
def peer_id_in_range(self, peer_id: ID) -> bool:
"""
Check if a peer ID is in the range of this bucket.
params: peer_id: The peer ID to check
Returns
-------
bool
True if the peer ID is in range, False otherwise
"""
key = peer_id_to_key(peer_id)
return self.key_in_range(key)
def split(self) -> tuple["KBucket", "KBucket"]:
"""
Split the bucket into two buckets.
@ -376,8 +408,9 @@ class KBucket:
# Redistribute peers
for peer_id, (peer_info, timestamp) in self.peers.items():
peer_key = int.from_bytes(peer_id.to_bytes(), byteorder="big")
if peer_key < midpoint:
peer_key = peer_id_to_key(peer_id)
peer_key_int = key_to_int(peer_key)
if peer_key_int < midpoint:
lower_bucket.peers[peer_id] = (peer_info, timestamp)
else:
upper_bucket.peers[peer_id] = (peer_info, timestamp)
@ -458,7 +491,38 @@ class RoutingTable:
success = await bucket.add_peer(peer_info)
if success:
logger.debug(f"Successfully added peer {peer_id} to routing table")
return success
return True
# If bucket is full and couldn't add peer, try splitting the bucket
# Only split if the bucket contains our Peer ID
if self._should_split_bucket(bucket):
logger.debug(
f"Bucket is full, attempting to split bucket for peer {peer_id}"
)
split_success = self._split_bucket(bucket)
if split_success:
# After splitting,
# find the appropriate bucket for the peer and try to add it
target_bucket = self.find_bucket(peer_info.peer_id)
success = await target_bucket.add_peer(peer_info)
if success:
logger.debug(
f"Successfully added peer {peer_id} after bucket split"
)
return True
else:
logger.debug(
f"Failed to add peer {peer_id} even after bucket split"
)
return False
else:
logger.debug(f"Failed to split bucket for peer {peer_id}")
return False
else:
logger.debug(
f"Bucket is full and cannot be split, peer {peer_id} not added"
)
return False
except Exception as e:
logger.debug(f"Error adding peer {peer_obj} to routing table: {e}")
@ -480,9 +544,9 @@ class RoutingTable:
def find_bucket(self, peer_id: ID) -> KBucket:
"""
Find the bucket that would contain the given peer ID or PeerInfo.
Find the bucket that would contain the given peer ID.
:param peer_obj: Either a peer ID or a PeerInfo object
:param peer_id: The peer ID to find a bucket for
Returns
-------
@ -490,7 +554,7 @@ class RoutingTable:
"""
for bucket in self.buckets:
if bucket.key_in_range(peer_id.to_bytes()):
if bucket.peer_id_in_range(peer_id):
return bucket
return self.buckets[0]
@ -513,7 +577,11 @@ class RoutingTable:
all_peers.extend(bucket.peer_ids())
# Sort by XOR distance to the key
all_peers.sort(key=lambda p: xor_distance(p.to_bytes(), key))
def distance_to_key(peer_id: ID) -> int:
peer_key = peer_id_to_key(peer_id)
return xor_distance(peer_key, key)
all_peers.sort(key=distance_to_key)
return all_peers[:count]
@ -591,6 +659,20 @@ class RoutingTable:
stale_peers.extend(bucket.get_stale_peers(stale_threshold_seconds))
return stale_peers
def get_peer_infos(self) -> list[PeerInfo]:
"""
Get all PeerInfo objects in the routing table.
Returns
-------
List[PeerInfo]: List of all PeerInfo objects
"""
peer_infos = []
for bucket in self.buckets:
peer_infos.extend(bucket.peer_infos())
return peer_infos
def cleanup_routing_table(self) -> None:
"""
Cleanup the routing table by removing all data.
@ -598,3 +680,66 @@ class RoutingTable:
"""
self.buckets = [KBucket(self.host, BUCKET_SIZE)]
logger.info("Routing table cleaned up, all data removed.")
def _should_split_bucket(self, bucket: KBucket) -> bool:
"""
Check if a bucket should be split according to Kademlia rules.
:param bucket: The bucket to check
:return: True if the bucket should be split
"""
# Check if we've exceeded maximum buckets
if len(self.buckets) >= MAXIMUM_BUCKETS:
logger.debug("Maximum number of buckets reached, cannot split")
return False
# Check if the bucket contains our local ID
local_key = peer_id_to_key(self.local_id)
local_key_int = key_to_int(local_key)
contains_local_id = bucket.min_range <= local_key_int < bucket.max_range
logger.debug(
f"Bucket range: {bucket.min_range} - {bucket.max_range}, "
f"local_key_int: {local_key_int}, contains_local: {contains_local_id}"
)
return contains_local_id
def _split_bucket(self, bucket: KBucket) -> bool:
"""
Split a bucket into two buckets.
:param bucket: The bucket to split
:return: True if the bucket was successfully split
"""
try:
# Find the bucket index
bucket_index = self.buckets.index(bucket)
logger.debug(f"Splitting bucket at index {bucket_index}")
# Split the bucket
lower_bucket, upper_bucket = bucket.split()
# Replace the original bucket with the two new buckets
self.buckets[bucket_index] = lower_bucket
self.buckets.insert(bucket_index + 1, upper_bucket)
logger.debug(
f"Bucket split successful. New bucket count: {len(self.buckets)}"
)
logger.debug(
f"Lower bucket range: "
f"{lower_bucket.min_range} - {lower_bucket.max_range}, "
f"peers: {lower_bucket.size()}"
)
logger.debug(
f"Upper bucket range: "
f"{upper_bucket.min_range} - {upper_bucket.max_range}, "
f"peers: {upper_bucket.size()}"
)
return True
except Exception as e:
logger.error(f"Error splitting bucket: {e}")
return False

View File

@ -2,13 +2,93 @@
Utility functions for Kademlia DHT implementation.
"""
import logging
import base58
import multihash
from libp2p.abc import IHost
from libp2p.peer.envelope import consume_envelope
from libp2p.peer.id import (
ID,
)
from .pb.kademlia_pb2 import (
Message,
)
logger = logging.getLogger("kademlia-example.utils")
def maybe_consume_signed_record(
msg: Message | Message.Peer, host: IHost, peer_id: ID | None = None
) -> bool:
"""
Attempt to parse and store a signed-peer-record (Envelope) received during
DHT communication. If the record is invalid, the peer-id does not match, or
updating the peerstore fails, the function logs an error and returns False.
Parameters
----------
msg : Message | Message.Peer
The protobuf message received during DHT communication. Can either be a
top-level `Message` containing `senderRecord` or a `Message.Peer`
containing `signedRecord`.
host : IHost
The local host instance, providing access to the peerstore for storing
verified peer records.
peer_id : ID | None, optional
The expected peer ID for record validation. If provided, the peer ID
inside the record must match this value.
Returns
-------
bool
True if a valid signed peer record was successfully consumed and stored,
False otherwise.
"""
if isinstance(msg, Message):
if msg.HasField("senderRecord"):
try:
# Convert the signed-peer-record(Envelope) from
# protobuf bytes
envelope, record = consume_envelope(
msg.senderRecord,
"libp2p-peer-record",
)
if not (isinstance(peer_id, ID) and record.peer_id == peer_id):
return False
# Use the default TTL of 2 hours (7200 seconds)
if not host.get_peerstore().consume_peer_record(envelope, 7200):
logger.error("Failed to update the Certified-Addr-Book")
return False
except Exception as e:
logger.error("Failed to update the Certified-Addr-Book: %s", e)
return False
else:
if msg.HasField("signedRecord"):
try:
# Convert the signed-peer-record(Envelope) from
# protobuf bytes
envelope, record = consume_envelope(
msg.signedRecord,
"libp2p-peer-record",
)
if not record.peer_id.to_bytes() == msg.id:
return False
# Use the default TTL of 2 hours (7200 seconds)
if not host.get_peerstore().consume_peer_record(envelope, 7200):
logger.error("Failed to update the Certified-Addr-Book")
return False
except Exception as e:
logger.error(
"Failed to update the Certified-Addr-Book: %s",
e,
)
return False
return True
def create_key_from_binary(binary_data: bytes) -> bytes:
"""

View File

@ -15,9 +15,11 @@ from libp2p.abc import (
from libp2p.custom_types import (
TProtocol,
)
from libp2p.kad_dht.utils import maybe_consume_signed_record
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerstore import env_to_send_in_RPC
from .common import (
DEFAULT_TTL,
@ -110,6 +112,10 @@ class ValueStore:
message = Message()
message.type = Message.MessageType.PUT_VALUE
# Create sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
message.senderRecord = envelope_bytes
# Set message fields
message.key = key
message.record.key = key
@ -155,7 +161,13 @@ class ValueStore:
# Check if response is valid
if response.type == Message.MessageType.PUT_VALUE:
if response.key:
# Consume the sender's signed-peer-record if sent
if not maybe_consume_signed_record(response, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, ignoring the response"
)
return False
if response.key == key:
result = True
return result
@ -231,6 +243,10 @@ class ValueStore:
message.type = Message.MessageType.GET_VALUE
message.key = key
# Create sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
message.senderRecord = envelope_bytes
# Serialize and send the protobuf message
proto_bytes = message.SerializeToString()
await stream.write(varint.encode(len(proto_bytes)))
@ -275,6 +291,13 @@ class ValueStore:
and response.HasField("record")
and response.record.value
):
# Consume the sender's signed-peer-record
if not maybe_consume_signed_record(response, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, ignoring the response"
)
return None
logger.debug(
f"Received value for key {key.hex()} from peer {peer_id}"
)

View File

@ -23,7 +23,8 @@ if TYPE_CHECKING:
"""
Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee14e334f7caeedad722/swarm_conn.go
Reference: https://github.com/libp2p/go-libp2p-swarm/blob/
04c86bbdafd390651cb2ee14e334f7caeedad722/swarm_conn.go
"""
@ -43,6 +44,21 @@ class SwarmConn(INetConn):
self.streams = set()
self.event_closed = trio.Event()
self.event_started = trio.Event()
# Provide back-references/hooks expected by NetStream
try:
setattr(self.muxed_conn, "swarm", self.swarm)
# NetStream expects an awaitable remove_stream hook
async def _remove_stream_hook(stream: NetStream) -> None:
self.remove_stream(stream)
setattr(self.muxed_conn, "remove_stream", _remove_stream_hook)
except Exception as e:
logging.warning(
f"Failed to set optional conveniences on muxed_conn "
f"for peer {muxed_conn.peer_id}: {e}"
)
# optional conveniences
if hasattr(muxed_conn, "on_close"):
logging.debug(f"Setting on_close for peer {muxed_conn.peer_id}")
setattr(muxed_conn, "on_close", self._on_muxed_conn_closed)

View File

@ -1,4 +1,10 @@
from collections.abc import (
Awaitable,
Callable,
)
from dataclasses import dataclass
import logging
import random
from multiaddr import (
Multiaddr,
@ -55,6 +61,59 @@ from .exceptions import (
logger = logging.getLogger("libp2p.network.swarm")
@dataclass
class RetryConfig:
"""
Configuration for retry logic with exponential backoff.
This configuration controls how connection attempts are retried when they fail.
The retry mechanism uses exponential backoff with jitter to prevent thundering
herd problems in distributed systems.
Attributes:
max_retries: Maximum number of retry attempts before giving up.
Default: 3 attempts
initial_delay: Initial delay in seconds before the first retry.
Default: 0.1 seconds (100ms)
max_delay: Maximum delay cap in seconds to prevent excessive wait times.
Default: 30.0 seconds
backoff_multiplier: Multiplier for exponential backoff (each retry multiplies
the delay by this factor). Default: 2.0 (doubles each time)
jitter_factor: Random jitter factor (0.0-1.0) to add randomness to delays
and prevent synchronized retries. Default: 0.1 (10% jitter)
"""
max_retries: int = 3
initial_delay: float = 0.1
max_delay: float = 30.0
backoff_multiplier: float = 2.0
jitter_factor: float = 0.1
@dataclass
class ConnectionConfig:
"""
Configuration for multi-connection support.
This configuration controls how multiple connections per peer are managed,
including connection limits, timeouts, and load balancing strategies.
Attributes:
max_connections_per_peer: Maximum number of connections allowed to a single
peer. Default: 3 connections
connection_timeout: Timeout in seconds for establishing new connections.
Default: 30.0 seconds
load_balancing_strategy: Strategy for distributing streams across connections.
Options: "round_robin" (default) or "least_loaded"
"""
max_connections_per_peer: int = 3
connection_timeout: float = 30.0
load_balancing_strategy: str = "round_robin" # or "least_loaded"
def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn:
async def stream_handler(stream: INetStream) -> None:
await network.get_manager().wait_finished()
@ -67,9 +126,8 @@ class Swarm(Service, INetworkService):
peerstore: IPeerStore
upgrader: TransportUpgrader
transport: ITransport
# TODO: Connection and `peer_id` are 1-1 mapping in our implementation,
# whereas in Go one `peer_id` may point to multiple connections.
connections: dict[ID, INetConn]
# Enhanced: Support for multiple connections per peer
connections: dict[ID, list[INetConn]] # Multiple connections per peer
listeners: dict[str, IListener]
common_stream_handler: StreamHandlerFn
listener_nursery: trio.Nursery | None
@ -77,18 +135,31 @@ class Swarm(Service, INetworkService):
notifees: list[INotifee]
# Enhanced: New configuration
retry_config: RetryConfig
connection_config: ConnectionConfig
_round_robin_index: dict[ID, int]
def __init__(
self,
peer_id: ID,
peerstore: IPeerStore,
upgrader: TransportUpgrader,
transport: ITransport,
retry_config: RetryConfig | None = None,
connection_config: ConnectionConfig | None = None,
):
self.self_id = peer_id
self.peerstore = peerstore
self.upgrader = upgrader
self.transport = transport
self.connections = dict()
# Enhanced: Initialize retry and connection configuration
self.retry_config = retry_config or RetryConfig()
self.connection_config = connection_config or ConnectionConfig()
# Enhanced: Initialize connections as 1:many mapping
self.connections = {}
self.listeners = dict()
# Create Notifee array
@ -99,6 +170,9 @@ class Swarm(Service, INetworkService):
self.listener_nursery = None
self.event_listener_nursery_created = trio.Event()
# Load balancing state
self._round_robin_index = {}
async def run(self) -> None:
async with trio.open_nursery() as nursery:
# Create a nursery for listener tasks.
@ -118,18 +192,74 @@ class Swarm(Service, INetworkService):
def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None:
self.common_stream_handler = stream_handler
async def dial_peer(self, peer_id: ID) -> INetConn:
def get_connections(self, peer_id: ID | None = None) -> list[INetConn]:
"""
Try to create a connection to peer_id.
Get connections for peer (like JS getConnections, Go ConnsToPeer).
Parameters
----------
peer_id : ID | None
The peer ID to get connections for. If None, returns all connections.
Returns
-------
list[INetConn]
List of connections to the specified peer, or all connections
if peer_id is None.
"""
if peer_id is not None:
return self.connections.get(peer_id, [])
# Return all connections from all peers
all_conns = []
for conns in self.connections.values():
all_conns.extend(conns)
return all_conns
def get_connections_map(self) -> dict[ID, list[INetConn]]:
"""
Get all connections map (like JS getConnectionsMap).
Returns
-------
dict[ID, list[INetConn]]
The complete mapping of peer IDs to their connection lists.
"""
return self.connections.copy()
def get_connection(self, peer_id: ID) -> INetConn | None:
"""
Get single connection for backward compatibility.
Parameters
----------
peer_id : ID
The peer ID to get a connection for.
Returns
-------
INetConn | None
The first available connection, or None if no connections exist.
"""
conns = self.get_connections(peer_id)
return conns[0] if conns else None
async def dial_peer(self, peer_id: ID) -> list[INetConn]:
"""
Try to create connections to peer_id with enhanced retry logic.
:param peer_id: peer if we want to dial
:raises SwarmException: raised when an error occurs
:return: muxed connection
:return: list of muxed connections
"""
if peer_id in self.connections:
# If muxed connection already exists for peer_id,
# set muxed connection equal to existing muxed connection
return self.connections[peer_id]
# Check if we already have connections
existing_connections = self.get_connections(peer_id)
if existing_connections:
logger.debug(f"Reusing existing connections to peer {peer_id}")
return existing_connections
logger.debug("attempting to dial peer %s", peer_id)
@ -142,12 +272,19 @@ class Swarm(Service, INetworkService):
if not addrs:
raise SwarmException(f"No known addresses to peer {peer_id}")
connections = []
exceptions: list[SwarmException] = []
# Try all known addresses
# Enhanced: Try all known addresses with retry logic
for multiaddr in addrs:
try:
return await self.dial_addr(multiaddr, peer_id)
connection = await self._dial_with_retry(multiaddr, peer_id)
connections.append(connection)
# Limit number of connections per peer
if len(connections) >= self.connection_config.max_connections_per_peer:
break
except SwarmException as e:
exceptions.append(e)
logger.debug(
@ -157,15 +294,73 @@ class Swarm(Service, INetworkService):
exc_info=e,
)
# Tried all addresses, raising exception.
raise SwarmException(
f"unable to connect to {peer_id}, no addresses established a successful "
"connection (with exceptions)"
) from MultiError(exceptions)
if not connections:
# Tried all addresses, raising exception.
raise SwarmException(
f"unable to connect to {peer_id}, no addresses established a "
"successful connection (with exceptions)"
) from MultiError(exceptions)
async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn:
return connections
async def _dial_with_retry(self, addr: Multiaddr, peer_id: ID) -> INetConn:
"""
Try to create a connection to peer_id with addr.
Enhanced: Dial with retry logic and exponential backoff.
:param addr: the address to dial
:param peer_id: the peer we want to connect to
:raises SwarmException: raised when all retry attempts fail
:return: network connection
"""
last_exception = None
for attempt in range(self.retry_config.max_retries + 1):
try:
return await self._dial_addr_single_attempt(addr, peer_id)
except Exception as e:
last_exception = e
if attempt < self.retry_config.max_retries:
delay = self._calculate_backoff_delay(attempt)
logger.debug(
f"Connection attempt {attempt + 1} failed, "
f"retrying in {delay:.2f}s: {e}"
)
await trio.sleep(delay)
else:
logger.debug(f"All {self.retry_config.max_retries} attempts failed")
# Convert the last exception to SwarmException for consistency
if last_exception is not None:
if isinstance(last_exception, SwarmException):
raise last_exception
else:
raise SwarmException(
f"Failed to connect after {self.retry_config.max_retries} attempts"
) from last_exception
# This should never be reached, but mypy requires it
raise SwarmException("Unexpected error in retry logic")
def _calculate_backoff_delay(self, attempt: int) -> float:
"""
Enhanced: Calculate backoff delay with jitter to prevent thundering herd.
:param attempt: the current attempt number (0-based)
:return: delay in seconds
"""
delay = min(
self.retry_config.initial_delay
* (self.retry_config.backoff_multiplier**attempt),
self.retry_config.max_delay,
)
# Add jitter to prevent synchronized retries
jitter = delay * self.retry_config.jitter_factor
return delay + random.uniform(-jitter, jitter)
async def _dial_addr_single_attempt(self, addr: Multiaddr, peer_id: ID) -> INetConn:
"""
Enhanced: Single attempt to dial an address (extracted from original dial_addr).
:param addr: the address we want to connect with
:param peer_id: the peer we want to connect to
@ -212,19 +407,97 @@ class Swarm(Service, INetworkService):
return swarm_conn
async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn:
"""
Enhanced: Try to create a connection to peer_id with addr using retry logic.
:param addr: the address we want to connect with
:param peer_id: the peer we want to connect to
:raises SwarmException: raised when an error occurs
:return: network connection
"""
return await self._dial_with_retry(addr, peer_id)
async def new_stream(self, peer_id: ID) -> INetStream:
"""
Enhanced: Create a new stream with load balancing across multiple connections.
:param peer_id: peer_id of destination
:raises SwarmException: raised when an error occurs
:return: net stream instance
"""
logger.debug("attempting to open a stream to peer %s", peer_id)
swarm_conn = await self.dial_peer(peer_id)
# Get existing connections or dial new ones
connections = self.get_connections(peer_id)
if not connections:
connections = await self.dial_peer(peer_id)
net_stream = await swarm_conn.new_stream()
logger.debug("successfully opened a stream to peer %s", peer_id)
return net_stream
# Load balancing strategy at interface level
connection = self._select_connection(connections, peer_id)
try:
net_stream = await connection.new_stream()
logger.debug("successfully opened a stream to peer %s", peer_id)
return net_stream
except Exception as e:
logger.debug(f"Failed to create stream on connection: {e}")
# Try other connections if available
for other_conn in connections:
if other_conn != connection:
try:
net_stream = await other_conn.new_stream()
logger.debug(
f"Successfully opened a stream to peer {peer_id} "
"using alternative connection"
)
return net_stream
except Exception:
continue
# All connections failed, raise exception
raise SwarmException(f"Failed to create stream to peer {peer_id}") from e
def _select_connection(self, connections: list[INetConn], peer_id: ID) -> INetConn:
"""
Select connection based on load balancing strategy.
Parameters
----------
connections : list[INetConn]
List of available connections.
peer_id : ID
The peer ID for round-robin tracking.
strategy : str
Load balancing strategy ("round_robin", "least_loaded", etc.).
Returns
-------
INetConn
Selected connection.
"""
if not connections:
raise ValueError("No connections available")
strategy = self.connection_config.load_balancing_strategy
if strategy == "round_robin":
# Simple round-robin selection
if peer_id not in self._round_robin_index:
self._round_robin_index[peer_id] = 0
index = self._round_robin_index[peer_id] % len(connections)
self._round_robin_index[peer_id] += 1
return connections[index]
elif strategy == "least_loaded":
# Find connection with least streams
return min(connections, key=lambda c: len(c.get_streams()))
else:
# Default to first connection
return connections[0]
async def listen(self, *multiaddrs: Multiaddr) -> bool:
"""
@ -245,9 +518,11 @@ class Swarm(Service, INetworkService):
# We need to wait until `self.listener_nursery` is created.
await self.event_listener_nursery_created.wait()
success_count = 0
for maddr in multiaddrs:
if str(maddr) in self.listeners:
return True
success_count += 1
continue
async def conn_handler(
read_write_closer: ReadWriteCloser, maddr: Multiaddr = maddr
@ -298,13 +573,14 @@ class Swarm(Service, INetworkService):
# Call notifiers since event occurred
await self.notify_listen(maddr)
return True
success_count += 1
logger.debug("successfully started listening on: %s", maddr)
except OSError:
# Failed. Continue looping.
logger.debug("fail to listen on: %s", maddr)
# No maddr succeeded
return False
# Return true if at least one address succeeded
return success_count > 0
async def close(self) -> None:
"""
@ -317,17 +593,25 @@ class Swarm(Service, INetworkService):
# Perform alternative cleanup if the manager isn't initialized
# Close all connections manually
if hasattr(self, "connections"):
for conn_id in list(self.connections.keys()):
conn = self.connections[conn_id]
await conn.close()
for peer_id, conns in list(self.connections.items()):
for conn in conns:
await conn.close()
# Clear connection tracking dictionary
self.connections.clear()
# Close all listeners
if hasattr(self, "listeners"):
for listener in self.listeners.values():
for maddr_str, listener in self.listeners.items():
await listener.close()
# Notify about listener closure
try:
multiaddr = Multiaddr(maddr_str)
await self.notify_listen_close(multiaddr)
except Exception as e:
logger.warning(
f"Failed to notify listen_close for {maddr_str}: {e}"
)
self.listeners.clear()
# Close the transport if it exists and has a close method
@ -341,12 +625,28 @@ class Swarm(Service, INetworkService):
logger.debug("swarm successfully closed")
async def close_peer(self, peer_id: ID) -> None:
if peer_id not in self.connections:
"""
Close all connections to the specified peer.
Parameters
----------
peer_id : ID
The peer ID to close connections for.
"""
connections = self.get_connections(peer_id)
if not connections:
return
connection = self.connections[peer_id]
# NOTE: `connection.close` will delete `peer_id` from `self.connections`
# and `notify_disconnected` for us.
await connection.close()
# Close all connections
for connection in connections:
try:
await connection.close()
except Exception as e:
logger.warning(f"Error closing connection to {peer_id}: {e}")
# Remove from connections dict
self.connections.pop(peer_id, None)
logger.debug("successfully close the connection to peer %s", peer_id)
@ -365,21 +665,71 @@ class Swarm(Service, INetworkService):
await muxed_conn.event_started.wait()
self.manager.run_task(swarm_conn.start)
await swarm_conn.event_started.wait()
# Store muxed_conn with peer id
self.connections[muxed_conn.peer_id] = swarm_conn
# Add to connections dict with deduplication
peer_id = muxed_conn.peer_id
if peer_id not in self.connections:
self.connections[peer_id] = []
# Check for duplicate connections by comparing the underlying muxed connection
for existing_conn in self.connections[peer_id]:
if existing_conn.muxed_conn == muxed_conn:
logger.debug(f"Connection already exists for peer {peer_id}")
# existing_conn is a SwarmConn since it's stored in the connections list
return existing_conn # type: ignore[return-value]
self.connections[peer_id].append(swarm_conn)
# Trim if we exceed max connections
max_conns = self.connection_config.max_connections_per_peer
if len(self.connections[peer_id]) > max_conns:
self._trim_connections(peer_id)
# Call notifiers since event occurred
await self.notify_connected(swarm_conn)
return swarm_conn
def _trim_connections(self, peer_id: ID) -> None:
"""
Remove oldest connections when limit is exceeded.
"""
connections = self.connections[peer_id]
if len(connections) <= self.connection_config.max_connections_per_peer:
return
# Sort by creation time and remove oldest
# For now, just keep the most recent connections
max_conns = self.connection_config.max_connections_per_peer
connections_to_remove = connections[:-max_conns]
for conn in connections_to_remove:
logger.debug(f"Trimming old connection for peer {peer_id}")
trio.lowlevel.spawn_system_task(self._close_connection_async, conn)
# Keep only the most recent connections
max_conns = self.connection_config.max_connections_per_peer
self.connections[peer_id] = connections[-max_conns:]
async def _close_connection_async(self, connection: INetConn) -> None:
"""Close a connection asynchronously."""
try:
await connection.close()
except Exception as e:
logger.warning(f"Error closing connection: {e}")
def remove_conn(self, swarm_conn: SwarmConn) -> None:
"""
Simply remove the connection from Swarm's records, without closing
the connection.
"""
peer_id = swarm_conn.muxed_conn.peer_id
if peer_id not in self.connections:
return
del self.connections[peer_id]
if peer_id in self.connections:
self.connections[peer_id] = [
conn for conn in self.connections[peer_id] if conn != swarm_conn
]
if not self.connections[peer_id]:
del self.connections[peer_id]
# Notifee
@ -411,7 +761,35 @@ class Swarm(Service, INetworkService):
nursery.start_soon(notifee.listen, self, multiaddr)
async def notify_closed_stream(self, stream: INetStream) -> None:
raise NotImplementedError
async with trio.open_nursery() as nursery:
for notifee in self.notifees:
nursery.start_soon(notifee.closed_stream, self, stream)
async def notify_listen_close(self, multiaddr: Multiaddr) -> None:
raise NotImplementedError
async with trio.open_nursery() as nursery:
for notifee in self.notifees:
nursery.start_soon(notifee.listen_close, self, multiaddr)
# Generic notifier used by NetStream._notify_closed
async def notify_all(self, notifier: Callable[[INotifee], Awaitable[None]]) -> None:
async with trio.open_nursery() as nursery:
for notifee in self.notifees:
nursery.start_soon(notifier, notifee)
# Backward compatibility properties
@property
def connections_legacy(self) -> dict[ID, INetConn]:
"""
Legacy 1:1 mapping for backward compatibility.
Returns
-------
dict[ID, INetConn]
Legacy mapping with only the first connection per peer.
"""
legacy_conns = {}
for peer_id, conns in self.connections.items():
if conns:
legacy_conns[peer_id] = conns[0]
return legacy_conns

View File

@ -1,5 +1,7 @@
from typing import Any, cast
import multiaddr
from libp2p.crypto.ed25519 import Ed25519PublicKey
from libp2p.crypto.keys import PrivateKey, PublicKey
from libp2p.crypto.rsa import RSAPublicKey
@ -131,6 +133,9 @@ class Envelope:
)
return False
def _env_addrs_set(self) -> set[multiaddr.Multiaddr]:
return {b for b in self.record().addrs}
def pub_key_to_protobuf(pub_key: PublicKey) -> cryto_pb.PublicKey:
"""

View File

@ -16,6 +16,7 @@ import trio
from trio import MemoryReceiveChannel, MemorySendChannel
from libp2p.abc import (
IHost,
IPeerStore,
)
from libp2p.crypto.keys import (
@ -23,7 +24,8 @@ from libp2p.crypto.keys import (
PrivateKey,
PublicKey,
)
from libp2p.peer.envelope import Envelope
from libp2p.peer.envelope import Envelope, seal_record
from libp2p.peer.peer_record import PeerRecord
from .id import (
ID,
@ -39,8 +41,86 @@ from .peerinfo import (
PERMANENT_ADDR_TTL = 0
# TODO: Set up an async task for periodic peer-store cleanup
# for expired addresses and records.
def create_signed_peer_record(
peer_id: ID, addrs: list[Multiaddr], pvt_key: PrivateKey
) -> Envelope:
"""Creates a signed_peer_record wrapped in an Envelope"""
record = PeerRecord(peer_id, addrs)
envelope = seal_record(record, pvt_key)
return envelope
def env_to_send_in_RPC(host: IHost) -> tuple[bytes, bool]:
"""
Return the signed peer record (Envelope) to be sent in an RPC.
This function checks whether the host already has a cached signed peer record
(SPR). If one exists and its addresses match the host's current listen
addresses, the cached envelope is reused. Otherwise, a new signed peer record
is created, cached, and returned.
Parameters
----------
host : IHost
The local host instance, providing access to peer ID, listen addresses,
private key, and the peerstore.
Returns
-------
tuple[bytes, bool]
A 2-tuple where the first element is the serialized envelope (bytes)
for the signed peer record, and the second element is a boolean flag
indicating whether a new record was created (True) or an existing cached
one was reused (False).
"""
listen_addrs_set = {addr for addr in host.get_addrs()}
local_env = host.get_peerstore().get_local_record()
if local_env is None:
# No cached SPR yet -> create one
return issue_and_cache_local_record(host), True
else:
record_addrs_set = local_env._env_addrs_set()
if record_addrs_set == listen_addrs_set:
# Perfect match -> reuse cached envelope
return local_env.marshal_envelope(), False
else:
# Addresses changed -> issue a new SPR and cache it
return issue_and_cache_local_record(host), True
def issue_and_cache_local_record(host: IHost) -> bytes:
"""
Create and cache a new signed peer record (Envelope) for the host.
This function generates a new signed peer record from the hosts peer ID,
listen addresses, and private key. The resulting envelope is stored in
the peerstore as the local record for future reuse.
Parameters
----------
host : IHost
The local host instance, providing access to peer ID, listen addresses,
private key, and the peerstore.
Returns
-------
bytes
The serialized envelope (bytes) representing the newly created signed
peer record.
"""
env = create_signed_peer_record(
host.get_id(),
host.get_addrs(),
host.get_private_key(),
)
# Cache it for next time use
host.get_peerstore().set_local_record(env)
return env.marshal_envelope()
class PeerRecordState:
envelope: Envelope
seq: int
@ -57,8 +137,17 @@ class PeerStore(IPeerStore):
self.peer_data_map = defaultdict(PeerData)
self.addr_update_channels: dict[ID, MemorySendChannel[Multiaddr]] = {}
self.peer_record_map: dict[ID, PeerRecordState] = {}
self.local_peer_record: Envelope | None = None
self.max_records = max_records
def get_local_record(self) -> Envelope | None:
"""Get the local-signed-record wrapped in Envelope"""
return self.local_peer_record
def set_local_record(self, envelope: Envelope) -> None:
"""Set the local-signed-record wrapped in Envelope"""
self.local_peer_record = envelope
def peer_info(self, peer_id: ID) -> PeerInfo:
"""
:param peer_id: peer ID to get info for
@ -217,7 +306,6 @@ class PeerStore(IPeerStore):
# -----CERT-ADDR-BOOK-----
# TODO: Make proper use of this function
def maybe_delete_peer_record(self, peer_id: ID) -> None:
"""
Delete the signed peer record for a peer if it has no know

View File

@ -48,12 +48,11 @@ class Multiselect(IMultiselectMuxer):
"""
self.handlers[protocol] = handler
# FIXME: Make TProtocol Optional[TProtocol] to keep types consistent
async def negotiate(
self,
communicator: IMultiselectCommunicator,
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
) -> tuple[TProtocol, StreamHandlerFn | None]:
) -> tuple[TProtocol | None, StreamHandlerFn | None]:
"""
Negotiate performs protocol selection.
@ -84,14 +83,14 @@ class Multiselect(IMultiselectMuxer):
raise MultiselectError() from error
else:
protocol = TProtocol(command)
if protocol in self.handlers:
protocol_to_check = None if not command else TProtocol(command)
if protocol_to_check in self.handlers:
try:
await communicator.write(protocol)
await communicator.write(command)
except MultiselectCommunicatorError as error:
raise MultiselectError() from error
return protocol, self.handlers[protocol]
return protocol_to_check, self.handlers[protocol_to_check]
try:
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
except MultiselectCommunicatorError as error:

View File

@ -134,8 +134,10 @@ class MultiselectClient(IMultiselectClient):
:raise MultiselectClientError: raised when protocol negotiation failed
:return: selected protocol
"""
# Represent `None` protocol as an empty string.
protocol_str = protocol if protocol is not None else ""
try:
await communicator.write(protocol)
await communicator.write(protocol_str)
except MultiselectCommunicatorError as error:
raise MultiselectClientError() from error
@ -145,7 +147,7 @@ class MultiselectClient(IMultiselectClient):
except MultiselectCommunicatorError as error:
raise MultiselectClientError() from error
if response == protocol:
if response == protocol_str:
return protocol
if response == PROTOCOL_NOT_FOUND_MSG:
raise MultiselectClientError("protocol not supported")

View File

@ -30,7 +30,10 @@ class MultiselectCommunicator(IMultiselectCommunicator):
"""
:raise MultiselectCommunicatorError: raised when failed to write to underlying reader
""" # noqa: E501
msg_bytes = encode_delim(msg_str.encode())
if msg_str is None:
msg_bytes = encode_delim(b"")
else:
msg_bytes = encode_delim(msg_str.encode())
try:
await self.read_writer.write(msg_bytes)
except IOException as error:

View File

@ -15,6 +15,7 @@ from libp2p.custom_types import (
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerstore import env_to_send_in_RPC
from .exceptions import (
PubsubRouterError,
@ -103,6 +104,11 @@ class FloodSub(IPubsubRouter):
)
rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg])
# Add the senderRecord of the peer in the RPC msg
if isinstance(self.pubsub, Pubsub):
envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host)
rpc_msg.senderRecord = envelope_bytes
logger.debug("publishing message %s", pubsub_msg)
if self.pubsub is None:

View File

@ -1,6 +1,3 @@
from ast import (
literal_eval,
)
from collections import (
defaultdict,
)
@ -22,6 +19,7 @@ from libp2p.abc import (
IPubsubRouter,
)
from libp2p.custom_types import (
MessageID,
TProtocol,
)
from libp2p.peer.id import (
@ -34,10 +32,12 @@ from libp2p.peer.peerinfo import (
)
from libp2p.peer.peerstore import (
PERMANENT_ADDR_TTL,
env_to_send_in_RPC,
)
from libp2p.pubsub import (
floodsub,
)
from libp2p.pubsub.utils import maybe_consume_signed_record
from libp2p.tools.async_service import (
Service,
)
@ -54,6 +54,10 @@ from .pb import (
from .pubsub import (
Pubsub,
)
from .utils import (
parse_message_id_safe,
safe_parse_message_id,
)
PROTOCOL_ID = TProtocol("/meshsub/1.0.0")
PROTOCOL_ID_V11 = TProtocol("/meshsub/1.1.0")
@ -226,6 +230,12 @@ class GossipSub(IPubsubRouter, Service):
:param rpc: RPC message
:param sender_peer_id: id of the peer who sent the message
"""
# Process the senderRecord if sent
if isinstance(self.pubsub, Pubsub):
if not maybe_consume_signed_record(rpc, self.pubsub.host, sender_peer_id):
logger.error("Received an invalid-signed-record, ignoring the message")
return
control_message = rpc.control
# Relay each rpc control message to the appropriate handler
@ -253,6 +263,11 @@ class GossipSub(IPubsubRouter, Service):
)
rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg])
# Add the senderRecord of the peer in the RPC msg
if isinstance(self.pubsub, Pubsub):
envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host)
rpc_msg.senderRecord = envelope_bytes
logger.debug("publishing message %s", pubsub_msg)
for peer_id in peers_gen:
@ -775,16 +790,16 @@ class GossipSub(IPubsubRouter, Service):
# Get list of all seen (seqnos, from) from the (seqno, from) tuples in
# seen_messages cache
seen_seqnos_and_peers = [
seqno_and_from for seqno_and_from in self.pubsub.seen_messages.cache.keys()
str(seqno_and_from)
for seqno_and_from in self.pubsub.seen_messages.cache.keys()
]
# Add all unknown message ids (ids that appear in ihave_msg but not in
# seen_seqnos) to list of messages we want to request
# FIXME: Update type of message ID
msg_ids_wanted: list[Any] = [
msg_id
msg_ids_wanted: list[MessageID] = [
parse_message_id_safe(msg_id)
for msg_id in ihave_msg.messageIDs
if literal_eval(msg_id) not in seen_seqnos_and_peers
if msg_id not in seen_seqnos_and_peers
]
# Request messages with IWANT message
@ -798,9 +813,9 @@ class GossipSub(IPubsubRouter, Service):
Forwards all request messages that are present in mcache to the
requesting peer.
"""
# FIXME: Update type of message ID
# FIXME: Find a better way to parse the msg ids
msg_ids: list[Any] = [literal_eval(msg) for msg in iwant_msg.messageIDs]
msg_ids: list[tuple[bytes, bytes]] = [
safe_parse_message_id(msg) for msg in iwant_msg.messageIDs
]
msgs_to_forward: list[rpc_pb2.Message] = []
for msg_id_iwant in msg_ids:
# Check if the wanted message ID is present in mcache
@ -818,6 +833,13 @@ class GossipSub(IPubsubRouter, Service):
# 1) Package these messages into a single packet
packet: rpc_pb2.RPC = rpc_pb2.RPC()
# Here the an RPC message is being created and published in response
# to the iwant control msg, so we will send a freshly created senderRecord
# with the RPC msg
if isinstance(self.pubsub, Pubsub):
envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host)
packet.senderRecord = envelope_bytes
packet.publish.extend(msgs_to_forward)
if self.pubsub is None:
@ -973,6 +995,12 @@ class GossipSub(IPubsubRouter, Service):
raise NoPubsubAttached
# Add control message to packet
packet: rpc_pb2.RPC = rpc_pb2.RPC()
# Add the sender's peer-record in the RPC msg
if isinstance(self.pubsub, Pubsub):
envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host)
packet.senderRecord = envelope_bytes
packet.control.CopyFrom(control_msg)
# Get stream for peer from pubsub

View File

@ -14,6 +14,7 @@ message RPC {
}
optional ControlMessage control = 3;
optional bytes senderRecord = 4;
}
message Message {

View File

@ -1,11 +1,12 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: libp2p/pubsub/pb/rpc.proto
# Protobuf Python Version: 4.25.3
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
@ -13,39 +14,39 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1alibp2p/pubsub/pb/rpc.proto\x12\tpubsub.pb\"\xb4\x01\n\x03RPC\x12-\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x16.pubsub.pb.RPC.SubOpts\x12#\n\x07publish\x18\x02 \x03(\x0b\x32\x12.pubsub.pb.Message\x12*\n\x07\x63ontrol\x18\x03 \x01(\x0b\x32\x19.pubsub.pb.ControlMessage\x1a-\n\x07SubOpts\x12\x11\n\tsubscribe\x18\x01 \x01(\x08\x12\x0f\n\x07topicid\x18\x02 \x01(\t\"i\n\x07Message\x12\x0f\n\x07\x66rom_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\r\n\x05seqno\x18\x03 \x01(\x0c\x12\x10\n\x08topicIDs\x18\x04 \x03(\t\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x12\x0b\n\x03key\x18\x06 \x01(\x0c\"\xb0\x01\n\x0e\x43ontrolMessage\x12&\n\x05ihave\x18\x01 \x03(\x0b\x32\x17.pubsub.pb.ControlIHave\x12&\n\x05iwant\x18\x02 \x03(\x0b\x32\x17.pubsub.pb.ControlIWant\x12&\n\x05graft\x18\x03 \x03(\x0b\x32\x17.pubsub.pb.ControlGraft\x12&\n\x05prune\x18\x04 \x03(\x0b\x32\x17.pubsub.pb.ControlPrune\"3\n\x0c\x43ontrolIHave\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\x12\n\nmessageIDs\x18\x02 \x03(\t\"\"\n\x0c\x43ontrolIWant\x12\x12\n\nmessageIDs\x18\x01 \x03(\t\"\x1f\n\x0c\x43ontrolGraft\x12\x0f\n\x07topicID\x18\x01 \x01(\t\"T\n\x0c\x43ontrolPrune\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\"\n\x05peers\x18\x02 \x03(\x0b\x32\x13.pubsub.pb.PeerInfo\x12\x0f\n\x07\x62\x61\x63koff\x18\x03 \x01(\x04\"4\n\x08PeerInfo\x12\x0e\n\x06peerID\x18\x01 \x01(\x0c\x12\x18\n\x10signedPeerRecord\x18\x02 \x01(\x0c\"\x87\x03\n\x0fTopicDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\x04\x61uth\x18\x02 \x01(\x0b\x32#.pubsub.pb.TopicDescriptor.AuthOpts\x12/\n\x03\x65nc\x18\x03 \x01(\x0b\x32\".pubsub.pb.TopicDescriptor.EncOpts\x1a|\n\x08\x41uthOpts\x12:\n\x04mode\x18\x01 \x01(\x0e\x32,.pubsub.pb.TopicDescriptor.AuthOpts.AuthMode\x12\x0c\n\x04keys\x18\x02 \x03(\x0c\"&\n\x08\x41uthMode\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03KEY\x10\x01\x12\x07\n\x03WOT\x10\x02\x1a\x83\x01\n\x07\x45ncOpts\x12\x38\n\x04mode\x18\x01 \x01(\x0e\x32*.pubsub.pb.TopicDescriptor.EncOpts.EncMode\x12\x11\n\tkeyHashes\x18\x02 \x03(\x0c\"+\n\x07\x45ncMode\x12\x08\n\x04NONE\x10\x00\x12\r\n\tSHAREDKEY\x10\x01\x12\x07\n\x03WOT\x10\x02')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1alibp2p/pubsub/pb/rpc.proto\x12\tpubsub.pb\"\xca\x01\n\x03RPC\x12-\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x16.pubsub.pb.RPC.SubOpts\x12#\n\x07publish\x18\x02 \x03(\x0b\x32\x12.pubsub.pb.Message\x12*\n\x07\x63ontrol\x18\x03 \x01(\x0b\x32\x19.pubsub.pb.ControlMessage\x12\x14\n\x0csenderRecord\x18\x04 \x01(\x0c\x1a-\n\x07SubOpts\x12\x11\n\tsubscribe\x18\x01 \x01(\x08\x12\x0f\n\x07topicid\x18\x02 \x01(\t\"i\n\x07Message\x12\x0f\n\x07\x66rom_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\r\n\x05seqno\x18\x03 \x01(\x0c\x12\x10\n\x08topicIDs\x18\x04 \x03(\t\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x12\x0b\n\x03key\x18\x06 \x01(\x0c\"\xb0\x01\n\x0e\x43ontrolMessage\x12&\n\x05ihave\x18\x01 \x03(\x0b\x32\x17.pubsub.pb.ControlIHave\x12&\n\x05iwant\x18\x02 \x03(\x0b\x32\x17.pubsub.pb.ControlIWant\x12&\n\x05graft\x18\x03 \x03(\x0b\x32\x17.pubsub.pb.ControlGraft\x12&\n\x05prune\x18\x04 \x03(\x0b\x32\x17.pubsub.pb.ControlPrune\"3\n\x0c\x43ontrolIHave\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\x12\n\nmessageIDs\x18\x02 \x03(\t\"\"\n\x0c\x43ontrolIWant\x12\x12\n\nmessageIDs\x18\x01 \x03(\t\"\x1f\n\x0c\x43ontrolGraft\x12\x0f\n\x07topicID\x18\x01 \x01(\t\"T\n\x0c\x43ontrolPrune\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\"\n\x05peers\x18\x02 \x03(\x0b\x32\x13.pubsub.pb.PeerInfo\x12\x0f\n\x07\x62\x61\x63koff\x18\x03 \x01(\x04\"4\n\x08PeerInfo\x12\x0e\n\x06peerID\x18\x01 \x01(\x0c\x12\x18\n\x10signedPeerRecord\x18\x02 \x01(\x0c\"\x87\x03\n\x0fTopicDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\x04\x61uth\x18\x02 \x01(\x0b\x32#.pubsub.pb.TopicDescriptor.AuthOpts\x12/\n\x03\x65nc\x18\x03 \x01(\x0b\x32\".pubsub.pb.TopicDescriptor.EncOpts\x1a|\n\x08\x41uthOpts\x12:\n\x04mode\x18\x01 \x01(\x0e\x32,.pubsub.pb.TopicDescriptor.AuthOpts.AuthMode\x12\x0c\n\x04keys\x18\x02 \x03(\x0c\"&\n\x08\x41uthMode\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03KEY\x10\x01\x12\x07\n\x03WOT\x10\x02\x1a\x83\x01\n\x07\x45ncOpts\x12\x38\n\x04mode\x18\x01 \x01(\x0e\x32*.pubsub.pb.TopicDescriptor.EncOpts.EncMode\x12\x11\n\tkeyHashes\x18\x02 \x03(\x0c\"+\n\x07\x45ncMode\x12\x08\n\x04NONE\x10\x00\x12\r\n\tSHAREDKEY\x10\x01\x12\x07\n\x03WOT\x10\x02')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.pubsub.pb.rpc_pb2', globals())
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.pubsub.pb.rpc_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_RPC._serialized_start=42
_RPC._serialized_end=222
_RPC_SUBOPTS._serialized_start=177
_RPC_SUBOPTS._serialized_end=222
_MESSAGE._serialized_start=224
_MESSAGE._serialized_end=329
_CONTROLMESSAGE._serialized_start=332
_CONTROLMESSAGE._serialized_end=508
_CONTROLIHAVE._serialized_start=510
_CONTROLIHAVE._serialized_end=561
_CONTROLIWANT._serialized_start=563
_CONTROLIWANT._serialized_end=597
_CONTROLGRAFT._serialized_start=599
_CONTROLGRAFT._serialized_end=630
_CONTROLPRUNE._serialized_start=632
_CONTROLPRUNE._serialized_end=716
_PEERINFO._serialized_start=718
_PEERINFO._serialized_end=770
_TOPICDESCRIPTOR._serialized_start=773
_TOPICDESCRIPTOR._serialized_end=1164
_TOPICDESCRIPTOR_AUTHOPTS._serialized_start=906
_TOPICDESCRIPTOR_AUTHOPTS._serialized_end=1030
_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_start=992
_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_end=1030
_TOPICDESCRIPTOR_ENCOPTS._serialized_start=1033
_TOPICDESCRIPTOR_ENCOPTS._serialized_end=1164
_TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_start=1121
_TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_end=1164
_globals['_RPC']._serialized_start=42
_globals['_RPC']._serialized_end=244
_globals['_RPC_SUBOPTS']._serialized_start=199
_globals['_RPC_SUBOPTS']._serialized_end=244
_globals['_MESSAGE']._serialized_start=246
_globals['_MESSAGE']._serialized_end=351
_globals['_CONTROLMESSAGE']._serialized_start=354
_globals['_CONTROLMESSAGE']._serialized_end=530
_globals['_CONTROLIHAVE']._serialized_start=532
_globals['_CONTROLIHAVE']._serialized_end=583
_globals['_CONTROLIWANT']._serialized_start=585
_globals['_CONTROLIWANT']._serialized_end=619
_globals['_CONTROLGRAFT']._serialized_start=621
_globals['_CONTROLGRAFT']._serialized_end=652
_globals['_CONTROLPRUNE']._serialized_start=654
_globals['_CONTROLPRUNE']._serialized_end=738
_globals['_PEERINFO']._serialized_start=740
_globals['_PEERINFO']._serialized_end=792
_globals['_TOPICDESCRIPTOR']._serialized_start=795
_globals['_TOPICDESCRIPTOR']._serialized_end=1186
_globals['_TOPICDESCRIPTOR_AUTHOPTS']._serialized_start=928
_globals['_TOPICDESCRIPTOR_AUTHOPTS']._serialized_end=1052
_globals['_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE']._serialized_start=1014
_globals['_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE']._serialized_end=1052
_globals['_TOPICDESCRIPTOR_ENCOPTS']._serialized_start=1055
_globals['_TOPICDESCRIPTOR_ENCOPTS']._serialized_end=1186
_globals['_TOPICDESCRIPTOR_ENCOPTS_ENCMODE']._serialized_start=1143
_globals['_TOPICDESCRIPTOR_ENCOPTS_ENCMODE']._serialized_end=1186
# @@protoc_insertion_point(module_scope)

View File

@ -1,323 +1,132 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
Modified from https://github.com/libp2p/go-libp2p-pubsub/blob/master/pb/rpc.proto"""
from google.protobuf.internal import containers as _containers
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import sys
import typing
DESCRIPTOR: _descriptor.FileDescriptor
if sys.version_info >= (3, 10):
import typing as typing_extensions
else:
import typing_extensions
class RPC(_message.Message):
__slots__ = ("subscriptions", "publish", "control", "senderRecord")
class SubOpts(_message.Message):
__slots__ = ("subscribe", "topicid")
SUBSCRIBE_FIELD_NUMBER: _ClassVar[int]
TOPICID_FIELD_NUMBER: _ClassVar[int]
subscribe: bool
topicid: str
def __init__(self, subscribe: bool = ..., topicid: _Optional[str] = ...) -> None: ...
SUBSCRIPTIONS_FIELD_NUMBER: _ClassVar[int]
PUBLISH_FIELD_NUMBER: _ClassVar[int]
CONTROL_FIELD_NUMBER: _ClassVar[int]
SENDERRECORD_FIELD_NUMBER: _ClassVar[int]
subscriptions: _containers.RepeatedCompositeFieldContainer[RPC.SubOpts]
publish: _containers.RepeatedCompositeFieldContainer[Message]
control: ControlMessage
senderRecord: bytes
def __init__(self, subscriptions: _Optional[_Iterable[_Union[RPC.SubOpts, _Mapping]]] = ..., publish: _Optional[_Iterable[_Union[Message, _Mapping]]] = ..., control: _Optional[_Union[ControlMessage, _Mapping]] = ..., senderRecord: _Optional[bytes] = ...) -> None: ... # type: ignore
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
class Message(_message.Message):
__slots__ = ("from_id", "data", "seqno", "topicIDs", "signature", "key")
FROM_ID_FIELD_NUMBER: _ClassVar[int]
DATA_FIELD_NUMBER: _ClassVar[int]
SEQNO_FIELD_NUMBER: _ClassVar[int]
TOPICIDS_FIELD_NUMBER: _ClassVar[int]
SIGNATURE_FIELD_NUMBER: _ClassVar[int]
KEY_FIELD_NUMBER: _ClassVar[int]
from_id: bytes
data: bytes
seqno: bytes
topicIDs: _containers.RepeatedScalarFieldContainer[str]
signature: bytes
key: bytes
def __init__(self, from_id: _Optional[bytes] = ..., data: _Optional[bytes] = ..., seqno: _Optional[bytes] = ..., topicIDs: _Optional[_Iterable[str]] = ..., signature: _Optional[bytes] = ..., key: _Optional[bytes] = ...) -> None: ...
@typing.final
class RPC(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class ControlMessage(_message.Message):
__slots__ = ("ihave", "iwant", "graft", "prune")
IHAVE_FIELD_NUMBER: _ClassVar[int]
IWANT_FIELD_NUMBER: _ClassVar[int]
GRAFT_FIELD_NUMBER: _ClassVar[int]
PRUNE_FIELD_NUMBER: _ClassVar[int]
ihave: _containers.RepeatedCompositeFieldContainer[ControlIHave]
iwant: _containers.RepeatedCompositeFieldContainer[ControlIWant]
graft: _containers.RepeatedCompositeFieldContainer[ControlGraft]
prune: _containers.RepeatedCompositeFieldContainer[ControlPrune]
def __init__(self, ihave: _Optional[_Iterable[_Union[ControlIHave, _Mapping]]] = ..., iwant: _Optional[_Iterable[_Union[ControlIWant, _Mapping]]] = ..., graft: _Optional[_Iterable[_Union[ControlGraft, _Mapping]]] = ..., prune: _Optional[_Iterable[_Union[ControlPrune, _Mapping]]] = ...) -> None: ... # type: ignore
@typing.final
class SubOpts(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class ControlIHave(_message.Message):
__slots__ = ("topicID", "messageIDs")
TOPICID_FIELD_NUMBER: _ClassVar[int]
MESSAGEIDS_FIELD_NUMBER: _ClassVar[int]
topicID: str
messageIDs: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, topicID: _Optional[str] = ..., messageIDs: _Optional[_Iterable[str]] = ...) -> None: ...
SUBSCRIBE_FIELD_NUMBER: builtins.int
TOPICID_FIELD_NUMBER: builtins.int
subscribe: builtins.bool
"""subscribe or unsubscribe"""
topicid: builtins.str
def __init__(
self,
*,
subscribe: builtins.bool | None = ...,
topicid: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["subscribe", b"subscribe", "topicid", b"topicid"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["subscribe", b"subscribe", "topicid", b"topicid"]) -> None: ...
class ControlIWant(_message.Message):
__slots__ = ("messageIDs",)
MESSAGEIDS_FIELD_NUMBER: _ClassVar[int]
messageIDs: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, messageIDs: _Optional[_Iterable[str]] = ...) -> None: ...
SUBSCRIPTIONS_FIELD_NUMBER: builtins.int
PUBLISH_FIELD_NUMBER: builtins.int
CONTROL_FIELD_NUMBER: builtins.int
@property
def subscriptions(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___RPC.SubOpts]: ...
@property
def publish(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message]: ...
@property
def control(self) -> global___ControlMessage: ...
def __init__(
self,
*,
subscriptions: collections.abc.Iterable[global___RPC.SubOpts] | None = ...,
publish: collections.abc.Iterable[global___Message] | None = ...,
control: global___ControlMessage | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["control", b"control"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["control", b"control", "publish", b"publish", "subscriptions", b"subscriptions"]) -> None: ...
class ControlGraft(_message.Message):
__slots__ = ("topicID",)
TOPICID_FIELD_NUMBER: _ClassVar[int]
topicID: str
def __init__(self, topicID: _Optional[str] = ...) -> None: ...
global___RPC = RPC
class ControlPrune(_message.Message):
__slots__ = ("topicID", "peers", "backoff")
TOPICID_FIELD_NUMBER: _ClassVar[int]
PEERS_FIELD_NUMBER: _ClassVar[int]
BACKOFF_FIELD_NUMBER: _ClassVar[int]
topicID: str
peers: _containers.RepeatedCompositeFieldContainer[PeerInfo]
backoff: int
def __init__(self, topicID: _Optional[str] = ..., peers: _Optional[_Iterable[_Union[PeerInfo, _Mapping]]] = ..., backoff: _Optional[int] = ...) -> None: ... # type: ignore
@typing.final
class Message(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class PeerInfo(_message.Message):
__slots__ = ("peerID", "signedPeerRecord")
PEERID_FIELD_NUMBER: _ClassVar[int]
SIGNEDPEERRECORD_FIELD_NUMBER: _ClassVar[int]
peerID: bytes
signedPeerRecord: bytes
def __init__(self, peerID: _Optional[bytes] = ..., signedPeerRecord: _Optional[bytes] = ...) -> None: ...
FROM_ID_FIELD_NUMBER: builtins.int
DATA_FIELD_NUMBER: builtins.int
SEQNO_FIELD_NUMBER: builtins.int
TOPICIDS_FIELD_NUMBER: builtins.int
SIGNATURE_FIELD_NUMBER: builtins.int
KEY_FIELD_NUMBER: builtins.int
from_id: builtins.bytes
data: builtins.bytes
seqno: builtins.bytes
signature: builtins.bytes
key: builtins.bytes
@property
def topicIDs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
def __init__(
self,
*,
from_id: builtins.bytes | None = ...,
data: builtins.bytes | None = ...,
seqno: builtins.bytes | None = ...,
topicIDs: collections.abc.Iterable[builtins.str] | None = ...,
signature: builtins.bytes | None = ...,
key: builtins.bytes | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["data", b"data", "from_id", b"from_id", "key", b"key", "seqno", b"seqno", "signature", b"signature"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["data", b"data", "from_id", b"from_id", "key", b"key", "seqno", b"seqno", "signature", b"signature", "topicIDs", b"topicIDs"]) -> None: ...
global___Message = Message
@typing.final
class ControlMessage(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
IHAVE_FIELD_NUMBER: builtins.int
IWANT_FIELD_NUMBER: builtins.int
GRAFT_FIELD_NUMBER: builtins.int
PRUNE_FIELD_NUMBER: builtins.int
@property
def ihave(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlIHave]: ...
@property
def iwant(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlIWant]: ...
@property
def graft(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlGraft]: ...
@property
def prune(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlPrune]: ...
def __init__(
self,
*,
ihave: collections.abc.Iterable[global___ControlIHave] | None = ...,
iwant: collections.abc.Iterable[global___ControlIWant] | None = ...,
graft: collections.abc.Iterable[global___ControlGraft] | None = ...,
prune: collections.abc.Iterable[global___ControlPrune] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["graft", b"graft", "ihave", b"ihave", "iwant", b"iwant", "prune", b"prune"]) -> None: ...
global___ControlMessage = ControlMessage
@typing.final
class ControlIHave(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TOPICID_FIELD_NUMBER: builtins.int
MESSAGEIDS_FIELD_NUMBER: builtins.int
topicID: builtins.str
@property
def messageIDs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
def __init__(
self,
*,
topicID: builtins.str | None = ...,
messageIDs: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["topicID", b"topicID"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["messageIDs", b"messageIDs", "topicID", b"topicID"]) -> None: ...
global___ControlIHave = ControlIHave
@typing.final
class ControlIWant(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
MESSAGEIDS_FIELD_NUMBER: builtins.int
@property
def messageIDs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
def __init__(
self,
*,
messageIDs: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["messageIDs", b"messageIDs"]) -> None: ...
global___ControlIWant = ControlIWant
@typing.final
class ControlGraft(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TOPICID_FIELD_NUMBER: builtins.int
topicID: builtins.str
def __init__(
self,
*,
topicID: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["topicID", b"topicID"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["topicID", b"topicID"]) -> None: ...
global___ControlGraft = ControlGraft
@typing.final
class ControlPrune(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TOPICID_FIELD_NUMBER: builtins.int
PEERS_FIELD_NUMBER: builtins.int
BACKOFF_FIELD_NUMBER: builtins.int
topicID: builtins.str
backoff: builtins.int
@property
def peers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___PeerInfo]: ...
def __init__(
self,
*,
topicID: builtins.str | None = ...,
peers: collections.abc.Iterable[global___PeerInfo] | None = ...,
backoff: builtins.int | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["backoff", b"backoff", "topicID", b"topicID"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["backoff", b"backoff", "peers", b"peers", "topicID", b"topicID"]) -> None: ...
global___ControlPrune = ControlPrune
@typing.final
class PeerInfo(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
PEERID_FIELD_NUMBER: builtins.int
SIGNEDPEERRECORD_FIELD_NUMBER: builtins.int
peerID: builtins.bytes
signedPeerRecord: builtins.bytes
def __init__(
self,
*,
peerID: builtins.bytes | None = ...,
signedPeerRecord: builtins.bytes | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["peerID", b"peerID", "signedPeerRecord", b"signedPeerRecord"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["peerID", b"peerID", "signedPeerRecord", b"signedPeerRecord"]) -> None: ...
global___PeerInfo = PeerInfo
@typing.final
class TopicDescriptor(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@typing.final
class AuthOpts(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class _AuthMode:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _AuthModeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[TopicDescriptor.AuthOpts._AuthMode.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
NONE: TopicDescriptor.AuthOpts._AuthMode.ValueType # 0
"""no authentication, anyone can publish"""
KEY: TopicDescriptor.AuthOpts._AuthMode.ValueType # 1
"""only messages signed by keys in the topic descriptor are accepted"""
WOT: TopicDescriptor.AuthOpts._AuthMode.ValueType # 2
"""web of trust, certificates can allow publisher set to grow"""
class AuthMode(_AuthMode, metaclass=_AuthModeEnumTypeWrapper): ...
NONE: TopicDescriptor.AuthOpts.AuthMode.ValueType # 0
"""no authentication, anyone can publish"""
KEY: TopicDescriptor.AuthOpts.AuthMode.ValueType # 1
"""only messages signed by keys in the topic descriptor are accepted"""
WOT: TopicDescriptor.AuthOpts.AuthMode.ValueType # 2
"""web of trust, certificates can allow publisher set to grow"""
MODE_FIELD_NUMBER: builtins.int
KEYS_FIELD_NUMBER: builtins.int
mode: global___TopicDescriptor.AuthOpts.AuthMode.ValueType
@property
def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]:
"""root keys to trust"""
def __init__(
self,
*,
mode: global___TopicDescriptor.AuthOpts.AuthMode.ValueType | None = ...,
keys: collections.abc.Iterable[builtins.bytes] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["mode", b"mode"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["keys", b"keys", "mode", b"mode"]) -> None: ...
@typing.final
class EncOpts(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class _EncMode:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _EncModeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[TopicDescriptor.EncOpts._EncMode.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
NONE: TopicDescriptor.EncOpts._EncMode.ValueType # 0
"""no encryption, anyone can read"""
SHAREDKEY: TopicDescriptor.EncOpts._EncMode.ValueType # 1
"""messages are encrypted with shared key"""
WOT: TopicDescriptor.EncOpts._EncMode.ValueType # 2
"""web of trust, certificates can allow publisher set to grow"""
class EncMode(_EncMode, metaclass=_EncModeEnumTypeWrapper): ...
NONE: TopicDescriptor.EncOpts.EncMode.ValueType # 0
"""no encryption, anyone can read"""
SHAREDKEY: TopicDescriptor.EncOpts.EncMode.ValueType # 1
"""messages are encrypted with shared key"""
WOT: TopicDescriptor.EncOpts.EncMode.ValueType # 2
"""web of trust, certificates can allow publisher set to grow"""
MODE_FIELD_NUMBER: builtins.int
KEYHASHES_FIELD_NUMBER: builtins.int
mode: global___TopicDescriptor.EncOpts.EncMode.ValueType
@property
def keyHashes(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]:
"""the hashes of the shared keys used (salted)"""
def __init__(
self,
*,
mode: global___TopicDescriptor.EncOpts.EncMode.ValueType | None = ...,
keyHashes: collections.abc.Iterable[builtins.bytes] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["mode", b"mode"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["keyHashes", b"keyHashes", "mode", b"mode"]) -> None: ...
NAME_FIELD_NUMBER: builtins.int
AUTH_FIELD_NUMBER: builtins.int
ENC_FIELD_NUMBER: builtins.int
name: builtins.str
@property
def auth(self) -> global___TopicDescriptor.AuthOpts: ...
@property
def enc(self) -> global___TopicDescriptor.EncOpts: ...
def __init__(
self,
*,
name: builtins.str | None = ...,
auth: global___TopicDescriptor.AuthOpts | None = ...,
enc: global___TopicDescriptor.EncOpts | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["auth", b"auth", "enc", b"enc", "name", b"name"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["auth", b"auth", "enc", b"enc", "name", b"name"]) -> None: ...
global___TopicDescriptor = TopicDescriptor
class TopicDescriptor(_message.Message):
__slots__ = ("name", "auth", "enc")
class AuthOpts(_message.Message):
__slots__ = ("mode", "keys")
class AuthMode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
NONE: _ClassVar[TopicDescriptor.AuthOpts.AuthMode]
KEY: _ClassVar[TopicDescriptor.AuthOpts.AuthMode]
WOT: _ClassVar[TopicDescriptor.AuthOpts.AuthMode]
NONE: TopicDescriptor.AuthOpts.AuthMode
KEY: TopicDescriptor.AuthOpts.AuthMode
WOT: TopicDescriptor.AuthOpts.AuthMode
MODE_FIELD_NUMBER: _ClassVar[int]
KEYS_FIELD_NUMBER: _ClassVar[int]
mode: TopicDescriptor.AuthOpts.AuthMode
keys: _containers.RepeatedScalarFieldContainer[bytes]
def __init__(self, mode: _Optional[_Union[TopicDescriptor.AuthOpts.AuthMode, str]] = ..., keys: _Optional[_Iterable[bytes]] = ...) -> None: ...
class EncOpts(_message.Message):
__slots__ = ("mode", "keyHashes")
class EncMode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
NONE: _ClassVar[TopicDescriptor.EncOpts.EncMode]
SHAREDKEY: _ClassVar[TopicDescriptor.EncOpts.EncMode]
WOT: _ClassVar[TopicDescriptor.EncOpts.EncMode]
NONE: TopicDescriptor.EncOpts.EncMode
SHAREDKEY: TopicDescriptor.EncOpts.EncMode
WOT: TopicDescriptor.EncOpts.EncMode
MODE_FIELD_NUMBER: _ClassVar[int]
KEYHASHES_FIELD_NUMBER: _ClassVar[int]
mode: TopicDescriptor.EncOpts.EncMode
keyHashes: _containers.RepeatedScalarFieldContainer[bytes]
def __init__(self, mode: _Optional[_Union[TopicDescriptor.EncOpts.EncMode, str]] = ..., keyHashes: _Optional[_Iterable[bytes]] = ...) -> None: ...
NAME_FIELD_NUMBER: _ClassVar[int]
AUTH_FIELD_NUMBER: _ClassVar[int]
ENC_FIELD_NUMBER: _ClassVar[int]
name: str
auth: TopicDescriptor.AuthOpts
enc: TopicDescriptor.EncOpts
def __init__(self, name: _Optional[str] = ..., auth: _Optional[_Union[TopicDescriptor.AuthOpts, _Mapping]] = ..., enc: _Optional[_Union[TopicDescriptor.EncOpts, _Mapping]] = ...) -> None: ... # type: ignore

View File

@ -56,6 +56,8 @@ from libp2p.peer.id import (
from libp2p.peer.peerdata import (
PeerDataError,
)
from libp2p.peer.peerstore import env_to_send_in_RPC
from libp2p.pubsub.utils import maybe_consume_signed_record
from libp2p.tools.async_service import (
Service,
)
@ -247,6 +249,10 @@ class Pubsub(Service, IPubsub):
packet.subscriptions.extend(
[rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)]
)
# Add the sender's signedRecord in the RPC message
envelope_bytes, _ = env_to_send_in_RPC(self.host)
packet.senderRecord = envelope_bytes
return packet
async def continuously_read_stream(self, stream: INetStream) -> None:
@ -263,6 +269,14 @@ class Pubsub(Service, IPubsub):
incoming: bytes = await read_varint_prefixed_bytes(stream)
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
rpc_incoming.ParseFromString(incoming)
# Process the sender's signed-record if sent
if not maybe_consume_signed_record(rpc_incoming, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, ignoring the incoming msg"
)
continue
if rpc_incoming.publish:
# deal with RPC.publish
for msg in rpc_incoming.publish:
@ -572,6 +586,9 @@ class Pubsub(Service, IPubsub):
[rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)]
)
# Add the senderRecord of the peer in the RPC msg
envelope_bytes, _ = env_to_send_in_RPC(self.host)
packet.senderRecord = envelope_bytes
# Send out subscribe message to all peers
await self.message_all_peers(packet.SerializeToString())
@ -604,6 +621,9 @@ class Pubsub(Service, IPubsub):
packet.subscriptions.extend(
[rpc_pb2.RPC.SubOpts(subscribe=False, topicid=topic_id)]
)
# Add the senderRecord of the peer in the RPC msg
envelope_bytes, _ = env_to_send_in_RPC(self.host)
packet.senderRecord = envelope_bytes
# Send out unsubscribe message to all peers
await self.message_all_peers(packet.SerializeToString())

80
libp2p/pubsub/utils.py Normal file
View File

@ -0,0 +1,80 @@
import ast
import logging
from libp2p.abc import IHost
from libp2p.custom_types import (
MessageID,
)
from libp2p.peer.envelope import consume_envelope
from libp2p.peer.id import ID
from libp2p.pubsub.pb.rpc_pb2 import RPC
logger = logging.getLogger("pubsub-example.utils")
def maybe_consume_signed_record(msg: RPC, host: IHost, peer_id: ID) -> bool:
"""
Attempt to parse and store a signed-peer-record (Envelope) received during
PubSub communication. If the record is invalid, the peer-id does not match, or
updating the peerstore fails, the function logs an error and returns False.
Parameters
----------
msg : RPC
The protobuf message received during PubSub communication.
host : IHost
The local host instance, providing access to the peerstore for storing
verified peer records.
peer_id : ID | None, optional
The expected peer ID for record validation. If provided, the peer ID
inside the record must match this value.
Returns
-------
bool
True if a valid signed peer record was successfully consumed and stored,
False otherwise.
"""
if msg.HasField("senderRecord"):
try:
# Convert the signed-peer-record(Envelope) from
# protobuf bytes
envelope, record = consume_envelope(msg.senderRecord, "libp2p-peer-record")
if not record.peer_id == peer_id:
return False
# Use the default TTL of 2 hours (7200 seconds)
if not host.get_peerstore().consume_peer_record(envelope, 7200):
logger.error("Failed to update the Certified-Addr-Book")
return False
except Exception as e:
logger.error("Failed to update the Certified-Addr-Book: %s", e)
return False
return True
def parse_message_id_safe(msg_id_str: str) -> MessageID:
"""Safely handle message ID as string."""
return MessageID(msg_id_str)
def safe_parse_message_id(msg_id_str: str) -> tuple[bytes, bytes]:
"""
Safely parse message ID using ast.literal_eval with validation.
:param msg_id_str: String representation of message ID
:return: Tuple of (seqno, from_id) as bytes
:raises ValueError: If parsing fails
"""
try:
parsed = ast.literal_eval(msg_id_str)
if not isinstance(parsed, tuple) or len(parsed) != 2:
raise ValueError("Invalid message ID format")
seqno, from_id = parsed
if not isinstance(seqno, bytes) or not isinstance(from_id, bytes):
raise ValueError("Message ID components must be bytes")
return (seqno, from_id)
except (ValueError, SyntaxError) as e:
raise ValueError(f"Invalid message ID format: {e}")

View File

@ -0,0 +1,68 @@
from abc import ABC, abstractmethod
from libp2p.abc import IRawConnection
from libp2p.custom_types import TProtocol
from libp2p.peer.id import ID
from .pb import noise_pb2 as noise_pb
class EarlyDataHandler(ABC):
"""Interface for handling early data during Noise handshake"""
@abstractmethod
async def send(
self, conn: IRawConnection, peer_id: ID
) -> noise_pb.NoiseExtensions | None:
"""Called to generate early data to send during handshake"""
pass
@abstractmethod
async def received(
self, conn: IRawConnection, extensions: noise_pb.NoiseExtensions | None
) -> None:
"""Called when early data is received during handshake"""
pass
class TransportEarlyDataHandler(EarlyDataHandler):
"""Default early data handler for muxer negotiation"""
def __init__(self, supported_muxers: list[TProtocol]):
self.supported_muxers = supported_muxers
self.received_muxers: list[TProtocol] = []
async def send(
self, conn: IRawConnection, peer_id: ID
) -> noise_pb.NoiseExtensions | None:
"""Send our supported muxers list"""
if not self.supported_muxers:
return None
extensions = noise_pb.NoiseExtensions()
# Convert TProtocol to string for serialization
extensions.stream_muxers[:] = [str(muxer) for muxer in self.supported_muxers]
return extensions
async def received(
self, conn: IRawConnection, extensions: noise_pb.NoiseExtensions | None
) -> None:
"""Store received muxers list"""
if extensions and extensions.stream_muxers:
self.received_muxers = [
TProtocol(muxer) for muxer in extensions.stream_muxers
]
def match_muxers(self, is_initiator: bool) -> TProtocol | None:
"""Find first common muxer between local and remote"""
if is_initiator:
# Initiator: find first local muxer that remote supports
for local_muxer in self.supported_muxers:
if local_muxer in self.received_muxers:
return local_muxer
else:
# Responder: find first remote muxer that we support
for remote_muxer in self.received_muxers:
if remote_muxer in self.supported_muxers:
return remote_muxer
return None

View File

@ -30,6 +30,9 @@ from libp2p.security.secure_session import (
SecureSession,
)
from .early_data import (
EarlyDataHandler,
)
from .exceptions import (
HandshakeHasNotFinished,
InvalidSignature,
@ -45,6 +48,7 @@ from .messages import (
make_handshake_payload_sig,
verify_handshake_payload_sig,
)
from .pb import noise_pb2 as noise_pb
class IPattern(ABC):
@ -62,7 +66,8 @@ class BasePattern(IPattern):
noise_static_key: PrivateKey
local_peer: ID
libp2p_privkey: PrivateKey
early_data: bytes | None
initiator_early_data_handler: EarlyDataHandler | None
responder_early_data_handler: EarlyDataHandler | None
def create_noise_state(self) -> NoiseState:
noise_state = NoiseState.from_name(self.protocol_name)
@ -73,11 +78,50 @@ class BasePattern(IPattern):
raise NoiseStateError("noise_protocol is not initialized")
return noise_state
def make_handshake_payload(self) -> NoiseHandshakePayload:
async def make_handshake_payload(
self, conn: IRawConnection, peer_id: ID, is_initiator: bool
) -> NoiseHandshakePayload:
signature = make_handshake_payload_sig(
self.libp2p_privkey, self.noise_static_key.get_public_key()
)
return NoiseHandshakePayload(self.libp2p_privkey.get_public_key(), signature)
# NEW: Get early data from appropriate handler
extensions = None
if is_initiator and self.initiator_early_data_handler:
extensions = await self.initiator_early_data_handler.send(conn, peer_id)
elif not is_initiator and self.responder_early_data_handler:
extensions = await self.responder_early_data_handler.send(conn, peer_id)
# NEW: Serialize extensions into early_data field
early_data = None
if extensions:
early_data = extensions.SerializeToString()
return NoiseHandshakePayload(
self.libp2p_privkey.get_public_key(),
signature,
early_data, # ← This is the key addition
)
async def handle_received_payload(
self, conn: IRawConnection, payload: NoiseHandshakePayload, is_initiator: bool
) -> None:
"""Process early data from received payload"""
if not payload.early_data:
return
# Deserialize the NoiseExtensions from early_data field
try:
extensions = noise_pb.NoiseExtensions.FromString(payload.early_data)
except Exception:
# Invalid extensions, ignore silently
return
# Pass to appropriate handler
if is_initiator and self.initiator_early_data_handler:
await self.initiator_early_data_handler.received(conn, extensions)
elif not is_initiator and self.responder_early_data_handler:
await self.responder_early_data_handler.received(conn, extensions)
class PatternXX(BasePattern):
@ -86,13 +130,15 @@ class PatternXX(BasePattern):
local_peer: ID,
libp2p_privkey: PrivateKey,
noise_static_key: PrivateKey,
early_data: bytes | None = None,
initiator_early_data_handler: EarlyDataHandler | None,
responder_early_data_handler: EarlyDataHandler | None,
) -> None:
self.protocol_name = b"Noise_XX_25519_ChaChaPoly_SHA256"
self.local_peer = local_peer
self.libp2p_privkey = libp2p_privkey
self.noise_static_key = noise_static_key
self.early_data = early_data
self.initiator_early_data_handler = initiator_early_data_handler
self.responder_early_data_handler = responder_early_data_handler
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
noise_state = self.create_noise_state()
@ -106,18 +152,23 @@ class PatternXX(BasePattern):
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
# Consume msg#1.
# 1. Consume msg#1 (just empty bytes)
await read_writer.read_msg()
# Send msg#2, which should include our handshake payload.
our_payload = self.make_handshake_payload()
# 2. Send msg#2 with our payload INCLUDING EARLY DATA
our_payload = await self.make_handshake_payload(
conn,
self.local_peer, # We send our own peer ID in responder role
is_initiator=False,
)
msg_2 = our_payload.serialize()
await read_writer.write_msg(msg_2)
# Receive and consume msg#3.
# 3. Receive msg#3
msg_3 = await read_writer.read_msg()
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3)
# Extract remote pubkey from noise handshake state
if handshake_state.rs is None:
raise NoiseStateError(
"something is wrong in the underlying noise `handshake_state`: "
@ -126,14 +177,31 @@ class PatternXX(BasePattern):
)
remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
# 4. Verify signature (unchanged)
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
raise InvalidSignature
# NEW: Process early data from msg#3 AFTER signature verification
await self.handle_received_payload(
conn, peer_handshake_payload, is_initiator=False
)
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
if not noise_state.handshake_finished:
raise HandshakeHasNotFinished(
"handshake is done but it is not marked as finished in `noise_state`"
)
# NEW: Get negotiated muxer for connection state
# negotiated_muxer = None
if self.responder_early_data_handler and hasattr(
self.responder_early_data_handler, "match_muxers"
):
# negotiated_muxer =
# self.responder_early_data_handler.match_muxers(is_initiator=False)
pass
transport_read_writer = NoiseTransportReadWriter(conn, noise_state)
return SecureSession(
local_peer=self.local_peer,
@ -142,6 +210,8 @@ class PatternXX(BasePattern):
remote_permanent_pubkey=remote_pubkey,
is_initiator=False,
conn=transport_read_writer,
# NOTE: negotiated_muxer would need to be added to SecureSession constructor
# For now, store it in connection metadata or similar
)
async def handshake_outbound(
@ -158,24 +228,27 @@ class PatternXX(BasePattern):
if handshake_state is None:
raise NoiseStateError("Handshake state is not initialized")
# Send msg#1, which is *not* encrypted.
# 1. Send msg#1 (empty) - no early data possible in XX pattern
msg_1 = b""
await read_writer.write_msg(msg_1)
# Read msg#2 from the remote, which contains the public key of the peer.
# 2. Read msg#2 from responder
msg_2 = await read_writer.read_msg()
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2)
# Extract remote pubkey from noise handshake state
if handshake_state.rs is None:
raise NoiseStateError(
"something is wrong in the underlying noise `handshake_state`: "
"we received and consumed msg#3, which should have included the "
"we received and consumed msg#2, which should have included the "
"remote static public key, but it is not present in the handshake_state"
)
remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
# Verify signature BEFORE processing early data (security)
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
raise InvalidSignature
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
if remote_peer_id_from_pubkey != remote_peer:
raise PeerIDMismatchesPubkey(
@ -184,8 +257,15 @@ class PatternXX(BasePattern):
f"remote_peer_id_from_pubkey={remote_peer_id_from_pubkey}"
)
# Send msg#3, which includes our encrypted payload and our noise static key.
our_payload = self.make_handshake_payload()
# NEW: Process early data from msg#2 AFTER verification
await self.handle_received_payload(
conn, peer_handshake_payload, is_initiator=True
)
# 3. Send msg#3 with our payload INCLUDING EARLY DATA
our_payload = await self.make_handshake_payload(
conn, remote_peer, is_initiator=True
)
msg_3 = our_payload.serialize()
await read_writer.write_msg(msg_3)
@ -193,6 +273,16 @@ class PatternXX(BasePattern):
raise HandshakeHasNotFinished(
"handshake is done but it is not marked as finished in `noise_state`"
)
# NEW: Get negotiated muxer
# negotiated_muxer = None
if self.initiator_early_data_handler and hasattr(
self.initiator_early_data_handler, "match_muxers"
):
pass
# negotiated_muxer =
# self.initiator_early_data_handler.match_muxers(is_initiator=True)
transport_read_writer = NoiseTransportReadWriter(conn, noise_state)
return SecureSession(
local_peer=self.local_peer,
@ -201,6 +291,8 @@ class PatternXX(BasePattern):
remote_permanent_pubkey=remote_pubkey,
is_initiator=True,
conn=transport_read_writer,
# NOTE: negotiated_muxer would need to be added to SecureSession constructor
# For now, store it in connection metadata or similar
)
@staticmethod

View File

@ -1,8 +1,13 @@
syntax = "proto3";
syntax = "proto2";
package pb;
message NoiseHandshakePayload {
bytes identity_key = 1;
bytes identity_sig = 2;
bytes data = 3;
message NoiseExtensions {
repeated bytes webtransport_certhashes = 1;
repeated string stream_muxers = 2;
}
message NoiseHandshakePayload {
optional bytes identity_key = 1;
optional bytes identity_sig = 2;
optional bytes data = 3;
}

View File

@ -13,13 +13,15 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$libp2p/security/noise/pb/noise.proto\x12\x02pb\"Q\n\x15NoiseHandshakePayload\x12\x14\n\x0cidentity_key\x18\x01 \x01(\x0c\x12\x14\n\x0cidentity_sig\x18\x02 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$libp2p/security/noise/pb/noise.proto\x12\x02pb\"I\n\x0fNoiseExtensions\x12\x1f\n\x17webtransport_certhashes\x18\x01 \x03(\x0c\x12\x15\n\rstream_muxers\x18\x02 \x03(\t\"Q\n\x15NoiseHandshakePayload\x12\x14\n\x0cidentity_key\x18\x01 \x01(\x0c\x12\x14\n\x0cidentity_sig\x18\x02 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.security.noise.pb.noise_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_NOISEHANDSHAKEPAYLOAD._serialized_start=44
_NOISEHANDSHAKEPAYLOAD._serialized_end=125
_NOISEEXTENSIONS._serialized_start=44
_NOISEEXTENSIONS._serialized_end=117
_NOISEHANDSHAKEPAYLOAD._serialized_start=119
_NOISEHANDSHAKEPAYLOAD._serialized_end=200
# @@protoc_insertion_point(module_scope)

View File

@ -4,12 +4,34 @@ isort:skip_file
"""
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
import typing
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class NoiseExtensions(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
WEBTRANSPORT_CERTHASHES_FIELD_NUMBER: builtins.int
STREAM_MUXERS_FIELD_NUMBER: builtins.int
@property
def webtransport_certhashes(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
@property
def stream_muxers(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
def __init__(
self,
*,
webtransport_certhashes: collections.abc.Iterable[builtins.bytes] | None = ...,
stream_muxers: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["stream_muxers", b"stream_muxers", "webtransport_certhashes", b"webtransport_certhashes"]) -> None: ...
global___NoiseExtensions = NoiseExtensions
@typing.final
class NoiseHandshakePayload(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@ -23,10 +45,11 @@ class NoiseHandshakePayload(google.protobuf.message.Message):
def __init__(
self,
*,
identity_key: builtins.bytes = ...,
identity_sig: builtins.bytes = ...,
data: builtins.bytes = ...,
identity_key: builtins.bytes | None = ...,
identity_sig: builtins.bytes | None = ...,
data: builtins.bytes | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["data", b"data", "identity_key", b"identity_key", "identity_sig", b"identity_sig"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["data", b"data", "identity_key", b"identity_key", "identity_sig", b"identity_sig"]) -> None: ...
global___NoiseHandshakePayload = NoiseHandshakePayload

View File

@ -14,6 +14,7 @@ from libp2p.peer.id import (
ID,
)
from .early_data import EarlyDataHandler, TransportEarlyDataHandler
from .patterns import (
IPattern,
PatternXX,
@ -26,35 +27,40 @@ class Transport(ISecureTransport):
libp2p_privkey: PrivateKey
noise_privkey: PrivateKey
local_peer: ID
early_data: bytes | None
with_noise_pipes: bool
supported_muxers: list[TProtocol]
initiator_early_data_handler: EarlyDataHandler | None
responder_early_data_handler: EarlyDataHandler | None
def __init__(
self,
libp2p_keypair: KeyPair,
noise_privkey: PrivateKey,
early_data: bytes | None = None,
with_noise_pipes: bool = False,
supported_muxers: list[TProtocol] | None = None,
initiator_handler: EarlyDataHandler | None = None,
responder_handler: EarlyDataHandler | None = None,
) -> None:
self.libp2p_privkey = libp2p_keypair.private_key
self.noise_privkey = noise_privkey
self.local_peer = ID.from_pubkey(libp2p_keypair.public_key)
self.early_data = early_data
self.with_noise_pipes = with_noise_pipes
self.supported_muxers = supported_muxers or []
if self.with_noise_pipes:
raise NotImplementedError
# Create default handlers for muxer negotiation if none provided
if initiator_handler is None and self.supported_muxers:
initiator_handler = TransportEarlyDataHandler(self.supported_muxers)
if responder_handler is None and self.supported_muxers:
responder_handler = TransportEarlyDataHandler(self.supported_muxers)
self.initiator_early_data_handler = initiator_handler
self.responder_early_data_handler = responder_handler
def get_pattern(self) -> IPattern:
if self.with_noise_pipes:
raise NotImplementedError
else:
return PatternXX(
self.local_peer,
self.libp2p_privkey,
self.noise_privkey,
self.early_data,
)
return PatternXX(
self.local_peer,
self.libp2p_privkey,
self.noise_privkey,
self.initiator_early_data_handler,
self.responder_early_data_handler,
)
async def secure_inbound(self, conn: IRawConnection) -> ISecureConn:
pattern = self.get_pattern()

View File

@ -17,6 +17,9 @@ from libp2p.custom_types import (
from libp2p.peer.id import (
ID,
)
from libp2p.protocol_muxer.exceptions import (
MultiselectError,
)
from libp2p.protocol_muxer.multiselect import (
Multiselect,
)
@ -104,7 +107,7 @@ class SecurityMultistream(ABC):
:param is_initiator: true if we are the initiator, false otherwise
:return: selected secure transport
"""
protocol: TProtocol
protocol: TProtocol | None
communicator = MultiselectCommunicator(conn)
if is_initiator:
# Select protocol if initiator
@ -114,5 +117,9 @@ class SecurityMultistream(ABC):
else:
# Select protocol if non-initiator
protocol, _ = await self.multiselect.negotiate(communicator)
if protocol is None:
raise MultiselectError(
"Failed to negotiate a security protocol: no protocol selected"
)
# Return transport from protocol
return self.transports[protocol]

View File

@ -17,6 +17,9 @@ from libp2p.custom_types import (
from libp2p.peer.id import (
ID,
)
from libp2p.protocol_muxer.exceptions import (
MultiselectError,
)
from libp2p.protocol_muxer.multiselect import (
Multiselect,
)
@ -73,7 +76,7 @@ class MuxerMultistream:
:param conn: conn to choose a transport over
:return: selected muxer transport
"""
protocol: TProtocol
protocol: TProtocol | None
communicator = MultiselectCommunicator(conn)
if conn.is_initiator:
protocol = await self.multiselect_client.select_one_of(
@ -81,6 +84,10 @@ class MuxerMultistream:
)
else:
protocol, _ = await self.multiselect.negotiate(communicator)
if protocol is None:
raise MultiselectError(
"Fail to negotiate a stream muxer protocol: no protocol selected"
)
return self.transports[protocol]
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:

View File

@ -1,9 +1,7 @@
from libp2p.abc import (
IListener,
IMuxedConn,
IRawConnection,
ISecureConn,
ITransport,
)
from libp2p.custom_types import (
TMuxerOptions,
@ -43,10 +41,6 @@ class TransportUpgrader:
self.security_multistream = SecurityMultistream(secure_transports_by_protocol)
self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol)
def upgrade_listener(self, transport: ITransport, listeners: IListener) -> None:
"""Upgrade multiaddr listeners to libp2p-transport listeners."""
# TODO: Figure out what to do with this function.
async def upgrade_security(
self,
raw_conn: IRawConnection,

View File

@ -15,6 +15,13 @@ from libp2p.utils.version import (
get_agent_version,
)
from libp2p.utils.address_validation import (
get_available_interfaces,
get_optimal_binding_address,
expand_wildcard_address,
find_free_port,
)
__all__ = [
"decode_uvarint_from_stream",
"encode_delim",
@ -26,4 +33,8 @@ __all__ = [
"decode_varint_from_bytes",
"decode_varint_with_size",
"read_length_prefixed_protobuf",
"get_available_interfaces",
"get_optimal_binding_address",
"expand_wildcard_address",
"find_free_port",
]

View File

@ -0,0 +1,160 @@
from __future__ import annotations
import socket
from multiaddr import Multiaddr
try:
from multiaddr.utils import ( # type: ignore
get_network_addrs,
get_thin_waist_addresses,
)
_HAS_THIN_WAIST = True
except ImportError: # pragma: no cover - only executed in older environments
_HAS_THIN_WAIST = False
get_thin_waist_addresses = None # type: ignore
get_network_addrs = None # type: ignore
def _safe_get_network_addrs(ip_version: int) -> list[str]:
"""
Internal safe wrapper. Returns a list of IP addresses for the requested IP version.
Falls back to minimal defaults when Thin Waist helpers are missing.
:param ip_version: 4 or 6
"""
if _HAS_THIN_WAIST and get_network_addrs:
try:
return get_network_addrs(ip_version) or []
except Exception: # pragma: no cover - defensive
return []
# Fallback behavior (very conservative)
if ip_version == 4:
return ["127.0.0.1"]
if ip_version == 6:
return ["::1"]
return []
def find_free_port() -> int:
"""Find a free port on localhost."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0)) # Bind to a free port provided by the OS
return s.getsockname()[1]
def _safe_expand(addr: Multiaddr, port: int | None = None) -> list[Multiaddr]:
"""
Internal safe expansion wrapper. Returns a list of Multiaddr objects.
If Thin Waist isn't available, returns [addr] (identity).
"""
if _HAS_THIN_WAIST and get_thin_waist_addresses:
try:
if port is not None:
return get_thin_waist_addresses(addr, port=port) or []
return get_thin_waist_addresses(addr) or []
except Exception: # pragma: no cover - defensive
return [addr]
return [addr]
def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr]:
"""
Discover available network interfaces (IPv4 + IPv6 if supported) for binding.
:param port: Port number to bind to.
:param protocol: Transport protocol (e.g., "tcp" or "udp").
:return: List of Multiaddr objects representing candidate interface addresses.
"""
addrs: list[Multiaddr] = []
# IPv4 enumeration
seen_v4: set[str] = set()
for ip in _safe_get_network_addrs(4):
seen_v4.add(ip)
addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}"))
# Ensure IPv4 loopback is always included when IPv4 interfaces are discovered
if seen_v4 and "127.0.0.1" not in seen_v4:
addrs.append(Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}"))
# TODO: IPv6 support temporarily disabled due to libp2p handshake issues
# IPv6 connections fail during protocol negotiation (SecurityUpgradeFailure)
# Re-enable IPv6 support once the following issues are resolved:
# - libp2p security handshake over IPv6
# - multiselect protocol over IPv6
# - connection establishment over IPv6
#
# seen_v6: set[str] = set()
# for ip in _safe_get_network_addrs(6):
# seen_v6.add(ip)
# addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}"))
#
# # Always include IPv6 loopback for testing purposes when IPv6 is available
# # This ensures IPv6 functionality can be tested even without global IPv6 addresses
# if "::1" not in seen_v6:
# addrs.append(Multiaddr(f"/ip6/::1/{protocol}/{port}"))
# Fallback if nothing discovered
if not addrs:
addrs.append(Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}"))
return addrs
def expand_wildcard_address(
addr: Multiaddr, port: int | None = None
) -> list[Multiaddr]:
"""
Expand a wildcard (e.g. /ip4/0.0.0.0/tcp/0) into all concrete interfaces.
:param addr: Multiaddr to expand.
:param port: Optional override for port selection.
:return: List of concrete Multiaddr instances.
"""
expanded = _safe_expand(addr, port=port)
if not expanded: # Safety fallback
return [addr]
return expanded
def get_optimal_binding_address(port: int, protocol: str = "tcp") -> Multiaddr:
"""
Choose an optimal address for an example to bind to:
- Prefer non-loopback IPv4
- Then non-loopback IPv6
- Fallback to loopback
- Fallback to wildcard
:param port: Port number.
:param protocol: Transport protocol.
:return: A single Multiaddr chosen heuristically.
"""
candidates = get_available_interfaces(port, protocol)
def is_non_loopback(ma: Multiaddr) -> bool:
s = str(ma)
return not ("/ip4/127." in s or "/ip6/::1" in s)
for c in candidates:
if "/ip4/" in str(c) and is_non_loopback(c):
return c
for c in candidates:
if "/ip6/" in str(c) and is_non_loopback(c):
return c
for c in candidates:
if "/ip4/127." in str(c) or "/ip6/::1" in str(c):
return c
# As a final fallback, produce a wildcard
return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")
__all__ = [
"get_available_interfaces",
"get_optimal_binding_address",
"expand_wildcard_address",
"find_free_port",
]

View File

@ -1,7 +1,4 @@
import atexit
from datetime import (
datetime,
)
import logging
import logging.handlers
import os
@ -21,6 +18,9 @@ log_queue: "queue.Queue[Any]" = queue.Queue()
# Store the current listener to stop it on exit
_current_listener: logging.handlers.QueueListener | None = None
# Store the handlers for proper cleanup
_current_handlers: list[logging.Handler] = []
# Event to track when the listener is ready
_listener_ready = threading.Event()
@ -95,7 +95,7 @@ def setup_logging() -> None:
- Child loggers inherit their parent's level unless explicitly set
- The root libp2p logger controls the default level
"""
global _current_listener, _listener_ready
global _current_listener, _listener_ready, _current_handlers
# Reset the event
_listener_ready.clear()
@ -105,6 +105,12 @@ def setup_logging() -> None:
_current_listener.stop()
_current_listener = None
# Close and clear existing handlers
for handler in _current_handlers:
if isinstance(handler, logging.FileHandler):
handler.close()
_current_handlers.clear()
# Get the log level from environment variable
debug_str = os.environ.get("LIBP2P_DEBUG", "")
@ -148,13 +154,10 @@ def setup_logging() -> None:
log_path = Path(log_file)
log_path.parent.mkdir(parents=True, exist_ok=True)
else:
# Default log file with timestamp and unique identifier
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
unique_id = os.urandom(4).hex() # Add a unique identifier to prevent collisions
if os.name == "nt": # Windows
log_file = f"C:\\Windows\\Temp\\py-libp2p_{timestamp}_{unique_id}.log"
else: # Unix-like
log_file = f"/tmp/py-libp2p_{timestamp}_{unique_id}.log"
# Use cross-platform temp file creation
from libp2p.utils.paths import create_temp_file
log_file = str(create_temp_file(prefix="py-libp2p_", suffix=".log"))
# Print the log file path so users know where to find it
print(f"Logging to: {log_file}", file=sys.stderr)
@ -195,6 +198,9 @@ def setup_logging() -> None:
logger.setLevel(level)
logger.propagate = False # Prevent message duplication
# Store handlers globally for cleanup
_current_handlers.extend(handlers)
# Start the listener AFTER configuring all loggers
_current_listener = logging.handlers.QueueListener(
log_queue, *handlers, respect_handler_level=True
@ -209,7 +215,13 @@ def setup_logging() -> None:
@atexit.register
def cleanup_logging() -> None:
"""Clean up logging resources on exit."""
global _current_listener
global _current_listener, _current_handlers
if _current_listener is not None:
_current_listener.stop()
_current_listener = None
# Close all file handlers to ensure proper cleanup on Windows
for handler in _current_handlers:
if isinstance(handler, logging.FileHandler):
handler.close()
_current_handlers.clear()

267
libp2p/utils/paths.py Normal file
View File

@ -0,0 +1,267 @@
"""
Cross-platform path utilities for py-libp2p.
This module provides standardized path operations to ensure consistent
behavior across Windows, macOS, and Linux platforms.
"""
import os
from pathlib import Path
import sys
import tempfile
from typing import Union
PathLike = Union[str, Path]
def get_temp_dir() -> Path:
"""
Get cross-platform temporary directory.
Returns:
Path: Platform-specific temporary directory path
"""
return Path(tempfile.gettempdir())
def get_project_root() -> Path:
"""
Get the project root directory.
Returns:
Path: Path to the py-libp2p project root
"""
# Navigate from libp2p/utils/paths.py to project root
return Path(__file__).parent.parent.parent
def join_paths(*parts: PathLike) -> Path:
"""
Cross-platform path joining.
Args:
*parts: Path components to join
Returns:
Path: Joined path using platform-appropriate separator
"""
return Path(*parts)
def ensure_dir_exists(path: PathLike) -> Path:
"""
Ensure directory exists, create if needed.
Args:
path: Directory path to ensure exists
Returns:
Path: Path object for the directory
"""
path_obj = Path(path)
path_obj.mkdir(parents=True, exist_ok=True)
return path_obj
def get_config_dir() -> Path:
"""
Get user config directory (cross-platform).
Returns:
Path: Platform-specific config directory
"""
if os.name == "nt": # Windows
appdata = os.environ.get("APPDATA", "")
if appdata:
return Path(appdata) / "py-libp2p"
else:
# Fallback to user home directory
return Path.home() / "AppData" / "Roaming" / "py-libp2p"
else: # Unix-like (Linux, macOS)
return Path.home() / ".config" / "py-libp2p"
def get_script_dir(script_path: PathLike | None = None) -> Path:
"""
Get the directory containing a script file.
Args:
script_path: Path to the script file. If None, uses __file__
Returns:
Path: Directory containing the script
Raises:
RuntimeError: If script path cannot be determined
"""
if script_path is None:
# This will be the directory of the calling script
import inspect
frame = inspect.currentframe()
if frame and frame.f_back:
script_path = frame.f_back.f_globals.get("__file__")
else:
raise RuntimeError("Could not determine script path")
if script_path is None:
raise RuntimeError("Script path is None")
return Path(script_path).parent.absolute()
def create_temp_file(prefix: str = "py-libp2p_", suffix: str = ".log") -> Path:
"""
Create a temporary file with a unique name.
Args:
prefix: File name prefix
suffix: File name suffix
Returns:
Path: Path to the created temporary file
"""
temp_dir = get_temp_dir()
# Create a unique filename using timestamp and random bytes
import secrets
import time
timestamp = time.strftime("%Y%m%d_%H%M%S")
microseconds = f"{time.time() % 1:.6f}"[2:] # Get microseconds as string
unique_id = secrets.token_hex(4)
filename = f"{prefix}{timestamp}_{microseconds}_{unique_id}{suffix}"
temp_file = temp_dir / filename
# Create the file by touching it
temp_file.touch()
return temp_file
def resolve_relative_path(base_path: PathLike, relative_path: PathLike) -> Path:
"""
Resolve a relative path from a base path.
Args:
base_path: Base directory path
relative_path: Relative path to resolve
Returns:
Path: Resolved absolute path
"""
base = Path(base_path).resolve()
relative = Path(relative_path)
if relative.is_absolute():
return relative
else:
return (base / relative).resolve()
def normalize_path(path: PathLike) -> Path:
"""
Normalize a path, resolving any symbolic links and relative components.
Args:
path: Path to normalize
Returns:
Path: Normalized absolute path
"""
return Path(path).resolve()
def get_venv_path() -> Path | None:
"""
Get virtual environment path if active.
Returns:
Path: Virtual environment path if active, None otherwise
"""
venv_path = os.environ.get("VIRTUAL_ENV")
if venv_path:
return Path(venv_path)
return None
def get_python_executable() -> Path:
"""
Get current Python executable path.
Returns:
Path: Path to the current Python executable
"""
return Path(sys.executable)
def find_executable(name: str) -> Path | None:
"""
Find executable in system PATH.
Args:
name: Name of the executable to find
Returns:
Path: Path to executable if found, None otherwise
"""
# Check if name already contains path
if os.path.dirname(name):
path = Path(name)
if path.exists() and os.access(path, os.X_OK):
return path
return None
# Search in PATH
for path_dir in os.environ.get("PATH", "").split(os.pathsep):
if not path_dir:
continue
path = Path(path_dir) / name
if path.exists() and os.access(path, os.X_OK):
return path
return None
def get_script_binary_path() -> Path:
"""
Get path to script's binary directory.
Returns:
Path: Directory containing the script's binary
"""
return get_python_executable().parent
def get_binary_path(binary_name: str) -> Path | None:
"""
Find binary in PATH or virtual environment.
Args:
binary_name: Name of the binary to find
Returns:
Path: Path to binary if found, None otherwise
"""
# First check in virtual environment if active
venv_path = get_venv_path()
if venv_path:
venv_bin = venv_path / "bin" if os.name != "nt" else venv_path / "Scripts"
binary_path = venv_bin / binary_name
if binary_path.exists() and os.access(binary_path, os.X_OK):
return binary_path
# Fall back to system PATH
return find_executable(binary_name)

View File

@ -0,0 +1 @@
Make TProtocol as Optional[TProtocol] to keep types consistent in py-libp2p/libp2p/protocol_muxer/multiselect.py

View File

@ -0,0 +1 @@
Added Thin Waist address validation utilities (with support for interface enumeration, optimal binding, and wildcard expansion).

View File

@ -0,0 +1,7 @@
Add Thin Waist address validation utilities and integrate into echo example
- Add ``libp2p/utils/address_validation.py`` with dynamic interface discovery
- Implement ``get_available_interfaces()``, ``get_optimal_binding_address()``, and ``expand_wildcard_address()``
- Update echo example to use dynamic address discovery instead of hardcoded wildcard
- Add safe fallbacks for environments lacking Thin Waist support
- Temporarily disable IPv6 support due to libp2p handshake issues (TODO: re-enable when resolved)

View File

@ -0,0 +1 @@
KAD-DHT now include signed-peer-records in its protobuf message schema, for more secure peer-discovery.

View File

@ -0,0 +1,3 @@
Remove the already completed TODO tasks in Peerstore:
TODO: Set up an async task for periodic peer-store cleanup for expired addresses and records.
TODO: Make proper use of this function

View File

@ -0,0 +1 @@
Added `Random Walk` peer discovery module that enables random peer exploration for improved peer discovery.

View File

@ -0,0 +1,6 @@
Implement closed_stream notification in MyNotifee
- Add notify_closed_stream method to swarm notification system for proper stream lifecycle management
- Integrate remove_stream hook in SwarmConn to enable stream closure notifications
- Add comprehensive tests for closed_stream functionality in test_notify.py
- Enable stream lifecycle integration for proper cleanup and resource management

View File

@ -0,0 +1 @@
Added multiselect type consistency in negotiate method. Updates all the usages of the method.

View File

@ -0,0 +1 @@
Fixed message id type inconsistency in handle ihave and message id parsing improvement in handle iwant in pubsub module.

View File

@ -0,0 +1 @@
Fix kbucket splitting in routing table when full. Routing table now maintains multiple kbuckets and properly distributes peers as specified by the Kademlia DHT protocol.

View File

@ -0,0 +1 @@
Add automatic peer dialing in bootstrap module using trio.Nursery.

View File

@ -0,0 +1 @@
Improved PubsubNotifee integration tests and added failure scenario coverage.

View File

@ -0,0 +1 @@
Fix type for gossipsub_message_id for consistency and security

View File

@ -0,0 +1,5 @@
Fix multi-address listening bug in swarm.listen()
- Fix early return in swarm.listen() that prevented listening on all addresses
- Add comprehensive tests for multi-address listening functionality
- Ensure all available interfaces are properly bound and connectable

View File

@ -0,0 +1 @@
Enhanced Swarm networking with retry logic, exponential backoff, and multi-connection support. Added configurable retry mechanisms that automatically recover from transient connection failures using exponential backoff with jitter to prevent thundering herd problems. Introduced connection pooling that allows multiple concurrent connections per peer for improved performance and fault tolerance. Added load balancing across connections and automatic connection health management. All enhancements are fully backward compatible and can be configured through new RetryConfig and ConnectionConfig classes.

View File

@ -0,0 +1,5 @@
Remove unused upgrade_listener function from transport upgrader
- Remove unused `upgrade_listener` function from `libp2p/transport/upgrader.py` (Issue 2 from #726)
- Clean up unused imports related to the removed function
- Improve code maintainability by removing dead code

View File

@ -0,0 +1,2 @@
Fixed cross-platform path handling by replacing hardcoded OS-specific
paths with standardized utilities in core modules and examples.

View File

@ -0,0 +1 @@
PubSub routers now include signed-peer-records in RPC messages for secure peer-info exchange.

View File

@ -10,8 +10,10 @@ readme = "README.md"
requires-python = ">=3.10, <4.0"
license = { text = "MIT AND Apache-2.0" }
keywords = ["libp2p", "p2p"]
authors = [
{ name = "The Ethereum Foundation", email = "snakecharmers@ethereum.org" },
maintainers = [
{ name = "pacrob", email = "pacrob-py-libp2p@proton.me" },
{ name = "Manu Sheel Gupta", email = "manu@seeta.in" },
{ name = "Dave Grantham", email = "dwg@linuxprogrammer.org" },
]
dependencies = [
"base58>=1.0.3",

255
scripts/audit_paths.py Normal file
View File

@ -0,0 +1,255 @@
#!/usr/bin/env python3
"""
Audit script to identify path handling issues in the py-libp2p codebase.
This script scans for patterns that should be migrated to use the new
cross-platform path utilities.
"""
import argparse
from pathlib import Path
import re
from typing import Any
def scan_for_path_issues(directory: Path) -> dict[str, list[dict[str, Any]]]:
"""
Scan for path handling issues in the codebase.
Args:
directory: Root directory to scan
Returns:
Dictionary mapping issue types to lists of found issues
"""
issues = {
"hard_coded_slash": [],
"os_path_join": [],
"temp_hardcode": [],
"os_path_dirname": [],
"os_path_abspath": [],
"direct_path_concat": [],
}
# Patterns to search for
patterns = {
"hard_coded_slash": r'["\'][^"\']*\/[^"\']*["\']',
"os_path_join": r"os\.path\.join\(",
"temp_hardcode": r'["\']\/tmp\/|["\']C:\\\\',
"os_path_dirname": r"os\.path\.dirname\(",
"os_path_abspath": r"os\.path\.abspath\(",
"direct_path_concat": r'["\'][^"\']*["\']\s*\+\s*["\'][^"\']*["\']',
}
# Files to exclude
exclude_patterns = [
r"__pycache__",
r"\.git",
r"\.pytest_cache",
r"\.mypy_cache",
r"\.ruff_cache",
r"env/",
r"venv/",
r"\.venv/",
]
for py_file in directory.rglob("*.py"):
# Skip excluded files
if any(re.search(pattern, str(py_file)) for pattern in exclude_patterns):
continue
try:
content = py_file.read_text(encoding="utf-8")
except UnicodeDecodeError:
print(f"Warning: Could not read {py_file} (encoding issue)")
continue
for issue_type, pattern in patterns.items():
matches = re.finditer(pattern, content, re.MULTILINE)
for match in matches:
line_num = content[: match.start()].count("\n") + 1
line_content = content.split("\n")[line_num - 1].strip()
issues[issue_type].append(
{
"file": py_file,
"line": line_num,
"content": match.group(),
"full_line": line_content,
"relative_path": py_file.relative_to(directory),
}
)
return issues
def generate_migration_suggestions(issues: dict[str, list[dict[str, Any]]]) -> str:
"""
Generate migration suggestions for found issues.
Args:
issues: Dictionary of found issues
Returns:
Formatted string with migration suggestions
"""
suggestions = []
for issue_type, issue_list in issues.items():
if not issue_list:
continue
suggestions.append(f"\n## {issue_type.replace('_', ' ').title()}")
suggestions.append(f"Found {len(issue_list)} instances:")
for issue in issue_list[:10]: # Show first 10 examples
suggestions.append(f"\n### {issue['relative_path']}:{issue['line']}")
suggestions.append("```python")
suggestions.append("# Current code:")
suggestions.append(f"{issue['full_line']}")
suggestions.append("```")
# Add migration suggestion based on issue type
if issue_type == "os_path_join":
suggestions.append("```python")
suggestions.append("# Suggested fix:")
suggestions.append("from libp2p.utils.paths import join_paths")
suggestions.append(
"# Replace os.path.join(a, b, c) with join_paths(a, b, c)"
)
suggestions.append("```")
elif issue_type == "temp_hardcode":
suggestions.append("```python")
suggestions.append("# Suggested fix:")
suggestions.append(
"from libp2p.utils.paths import get_temp_dir, create_temp_file"
)
temp_fix_msg = (
"# Replace hard-coded temp paths with get_temp_dir() or "
"create_temp_file()"
)
suggestions.append(temp_fix_msg)
suggestions.append("```")
elif issue_type == "os_path_dirname":
suggestions.append("```python")
suggestions.append("# Suggested fix:")
suggestions.append("from libp2p.utils.paths import get_script_dir")
script_dir_fix_msg = (
"# Replace os.path.dirname(os.path.abspath(__file__)) with "
"get_script_dir(__file__)"
)
suggestions.append(script_dir_fix_msg)
suggestions.append("```")
if len(issue_list) > 10:
suggestions.append(f"\n... and {len(issue_list) - 10} more instances")
return "\n".join(suggestions)
def generate_summary_report(issues: dict[str, list[dict[str, Any]]]) -> str:
"""
Generate a summary report of all found issues.
Args:
issues: Dictionary of found issues
Returns:
Formatted summary report
"""
total_issues = sum(len(issue_list) for issue_list in issues.values())
report = [
"# Cross-Platform Path Handling Audit Report",
"",
"## Summary",
f"Total issues found: {total_issues}",
"",
"## Issue Breakdown:",
]
for issue_type, issue_list in issues.items():
if issue_list:
issue_title = issue_type.replace("_", " ").title()
instances_count = len(issue_list)
report.append(f"- **{issue_title}**: {instances_count} instances")
report.append("")
report.append("## Priority Matrix:")
report.append("")
report.append("| Priority | Issue Type | Risk Level | Impact |")
report.append("|----------|------------|------------|---------|")
priority_map = {
"temp_hardcode": (
"🔴 P0",
"HIGH",
"Core functionality fails on different platforms",
),
"os_path_join": ("🟡 P1", "MEDIUM", "Examples and utilities may break"),
"os_path_dirname": ("🟡 P1", "MEDIUM", "Script location detection issues"),
"hard_coded_slash": ("🟢 P2", "LOW", "Future-proofing and consistency"),
"os_path_abspath": ("🟢 P2", "LOW", "Path resolution consistency"),
"direct_path_concat": ("🟢 P2", "LOW", "String concatenation issues"),
}
for issue_type, issue_list in issues.items():
if issue_list:
priority, risk, impact = priority_map.get(
issue_type, ("🟢 P2", "LOW", "General improvement")
)
issue_title = issue_type.replace("_", " ").title()
report.append(f"| {priority} | {issue_title} | {risk} | {impact} |")
return "\n".join(report)
def main():
"""Main function to run the audit."""
parser = argparse.ArgumentParser(
description="Audit py-libp2p codebase for path handling issues"
)
parser.add_argument(
"--directory",
default=".",
help="Directory to scan (default: current directory)",
)
parser.add_argument("--output", help="Output file for detailed report")
parser.add_argument(
"--summary-only", action="store_true", help="Only show summary report"
)
args = parser.parse_args()
directory = Path(args.directory)
if not directory.exists():
print(f"Error: Directory {directory} does not exist")
return 1
print("🔍 Scanning for path handling issues...")
issues = scan_for_path_issues(directory)
# Generate and display summary
summary = generate_summary_report(issues)
print(summary)
if not args.summary_only:
# Generate detailed suggestions
suggestions = generate_migration_suggestions(issues)
if args.output:
with open(args.output, "w", encoding="utf-8") as f:
f.write(summary)
f.write(suggestions)
print(f"\n📄 Detailed report saved to {args.output}")
else:
print(suggestions)
return 0
if __name__ == "__main__":
exit(main())

View File

@ -1,3 +1,10 @@
from unittest.mock import (
AsyncMock,
MagicMock,
)
import pytest
from libp2p import (
new_swarm,
)
@ -10,6 +17,9 @@ from libp2p.host.basic_host import (
from libp2p.host.defaults import (
get_default_protocols,
)
from libp2p.host.exceptions import (
StreamFailure,
)
def test_default_protocols():
@ -22,3 +32,30 @@ def test_default_protocols():
# NOTE: comparing keys for equality as handlers may be closures that do not compare
# in the way this test is concerned with
assert handlers.keys() == get_default_protocols(host).keys()
@pytest.mark.trio
async def test_swarm_stream_handler_no_protocol_selected(monkeypatch):
key_pair = create_new_key_pair()
swarm = new_swarm(key_pair)
host = BasicHost(swarm)
# Create a mock net_stream
net_stream = MagicMock()
net_stream.reset = AsyncMock()
net_stream.muxed_conn.peer_id = "peer-test"
# Monkeypatch negotiate to simulate "no protocol selected"
async def fake_negotiate(comm, timeout):
return None, None
monkeypatch.setattr(host.multiselect, "negotiate", fake_negotiate)
# Now run the handler and expect StreamFailure
with pytest.raises(
StreamFailure, match="Failed to negotiate protocol: no protocol selected"
):
await host._swarm_stream_handler(net_stream)
# Ensure reset was called since negotiation failed
net_stream.reset.assert_awaited()

View File

@ -164,8 +164,8 @@ async def test_live_peers_unexpected_drop(security_protocol):
assert peer_a_id in host_b.get_live_peers()
# Simulate unexpected connection drop by directly closing the connection
conn = host_a.get_network().connections[peer_b_id]
await conn.muxed_conn.close()
conns = host_a.get_network().connections[peer_b_id]
await conns[0].muxed_conn.close()
# Allow for connection cleanup
await trio.sleep(0.1)

View File

@ -9,11 +9,15 @@ This module tests core functionality of the Kademlia DHT including:
import hashlib
import logging
import os
from unittest.mock import patch
import uuid
import pytest
import multiaddr
import trio
from libp2p.crypto.rsa import create_new_key_pair
from libp2p.kad_dht.kad_dht import (
DHTMode,
KadDHT,
@ -21,9 +25,13 @@ from libp2p.kad_dht.kad_dht import (
from libp2p.kad_dht.utils import (
create_key_from_binary,
)
from libp2p.peer.envelope import Envelope, seal_record
from libp2p.peer.id import ID
from libp2p.peer.peer_record import PeerRecord
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.peer.peerstore import create_signed_peer_record
from libp2p.tools.async_service import (
background_trio_service,
)
@ -76,10 +84,52 @@ async def test_find_node(dht_pair: tuple[KadDHT, KadDHT]):
"""Test that nodes can find each other in the DHT."""
dht_a, dht_b = dht_pair
# An extra FIND_NODE req is sent between the 2 nodes while dht creation,
# so both the nodes will have records of each other before the next FIND_NODE
# req is sent
envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id())
assert isinstance(envelope_a, Envelope)
assert isinstance(envelope_b, Envelope)
record_a = envelope_a.record()
record_b = envelope_b.record()
# Node A should be able to find Node B
with trio.fail_after(TEST_TIMEOUT):
found_info = await dht_a.find_peer(dht_b.host.get_id())
# Verifies if the senderRecord in the FIND_NODE request is correctly processed
assert isinstance(
dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id()), Envelope
)
# Verifies if the senderRecord in the FIND_NODE response is correctly processed
assert isinstance(
dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()), Envelope
)
# These are the records that were sent between the peers during the FIND_NODE req
envelope_a_find_peer = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_find_peer = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_find_peer, Envelope)
assert isinstance(envelope_b_find_peer, Envelope)
record_a_find_peer = envelope_a_find_peer.record()
record_b_find_peer = envelope_b_find_peer.record()
# This proves that both the records are same, and a latest cached signed record
# was passed between the peers during FIND_NODE execution, which proves the
# signed-record transfer/re-issuing works correctly in FIND_NODE executions.
assert record_a.seq == record_a_find_peer.seq
assert record_b.seq == record_b_find_peer.seq
# Verify that the found peer has the correct peer ID
assert found_info is not None, "Failed to find the target peer"
assert found_info.peer_id == dht_b.host.get_id(), "Found incorrect peer ID"
@ -104,14 +154,44 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]):
await dht_a.routing_table.add_peer(peer_b_info)
print("Routing table of a has ", dht_a.routing_table.get_peer_ids())
# An extra FIND_NODE req is sent between the 2 nodes while dht creation,
# so both the nodes will have records of each other before PUT_VALUE req is sent
envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id())
assert isinstance(envelope_a, Envelope)
assert isinstance(envelope_b, Envelope)
record_a = envelope_a.record()
record_b = envelope_b.record()
# Store the value using the first node (this will also store locally)
with trio.fail_after(TEST_TIMEOUT):
await dht_a.put_value(key, value)
# These are the records that were sent between the peers during the PUT_VALUE req
envelope_a_put_value = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_put_value = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_put_value, Envelope)
assert isinstance(envelope_b_put_value, Envelope)
record_a_put_value = envelope_a_put_value.record()
record_b_put_value = envelope_b_put_value.record()
# This proves that both the records are same, and a latest cached signed record
# was passed between the peers during PUT_VALUE execution, which proves the
# signed-record transfer/re-issuing works correctly in PUT_VALUE executions.
assert record_a.seq == record_a_put_value.seq
assert record_b.seq == record_b_put_value.seq
# # Log debugging information
logger.debug("Put value with key %s...", key.hex()[:10])
logger.debug("Node A value store: %s", dht_a.value_store.store)
print("hello test")
# # Allow more time for the value to propagate
await trio.sleep(0.5)
@ -126,6 +206,26 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]):
print("the value stored in node b is", dht_b.get_value_store_size())
logger.debug("Retrieved value: %s", retrieved_value)
# These are the records that were sent between the peers during the PUT_VALUE req
envelope_a_get_value = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_get_value = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_get_value, Envelope)
assert isinstance(envelope_b_get_value, Envelope)
record_a_get_value = envelope_a_get_value.record()
record_b_get_value = envelope_b_get_value.record()
# This proves that there was no record exchange between the nodes during GET_VALUE
# execution, as dht_b already had the key/value pair stored locally after the
# PUT_VALUE execution.
assert record_a_get_value.seq == record_a_put_value.seq
assert record_b_get_value.seq == record_b_put_value.seq
# Verify that the retrieved value matches the original
assert retrieved_value == value, "Retrieved value does not match the stored value"
@ -142,11 +242,44 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]):
# Store content on the first node
dht_a.value_store.put(content_id, content)
# An extra FIND_NODE req is sent between the 2 nodes while dht creation,
# so both the nodes will have records of each other before PUT_VALUE req is sent
envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id())
assert isinstance(envelope_a, Envelope)
assert isinstance(envelope_b, Envelope)
record_a = envelope_a.record()
record_b = envelope_b.record()
# Advertise the first node as a provider
with trio.fail_after(TEST_TIMEOUT):
success = await dht_a.provide(content_id)
assert success, "Failed to advertise as provider"
# These are the records that were sent between the peers during
# the ADD_PROVIDER req
envelope_a_add_prov = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_add_prov = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_add_prov, Envelope)
assert isinstance(envelope_b_add_prov, Envelope)
record_a_add_prov = envelope_a_add_prov.record()
record_b_add_prov = envelope_b_add_prov.record()
# This proves that both the records are same, the latest cached signed record
# was passed between the peers during ADD_PROVIDER execution, which proves the
# signed-record transfer/re-issuing of the latest record works correctly in
# ADD_PROVIDER executions.
assert record_a.seq == record_a_add_prov.seq
assert record_b.seq == record_b_add_prov.seq
# Allow time for the provider record to propagate
await trio.sleep(0.1)
@ -154,6 +287,26 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]):
with trio.fail_after(TEST_TIMEOUT):
providers = await dht_b.find_providers(content_id)
# These are the records in each peer after the find_provider execution
envelope_a_find_prov = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_find_prov = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_find_prov, Envelope)
assert isinstance(envelope_b_find_prov, Envelope)
record_a_find_prov = envelope_a_find_prov.record()
record_b_find_prov = envelope_b_find_prov.record()
# This proves that both the records are same, as the dht_b already
# has the provider record for the content_id, after the ADD_PROVIDER
# advertisement by dht_a
assert record_a_find_prov.seq == record_a_add_prov.seq
assert record_b_find_prov.seq == record_b_add_prov.seq
# Verify that we found the first node as a provider
assert providers, "No providers found"
assert any(p.peer_id == dht_a.local_peer_id for p in providers), (
@ -166,3 +319,143 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]):
assert retrieved_value == content, (
"Retrieved content does not match the original"
)
# These are the record state of each peer aftet the GET_VALUE execution
envelope_a_get_value = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_get_value = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_get_value, Envelope)
assert isinstance(envelope_b_get_value, Envelope)
record_a_get_value = envelope_a_get_value.record()
record_b_get_value = envelope_b_get_value.record()
# This proves that both the records are same, meaning that the latest cached
# signed-record tranfer happened during the GET_VALUE execution by dht_b,
# which means the signed-record transfer/re-issuing works correctly
# in GET_VALUE executions.
assert record_a_find_prov.seq == record_a_get_value.seq
assert record_b_find_prov.seq == record_b_get_value.seq
# Create a new provider record in dht_a
provider_key_pair = create_new_key_pair()
provider_peer_id = ID.from_pubkey(provider_key_pair.public_key)
provider_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123")
provider_peer_info = PeerInfo(peer_id=provider_peer_id, addrs=[provider_addr])
# Generate a random content ID
content_2 = f"random-content-{uuid.uuid4()}".encode()
content_id_2 = hashlib.sha256(content_2).digest()
provider_signed_envelope = create_signed_peer_record(
provider_peer_id, [provider_addr], provider_key_pair.private_key
)
assert (
dht_a.host.get_peerstore().consume_peer_record(provider_signed_envelope, 7200)
is True
)
# Store this provider record in dht_a
dht_a.provider_store.add_provider(content_id_2, provider_peer_info)
# Fetch the provider-record via peer-discovery at dht_b's end
peerinfo = await dht_b.provider_store.find_providers(content_id_2)
assert len(peerinfo) == 1
assert peerinfo[0].peer_id == provider_peer_id
provider_envelope = dht_b.host.get_peerstore().get_peer_record(provider_peer_id)
# This proves that the signed-envelope of provider is consumed on dht_b's end
assert provider_envelope is not None
assert (
provider_signed_envelope.marshal_envelope()
== provider_envelope.marshal_envelope()
)
@pytest.mark.trio
async def test_reissue_when_listen_addrs_change(dht_pair: tuple[KadDHT, KadDHT]):
dht_a, dht_b = dht_pair
# Warm-up: A stores B's current record
with trio.fail_after(10):
await dht_a.find_peer(dht_b.host.get_id())
env0 = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
assert isinstance(env0, Envelope)
seq0 = env0.record().seq
# Simulate B's listen addrs changing (different port)
new_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123")
# Patch just for the duration we force B to respond:
with patch.object(dht_b.host, "get_addrs", return_value=[new_addr]):
# Force B to send a response (which should include a fresh SPR)
with trio.fail_after(10):
await dht_a.peer_routing._query_peer_for_closest(
dht_b.host.get_id(), os.urandom(32)
)
# A should now hold B's new record with a bumped seq
env1 = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
assert isinstance(env1, Envelope)
seq1 = env1.record().seq
# This proves that upon the change in listen_addrs, we issue new records
assert seq1 > seq0, f"Expected seq to bump after addr change, got {seq0} -> {seq1}"
@pytest.mark.trio
async def test_dht_req_fail_with_invalid_record_transfer(
dht_pair: tuple[KadDHT, KadDHT],
):
"""
Testing showing failure of storing and retrieving values in the DHT,
if invalid signed-records are sent.
"""
dht_a, dht_b = dht_pair
peer_b_info = PeerInfo(dht_b.host.get_id(), dht_b.host.get_addrs())
# Generate a random key and value
key = create_key_from_binary(b"test-key")
value = b"test-value"
# First add the value directly to node A's store to verify storage works
dht_a.value_store.put(key, value)
local_value = dht_a.value_store.get(key)
assert local_value == value, "Local value storage failed"
await dht_a.routing_table.add_peer(peer_b_info)
# Corrupt dht_a's local peer_record
envelope = dht_a.host.get_peerstore().get_local_record()
if envelope is not None:
true_record = envelope.record()
key_pair = create_new_key_pair()
if envelope is not None:
envelope.public_key = key_pair.public_key
dht_a.host.get_peerstore().set_local_record(envelope)
await dht_a.put_value(key, value)
retrieved_value = dht_b.value_store.get(key)
# This proves that DHT_B rejected DHT_A PUT_RECORD req upon receiving
# the corrupted invalid record
assert retrieved_value is None
# Create a corrupt envelope with correct signature but false peer_id
false_record = PeerRecord(ID.from_pubkey(key_pair.public_key), true_record.addrs)
false_envelope = seal_record(false_record, dht_a.host.get_private_key())
dht_a.host.get_peerstore().set_local_record(false_envelope)
await dht_a.put_value(key, value)
retrieved_value = dht_b.value_store.get(key)
# This proves that DHT_B rejected DHT_A PUT_RECORD req upon receving
# the record with a different peer_id regardless of a valid signature
assert retrieved_value is None

View File

@ -57,7 +57,10 @@ class TestPeerRouting:
def mock_host(self):
"""Create a mock host for testing."""
host = Mock()
host.get_id.return_value = create_valid_peer_id("local")
key_pair = create_new_key_pair()
host.get_id.return_value = ID.from_pubkey(key_pair.public_key)
host.get_public_key.return_value = key_pair.public_key
host.get_private_key.return_value = key_pair.private_key
host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
host.get_peerstore.return_value = Mock()
host.new_stream = AsyncMock()

View File

@ -226,6 +226,32 @@ class TestKBucket:
class TestRoutingTable:
"""Test suite for RoutingTable class."""
@pytest.mark.trio
async def test_kbucket_split_behavior(self, mock_host, local_peer_id):
"""
Test that adding more than BUCKET_SIZE peers to the routing table
triggers kbucket splitting and all peers are added.
"""
routing_table = RoutingTable(local_peer_id, mock_host)
num_peers = BUCKET_SIZE + 5
peer_ids = []
for i in range(num_peers):
key_pair = create_new_key_pair()
peer_id = ID.from_pubkey(key_pair.public_key)
peer_info = PeerInfo(peer_id, [Multiaddr(f"/ip4/127.0.0.1/tcp/{9000 + i}")])
peer_ids.append(peer_id)
added = await routing_table.add_peer(peer_info)
assert added, f"Peer {peer_id} should be added"
assert len(routing_table.buckets) > 1, "KBucket splitting did not occur"
for pid in peer_ids:
assert routing_table.peer_in_table(pid), f"Peer {pid} not found after split"
all_peer_ids = routing_table.get_peer_ids()
assert set(peer_ids).issubset(set(all_peer_ids)), (
"Not all peers present after split"
)
@pytest.fixture
def mock_host(self):
"""Create a mock host for testing."""

View File

@ -0,0 +1,325 @@
import time
from typing import cast
from unittest.mock import Mock
import pytest
from multiaddr import Multiaddr
import trio
from libp2p.abc import INetConn, INetStream
from libp2p.network.exceptions import SwarmException
from libp2p.network.swarm import (
ConnectionConfig,
RetryConfig,
Swarm,
)
from libp2p.peer.id import ID
class MockConnection(INetConn):
"""Mock connection for testing."""
def __init__(self, peer_id: ID, is_closed: bool = False):
self.peer_id = peer_id
self._is_closed = is_closed
self.streams = set() # Track streams properly
# Mock the muxed_conn attribute that Swarm expects
self.muxed_conn = Mock()
self.muxed_conn.peer_id = peer_id
# Required by INetConn interface
self.event_started = trio.Event()
async def close(self):
self._is_closed = True
@property
def is_closed(self) -> bool:
return self._is_closed
async def new_stream(self) -> INetStream:
# Create a mock stream and add it to the connection's stream set
mock_stream = Mock(spec=INetStream)
self.streams.add(mock_stream)
return mock_stream
def get_streams(self) -> tuple[INetStream, ...]:
"""Return all streams associated with this connection."""
return tuple(self.streams)
def get_transport_addresses(self) -> list[Multiaddr]:
"""Mock implementation of get_transport_addresses."""
return []
class MockNetStream(INetStream):
"""Mock network stream for testing."""
def __init__(self, peer_id: ID):
self.peer_id = peer_id
@pytest.mark.trio
async def test_retry_config_defaults():
"""Test RetryConfig default values."""
config = RetryConfig()
assert config.max_retries == 3
assert config.initial_delay == 0.1
assert config.max_delay == 30.0
assert config.backoff_multiplier == 2.0
assert config.jitter_factor == 0.1
@pytest.mark.trio
async def test_connection_config_defaults():
"""Test ConnectionConfig default values."""
config = ConnectionConfig()
assert config.max_connections_per_peer == 3
assert config.connection_timeout == 30.0
assert config.load_balancing_strategy == "round_robin"
@pytest.mark.trio
async def test_enhanced_swarm_constructor():
"""Test enhanced Swarm constructor with new configuration."""
# Create mock dependencies
peer_id = ID(b"QmTest")
peerstore = Mock()
upgrader = Mock()
transport = Mock()
# Test with default config
swarm = Swarm(peer_id, peerstore, upgrader, transport)
assert swarm.retry_config.max_retries == 3
assert swarm.connection_config.max_connections_per_peer == 3
assert isinstance(swarm.connections, dict)
# Test with custom config
custom_retry = RetryConfig(max_retries=5, initial_delay=0.5)
custom_conn = ConnectionConfig(max_connections_per_peer=5)
swarm = Swarm(peer_id, peerstore, upgrader, transport, custom_retry, custom_conn)
assert swarm.retry_config.max_retries == 5
assert swarm.retry_config.initial_delay == 0.5
assert swarm.connection_config.max_connections_per_peer == 5
@pytest.mark.trio
async def test_swarm_backoff_calculation():
"""Test exponential backoff calculation with jitter."""
peer_id = ID(b"QmTest")
peerstore = Mock()
upgrader = Mock()
transport = Mock()
retry_config = RetryConfig(
initial_delay=0.1, max_delay=1.0, backoff_multiplier=2.0, jitter_factor=0.1
)
swarm = Swarm(peer_id, peerstore, upgrader, transport, retry_config)
# Test backoff calculation
delay1 = swarm._calculate_backoff_delay(0)
delay2 = swarm._calculate_backoff_delay(1)
delay3 = swarm._calculate_backoff_delay(2)
# Should increase exponentially
assert delay2 > delay1
assert delay3 > delay2
# Should respect max delay
assert delay1 <= 1.0
assert delay2 <= 1.0
assert delay3 <= 1.0
# Should have jitter
assert delay1 != 0.1 # Should have jitter added
@pytest.mark.trio
async def test_swarm_retry_logic():
"""Test retry logic in dial operations."""
peer_id = ID(b"QmTest")
peerstore = Mock()
upgrader = Mock()
transport = Mock()
# Configure for fast testing
retry_config = RetryConfig(
max_retries=2,
initial_delay=0.01, # Very short for testing
max_delay=0.1,
)
swarm = Swarm(peer_id, peerstore, upgrader, transport, retry_config)
# Mock the single attempt method to fail twice then succeed
attempt_count = [0]
async def mock_single_attempt(addr, peer_id):
attempt_count[0] += 1
if attempt_count[0] < 3:
raise SwarmException(f"Attempt {attempt_count[0]} failed")
return MockConnection(peer_id)
swarm._dial_addr_single_attempt = mock_single_attempt
# Test retry logic
start_time = time.time()
result = await swarm._dial_with_retry(Mock(spec=Multiaddr), peer_id)
end_time = time.time()
# Should have succeeded after 3 attempts
assert attempt_count[0] == 3
assert isinstance(result, MockConnection)
assert end_time - start_time > 0.01 # Should have some delay
@pytest.mark.trio
async def test_swarm_load_balancing_strategies():
"""Test load balancing strategies."""
peer_id = ID(b"QmTest")
peerstore = Mock()
upgrader = Mock()
transport = Mock()
swarm = Swarm(peer_id, peerstore, upgrader, transport)
# Create mock connections with different stream counts
conn1 = MockConnection(peer_id)
conn2 = MockConnection(peer_id)
conn3 = MockConnection(peer_id)
# Add some streams to simulate load
await conn1.new_stream()
await conn1.new_stream()
await conn2.new_stream()
connections = [conn1, conn2, conn3]
# Test round-robin strategy
swarm.connection_config.load_balancing_strategy = "round_robin"
# Cast to satisfy type checker
connections_cast = cast("list[INetConn]", connections)
selected1 = swarm._select_connection(connections_cast, peer_id)
selected2 = swarm._select_connection(connections_cast, peer_id)
selected3 = swarm._select_connection(connections_cast, peer_id)
# Should cycle through connections
assert selected1 in connections
assert selected2 in connections
assert selected3 in connections
# Test least loaded strategy
swarm.connection_config.load_balancing_strategy = "least_loaded"
least_loaded = swarm._select_connection(connections_cast, peer_id)
# conn3 has 0 streams, conn2 has 1 stream, conn1 has 2 streams
# So conn3 should be selected as least loaded
assert least_loaded == conn3
# Test default strategy (first connection)
swarm.connection_config.load_balancing_strategy = "unknown"
default_selected = swarm._select_connection(connections_cast, peer_id)
assert default_selected == conn1
@pytest.mark.trio
async def test_swarm_multiple_connections_api():
"""Test the new multiple connections API methods."""
peer_id = ID(b"QmTest")
peerstore = Mock()
upgrader = Mock()
transport = Mock()
swarm = Swarm(peer_id, peerstore, upgrader, transport)
# Test empty connections
assert swarm.get_connections() == []
assert swarm.get_connections(peer_id) == []
assert swarm.get_connection(peer_id) is None
assert swarm.get_connections_map() == {}
# Add some connections
conn1 = MockConnection(peer_id)
conn2 = MockConnection(peer_id)
swarm.connections[peer_id] = [conn1, conn2]
# Test get_connections with peer_id
peer_connections = swarm.get_connections(peer_id)
assert len(peer_connections) == 2
assert conn1 in peer_connections
assert conn2 in peer_connections
# Test get_connections without peer_id (all connections)
all_connections = swarm.get_connections()
assert len(all_connections) == 2
assert conn1 in all_connections
assert conn2 in all_connections
# Test get_connection (backward compatibility)
single_conn = swarm.get_connection(peer_id)
assert single_conn in [conn1, conn2]
# Test get_connections_map
connections_map = swarm.get_connections_map()
assert peer_id in connections_map
assert connections_map[peer_id] == [conn1, conn2]
@pytest.mark.trio
async def test_swarm_connection_trimming():
"""Test connection trimming when limit is exceeded."""
peer_id = ID(b"QmTest")
peerstore = Mock()
upgrader = Mock()
transport = Mock()
# Set max connections to 2
connection_config = ConnectionConfig(max_connections_per_peer=2)
swarm = Swarm(
peer_id, peerstore, upgrader, transport, connection_config=connection_config
)
# Add 3 connections
conn1 = MockConnection(peer_id)
conn2 = MockConnection(peer_id)
conn3 = MockConnection(peer_id)
swarm.connections[peer_id] = [conn1, conn2, conn3]
# Trigger trimming
swarm._trim_connections(peer_id)
# Should have only 2 connections
assert len(swarm.connections[peer_id]) == 2
# The most recent connections should remain
remaining = swarm.connections[peer_id]
assert conn2 in remaining
assert conn3 in remaining
@pytest.mark.trio
async def test_swarm_backward_compatibility():
"""Test backward compatibility features."""
peer_id = ID(b"QmTest")
peerstore = Mock()
upgrader = Mock()
transport = Mock()
swarm = Swarm(peer_id, peerstore, upgrader, transport)
# Add connections
conn1 = MockConnection(peer_id)
conn2 = MockConnection(peer_id)
swarm.connections[peer_id] = [conn1, conn2]
# Test connections_legacy property
legacy_connections = swarm.connections_legacy
assert peer_id in legacy_connections
# Should return first connection
assert legacy_connections[peer_id] in [conn1, conn2]
if __name__ == "__main__":
pytest.main([__file__])

View File

@ -0,0 +1,82 @@
import pytest
from multiaddr import Multiaddr
import trio
from libp2p.abc import (
INetConn,
INetStream,
INetwork,
INotifee,
)
from libp2p.tools.utils import connect_swarm
from tests.utils.factories import SwarmFactory
class CountingNotifee(INotifee):
def __init__(self, event: trio.Event) -> None:
self._event = event
async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
pass
async def closed_stream(self, network: INetwork, stream: INetStream) -> None:
pass
async def connected(self, network: INetwork, conn: INetConn) -> None:
self._event.set()
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
pass
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
pass
async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
pass
class SlowNotifee(INotifee):
async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
pass
async def closed_stream(self, network: INetwork, stream: INetStream) -> None:
pass
async def connected(self, network: INetwork, conn: INetConn) -> None:
await trio.sleep(0.5)
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
pass
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
pass
async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
pass
@pytest.mark.trio
async def test_many_notifees_receive_connected_quickly() -> None:
async with SwarmFactory.create_batch_and_listen(2) as swarms:
count = 200
events = [trio.Event() for _ in range(count)]
for ev in events:
swarms[0].register_notifee(CountingNotifee(ev))
await connect_swarm(swarms[0], swarms[1])
with trio.fail_after(1.5):
for ev in events:
await ev.wait()
@pytest.mark.trio
async def test_slow_notifee_does_not_block_others() -> None:
async with SwarmFactory.create_batch_and_listen(2) as swarms:
fast_events = [trio.Event() for _ in range(20)]
for ev in fast_events:
swarms[0].register_notifee(CountingNotifee(ev))
swarms[0].register_notifee(SlowNotifee())
await connect_swarm(swarms[0], swarms[1])
# Fast notifees should complete quickly despite one slow notifee
with trio.fail_after(0.3):
for ev in fast_events:
await ev.wait()

View File

@ -5,11 +5,12 @@ the stream passed into opened_stream is correct.
Note: Listen event does not get hit because MyNotifee is passed
into network after network has already started listening
TODO: Add tests for closed_stream, listen_close when those
features are implemented in swarm
Note: ClosedStream events are processed asynchronously and may not be
immediately available due to the rapid nature of operations
"""
import enum
from unittest.mock import Mock
import pytest
from multiaddr import Multiaddr
@ -29,11 +30,11 @@ from tests.utils.factories import (
class Event(enum.Enum):
OpenedStream = 0
ClosedStream = 1 # Not implemented
ClosedStream = 1
Connected = 2
Disconnected = 3
Listen = 4
ListenClose = 5 # Not implemented
ListenClose = 5
class MyNotifee(INotifee):
@ -44,8 +45,11 @@ class MyNotifee(INotifee):
self.events.append(Event.OpenedStream)
async def closed_stream(self, network: INetwork, stream: INetStream) -> None:
# TODO: It is not implemented yet.
pass
if network is None:
raise ValueError("network parameter cannot be None")
if stream is None:
raise ValueError("stream parameter cannot be None")
self.events.append(Event.ClosedStream)
async def connected(self, network: INetwork, conn: INetConn) -> None:
self.events.append(Event.Connected)
@ -57,8 +61,11 @@ class MyNotifee(INotifee):
self.events.append(Event.Listen)
async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
# TODO: It is not implemented yet.
pass
if network is None:
raise ValueError("network parameter cannot be None")
if multiaddr is None:
raise ValueError("multiaddr parameter cannot be None")
self.events.append(Event.ListenClose)
@pytest.mark.trio
@ -103,28 +110,188 @@ async def test_notify(security_protocol):
# Wait for events
assert await wait_for_event(events_0_0, Event.Connected, 1.0)
assert await wait_for_event(events_0_0, Event.OpenedStream, 1.0)
# assert await wait_for_event(
# events_0_0, Event.ClosedStream, 1.0
# ) # Not implemented
assert await wait_for_event(events_0_0, Event.ClosedStream, 1.0)
assert await wait_for_event(events_0_0, Event.Disconnected, 1.0)
assert await wait_for_event(events_0_1, Event.Connected, 1.0)
assert await wait_for_event(events_0_1, Event.OpenedStream, 1.0)
# assert await wait_for_event(
# events_0_1, Event.ClosedStream, 1.0
# ) # Not implemented
assert await wait_for_event(events_0_1, Event.ClosedStream, 1.0)
assert await wait_for_event(events_0_1, Event.Disconnected, 1.0)
assert await wait_for_event(events_1_0, Event.Connected, 1.0)
assert await wait_for_event(events_1_0, Event.OpenedStream, 1.0)
# assert await wait_for_event(
# events_1_0, Event.ClosedStream, 1.0
# ) # Not implemented
assert await wait_for_event(events_1_0, Event.ClosedStream, 1.0)
assert await wait_for_event(events_1_0, Event.Disconnected, 1.0)
assert await wait_for_event(events_1_1, Event.Connected, 1.0)
assert await wait_for_event(events_1_1, Event.OpenedStream, 1.0)
# assert await wait_for_event(
# events_1_1, Event.ClosedStream, 1.0
# ) # Not implemented
assert await wait_for_event(events_1_1, Event.ClosedStream, 1.0)
assert await wait_for_event(events_1_1, Event.Disconnected, 1.0)
# Note: ListenClose events are triggered when swarm closes during cleanup
# The test framework automatically closes listeners, triggering ListenClose
# notifications
async def wait_for_event(events_list, event, timeout=1.0):
"""Helper to wait for a specific event to appear in the events list."""
with trio.move_on_after(timeout):
while event not in events_list:
await trio.sleep(0.01)
return True
return False
@pytest.mark.trio
async def test_notify_with_closed_stream_and_listen_close():
"""Test that closed_stream and listen_close events are properly triggered."""
# Event lists for notifees
events_0 = []
events_1 = []
# Create two swarms
async with SwarmFactory.create_batch_and_listen(2) as swarms:
# Register notifees
notifee_0 = MyNotifee(events_0)
notifee_1 = MyNotifee(events_1)
swarms[0].register_notifee(notifee_0)
swarms[1].register_notifee(notifee_1)
# Connect swarms
await connect_swarm(swarms[0], swarms[1])
# Create and close a stream to trigger closed_stream event
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
await stream.close()
# Note: Events are processed asynchronously and may not be immediately available
# due to the rapid nature of operations
@pytest.mark.trio
async def test_notify_edge_cases():
"""Test edge cases for notify system."""
events = []
async with SwarmFactory.create_batch_and_listen(2) as swarms:
notifee = MyNotifee(events)
swarms[0].register_notifee(notifee)
# Connect swarms first
await connect_swarm(swarms[0], swarms[1])
# Test 1: Multiple rapid stream operations
streams = []
for _ in range(5):
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
streams.append(stream)
# Close all streams rapidly
for stream in streams:
await stream.close()
@pytest.mark.trio
async def test_my_notifee_error_handling():
"""Test error handling for invalid parameters in MyNotifee methods."""
events = []
notifee = MyNotifee(events)
# Mock objects for testing
mock_network = Mock(spec=INetwork)
mock_stream = Mock(spec=INetStream)
mock_multiaddr = Mock(spec=Multiaddr)
# Test closed_stream with None parameters
with pytest.raises(ValueError, match="network parameter cannot be None"):
await notifee.closed_stream(None, mock_stream) # type: ignore
with pytest.raises(ValueError, match="stream parameter cannot be None"):
await notifee.closed_stream(mock_network, None) # type: ignore
# Test listen_close with None parameters
with pytest.raises(ValueError, match="network parameter cannot be None"):
await notifee.listen_close(None, mock_multiaddr) # type: ignore
with pytest.raises(ValueError, match="multiaddr parameter cannot be None"):
await notifee.listen_close(mock_network, None) # type: ignore
# Verify no events were recorded due to errors
assert len(events) == 0
@pytest.mark.trio
async def test_rapid_stream_operations():
"""Test rapid stream open/close operations."""
events_0 = []
events_1 = []
async with SwarmFactory.create_batch_and_listen(2) as swarms:
notifee_0 = MyNotifee(events_0)
notifee_1 = MyNotifee(events_1)
swarms[0].register_notifee(notifee_0)
swarms[1].register_notifee(notifee_1)
# Connect swarms
await connect_swarm(swarms[0], swarms[1])
# Rapidly create and close multiple streams
streams = []
for _ in range(3):
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
streams.append(stream)
# Close all streams immediately
for stream in streams:
await stream.close()
# Verify OpenedStream events are recorded
assert events_0.count(Event.OpenedStream) == 3
assert events_1.count(Event.OpenedStream) == 3
# Close peer to trigger disconnection events
await swarms[0].close_peer(swarms[1].get_peer_id())
@pytest.mark.trio
async def test_concurrent_stream_operations():
"""Test concurrent stream operations using trio nursery."""
events_0 = []
events_1 = []
async with SwarmFactory.create_batch_and_listen(2) as swarms:
notifee_0 = MyNotifee(events_0)
notifee_1 = MyNotifee(events_1)
swarms[0].register_notifee(notifee_0)
swarms[1].register_notifee(notifee_1)
# Connect swarms
await connect_swarm(swarms[0], swarms[1])
async def create_and_close_stream():
"""Create and immediately close a stream."""
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
await stream.close()
# Run multiple stream operations concurrently
async with trio.open_nursery() as nursery:
for _ in range(4):
nursery.start_soon(create_and_close_stream)
# Verify some OpenedStream events are recorded
# (concurrent operations may not all succeed)
opened_count_0 = events_0.count(Event.OpenedStream)
opened_count_1 = events_1.count(Event.OpenedStream)
assert opened_count_0 > 0, (
f"Expected some OpenedStream events, got {opened_count_0}"
)
assert opened_count_1 > 0, (
f"Expected some OpenedStream events, got {opened_count_1}"
)
# Close peer to trigger disconnection events
await swarms[0].close_peer(swarms[1].get_peer_id())

View File

@ -0,0 +1,76 @@
import enum
import pytest
from multiaddr import Multiaddr
import trio
from libp2p.abc import (
INetConn,
INetStream,
INetwork,
INotifee,
)
from libp2p.tools.async_service import background_trio_service
from libp2p.tools.constants import LISTEN_MADDR
from tests.utils.factories import SwarmFactory
class Event(enum.Enum):
Listen = 0
ListenClose = 1
class MyNotifee(INotifee):
def __init__(self, events: list[Event]):
self.events = events
async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
pass
async def closed_stream(self, network: INetwork, stream: INetStream) -> None:
pass
async def connected(self, network: INetwork, conn: INetConn) -> None:
pass
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
pass
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
self.events.append(Event.Listen)
async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
self.events.append(Event.ListenClose)
async def wait_for_event(
events_list: list[Event], event: Event, timeout: float = 1.0
) -> bool:
with trio.move_on_after(timeout):
while event not in events_list:
await trio.sleep(0.01)
return True
return False
@pytest.mark.trio
async def test_listen_emitted_when_registered_before_listen():
events: list[Event] = []
swarm = SwarmFactory.build()
swarm.register_notifee(MyNotifee(events))
async with background_trio_service(swarm):
# Start listening now; notifee was registered beforehand
assert await swarm.listen(LISTEN_MADDR)
assert await wait_for_event(events, Event.Listen)
@pytest.mark.trio
async def test_single_listener_close_emits_listen_close():
events: list[Event] = []
swarm = SwarmFactory.build()
swarm.register_notifee(MyNotifee(events))
async with background_trio_service(swarm):
assert await swarm.listen(LISTEN_MADDR)
# Explicitly notify listen_close (close path via manager doesn't emit it)
await swarm.notify_listen_close(LISTEN_MADDR)
assert await wait_for_event(events, Event.ListenClose)

View File

@ -16,6 +16,9 @@ from libp2p.network.exceptions import (
from libp2p.network.swarm import (
Swarm,
)
from libp2p.tools.async_service import (
background_trio_service,
)
from libp2p.tools.utils import (
connect_swarm,
)
@ -48,14 +51,19 @@ async def test_swarm_dial_peer(security_protocol):
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
await swarms[0].dial_peer(swarms[1].get_peer_id())
# New: dial_peer now returns list of connections
connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
assert len(connections) > 0
# Verify connections are established in both directions
assert swarms[0].get_peer_id() in swarms[1].connections
assert swarms[1].get_peer_id() in swarms[0].connections
# Test: Reuse connections when we already have ones with a peer.
conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
assert conn is conn_to_1
existing_connections = swarms[0].get_connections(swarms[1].get_peer_id())
new_connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
assert new_connections == existing_connections
@pytest.mark.trio
@ -104,7 +112,8 @@ async def test_swarm_close_peer(security_protocol):
@pytest.mark.trio
async def test_swarm_remove_conn(swarm_pair):
swarm_0, swarm_1 = swarm_pair
conn_0 = swarm_0.connections[swarm_1.get_peer_id()]
# Get the first connection from the list
conn_0 = swarm_0.connections[swarm_1.get_peer_id()][0]
swarm_0.remove_conn(conn_0)
assert swarm_1.get_peer_id() not in swarm_0.connections
# Test: Remove twice. There should not be errors.
@ -112,6 +121,67 @@ async def test_swarm_remove_conn(swarm_pair):
assert swarm_1.get_peer_id() not in swarm_0.connections
@pytest.mark.trio
async def test_swarm_multiple_connections(security_protocol):
"""Test multiple connections per peer functionality."""
async with SwarmFactory.create_batch_and_listen(
2, security_protocol=security_protocol
) as swarms:
# Setup multiple addresses for peer
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
# Dial peer - should return list of connections
connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
assert len(connections) > 0
# Test get_connections method
peer_connections = swarms[0].get_connections(swarms[1].get_peer_id())
assert len(peer_connections) == len(connections)
# Test get_connections_map method
connections_map = swarms[0].get_connections_map()
assert swarms[1].get_peer_id() in connections_map
assert len(connections_map[swarms[1].get_peer_id()]) == len(connections)
# Test get_connection method (backward compatibility)
single_conn = swarms[0].get_connection(swarms[1].get_peer_id())
assert single_conn is not None
assert single_conn in connections
@pytest.mark.trio
async def test_swarm_load_balancing(security_protocol):
"""Test load balancing across multiple connections."""
async with SwarmFactory.create_batch_and_listen(
2, security_protocol=security_protocol
) as swarms:
# Setup connection
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
# Create multiple streams - should use load balancing
streams = []
for _ in range(5):
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
streams.append(stream)
# Verify streams were created successfully
assert len(streams) == 5
# Clean up
for stream in streams:
await stream.close()
@pytest.mark.trio
async def test_swarm_multiaddr(security_protocol):
async with SwarmFactory.create_batch_and_listen(
@ -184,3 +254,116 @@ def test_new_swarm_quic_multiaddr_raises():
addr = Multiaddr("/ip4/127.0.0.1/udp/9999/quic")
with pytest.raises(ValueError, match="QUIC not yet supported"):
new_swarm(listen_addrs=[addr])
@pytest.mark.trio
async def test_swarm_listen_multiple_addresses(security_protocol):
"""Test that swarm can listen on multiple addresses simultaneously."""
from libp2p.utils.address_validation import get_available_interfaces
# Get multiple addresses to listen on
listen_addrs = get_available_interfaces(0) # Let OS choose ports
# Create a swarm and listen on multiple addresses
swarm = SwarmFactory.build(security_protocol=security_protocol)
async with background_trio_service(swarm):
# Listen on all addresses
success = await swarm.listen(*listen_addrs)
assert success, "Should successfully listen on at least one address"
# Check that we have listeners for the addresses
actual_listeners = list(swarm.listeners.keys())
assert len(actual_listeners) > 0, "Should have at least one listener"
# Verify that all successful listeners are in the listeners dict
successful_count = 0
for addr in listen_addrs:
addr_str = str(addr)
if addr_str in actual_listeners:
successful_count += 1
# This address successfully started listening
listener = swarm.listeners[addr_str]
listener_addrs = listener.get_addrs()
assert len(listener_addrs) > 0, (
f"Listener for {addr} should have addresses"
)
# Check that the listener address matches the expected address
# (port might be different if we used port 0)
expected_ip = addr.value_for_protocol("ip4")
expected_protocol = addr.value_for_protocol("tcp")
if expected_ip and expected_protocol:
found_matching = False
for listener_addr in listener_addrs:
if (
listener_addr.value_for_protocol("ip4") == expected_ip
and listener_addr.value_for_protocol("tcp") is not None
):
found_matching = True
break
assert found_matching, (
f"Listener for {addr} should have matching IP"
)
assert successful_count == len(listen_addrs), (
f"All {len(listen_addrs)} addresses should be listening, "
f"but only {successful_count} succeeded"
)
@pytest.mark.trio
async def test_swarm_listen_multiple_addresses_connectivity(security_protocol):
"""Test that real libp2p connections can be established to all listening addresses.""" # noqa: E501
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.utils.address_validation import get_available_interfaces
# Get multiple addresses to listen on
listen_addrs = get_available_interfaces(0) # Let OS choose ports
# Create a swarm and listen on multiple addresses
swarm1 = SwarmFactory.build(security_protocol=security_protocol)
async with background_trio_service(swarm1):
# Listen on all addresses
success = await swarm1.listen(*listen_addrs)
assert success, "Should successfully listen on at least one address"
# Verify all available interfaces are listening
assert len(swarm1.listeners) == len(listen_addrs), (
f"All {len(listen_addrs)} interfaces should be listening, "
f"but only {len(swarm1.listeners)} are"
)
# Create a second swarm to test connections
swarm2 = SwarmFactory.build(security_protocol=security_protocol)
async with background_trio_service(swarm2):
# Test connectivity to each listening address using real libp2p connections
for addr_str, listener in swarm1.listeners.items():
listener_addrs = listener.get_addrs()
for listener_addr in listener_addrs:
# Create a full multiaddr with peer ID for libp2p connection
peer_id = swarm1.get_peer_id()
full_addr = listener_addr.encapsulate(f"/p2p/{peer_id}")
# Test real libp2p connection
try:
peer_info = info_from_p2p_addr(full_addr)
# Add the peer info to swarm2's peerstore so it knows where to connect # noqa: E501
swarm2.peerstore.add_addrs(
peer_info.peer_id, [listener_addr], 10000
)
await swarm2.dial_peer(peer_info.peer_id)
# Verify connection was established
assert peer_info.peer_id in swarm2.connections, (
f"Connection to {full_addr} should be established"
)
assert swarm2.get_peer_id() in swarm1.connections, (
f"Connection from {full_addr} should be established"
)
except Exception as e:
pytest.fail(
f"Failed to establish libp2p connection to {full_addr}: {e}"
)

View File

@ -1,9 +1,9 @@
from collections import deque
import pytest
import trio
from libp2p.abc import (
IMultiselectCommunicator,
)
from libp2p.abc import IMultiselectCommunicator, INetStream
from libp2p.custom_types import TProtocol
from libp2p.protocol_muxer.exceptions import (
MultiselectClientError,
@ -13,6 +13,10 @@ from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
async def dummy_handler(stream: INetStream) -> None:
pass
class DummyMultiselectCommunicator(IMultiselectCommunicator):
"""
Dummy MultiSelectCommunicator to test out negotiate timmeout.
@ -31,7 +35,7 @@ class DummyMultiselectCommunicator(IMultiselectCommunicator):
@pytest.mark.trio
async def test_select_one_of_timeout():
async def test_select_one_of_timeout() -> None:
ECHO = TProtocol("/echo/1.0.0")
communicator = DummyMultiselectCommunicator()
@ -42,7 +46,7 @@ async def test_select_one_of_timeout():
@pytest.mark.trio
async def test_query_multistream_command_timeout():
async def test_query_multistream_command_timeout() -> None:
communicator = DummyMultiselectCommunicator()
client = MultiselectClient()
@ -51,9 +55,95 @@ async def test_query_multistream_command_timeout():
@pytest.mark.trio
async def test_negotiate_timeout():
async def test_negotiate_timeout() -> None:
communicator = DummyMultiselectCommunicator()
server = Multiselect()
with pytest.raises(MultiselectError, match="handshake read timeout"):
await server.negotiate(communicator, 2)
class HandshakeThenHangCommunicator(IMultiselectCommunicator):
handshaked: bool
def __init__(self) -> None:
self.handshaked = False
async def write(self, msg_str: str) -> None:
if msg_str == "/multistream/1.0.0":
self.handshaked = True
return
async def read(self) -> str:
if not self.handshaked:
return "/multistream/1.0.0"
# After handshake, hang on read.
await trio.sleep_forever()
# Should not be reached.
return ""
@pytest.mark.trio
async def test_negotiate_timeout_post_handshake() -> None:
communicator = HandshakeThenHangCommunicator()
server = Multiselect()
with pytest.raises(MultiselectError, match="handshake read timeout"):
await server.negotiate(communicator, 1)
class MockCommunicator(IMultiselectCommunicator):
def __init__(self, commands_to_read: list[str]):
self.read_queue = deque(commands_to_read)
self.written_data: list[str] = []
async def write(self, msg_str: str) -> None:
self.written_data.append(msg_str)
async def read(self) -> str:
if not self.read_queue:
raise EOFError
return self.read_queue.popleft()
@pytest.mark.trio
async def test_negotiate_empty_string_command() -> None:
# server receives an empty string, which means client wants `None` protocol.
server = Multiselect({None: dummy_handler})
# Handshake, then empty command
communicator = MockCommunicator(["/multistream/1.0.0", ""])
protocol, handler = await server.negotiate(communicator)
assert protocol is None
assert handler == dummy_handler
# Check that server sent back handshake and the protocol confirmation (empty string)
assert communicator.written_data == ["/multistream/1.0.0", ""]
@pytest.mark.trio
async def test_negotiate_with_none_handler() -> None:
# server has None handler, client sends "" to select it.
server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler})
# Handshake, then empty command
communicator = MockCommunicator(["/multistream/1.0.0", ""])
protocol, handler = await server.negotiate(communicator)
assert protocol is None
assert handler == dummy_handler
# Check written data: handshake, protocol confirmation
assert communicator.written_data == ["/multistream/1.0.0", ""]
@pytest.mark.trio
async def test_negotiate_with_none_handler_ls() -> None:
# server has None handler, client sends "ls" then empty string.
server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler})
# Handshake, ls, empty command
communicator = MockCommunicator(["/multistream/1.0.0", "ls", ""])
protocol, handler = await server.negotiate(communicator)
assert protocol is None
assert handler == dummy_handler
# Check written data: handshake, ls response, protocol confirmation
assert communicator.written_data[0] == "/multistream/1.0.0"
assert "/proto1" in communicator.written_data[1]
# Note: `ls` should not list the `None` protocol.
assert "None" not in communicator.written_data[1]
assert "\n\n" not in communicator.written_data[1]
assert communicator.written_data[2] == ""

View File

@ -159,3 +159,41 @@ async def test_get_protocols_returns_all_registered_protocols():
protocols = ms.get_protocols()
assert set(protocols) == {p1, p2, p3}
@pytest.mark.trio
async def test_negotiate_optional_tprotocol(security_protocol):
with pytest.raises(Exception):
await perform_simple_test(
None,
[None],
[None],
security_protocol,
)
@pytest.mark.trio
async def test_negotiate_optional_tprotocol_client_none_server_no_none(
security_protocol,
):
with pytest.raises(Exception):
await perform_simple_test(None, [None], [PROTOCOL_ECHO], security_protocol)
@pytest.mark.trio
async def test_negotiate_optional_tprotocol_client_none_in_list(security_protocol):
expected_selected_protocol = PROTOCOL_ECHO
await perform_simple_test(
expected_selected_protocol,
[None, PROTOCOL_ECHO],
[PROTOCOL_ECHO],
security_protocol,
)
@pytest.mark.trio
async def test_negotiate_optional_tprotocol_server_none_client_other(
security_protocol,
):
with pytest.raises(Exception):
await perform_simple_test(None, [PROTOCOL_ECHO], [None], security_protocol)

View File

@ -1,4 +1,8 @@
import random
from unittest.mock import (
AsyncMock,
MagicMock,
)
import pytest
import trio
@ -7,6 +11,9 @@ from libp2p.pubsub.gossipsub import (
PROTOCOL_ID,
GossipSub,
)
from libp2p.pubsub.pb import (
rpc_pb2,
)
from libp2p.tools.utils import (
connect,
)
@ -754,3 +761,173 @@ async def test_single_host():
assert connected_peers == 0, (
f"Single host has {connected_peers} connections, expected 0"
)
@pytest.mark.trio
async def test_handle_ihave(monkeypatch):
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
gossipsub_routers = []
for pubsub in pubsubs_gsub:
if isinstance(pubsub.router, GossipSub):
gossipsub_routers.append(pubsub.router)
gossipsubs = tuple(gossipsub_routers)
index_alice = 0
index_bob = 1
id_bob = pubsubs_gsub[index_bob].my_id
# Connect Alice and Bob
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
await trio.sleep(0.1) # Allow connections to establish
# Mock emit_iwant to capture calls
mock_emit_iwant = AsyncMock()
monkeypatch.setattr(gossipsubs[index_alice], "emit_iwant", mock_emit_iwant)
# Create a test message ID as a string representation of a (seqno, from) tuple
test_seqno = b"1234"
test_from = id_bob.to_bytes()
test_msg_id = f"(b'{test_seqno.hex()}', b'{test_from.hex()}')"
ihave_msg = rpc_pb2.ControlIHave(messageIDs=[test_msg_id])
# Mock seen_messages.cache to avoid false positives
monkeypatch.setattr(pubsubs_gsub[index_alice].seen_messages, "cache", {})
# Simulate Bob sending IHAVE to Alice
await gossipsubs[index_alice].handle_ihave(ihave_msg, id_bob)
# Check if emit_iwant was called with the correct message ID
mock_emit_iwant.assert_called_once()
called_args = mock_emit_iwant.call_args[0]
assert called_args[0] == [test_msg_id] # Expected message IDs
assert called_args[1] == id_bob # Sender peer ID
@pytest.mark.trio
async def test_handle_iwant(monkeypatch):
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
gossipsub_routers = []
for pubsub in pubsubs_gsub:
if isinstance(pubsub.router, GossipSub):
gossipsub_routers.append(pubsub.router)
gossipsubs = tuple(gossipsub_routers)
index_alice = 0
index_bob = 1
id_alice = pubsubs_gsub[index_alice].my_id
# Connect Alice and Bob
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
await trio.sleep(0.1) # Allow connections to establish
# Mock mcache.get to return a message
test_message = rpc_pb2.Message(data=b"test_data")
test_seqno = b"1234"
test_from = id_alice.to_bytes()
# ✅ Correct: use raw tuple and str() to serialize, no hex()
test_msg_id = str((test_seqno, test_from))
mock_mcache_get = MagicMock(return_value=test_message)
monkeypatch.setattr(gossipsubs[index_bob].mcache, "get", mock_mcache_get)
# Mock write_msg to capture the sent packet
mock_write_msg = AsyncMock()
monkeypatch.setattr(gossipsubs[index_bob].pubsub, "write_msg", mock_write_msg)
# Simulate Alice sending IWANT to Bob
iwant_msg = rpc_pb2.ControlIWant(messageIDs=[test_msg_id])
await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice)
# Check if write_msg was called with the correct packet
mock_write_msg.assert_called_once()
packet = mock_write_msg.call_args[0][1]
assert isinstance(packet, rpc_pb2.RPC)
assert len(packet.publish) == 1
assert packet.publish[0] == test_message
# Verify that mcache.get was called with the correct parsed message ID
mock_mcache_get.assert_called_once()
called_msg_id = mock_mcache_get.call_args[0][0]
assert isinstance(called_msg_id, tuple)
assert called_msg_id == (test_seqno, test_from)
@pytest.mark.trio
async def test_handle_iwant_invalid_msg_id(monkeypatch):
"""
Test that handle_iwant raises ValueError for malformed message IDs.
"""
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
gossipsub_routers = []
for pubsub in pubsubs_gsub:
if isinstance(pubsub.router, GossipSub):
gossipsub_routers.append(pubsub.router)
gossipsubs = tuple(gossipsub_routers)
index_alice = 0
index_bob = 1
id_alice = pubsubs_gsub[index_alice].my_id
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
await trio.sleep(0.1)
# Malformed message ID (not a tuple string)
malformed_msg_id = "not_a_valid_msg_id"
iwant_msg = rpc_pb2.ControlIWant(messageIDs=[malformed_msg_id])
# Mock mcache.get and write_msg to ensure they are not called
mock_mcache_get = MagicMock()
monkeypatch.setattr(gossipsubs[index_bob].mcache, "get", mock_mcache_get)
mock_write_msg = AsyncMock()
monkeypatch.setattr(gossipsubs[index_bob].pubsub, "write_msg", mock_write_msg)
with pytest.raises(ValueError):
await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice)
mock_mcache_get.assert_not_called()
mock_write_msg.assert_not_called()
# Message ID that's a tuple string but not (bytes, bytes)
invalid_tuple_msg_id = "('abc', 123)"
iwant_msg = rpc_pb2.ControlIWant(messageIDs=[invalid_tuple_msg_id])
with pytest.raises(ValueError):
await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice)
mock_mcache_get.assert_not_called()
mock_write_msg.assert_not_called()
@pytest.mark.trio
async def test_handle_ihave_empty_message_ids(monkeypatch):
"""
Test that handle_ihave with an empty messageIDs list does not call emit_iwant.
"""
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
gossipsub_routers = []
for pubsub in pubsubs_gsub:
if isinstance(pubsub.router, GossipSub):
gossipsub_routers.append(pubsub.router)
gossipsubs = tuple(gossipsub_routers)
index_alice = 0
index_bob = 1
id_bob = pubsubs_gsub[index_bob].my_id
# Connect Alice and Bob
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
await trio.sleep(0.1) # Allow connections to establish
# Mock emit_iwant to capture calls
mock_emit_iwant = AsyncMock()
monkeypatch.setattr(gossipsubs[index_alice], "emit_iwant", mock_emit_iwant)
# Empty messageIDs list
ihave_msg = rpc_pb2.ControlIHave(messageIDs=[])
# Mock seen_messages.cache to avoid false positives
monkeypatch.setattr(pubsubs_gsub[index_alice].seen_messages, "cache", {})
# Simulate Bob sending IHAVE to Alice
await gossipsubs[index_alice].handle_ihave(ihave_msg, id_bob)
# emit_iwant should not be called since there are no message IDs
mock_emit_iwant.assert_not_called()

View File

@ -8,8 +8,10 @@ from typing import (
from unittest.mock import patch
import pytest
import multiaddr
import trio
from libp2p.crypto.rsa import create_new_key_pair
from libp2p.custom_types import AsyncValidatorFn
from libp2p.exceptions import (
ValidationError,
@ -17,9 +19,11 @@ from libp2p.exceptions import (
from libp2p.network.stream.exceptions import (
StreamEOF,
)
from libp2p.peer.envelope import Envelope, seal_record
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peer_record import PeerRecord
from libp2p.pubsub.pb import (
rpc_pb2,
)
@ -87,6 +91,45 @@ async def test_re_unsubscribe():
assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids
@pytest.mark.trio
async def test_reissue_when_listen_addrs_change():
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
# Yield to let 0 notify 1
await trio.sleep(1)
assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
# Check whether signed-records were transfered properly in the subscribe call
envelope_b_sub = (
pubsubs_fsub[1]
.host.get_peerstore()
.get_peer_record(pubsubs_fsub[0].host.get_id())
)
assert isinstance(envelope_b_sub, Envelope)
# Simulate pubsubs_fsub[1].host listen addrs changing (different port)
new_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123")
# Patch just for the duration we force A to unsubscribe
with patch.object(pubsubs_fsub[0].host, "get_addrs", return_value=[new_addr]):
# Unsubscribe from A's side so that a new_record is issued
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
await trio.sleep(1)
# B should be holding A's new record with bumped seq
envelope_b_unsub = (
pubsubs_fsub[1]
.host.get_peerstore()
.get_peer_record(pubsubs_fsub[0].host.get_id())
)
assert isinstance(envelope_b_unsub, Envelope)
# This proves that a freshly signed record was issued rather than
# the latest-cached-one creating one.
assert envelope_b_sub.record().seq < envelope_b_unsub.record().seq
@pytest.mark.trio
async def test_peers_subscribe():
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
@ -95,11 +138,71 @@ async def test_peers_subscribe():
# Yield to let 0 notify 1
await trio.sleep(1)
assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
# Check whether signed-records were transfered properly in the subscribe call
envelope_b_sub = (
pubsubs_fsub[1]
.host.get_peerstore()
.get_peer_record(pubsubs_fsub[0].host.get_id())
)
assert isinstance(envelope_b_sub, Envelope)
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
# Yield to let 0 notify 1
await trio.sleep(1)
assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
envelope_b_unsub = (
pubsubs_fsub[1]
.host.get_peerstore()
.get_peer_record(pubsubs_fsub[0].host.get_id())
)
assert isinstance(envelope_b_unsub, Envelope)
# This proves that the latest-cached-record was re-issued rather than
# freshly creating one.
assert envelope_b_sub.record().seq == envelope_b_unsub.record().seq
@pytest.mark.trio
async def test_peer_subscribe_fail_upon_invald_record_transfer():
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
# Corrupt host_a's local peer record
envelope = pubsubs_fsub[0].host.get_peerstore().get_local_record()
if envelope is not None:
true_record = envelope.record()
key_pair = create_new_key_pair()
if envelope is not None:
envelope.public_key = key_pair.public_key
pubsubs_fsub[0].host.get_peerstore().set_local_record(envelope)
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
# Yeild to let 0 notify 1
await trio.sleep(1)
assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics.get(
TESTING_TOPIC, set()
)
# Create a corrupt envelope with correct signature but false peer-id
false_record = PeerRecord(
ID.from_pubkey(key_pair.public_key), true_record.addrs
)
false_envelope = seal_record(
false_record, pubsubs_fsub[0].host.get_private_key()
)
pubsubs_fsub[0].host.get_peerstore().set_local_record(false_envelope)
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
# Yeild to let 0 notify 1
await trio.sleep(1)
assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics.get(
TESTING_TOPIC, set()
)
@pytest.mark.trio
async def test_get_hello_packet():

View File

@ -0,0 +1,90 @@
from typing import cast
import pytest
import trio
from libp2p.tools.utils import connect
from tests.utils.factories import PubsubFactory
@pytest.mark.trio
async def test_connected_enqueues_and_adds_peer():
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
await connect(p0.host, p1.host)
await p0.wait_until_ready()
# Wait until peer is added via queue processing
with trio.fail_after(1.0):
while p1.my_id not in p0.peers:
await trio.sleep(0.01)
assert p1.my_id in p0.peers
@pytest.mark.trio
async def test_disconnected_enqueues_and_removes_peer():
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
await connect(p0.host, p1.host)
await p0.wait_until_ready()
# Ensure present first
with trio.fail_after(1.0):
while p1.my_id not in p0.peers:
await trio.sleep(0.01)
# Now disconnect and expect removal via dead peer queue
await p0.host.get_network().close_peer(p1.host.get_id())
with trio.fail_after(1.0):
while p1.my_id in p0.peers:
await trio.sleep(0.01)
assert p1.my_id not in p0.peers
@pytest.mark.trio
async def test_channel_closed_is_swallowed_in_notifee(monkeypatch) -> None:
# Ensure PubsubNotifee catches BrokenResourceError from its send channel
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
# Find the PubsubNotifee registered on the network
from libp2p.pubsub.pubsub_notifee import PubsubNotifee
network = p0.host.get_network()
notifees = getattr(network, "notifees", [])
target = None
for nf in notifees:
if isinstance(nf, cast(type, PubsubNotifee)):
target = nf
break
assert target is not None, "PubsubNotifee not found on network"
async def failing_send(_peer_id): # type: ignore[no-redef]
raise trio.BrokenResourceError
# Make initiator queue send fail; PubsubNotifee should swallow
monkeypatch.setattr(target.initiator_peers_queue, "send", failing_send)
# Connect peers; if exceptions are swallowed, service stays running
await connect(p0.host, p1.host)
await p0.wait_until_ready()
assert True
@pytest.mark.trio
async def test_duplicate_connection_does_not_duplicate_peer_state():
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
await connect(p0.host, p1.host)
await p0.wait_until_ready()
with trio.fail_after(1.0):
while p1.my_id not in p0.peers:
await trio.sleep(0.01)
# Connect again should not add duplicates
await connect(p0.host, p1.host)
await trio.sleep(0.1)
assert list(p0.peers.keys()).count(p1.my_id) == 1
@pytest.mark.trio
async def test_blacklist_blocks_peer_added_by_notifee():
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
# Blacklist before connecting
p0.add_to_blacklist(p1.my_id)
await connect(p0.host, p1.host)
await p0.wait_until_ready()
# Give handler a chance to run
await trio.sleep(0.1)
assert p1.my_id not in p0.peers

View File

@ -0,0 +1,13 @@
from libp2p.security.noise.pb import noise_pb2 as noise_pb
def test_noise_extensions_serialization():
# Test NoiseExtensions
ext = noise_pb.NoiseExtensions()
ext.stream_muxers.append("/mplex/6.7.0")
ext.stream_muxers.append("/yamux/1.0.0")
# Serialize and deserialize
data = ext.SerializeToString()
ext2 = noise_pb.NoiseExtensions.FromString(data)
assert list(ext2.stream_muxers) == ["/mplex/6.7.0", "/yamux/1.0.0"]

Some files were not shown because too many files have changed in this diff Show More