122 Commits

Author SHA1 Message Date
af61523c87 Merge branch 'main' into varun-r-mallya/protobuf-update 2025-08-10 16:56:10 +05:30
d2fdf70692 Merge pull request #819 from lla-dane/remove-todos
Remove completed TODO task comments in Peerstore
2025-08-10 16:54:41 +05:30
1ea50a3cf3 Add newsfragment to indicate updation
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
2025-08-10 11:47:30 +05:30
f4247faa51 added newsfragment 2025-08-10 11:39:34 +05:30
92e79bbb3f Merge branch 'main' into varun-r-mallya/protobuf-update 2025-08-10 11:29:59 +05:30
eb3121b818 remove completed TODO task comments 2025-08-10 11:28:11 +05:30
787648177f feat: enhance Makefile to keep protobufs always out of date
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
2025-08-10 11:25:06 +05:30
fc9b28910a Merge pull request #817 from acul71/fix/fix-816-remove-TODO-IK-patterns-in-noise
remove deprecated IK pattern TODO
2025-08-10 11:16:44 +05:30
26d0ed2d81 remove deprecated IK pattern TODO 2025-08-09 23:25:23 +00:00
618aff9368 chore: recompile protobufs
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
2025-08-10 01:31:40 +05:30
32e545d9c7 Merge pull request #813 from sumanjeet0012/feature/update-readme-status
docs: update README.md with module status and source links.
2025-08-09 19:45:47 +05:30
e712e6c0c4 docs: update README.md with module status and source links. 2025-08-09 18:41:44 +05:30
b143c96abd Merge pull request #803 from acul71/fix/fix-592-32bytes_prefix_go_fixme_comment
Remove FIXME about 32 bytes prefix in noise
2025-08-08 12:51:44 +05:30
678b920992 Remove FIXME 2025-08-08 02:01:36 +00:00
cb11f076c8 feat/606-enable-nat-traversal-via-hole-punching (#668)
* feat: base implementation of dcutr for hole-punching

* chore: removed circuit-relay imports from __init__

* feat: implemented dcutr protocol

* added test suite with mock setup

* Fix pre-commit hook issues in DCUtR implementation

* usages of CONNECT_TYPE and SYNC_TYPE have been replaced with HolePunch.Type.CONNECT and HolePunch.Type.SYNC

* added unit tests for dcutr and nat module and

* added multiaddr.get_peer_id() with proper DNS address handling and fixed method signature inconsistencies

* added assertions to verify DCUtR hole punch result in integration test

---------

Co-authored-by: Manu Sheel Gupta <manusheel.edu@gmail.com>
2025-08-07 19:00:16 -06:00
9ed44f5fa3 Merge pull request #753 from lla-dane/feat/certified-addrbook
WIP: Certified-Address-Book interface for Peer-Store
2025-07-26 23:41:28 +05:30
8786f06862 added newsfragment 2025-07-26 22:38:28 +05:30
8c96c5a941 Add the periodic peer-store cleanup in all the examples 2025-07-26 22:38:28 +05:30
16445714f7 overwrite old_addr with new_addrs in consume_peer_record 2025-07-26 22:38:28 +05:30
64bc388b33 added peer-store cleanup task in ping example 2025-07-26 22:38:28 +05:30
09e151aafc Added test for peer-store cleanup task 2025-07-26 22:38:28 +05:30
2d335d4394 Integrated Signed-peer-record transfer with identify/identify-push 2025-07-26 22:38:28 +05:30
8b8b051885 batch operations for consume_peer_record 2025-07-26 22:38:28 +05:30
07c8d4cd1f added periodic cleanup task 2025-07-26 22:38:28 +05:30
09e6feea8e merge new addresses with existing ones, in consume_peer_record 2025-07-26 22:38:28 +05:30
601a8a3ef0 enforce_peer_record_limit 2025-07-26 22:38:28 +05:30
9d597012cc fixed the linter <> protobuf issues 2025-07-26 22:38:28 +05:30
8625226be8 fix merge conflicts 2025-07-26 22:38:28 +05:30
c2b1738cd9 fix sphinx/docutils bugs 2025-07-26 22:38:28 +05:30
83acc38281 fix tox bugs 2025-07-26 22:38:28 +05:30
1899dac84c fix tox bugs 2025-07-26 22:38:28 +05:30
aab2a0b603 Completed: CertifiedAddrBook interface with related tests 2025-07-26 22:38:28 +05:30
bab08c0900 added test for Envelope.equal 2025-07-26 22:38:28 +05:30
6431fb8788 Implemented: Envelope wrapper class + linter hacks for protobuf checks 2025-07-26 22:38:28 +05:30
c8053417d5 remove the big google-protobuf file 2025-07-26 22:38:28 +05:30
6eba9d8ca0 downgrade the peer-record protobuf files to v@25.3 2025-07-26 22:38:27 +05:30
0e1b738cbb update protobuf-version 2025-07-26 22:38:27 +05:30
2ff5ae9c90 added hacks for linting errors 2025-07-26 22:38:27 +05:30
ecc443dcfe linter respacing 2025-07-26 22:38:27 +05:30
aa6039bcd3 PeerRecord class with ProtoBuff implemented 2025-07-26 22:38:27 +05:30
8352d19113 Merge pull request #752 from Minimega12121/todo/handletimeout
todo: handle timeout
2025-07-26 21:16:12 +05:30
ceb9f7d3f7 Merge branch 'main' into todo/handletimeout 2025-07-26 20:54:53 +05:30
9b667bd472 Merge pull request #711 from sumanjeet0012/feature/bootstrap
feat: Implemented Bootstrap module in py-libp2p
2025-07-26 01:56:16 +05:30
eca548851b added new fragment and tests 2025-07-25 16:19:29 +05:30
e91f458446 Enhance peer discovery logging and address resolution handling in BootstrapDiscovery 2025-07-24 00:11:05 +05:30
0416572457 Merge branch 'main' into feature/bootstrap 2025-07-21 20:55:51 +05:30
39375fb338 Merge branch 'main' into todo/handletimeout 2025-07-21 08:17:08 -07:00
8bf261ca77 Merge pull request #766 from acul71/py-multiaddr
feat: add py-multiaddr from git
2025-07-21 08:06:04 -07:00
3a927c8419 Merge branch 'main' into feature/bootstrap 2025-07-21 06:43:34 -07:00
ec92af20e7 Merge branch 'main' into py-multiaddr 2025-07-21 06:27:11 -07:00
01db5d5fa0 Merge pull request #785 from acul71/feat/issue-784-identify-push-raw-format
feat: Add identify-push raw format support and yamux logging improvent
2025-07-21 06:27:00 -07:00
21ee417793 Pin py-multiaddr dependency to specific git commit db8124e2321f316d3b7d2733c7df11d6ad9c03e6 2025-07-20 23:48:16 +02:00
37e4fee9f8 feat: Add identify-push raw format support and yamux logging improvements
- Add comprehensive integration tests for identify-push protocol
- Support both raw protobuf and varint message formats
- Improve yamux logging integration with LIBP2P_DEBUG
- Fix RawConnError handling to reduce log noise
- Add Ctrl+C handling to identify examples
- Enhance identify-push listener/dialer demo

Fixes: #784
2025-07-20 20:19:18 +02:00
c277cce2ed Merge branch 'main' into feature/bootstrap 2025-07-20 04:39:56 -07:00
048e6deb96 fix: utf-8 in reading in py-cid 2025-07-19 23:24:39 +00:00
2dc2dd4670 Merge branch 'main' into py-multiaddr 2025-07-19 02:22:39 -07:00
e6a355d395 Merge pull request #748 from Jineshbansal/add-read-write-lock
TODO: add read/write lock
2025-07-19 02:04:52 -07:00
7b181f3ce5 Merge branch 'main' into add-read-write-lock 2025-07-18 23:53:12 -07:00
0606788ab6 Merge pull request #779 from acul71/fix/issue-778-Incorrect_handling_of_raw_format_in_identify
Fix/issue 778 incorrect handling of raw format in identify
2025-07-18 22:11:58 -07:00
7d62a2f558 Merge branch 'main' into fix/issue-778-Incorrect_handling_of_raw_format_in_identify 2025-07-19 04:25:48 +02:00
26fd169ccc doc: newsfragment raw identify message 2025-07-19 04:25:06 +02:00
99db5b309f fix raw format in identify and tests 2025-07-19 04:11:27 +02:00
7cfe5b9dc7 style: add new line within newsfragment 2025-07-19 01:19:23 +02:00
092b9c0c57 chore(newsfragment): add entry to the release notes 2025-07-19 01:19:23 +02:00
fcf0546831 style: enforce consistent import block 2025-07-19 01:19:23 +02:00
85bad2d0ae replace: attributes with cache cached_property 2025-07-19 01:19:23 +02:00
11560f5cc9 TODO: throttle on async validators (#755)
* fixed todo: throttle on async validators

* added test: validate message respects concurrency limit

* added newsfragment

* added configurable validator semaphore in the PubSub constructor

* added the concurrency-checker in the original test-validate-msg test case

* separate out a _run_async_validator function

* remove redundant run_async_validator
2025-07-18 06:01:28 -06:00
3507531344 chore: clarify newline requirement in newsfragments README.md (#775)
* chore: clarify newline requirement in README

Small change in newsfragments README.md, that reduces some possible room for pull-request tox workflow errors.

* style: remove double backticks for single backticks

the linter strikes again XD.

* docs: clarify trailing newline requirement in newsfragments for lint checks

---------

Co-authored-by: Manu Sheel Gupta <manusheel.edu@gmail.com>
2025-07-17 20:43:00 -06:00
c9162beb2b add grave that were removed by mistake 2025-07-17 20:55:49 +05:30
f587e50cab Merge branch 'main' into todo/handletimeout 2025-07-17 02:36:44 -07:00
d1a0f4f767 Merge branch 'main' into py-multiaddr 2025-07-17 02:32:52 -07:00
3ca27c6e93 Merge pull request #772 from LVivona/replace/peerID/attribute
removal: private attribute for cached_property
2025-07-17 01:58:36 -07:00
b4482e1a5e Merge branch 'main' into replace/peerID/attribute 2025-07-17 01:47:06 -07:00
ae82895d86 style: add new line within newsfragment 2025-07-16 22:12:05 -04:00
9f40d97a05 chore(newsfragment): add entry to the release notes 2025-07-16 22:08:25 -04:00
6fe28dcdd3 Merge branch 'main' into py-multiaddr 2025-07-17 02:39:05 +02:00
41b1ecb67c Merge branch 'main' into feature/bootstrap 2025-07-16 15:00:35 -07:00
e3c9b4bd54 Merge branch 'main' into add-read-write-lock 2025-07-16 14:59:44 -07:00
e132b154e3 Merge branch 'main' into todo/handletimeout 2025-07-16 14:59:24 -07:00
430527625b Merge branch 'main' into replace/peerID/attribute 2025-07-16 17:15:10 -04:00
93fc063e70 Merge branch 'main' into todo/handletimeout 2025-07-16 10:11:23 -07:00
5315816521 Merge branch 'main' into add-read-write-lock 2025-07-16 09:38:10 -07:00
42f07ae1ab Merge branch 'main' into add-read-write-lock 2025-07-15 14:54:29 -07:00
773962c070 Merge branch 'main' into todo/handletimeout 2025-07-15 14:53:48 -07:00
ab94e77310 Merge branch 'main' into feature/bootstrap 2025-07-15 14:40:20 -07:00
23622ea1a0 style: enforce consistent import block 2025-07-15 15:28:03 -04:00
6aeb217349 replace: attributes with cache cached_property 2025-07-15 14:59:34 -04:00
003e7bf278 Merge branch 'main' into py-multiaddr 2025-07-15 10:38:01 -07:00
6f33cde9a9 feat: add py-multiaddr from git 2025-07-13 21:56:07 +00:00
9f38d48e26 Fix valid bootstrap address in test case 2025-07-14 01:45:12 +05:30
2c1e50428a Merge branch 'feature/bootstrap' of https://github.com/sumanjeet0012/py-libp2p into feature/bootstrap 2025-07-14 01:38:53 +05:30
9e76940e75 Refactor logging configuration to reduce verbosity and improve peer discovery events 2025-07-14 01:38:15 +05:30
9cd3805542 make readwrite more safe 2025-07-13 18:37:44 +05:30
d03bdd75d6 Merge branch 'main' into feature/bootstrap 2025-07-12 07:48:04 -07:00
8ff7bb1f20 Merge branch 'main' into todo/handletimeout 2025-07-12 07:25:08 -07:00
21db1c3b72 Merge branch 'main' into feature/bootstrap 2025-07-10 20:43:25 +05:30
3592ad308f Merge branch 'main' into add-read-write-lock 2025-07-10 07:52:06 -07:00
9669a92976 Fix formatting and linting issues 2025-07-10 19:25:58 +05:30
2dfee68f20 Refactor bootstrap discovery to use async methods and update bootstrap peers list 2025-07-10 19:24:09 +05:30
198208aef3 validate and filter bootstrap addresses during discovery initialization 2025-07-09 20:23:47 +05:30
cda163fc48 change ReadWriteLock class 2025-07-09 18:18:37 +05:30
26ed99dafd change tests path 2025-07-09 18:09:07 +05:30
a26fd95854 Merge branch 'feature/bootstrap' of https://github.com/sumanjeet0012/py-libp2p into feature/bootstrap 2025-07-09 01:46:31 +05:30
2965b4e364 DNS resolution working 2025-07-09 01:45:15 +05:30
242998ae9d add test for read-write-lock 2025-07-08 20:06:30 +05:30
5f497c7f5d add file in newsfragments folder 2025-07-08 19:17:43 +05:30
e65e38a3f1 fix: linting error related to read 2025-07-08 19:11:56 +05:30
8fb664bfdf Fix: linting errors 2025-07-08 18:34:30 +05:30
3dcd99a2d1 todo: handle timeout 2025-07-08 17:48:57 +05:30
75abc8b863 run ruff format 2025-07-08 07:35:45 +05:30
91dca97d83 TODO: add read/write lock 2025-07-07 21:55:32 +05:30
80c686ddce Merge branch 'main' into feature/bootstrap 2025-07-07 08:51:53 -07:00
dcb199a6b7 Merge branch 'main' into feature/bootstrap 2025-07-03 01:31:43 -07:00
16be6fab85 Merge branch 'main' into feature/bootstrap 2025-07-03 00:21:19 +05:30
cbb1e26a4f refactor fixed some lint issues 2025-06-30 23:19:03 +05:30
69a2cb00ba remove obsolete test script and add comprehensive validation tests for bootstrap addresses 2025-06-30 23:12:48 +05:30
ec20ca81dd remove unnecessary files from .gitignore 2025-06-30 22:46:19 +05:30
b5ec1bd7ee Merge branch 'main' into feature/bootstrap 2025-06-30 07:35:41 -07:00
ddbd190993 docs: added bootstrap docs in doctree 2025-06-30 11:38:06 +05:30
36be4c354b fix: ensure newline at end of file in Bootstrap peer discovery module documentation 2025-06-30 11:38:05 +05:30
befb2d31db added newsfragments 2025-06-30 11:38:05 +05:30
12ad2dcdf4 Added bootstrap module 2025-06-30 11:37:40 +05:30
96 changed files with 6568 additions and 1286 deletions

View File

@ -60,6 +60,7 @@ PB = libp2p/crypto/pb/crypto.proto \
libp2p/identity/identify/pb/identify.proto \
libp2p/host/autonat/pb/autonat.proto \
libp2p/relay/circuit_v2/pb/circuit.proto \
libp2p/relay/circuit_v2/pb/dcutr.proto \
libp2p/kad_dht/pb/kademlia.proto
PY = $(PB:.proto=_pb2.py)
@ -68,6 +69,8 @@ PYI = $(PB:.proto=_pb2.pyi)
## Set default to `protobufs`, otherwise `format` is called when typing only `make`
all: protobufs
.PHONY: protobufs clean-proto
protobufs: $(PY)
%_pb2.py: %.proto
@ -76,6 +79,11 @@ protobufs: $(PY)
clean-proto:
rm -f $(PY) $(PYI)
# Force protobuf regeneration by making them always out of date
$(PY): FORCE
FORCE:
# docs commands
docs: check-docs

View File

@ -34,19 +34,19 @@ ______________________________________________________________________
| -------------------------------------- | :--------: | :---------------------------------------------------------------------------------: |
| **`libp2p-tcp`** | ✅ | [source](https://github.com/libp2p/py-libp2p/blob/main/libp2p/transport/tcp/tcp.py) |
| **`libp2p-quic`** | 🌱 | |
| **`libp2p-websocket`** | | |
| **`libp2p-webrtc-browser-to-server`** | | |
| **`libp2p-webrtc-private-to-private`** | | |
| **`libp2p-websocket`** | 🌱 | |
| **`libp2p-webrtc-browser-to-server`** | 🌱 | |
| **`libp2p-webrtc-private-to-private`** | 🌱 | |
______________________________________________________________________
### NAT Traversal
| **NAT Traversal** | **Status** |
| ----------------------------- | :--------: |
| **`libp2p-circuit-relay-v2`** | |
| **`libp2p-autonat`** | |
| **`libp2p-hole-punching`** | |
| **NAT Traversal** | **Status** | **Source** |
| ----------------------------- | :--------: | :-----------------------------------------------------------------------------: |
| **`libp2p-circuit-relay-v2`** | | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/relay/circuit_v2) |
| **`libp2p-autonat`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/host/autonat) |
| **`libp2p-hole-punching`** | | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/relay/circuit_v2) |
______________________________________________________________________
@ -54,27 +54,27 @@ ______________________________________________________________________
| **Secure Communication** | **Status** | **Source** |
| ------------------------ | :--------: | :---------------------------------------------------------------------------: |
| **`libp2p-noise`** | 🌱 | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/security/noise) |
| **`libp2p-tls`** | | |
| **`libp2p-noise`** | | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/security/noise) |
| **`libp2p-tls`** | 🌱 | |
______________________________________________________________________
### Discovery
| **Discovery** | **Status** |
| -------------------- | :--------: |
| **`bootstrap`** | |
| **`random-walk`** | |
| **`mdns-discovery`** | |
| **`rendezvous`** | |
| **Discovery** | **Status** | **Source** |
| -------------------- | :--------: | :--------------------------------------------------------------------------------: |
| **`bootstrap`** | | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/bootstrap) |
| **`random-walk`** | 🌱 | |
| **`mdns-discovery`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/mdns) |
| **`rendezvous`** | 🌱 | |
______________________________________________________________________
### Peer Routing
| **Peer Routing** | **Status** |
| -------------------- | :--------: |
| **`libp2p-kad-dht`** | |
| **Peer Routing** | **Status** | **Source** |
| -------------------- | :--------: | :--------------------------------------------------------------------: |
| **`libp2p-kad-dht`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/kad_dht) |
______________________________________________________________________
@ -89,10 +89,10 @@ ______________________________________________________________________
### Stream Muxers
| **Stream Muxers** | **Status** | **Status** |
| ------------------ | :--------: | :----------------------------------------------------------------------------------------: |
| **`libp2p-yamux`** | 🌱 | |
| **`libp2p-mplex`** | 🛠️ | [source](https://github.com/libp2p/py-libp2p/blob/main/libp2p/stream_muxer/mplex/mplex.py) |
| **Stream Muxers** | **Status** | **Source** |
| ------------------ | :--------: | :-------------------------------------------------------------------------------: |
| **`libp2p-yamux`** | | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/stream_muxer/yamux) |
| **`libp2p-mplex`** | | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/stream_muxer/mplex) |
______________________________________________________________________
@ -100,7 +100,7 @@ ______________________________________________________________________
| **Storage** | **Status** |
| ------------------- | :--------: |
| **`libp2p-record`** | |
| **`libp2p-record`** | 🌱 |
______________________________________________________________________

View File

@ -0,0 +1,13 @@
libp2p.discovery.bootstrap package
==================================
Submodules
----------
Module contents
---------------
.. automodule:: libp2p.discovery.bootstrap
:members:
:undoc-members:
:show-inheritance:

View File

@ -7,6 +7,7 @@ Subpackages
.. toctree::
:maxdepth: 4
libp2p.discovery.bootstrap
libp2p.discovery.events
libp2p.discovery.mdns

View File

@ -0,0 +1,136 @@
import argparse
import logging
import secrets
import multiaddr
import trio
from libp2p import new_host
from libp2p.abc import PeerInfo
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.discovery.events.peerDiscovery import peerDiscovery
# Configure logging
logger = logging.getLogger("libp2p.discovery.bootstrap")
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(
logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
logger.addHandler(handler)
# Configure root logger to only show warnings and above to reduce noise
# This prevents verbose DEBUG messages from multiaddr, DNS, etc.
logging.getLogger().setLevel(logging.WARNING)
# Specifically silence noisy libraries
logging.getLogger("multiaddr").setLevel(logging.WARNING)
logging.getLogger("root").setLevel(logging.WARNING)
def on_peer_discovery(peer_info: PeerInfo) -> None:
"""Handler for peer discovery events."""
logger.info(f"🔍 Discovered peer: {peer_info.peer_id}")
logger.debug(f" Addresses: {[str(addr) for addr in peer_info.addrs]}")
# Example bootstrap peers
BOOTSTRAP_PEERS = [
"/dnsaddr/github.com/p2p/QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN",
"/dnsaddr/cloudflare.com/p2p/QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN",
"/dnsaddr/google.com/p2p/QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN",
"/dnsaddr/bootstrap.libp2p.io/p2p/QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN",
"/dnsaddr/bootstrap.libp2p.io/p2p/QmbLHAnMoJPWSCR5Zhtx6BHJX9KiKNN6tpvbUcqanj75Nb",
"/ip4/104.131.131.82/tcp/4001/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ",
"/ip6/2604:a880:1:20::203:d001/tcp/4001/p2p/QmSoLPppuBtQSGwKDZT2M73ULpjvfd3aZ6ha4oFGL1KrGM",
"/ip4/128.199.219.111/tcp/4001/p2p/QmSoLV4Bbm51jM9C4gDYZQ9Cy3U6aXMJDAbzgu2fzaDs64",
"/ip4/104.236.76.40/tcp/4001/p2p/QmSoLV4Bbm51jM9C4gDYZQ9Cy3U6aXMJDAbzgu2fzaDs64",
"/ip4/178.62.158.247/tcp/4001/p2p/QmSoLer265NRgSp2LA3dPaeykiS1J6DifTC88f5uVQKNAd",
"/ip6/2604:a880:1:20::203:d001/tcp/4001/p2p/QmSoLPppuBtQSGwKDZT2M73ULpjvfd3aZ6ha4oFGL1KrGM",
"/ip6/2400:6180:0:d0::151:6001/tcp/4001/p2p/QmSoLSafTMBsPKadTEgaXctDQVcqN88CNLHXMkTNwMKPnu",
"/ip6/2a03:b0c0:0:1010::23:1001/tcp/4001/p2p/QmSoLueR4xBeUbY9WZ9xGUUxunbKWcrNFTDAadQJmocnWm",
]
async def run(port: int, bootstrap_addrs: list[str]) -> None:
"""Run the bootstrap discovery example."""
# Generate key pair
secret = secrets.token_bytes(32)
key_pair = create_new_key_pair(secret)
# Create listen address
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
# Register peer discovery handler
peerDiscovery.register_peer_discovered_handler(on_peer_discovery)
logger.info("🚀 Starting Bootstrap Discovery Example")
logger.info(f"📍 Listening on: {listen_addr}")
logger.info(f"🌐 Bootstrap peers: {len(bootstrap_addrs)}")
print("\n" + "=" * 60)
print("Bootstrap Discovery Example")
print("=" * 60)
print("This example demonstrates connecting to bootstrap peers.")
print("Watch the logs for peer discovery events!")
print("Press Ctrl+C to exit.")
print("=" * 60)
# Create and run host with bootstrap discovery
host = new_host(key_pair=key_pair, bootstrap=bootstrap_addrs)
try:
async with host.run(listen_addrs=[listen_addr]):
# Keep running and log peer discovery events
await trio.sleep_forever()
except KeyboardInterrupt:
logger.info("👋 Shutting down...")
def main() -> None:
"""Main entry point."""
description = """
Bootstrap Discovery Example for py-libp2p
This example demonstrates how to use bootstrap peers for peer discovery.
Bootstrap peers are predefined peers that help new nodes join the network.
Usage:
python bootstrap.py -p 8000
python bootstrap.py -p 8001 --custom-bootstrap \\
"/ip4/127.0.0.1/tcp/8000/p2p/QmYourPeerID"
"""
parser = argparse.ArgumentParser(
description=description, formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument(
"-p", "--port", default=0, type=int, help="Port to listen on (default: random)"
)
parser.add_argument(
"--custom-bootstrap",
nargs="*",
help="Custom bootstrap addresses (space-separated)",
)
parser.add_argument(
"-v", "--verbose", action="store_true", help="Enable verbose output"
)
args = parser.parse_args()
if args.verbose:
logger.setLevel(logging.DEBUG)
# Use custom bootstrap addresses if provided, otherwise use defaults
bootstrap_addrs = (
args.custom_bootstrap if args.custom_bootstrap else BOOTSTRAP_PEERS
)
try:
trio.run(run, args.port, bootstrap_addrs)
except KeyboardInterrupt:
logger.info("Exiting...")
if __name__ == "__main__":
main()

View File

@ -43,6 +43,9 @@ async def run(port: int, destination: str) -> None:
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
host = new_host()
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
# Start the peer-store cleanup task
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
if not destination: # its the server
async def stream_handler(stream: INetStream) -> None:

View File

@ -45,7 +45,10 @@ async def run(port: int, destination: str, seed: int | None = None) -> None:
secret = secrets.token_bytes(32)
host = new_host(key_pair=create_new_key_pair(secret))
async with host.run(listen_addrs=[listen_addr]):
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
# Start the peer-store cleanup task
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
print(f"I am {host.get_id().to_string()}")
if not destination: # its the server

View File

@ -1,6 +1,7 @@
import argparse
import base64
import logging
import sys
import multiaddr
import trio
@ -13,6 +14,8 @@ from libp2p.identity.identify.identify import (
identify_handler_for,
parse_identify_response,
)
from libp2p.identity.identify.pb.identify_pb2 import Identify
from libp2p.peer.envelope import debug_dump_envelope, unmarshal_envelope
from libp2p.peer.peerinfo import (
info_from_p2p_addr,
)
@ -31,10 +34,11 @@ def decode_multiaddrs(raw_addrs):
return decoded_addrs
def print_identify_response(identify_response):
def print_identify_response(identify_response: Identify):
"""Pretty-print Identify response."""
public_key_b64 = base64.b64encode(identify_response.public_key).decode("utf-8")
listen_addrs = decode_multiaddrs(identify_response.listen_addrs)
signed_peer_record = unmarshal_envelope(identify_response.signedPeerRecord)
try:
observed_addr_decoded = decode_multiaddrs([identify_response.observed_addr])
except Exception:
@ -50,6 +54,8 @@ def print_identify_response(identify_response):
f" Agent Version: {identify_response.agent_version}"
)
debug_dump_envelope(signed_peer_record)
async def run(port: int, destination: str, use_varint_format: bool = True) -> None:
localhost_ip = "0.0.0.0"
@ -60,58 +66,158 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No
host_a = new_host()
# Set up identify handler with specified format
# Set use_varint_format = False, if want to checkout the Signed-PeerRecord
identify_handler = identify_handler_for(
host_a, use_varint_format=use_varint_format
)
host_a.set_stream_handler(IDENTIFY_PROTOCOL_ID, identify_handler)
async with host_a.run(listen_addrs=[listen_addr]):
async with (
host_a.run(listen_addrs=[listen_addr]),
trio.open_nursery() as nursery,
):
# Start the peer-store cleanup task
nursery.start_soon(host_a.get_peerstore().start_cleanup_task, 60)
# Get the actual address and replace 0.0.0.0 with 127.0.0.1 for client
# connections
server_addr = str(host_a.get_addrs()[0])
client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/")
format_name = "length-prefixed" if use_varint_format else "raw protobuf"
format_flag = "--raw-format" if not use_varint_format else ""
print(
f"First host listening (using {format_name} format). "
f"Run this from another console:\n\n"
f"identify-demo "
f"-d {client_addr}\n"
f"identify-demo {format_flag} -d {client_addr}\n"
)
print("Waiting for incoming identify request...")
await trio.sleep_forever()
# Add a custom handler to show connection events
async def custom_identify_handler(stream):
peer_id = stream.muxed_conn.peer_id
print(f"\n🔗 Received identify request from peer: {peer_id}")
# Show remote address in multiaddr format
try:
from libp2p.identity.identify.identify import (
_remote_address_to_multiaddr,
)
remote_address = stream.get_remote_address()
if remote_address:
observed_multiaddr = _remote_address_to_multiaddr(
remote_address
)
# Add the peer ID to create a complete multiaddr
complete_multiaddr = f"{observed_multiaddr}/p2p/{peer_id}"
print(f" Remote address: {complete_multiaddr}")
else:
print(f" Remote address: {remote_address}")
except Exception:
print(f" Remote address: {stream.get_remote_address()}")
# Call the original handler
await identify_handler(stream)
print(f"✅ Successfully processed identify request from {peer_id}")
# Replace the handler with our custom one
host_a.set_stream_handler(IDENTIFY_PROTOCOL_ID, custom_identify_handler)
try:
await trio.sleep_forever()
except KeyboardInterrupt:
print("\n🛑 Shutting down listener...")
logger.info("Listener interrupted by user")
return
else:
# Create second host (dialer)
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}")
host_b = new_host()
async with host_b.run(listen_addrs=[listen_addr]):
async with (
host_b.run(listen_addrs=[listen_addr]),
trio.open_nursery() as nursery,
):
# Start the peer-store cleanup task
nursery.start_soon(host_b.get_peerstore().start_cleanup_task, 60)
# Connect to the first host
print(f"dialer (host_b) listening on {host_b.get_addrs()[0]}")
maddr = multiaddr.Multiaddr(destination)
info = info_from_p2p_addr(maddr)
print(f"Second host connecting to peer: {info.peer_id}")
await host_b.connect(info)
try:
await host_b.connect(info)
except Exception as e:
error_msg = str(e)
if "unable to connect" in error_msg or "SwarmException" in error_msg:
print(f"\n❌ Cannot connect to peer: {info.peer_id}")
print(f" Address: {destination}")
print(f" Error: {error_msg}")
print(
"\n💡 Make sure the peer is running and the address is correct."
)
return
else:
# Re-raise other exceptions
raise
stream = await host_b.new_stream(info.peer_id, (IDENTIFY_PROTOCOL_ID,))
try:
print("Starting identify protocol...")
# Read the complete response (could be either format)
# Read a larger chunk to get all the data before stream closes
response = await stream.read(8192) # Read enough data in one go
# Read the response using the utility function
from libp2p.utils.varint import read_length_prefixed_protobuf
response = await read_length_prefixed_protobuf(
stream, use_varint_format
)
full_response = response
await stream.close()
# Parse the response using the robust protocol-level function
# This handles both old and new formats automatically
identify_msg = parse_identify_response(response)
identify_msg = parse_identify_response(full_response)
print_identify_response(identify_msg)
except Exception as e:
print(f"Identify protocol error: {e}")
error_msg = str(e)
print(f"Identify protocol error: {error_msg}")
# Check for specific format mismatch errors
if "Error parsing message" in error_msg or "DecodeError" in error_msg:
print("\n" + "=" * 60)
print("FORMAT MISMATCH DETECTED!")
print("=" * 60)
if use_varint_format:
print(
"You are using length-prefixed format (default) but the "
"listener"
)
print("is using raw protobuf format.")
print(
"\nTo fix this, run the dialer with the --raw-format flag:"
)
print(f"identify-demo --raw-format -d {destination}")
else:
print("You are using raw protobuf format but the listener")
print("is using length-prefixed format (default).")
print(
"\nTo fix this, run the dialer without the --raw-format "
"flag:"
)
print(f"identify-demo -d {destination}")
print("=" * 60)
else:
import traceback
traceback.print_exc()
return
@ -147,16 +253,27 @@ def main() -> None:
"length-prefixed (new format)"
),
)
args = parser.parse_args()
# Determine format: raw format if --raw-format is specified, otherwise
# length-prefixed
use_varint_format = not args.raw_format
# Determine format: use varint (length-prefixed) if --raw-format is specified,
# otherwise use raw protobuf format (old format)
use_varint_format = args.raw_format
try:
trio.run(run, *(args.port, args.destination, use_varint_format))
if args.destination:
# Run in dialer mode
trio.run(run, *(args.port, args.destination, use_varint_format))
else:
# Run in listener mode
trio.run(run, *(args.port, args.destination, use_varint_format))
except KeyboardInterrupt:
pass
print("\n👋 Goodbye!")
logger.info("Application interrupted by user")
except Exception as e:
print(f"\n❌ Error: {str(e)}")
logger.error("Error: %s", str(e))
sys.exit(1)
if __name__ == "__main__":

View File

@ -11,23 +11,26 @@ This example shows how to:
import logging
import multiaddr
import trio
from libp2p import (
new_host,
)
from libp2p.abc import (
INetStream,
)
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.custom_types import (
TProtocol,
)
from libp2p.identity.identify import (
identify_handler_for,
from libp2p.identity.identify.pb.identify_pb2 import (
Identify,
)
from libp2p.identity.identify_push import (
ID_PUSH,
identify_push_handler_for,
push_identify_to_peer,
)
from libp2p.peer.peerinfo import (
@ -38,8 +41,145 @@ from libp2p.peer.peerinfo import (
logger = logging.getLogger(__name__)
def create_custom_identify_handler(host, host_name: str):
"""Create a custom identify handler that displays received information."""
async def handle_identify(stream: INetStream) -> None:
peer_id = stream.muxed_conn.peer_id
print(f"\n🔍 {host_name} received identify request from peer: {peer_id}")
# Get the standard identify response using the existing function
from libp2p.identity.identify.identify import (
_mk_identify_protobuf,
_remote_address_to_multiaddr,
)
# Get observed address
observed_multiaddr = None
try:
remote_address = stream.get_remote_address()
if remote_address:
observed_multiaddr = _remote_address_to_multiaddr(remote_address)
except Exception:
pass
# Build the identify protobuf
identify_msg = _mk_identify_protobuf(host, observed_multiaddr)
response_data = identify_msg.SerializeToString()
print(f" 📋 {host_name} identify information:")
if identify_msg.HasField("protocol_version"):
print(f" Protocol Version: {identify_msg.protocol_version}")
if identify_msg.HasField("agent_version"):
print(f" Agent Version: {identify_msg.agent_version}")
if identify_msg.HasField("public_key"):
print(f" Public Key: {identify_msg.public_key.hex()[:16]}...")
if identify_msg.listen_addrs:
print(" Listen Addresses:")
for addr_bytes in identify_msg.listen_addrs:
addr = multiaddr.Multiaddr(addr_bytes)
print(f" - {addr}")
if identify_msg.protocols:
print(" Supported Protocols:")
for protocol in identify_msg.protocols:
print(f" - {protocol}")
# Send the response
await stream.write(response_data)
await stream.close()
return handle_identify
def create_custom_identify_push_handler(host, host_name: str):
"""Create a custom identify/push handler that displays received information."""
async def handle_identify_push(stream: INetStream) -> None:
peer_id = stream.muxed_conn.peer_id
print(f"\n📤 {host_name} received identify/push from peer: {peer_id}")
try:
# Read the identify message using the utility function
from libp2p.utils.varint import read_length_prefixed_protobuf
data = await read_length_prefixed_protobuf(stream, use_varint_format=True)
# Parse the identify message
identify_msg = Identify()
identify_msg.ParseFromString(data)
print(" 📋 Received identify information:")
if identify_msg.HasField("protocol_version"):
print(f" Protocol Version: {identify_msg.protocol_version}")
if identify_msg.HasField("agent_version"):
print(f" Agent Version: {identify_msg.agent_version}")
if identify_msg.HasField("public_key"):
print(f" Public Key: {identify_msg.public_key.hex()[:16]}...")
if identify_msg.HasField("observed_addr") and identify_msg.observed_addr:
observed_addr = multiaddr.Multiaddr(identify_msg.observed_addr)
print(f" Observed Address: {observed_addr}")
if identify_msg.listen_addrs:
print(" Listen Addresses:")
for addr_bytes in identify_msg.listen_addrs:
addr = multiaddr.Multiaddr(addr_bytes)
print(f" - {addr}")
if identify_msg.protocols:
print(" Supported Protocols:")
for protocol in identify_msg.protocols:
print(f" - {protocol}")
# Update the peerstore with the new information
from libp2p.identity.identify_push.identify_push import (
_update_peerstore_from_identify,
)
await _update_peerstore_from_identify(
host.get_peerstore(), peer_id, identify_msg
)
print(f"{host_name} updated peerstore with new information")
except Exception as e:
print(f" ❌ Error processing identify/push: {e}")
finally:
await stream.close()
return handle_identify_push
async def display_peerstore_info(host, host_name: str, peer_id, description: str):
"""Display peerstore information for a specific peer."""
peerstore = host.get_peerstore()
try:
addrs = peerstore.addrs(peer_id)
except Exception:
addrs = []
try:
protocols = peerstore.get_protocols(peer_id)
except Exception:
protocols = []
print(f"\n📚 {host_name} peerstore for {description}:")
print(f" Peer ID: {peer_id}")
if addrs:
print(" Addresses:")
for addr in addrs:
print(f" - {addr}")
else:
print(" Addresses: None")
if protocols:
print(" Protocols:")
for protocol in protocols:
print(f" - {protocol}")
else:
print(" Protocols: None")
async def main() -> None:
print("\n==== Starting Identify-Push Example ====\n")
print("\n==== Starting Enhanced Identify-Push Example ====\n")
# Create key pairs for the two hosts
key_pair_1 = create_new_key_pair()
@ -48,45 +188,57 @@ async def main() -> None:
# Create the first host
host_1 = new_host(key_pair=key_pair_1)
# Set up the identify and identify/push handlers
host_1.set_stream_handler(TProtocol("/ipfs/id/1.0.0"), identify_handler_for(host_1))
host_1.set_stream_handler(ID_PUSH, identify_push_handler_for(host_1))
# Set up custom identify and identify/push handlers
host_1.set_stream_handler(
TProtocol("/ipfs/id/1.0.0"), create_custom_identify_handler(host_1, "Host 1")
)
host_1.set_stream_handler(
ID_PUSH, create_custom_identify_push_handler(host_1, "Host 1")
)
# Create the second host
host_2 = new_host(key_pair=key_pair_2)
# Set up the identify and identify/push handlers
host_2.set_stream_handler(TProtocol("/ipfs/id/1.0.0"), identify_handler_for(host_2))
host_2.set_stream_handler(ID_PUSH, identify_push_handler_for(host_2))
# Set up custom identify and identify/push handlers
host_2.set_stream_handler(
TProtocol("/ipfs/id/1.0.0"), create_custom_identify_handler(host_2, "Host 2")
)
host_2.set_stream_handler(
ID_PUSH, create_custom_identify_push_handler(host_2, "Host 2")
)
# Start listening on random ports using the run context manager
import multiaddr
listen_addr_1 = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
listen_addr_2 = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
async with host_1.run([listen_addr_1]), host_2.run([listen_addr_2]):
async with (
host_1.run([listen_addr_1]),
host_2.run([listen_addr_2]),
trio.open_nursery() as nursery,
):
# Start the peer-store cleanup task
nursery.start_soon(host_1.get_peerstore().start_cleanup_task, 60)
nursery.start_soon(host_2.get_peerstore().start_cleanup_task, 60)
# Get the addresses of both hosts
addr_1 = host_1.get_addrs()[0]
logger.info(f"Host 1 listening on {addr_1}")
print(f"Host 1 listening on {addr_1}")
print(f"Peer ID: {host_1.get_id().pretty()}")
addr_2 = host_2.get_addrs()[0]
logger.info(f"Host 2 listening on {addr_2}")
print(f"Host 2 listening on {addr_2}")
print(f"Peer ID: {host_2.get_id().pretty()}")
print("\nConnecting Host 2 to Host 1...")
print("🏠 Host Configuration:")
print(f" Host 1: {addr_1}")
print(f" Host 1 Peer ID: {host_1.get_id().pretty()}")
print(f" Host 2: {addr_2}")
print(f" Host 2 Peer ID: {host_2.get_id().pretty()}")
print("\n🔗 Connecting Host 2 to Host 1...")
# Connect host_2 to host_1
peer_info = info_from_p2p_addr(addr_1)
await host_2.connect(peer_info)
logger.info("Host 2 connected to Host 1")
print("Host 2 successfully connected to Host 1")
print("Host 2 successfully connected to Host 1")
# Run the identify protocol from host_2 to host_1
# (so Host 1 learns Host 2's address)
print("\n🔄 Running identify protocol (Host 2 → Host 1)...")
from libp2p.identity.identify.identify import ID as IDENTIFY_PROTOCOL_ID
stream = await host_2.new_stream(host_1.get_id(), (IDENTIFY_PROTOCOL_ID,))
@ -94,64 +246,58 @@ async def main() -> None:
await stream.close()
# Run the identify protocol from host_1 to host_2
# (so Host 2 learns Host 1's address)
print("\n🔄 Running identify protocol (Host 1 → Host 2)...")
stream = await host_1.new_stream(host_2.get_id(), (IDENTIFY_PROTOCOL_ID,))
response = await stream.read()
await stream.close()
# --- NEW CODE: Update Host 1's peerstore with Host 2's addresses ---
from libp2p.identity.identify.pb.identify_pb2 import (
Identify,
)
# Update Host 1's peerstore with Host 2's addresses
identify_msg = Identify()
identify_msg.ParseFromString(response)
peerstore_1 = host_1.get_peerstore()
peer_id_2 = host_2.get_id()
for addr_bytes in identify_msg.listen_addrs:
maddr = multiaddr.Multiaddr(addr_bytes)
# TTL can be any positive int
peerstore_1.add_addr(
peer_id_2,
maddr,
ttl=3600,
)
# --- END NEW CODE ---
peerstore_1.add_addr(peer_id_2, maddr, ttl=3600)
# Now Host 1's peerstore should have Host 2's address
peerstore_1 = host_1.get_peerstore()
peer_id_2 = host_2.get_id()
addrs_1_for_2 = peerstore_1.addrs(peer_id_2)
logger.info(
f"[DEBUG] Host 1 peerstore addresses for Host 2 before push: "
f"{addrs_1_for_2}"
)
print(
f"[DEBUG] Host 1 peerstore addresses for Host 2 before push: "
f"{addrs_1_for_2}"
# Display peerstore information before push
await display_peerstore_info(
host_1, "Host 1", peer_id_2, "Host 2 (before push)"
)
# Push identify information from host_1 to host_2
logger.info("Host 1 pushing identify information to Host 2")
print("\nHost 1 pushing identify information to Host 2...")
print("\n📤 Host 1 pushing identify information to Host 2...")
try:
# Call push_identify_to_peer which now returns a boolean
success = await push_identify_to_peer(host_1, host_2.get_id())
if success:
logger.info("Identify push completed successfully")
print("Identify push completed successfully!")
print("Identify push completed successfully!")
else:
logger.warning("Identify push didn't complete successfully")
print("\nWarning: Identify push didn't complete successfully")
print("⚠️ Identify push didn't complete successfully")
except Exception as e:
logger.error(f"Error during identify push: {str(e)}")
print(f"\nError during identify push: {str(e)}")
print(f"Error during identify push: {str(e)}")
# Add this at the end of your async with block:
await trio.sleep(0.5) # Give background tasks time to finish
# Give a moment for the identify/push processing to complete
await trio.sleep(0.5)
# Display peerstore information after push
await display_peerstore_info(host_1, "Host 1", peer_id_2, "Host 2 (after push)")
await display_peerstore_info(
host_2, "Host 2", host_1.get_id(), "Host 1 (after push)"
)
# Give more time for background tasks to finish and connections to stabilize
print("\n⏳ Waiting for background tasks to complete...")
await trio.sleep(1.0)
# Gracefully close connections to prevent connection errors
print("🔌 Closing connections...")
await host_2.disconnect(host_1.get_id())
await trio.sleep(0.2)
print("\n🎉 Example completed successfully!")
if __name__ == "__main__":

View File

@ -41,6 +41,9 @@ from libp2p.identity.identify import (
ID as ID_IDENTIFY,
identify_handler_for,
)
from libp2p.identity.identify.identify import (
_remote_address_to_multiaddr,
)
from libp2p.identity.identify.pb.identify_pb2 import (
Identify,
)
@ -72,40 +75,30 @@ def custom_identify_push_handler_for(host, use_varint_format: bool = True):
async def handle_identify_push(stream: INetStream) -> None:
peer_id = stream.muxed_conn.peer_id
# Get remote address information
try:
if use_varint_format:
# Read length-prefixed identify message from the stream
from libp2p.utils.varint import decode_varint_from_bytes
remote_address = stream.get_remote_address()
if remote_address:
observed_multiaddr = _remote_address_to_multiaddr(remote_address)
logger.info(
"Connection from remote peer %s, address: %s, multiaddr: %s",
peer_id,
remote_address,
observed_multiaddr,
)
print(f"\n🔗 Received identify/push request from peer: {peer_id}")
# Add the peer ID to create a complete multiaddr
complete_multiaddr = f"{observed_multiaddr}/p2p/{peer_id}"
print(f" Remote address: {complete_multiaddr}")
except Exception as e:
logger.error("Error getting remote address: %s", e)
print(f"\n🔗 Received identify/push request from peer: {peer_id}")
# First read the varint length prefix
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
break
length_bytes += b
if b[0] & 0x80 == 0:
break
try:
# Use the utility function to read the protobuf message
from libp2p.utils.varint import read_length_prefixed_protobuf
if not length_bytes:
logger.warning("No length prefix received from peer %s", peer_id)
return
msg_length = decode_varint_from_bytes(length_bytes)
# Read the protobuf message
data = await stream.read(msg_length)
if len(data) != msg_length:
logger.warning("Incomplete message received from peer %s", peer_id)
return
else:
# Read raw protobuf message from the stream
data = b""
while True:
chunk = await stream.read(4096)
if not chunk:
break
data += chunk
data = await read_length_prefixed_protobuf(stream, use_varint_format)
identify_msg = Identify()
identify_msg.ParseFromString(data)
@ -155,11 +148,41 @@ def custom_identify_push_handler_for(host, use_varint_format: bool = True):
await _update_peerstore_from_identify(peerstore, peer_id, identify_msg)
logger.info("Successfully processed identify/push from peer %s", peer_id)
print(f"\nSuccessfully processed identify/push from peer {peer_id}")
print(f"Successfully processed identify/push from peer {peer_id}")
except Exception as e:
logger.error("Error processing identify/push from %s: %s", peer_id, e)
print(f"\nError processing identify/push from {peer_id}: {e}")
error_msg = str(e)
logger.error(
"Error processing identify/push from %s: %s", peer_id, error_msg
)
print(f"\nError processing identify/push from {peer_id}: {error_msg}")
# Check for specific format mismatch errors
if (
"Error parsing message" in error_msg
or "DecodeError" in error_msg
or "ParseFromString" in error_msg
):
print("\n" + "=" * 60)
print("FORMAT MISMATCH DETECTED!")
print("=" * 60)
if use_varint_format:
print(
"You are using length-prefixed format (default) but the "
"dialer is using raw protobuf format."
)
print("\nTo fix this, run the dialer with the --raw-format flag:")
print(
"identify-push-listener-dialer-demo --raw-format -d <ADDRESS>"
)
else:
print("You are using raw protobuf format but the dialer")
print("is using length-prefixed format (default).")
print(
"\nTo fix this, run the dialer without the --raw-format flag:"
)
print("identify-push-listener-dialer-demo -d <ADDRESS>")
print("=" * 60)
finally:
# Close the stream after processing
await stream.close()
@ -167,7 +190,9 @@ def custom_identify_push_handler_for(host, use_varint_format: bool = True):
return handle_identify_push
async def run_listener(port: int, use_varint_format: bool = True) -> None:
async def run_listener(
port: int, use_varint_format: bool = True, raw_format_flag: bool = False
) -> None:
"""Run a host in listener mode."""
format_name = "length-prefixed" if use_varint_format else "raw protobuf"
print(
@ -187,29 +212,41 @@ async def run_listener(port: int, use_varint_format: bool = True) -> None:
)
host.set_stream_handler(
ID_IDENTIFY_PUSH,
identify_push_handler_for(host, use_varint_format=use_varint_format),
custom_identify_push_handler_for(host, use_varint_format=use_varint_format),
)
# Start listening
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
async with host.run([listen_addr]):
addr = host.get_addrs()[0]
logger.info("Listener host ready!")
print("Listener host ready!")
try:
async with host.run([listen_addr]):
addr = host.get_addrs()[0]
logger.info("Listener host ready!")
print("Listener host ready!")
logger.info(f"Listening on: {addr}")
print(f"Listening on: {addr}")
logger.info(f"Listening on: {addr}")
print(f"Listening on: {addr}")
logger.info(f"Peer ID: {host.get_id().pretty()}")
print(f"Peer ID: {host.get_id().pretty()}")
logger.info(f"Peer ID: {host.get_id().pretty()}")
print(f"Peer ID: {host.get_id().pretty()}")
print("\nRun dialer with command:")
print(f"identify-push-listener-dialer-demo -d {addr}")
print("\nWaiting for incoming connections... (Ctrl+C to exit)")
print("\nRun dialer with command:")
if raw_format_flag:
print(f"identify-push-listener-dialer-demo -d {addr} --raw-format")
else:
print(f"identify-push-listener-dialer-demo -d {addr}")
print("\nWaiting for incoming identify/push requests... (Ctrl+C to exit)")
# Keep running until interrupted
await trio.sleep_forever()
# Keep running until interrupted
try:
await trio.sleep_forever()
except KeyboardInterrupt:
print("\n🛑 Shutting down listener...")
logger.info("Listener interrupted by user")
return
except Exception as e:
logger.error(f"Listener error: {e}")
raise
async def run_dialer(
@ -256,7 +293,9 @@ async def run_dialer(
try:
await host.connect(peer_info)
logger.info("Successfully connected to listener!")
print("Successfully connected to listener!")
print("Successfully connected to listener!")
print(f" Connected to: {peer_info.peer_id}")
print(f" Full address: {destination}")
# Push identify information to the listener
logger.info("Pushing identify information to listener...")
@ -270,7 +309,7 @@ async def run_dialer(
if success:
logger.info("Identify push completed successfully!")
print("Identify push completed successfully!")
print("Identify push completed successfully!")
logger.info("Example completed successfully!")
print("\nExample completed successfully!")
@ -281,17 +320,57 @@ async def run_dialer(
logger.warning("Example completed with warnings.")
print("Example completed with warnings.")
except Exception as e:
logger.error(f"Error during identify push: {str(e)}")
print(f"\nError during identify push: {str(e)}")
error_msg = str(e)
logger.error(f"Error during identify push: {error_msg}")
print(f"\nError during identify push: {error_msg}")
# Check for specific format mismatch errors
if (
"Error parsing message" in error_msg
or "DecodeError" in error_msg
or "ParseFromString" in error_msg
):
print("\n" + "=" * 60)
print("FORMAT MISMATCH DETECTED!")
print("=" * 60)
if use_varint_format:
print(
"You are using length-prefixed format (default) but the "
"listener is using raw protobuf format."
)
print(
"\nTo fix this, run the dialer with the --raw-format flag:"
)
print(
f"identify-push-listener-dialer-demo --raw-format -d "
f"{destination}"
)
else:
print("You are using raw protobuf format but the listener")
print("is using length-prefixed format (default).")
print(
"\nTo fix this, run the dialer without the --raw-format "
"flag:"
)
print(f"identify-push-listener-dialer-demo -d {destination}")
print("=" * 60)
logger.error("Example completed with errors.")
print("Example completed with errors.")
# Continue execution despite the push error
except Exception as e:
logger.error(f"Error during dialer operation: {str(e)}")
print(f"\nError during dialer operation: {str(e)}")
raise
error_msg = str(e)
if "unable to connect" in error_msg or "SwarmException" in error_msg:
print(f"\n❌ Cannot connect to peer: {peer_info.peer_id}")
print(f" Address: {destination}")
print(f" Error: {error_msg}")
print("\n💡 Make sure the peer is running and the address is correct.")
return
else:
logger.error(f"Error during dialer operation: {error_msg}")
print(f"\nError during dialer operation: {error_msg}")
raise
def main() -> None:
@ -301,12 +380,21 @@ def main() -> None:
Without arguments, it runs as a listener on random port.
With -d parameter, it runs as a dialer on random port.
Port 0 (default) means the OS will automatically assign an available port.
This prevents port conflicts when running multiple instances.
Use --raw-format to send raw protobuf messages (old format) instead of
length-prefixed protobuf messages (new format, default).
"""
parser = argparse.ArgumentParser(description=description)
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
parser.add_argument(
"-p",
"--port",
default=0,
type=int,
help="source port number (0 = random available port)",
)
parser.add_argument(
"-d",
"--destination",
@ -321,6 +409,7 @@ def main() -> None:
"length-prefixed (new format)"
),
)
args = parser.parse_args()
# Determine format: raw format if --raw-format is specified, otherwise
@ -333,12 +422,12 @@ def main() -> None:
trio.run(run_dialer, args.port, args.destination, use_varint_format)
else:
# Run in listener mode with random available port if not specified
trio.run(run_listener, args.port, use_varint_format)
trio.run(run_listener, args.port, use_varint_format, args.raw_format)
except KeyboardInterrupt:
print("\nInterrupted by user")
logger.info("Interrupted by user")
print("\n👋 Goodbye!")
logger.info("Application interrupted by user")
except Exception as e:
print(f"\nError: {str(e)}")
print(f"\nError: {str(e)}")
logger.error("Error: %s", str(e))
sys.exit(1)

View File

@ -151,7 +151,10 @@ async def run_node(
host = new_host(key_pair=key_pair)
listen_addr = Multiaddr(f"/ip4/127.0.0.1/tcp/{port}")
async with host.run(listen_addrs=[listen_addr]):
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
# Start the peer-store cleanup task
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
peer_id = host.get_id().pretty()
addr_str = f"/ip4/127.0.0.1/tcp/{port}/p2p/{peer_id}"
await connect_to_bootstrap_nodes(host, bootstrap_nodes)

View File

@ -46,7 +46,10 @@ async def run(port: int) -> None:
logger.info("Starting peer Discovery")
host = new_host(key_pair=key_pair, enable_mDNS=True)
async with host.run(listen_addrs=[listen_addr]):
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
# Start the peer-store cleanup task
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
await trio.sleep_forever()

View File

@ -59,6 +59,9 @@ async def run(port: int, destination: str) -> None:
host = new_host(listen_addrs=[listen_addr])
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
# Start the peer-store cleanup task
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
if not destination:
host.set_stream_handler(PING_PROTOCOL_ID, handle_ping)

View File

@ -144,6 +144,9 @@ async def run(topic: str, destination: str | None, port: int | None) -> None:
pubsub = Pubsub(host, gossipsub)
termination_event = trio.Event() # Event to signal termination
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
# Start the peer-store cleanup task
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
logger.info(f"Node started with peer ID: {host.get_id()}")
logger.info(f"Listening on: {listen_addr}")
logger.info("Initializing PubSub and GossipSub...")

View File

@ -251,6 +251,7 @@ def new_host(
muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
enable_mDNS: bool = False,
bootstrap: list[str] | None = None,
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
) -> IHost:
"""
@ -264,6 +265,7 @@ def new_host(
:param muxer_preference: optional explicit muxer preference
:param listen_addrs: optional list of multiaddrs to listen on
:param enable_mDNS: whether to enable mDNS discovery
:param bootstrap: optional list of bootstrap peer addresses as strings
:return: return a host instance
"""
swarm = new_swarm(
@ -276,7 +278,7 @@ def new_host(
)
if disc_opt is not None:
return RoutedHost(swarm, disc_opt, enable_mDNS)
return BasicHost(network=swarm,enable_mDNS=enable_mDNS , negotitate_timeout=negotiate_timeout)
return RoutedHost(swarm, disc_opt, enable_mDNS, bootstrap)
return BasicHost(network=swarm,enable_mDNS=enable_mDNS , bootstrap=bootstrap, negotitate_timeout=negotiate_timeout)
__version__ = __version("libp2p")

View File

@ -16,6 +16,7 @@ from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
Optional,
)
from multiaddr import (
@ -41,20 +42,19 @@ from libp2p.io.abc import (
from libp2p.peer.id import (
ID,
)
import libp2p.peer.pb.peer_record_pb2 as pb
from libp2p.peer.peerinfo import (
PeerInfo,
)
if TYPE_CHECKING:
from libp2p.peer.envelope import Envelope
from libp2p.peer.peer_record import PeerRecord
from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.pubsub.pubsub import (
Pubsub,
)
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.pubsub.pb import (
rpc_pb2,
)
@ -357,6 +357,14 @@ class INetConn(Closer):
:return: A tuple containing instances of INetStream.
"""
@abstractmethod
def get_transport_addresses(self) -> list[Multiaddr]:
"""
Retrieve the transport addresses used by this connection.
:return: A list of multiaddresses used by the transport.
"""
# -------------------------- peermetadata interface.py --------------------------
@ -493,6 +501,71 @@ class IAddrBook(ABC):
"""
# ------------------ certified-addr-book interface.py ---------------------
class ICertifiedAddrBook(ABC):
"""
Interface for a certified address book.
Provides methods for managing signed peer records
"""
@abstractmethod
def consume_peer_record(self, envelope: "Envelope", ttl: int) -> bool:
"""
Accept and store a signed PeerRecord, unless it's older than
the one already stored.
This function:
- Extracts the peer ID and sequence number from the envelope
- Rejects the record if it's older (lower seq)
- Updates the stored peer record and replaces associated
addresses if accepted
Parameters
----------
envelope:
Signed envelope containing a PeerRecord.
ttl:
Time-to-live for the included multiaddrs (in seconds).
"""
@abstractmethod
def get_peer_record(self, peer_id: ID) -> Optional["Envelope"]:
"""
Retrieve the most recent signed PeerRecord `Envelope` for a peer, if it exists
and is still relevant.
First, it runs cleanup via `maybe_delete_peer_record` to purge stale data.
Then it checks whether the peer has valid, unexpired addresses before
returning the associated envelope.
Parameters
----------
peer_id : ID
The peer to look up.
"""
@abstractmethod
def maybe_delete_peer_record(self, peer_id: ID) -> None:
"""
Delete the signed peer record for a peer if it has no know
(non-expired) addresses.
This is a garbage collection mechanism: if all addresses for a peer have expired
or been cleared, there's no point holding onto its signed `Envelope`
Parameters
----------
peer_id : ID
The peer whose record we may delete.
"""
# -------------------------- keybook interface.py --------------------------
@ -758,7 +831,9 @@ class IProtoBook(ABC):
# -------------------------- peerstore interface.py --------------------------
class IPeerStore(IPeerMetadata, IAddrBook, IKeyBook, IMetrics, IProtoBook):
class IPeerStore(
IPeerMetadata, IAddrBook, ICertifiedAddrBook, IKeyBook, IMetrics, IProtoBook
):
"""
Interface for a peer store.
@ -893,7 +968,65 @@ class IPeerStore(IPeerMetadata, IAddrBook, IKeyBook, IMetrics, IProtoBook):
"""
# --------CERTIFIED-ADDR-BOOK----------
@abstractmethod
def consume_peer_record(self, envelope: "Envelope", ttl: int) -> bool:
"""
Accept and store a signed PeerRecord, unless it's older
than the one already stored.
This function:
- Extracts the peer ID and sequence number from the envelope
- Rejects the record if it's older (lower seq)
- Updates the stored peer record and replaces associated addresses if accepted
Parameters
----------
envelope:
Signed envelope containing a PeerRecord.
ttl:
Time-to-live for the included multiaddrs (in seconds).
"""
@abstractmethod
def get_peer_record(self, peer_id: ID) -> Optional["Envelope"]:
"""
Retrieve the most recent signed PeerRecord `Envelope` for a peer, if it exists
and is still relevant.
First, it runs cleanup via `maybe_delete_peer_record` to purge stale data.
Then it checks whether the peer has valid, unexpired addresses before
returning the associated envelope.
Parameters
----------
peer_id : ID
The peer to look up.
"""
@abstractmethod
def maybe_delete_peer_record(self, peer_id: ID) -> None:
"""
Delete the signed peer record for a peer if it has no
know (non-expired) addresses.
This is a garbage collection mechanism: if all addresses for a peer have expired
or been cleared, there's no point holding onto its signed `Envelope`
Parameters
----------
peer_id : ID
The peer whose record we may delete.
"""
# --------KEY-BOOK----------
@abstractmethod
def pubkey(self, peer_id: ID) -> PublicKey:
"""
@ -1202,6 +1335,10 @@ class IPeerStore(IPeerMetadata, IAddrBook, IKeyBook, IMetrics, IProtoBook):
def clear_peerdata(self, peer_id: ID) -> None:
"""clear_peerdata"""
@abstractmethod
async def start_cleanup_task(self, cleanup_interval: int = 3600) -> None:
"""Start periodic cleanup of expired peer records and addresses."""
# -------------------------- listener interface.py --------------------------
@ -1689,6 +1826,121 @@ class IHost(ABC):
"""
# -------------------------- peer-record interface.py --------------------------
class IPeerRecord(ABC):
"""
Interface for a libp2p PeerRecord object.
A PeerRecord contains metadata about a peer such as its ID, public addresses,
and a strictly increasing sequence number for versioning.
PeerRecords are used in signed routing Envelopes for secure peer data propagation.
"""
@abstractmethod
def domain(self) -> str:
"""
Return the domain string for this record type.
Used in envelope validation to distinguish different record types.
"""
@abstractmethod
def codec(self) -> bytes:
"""
Return a binary codec prefix that identifies the PeerRecord type.
This is prepended in signed envelopes to allow type-safe decoding.
"""
@abstractmethod
def to_protobuf(self) -> pb.PeerRecord:
"""
Convert this PeerRecord into its Protobuf representation.
:raises ValueError: if serialization fails (e.g., invalid peer ID).
:return: A populated protobuf `PeerRecord` message.
"""
@abstractmethod
def marshal_record(self) -> bytes:
"""
Serialize this PeerRecord into a byte string.
Used when signing or sealing the record in an envelope.
:raises ValueError: if protobuf serialization fails.
:return: Byte-encoded PeerRecord.
"""
@abstractmethod
def equal(self, other: object) -> bool:
"""
Compare this PeerRecord with another for equality.
Two PeerRecords are considered equal if:
- They have the same `peer_id`
- Their `seq` numbers match
- Their address lists are identical and ordered
:param other: Object to compare with.
:return: True if equal, False otherwise.
"""
# -------------------------- envelope interface.py --------------------------
class IEnvelope(ABC):
@abstractmethod
def marshal_envelope(self) -> bytes:
"""
Serialize this Envelope into its protobuf wire format.
Converts all envelope fields into a `pb.Envelope` protobuf message
and returns the serialized bytes.
:return: Serialized envelope as bytes.
"""
@abstractmethod
def validate(self, domain: str) -> None:
"""
Verify the envelope's signature within the given domain scope.
This ensures that the envelope has not been tampered with
and was signed under the correct usage context.
:param domain: Domain string that contextualizes the signature.
:raises ValueError: If the signature is invalid.
"""
@abstractmethod
def record(self) -> "PeerRecord":
"""
Lazily decode and return the embedded PeerRecord.
This method unmarshals the payload bytes into a `PeerRecord` instance,
using the registered codec to identify the type. The decoded result
is cached for future use.
:return: Decoded PeerRecord object.
:raises Exception: If decoding fails or payload type is unsupported.
"""
@abstractmethod
def equal(self, other: Any) -> bool:
"""
Compare this Envelope with another for structural equality.
Two envelopes are considered equal if:
- They have the same public key
- The payload type and payload bytes match
- Their signatures are identical
:param other: Another object to compare.
:return: True if equal, False otherwise.
"""
# -------------------------- peerdata interface.py --------------------------

View File

@ -13,7 +13,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1dlibp2p/crypto/pb/crypto.proto\x12\tcrypto.pb\"?\n\tPublicKey\x12$\n\x08key_type\x18\x01 \x02(\x0e\x32\x12.crypto.pb.KeyType\x12\x0c\n\x04\x64\x61ta\x18\x02 \x02(\x0c\"@\n\nPrivateKey\x12$\n\x08key_type\x18\x01 \x02(\x0e\x32\x12.crypto.pb.KeyType\x12\x0c\n\x04\x64\x61ta\x18\x02 \x02(\x0c*G\n\x07KeyType\x12\x07\n\x03RSA\x10\x00\x12\x0b\n\x07\x45\x64\x32\x35\x35\x31\x39\x10\x01\x12\r\n\tSecp256k1\x10\x02\x12\t\n\x05\x45\x43\x44SA\x10\x03\x12\x0c\n\x08\x45\x43\x43_P256\x10\x04')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1dlibp2p/crypto/pb/crypto.proto\x12\tcrypto.pb\"?\n\tPublicKey\x12$\n\x08key_type\x18\x01 \x02(\x0e\x32\x12.crypto.pb.KeyType\x12\x0c\n\x04\x64\x61ta\x18\x02 \x02(\x0c\"@\n\nPrivateKey\x12$\n\x08key_type\x18\x01 \x02(\x0e\x32\x12.crypto.pb.KeyType\x12\x0c\n\x04\x64\x61ta\x18\x02 \x02(\x0c*S\n\x07KeyType\x12\x07\n\x03RSA\x10\x00\x12\x0b\n\x07\x45\x64\x32\x35\x35\x31\x39\x10\x01\x12\r\n\tSecp256k1\x10\x02\x12\t\n\x05\x45\x43\x44SA\x10\x03\x12\x0c\n\x08\x45\x43\x43_P256\x10\x04\x12\n\n\x06X25519\x10\x05')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.crypto.pb.crypto_pb2', globals())
@ -21,7 +21,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_KEYTYPE._serialized_start=175
_KEYTYPE._serialized_end=246
_KEYTYPE._serialized_end=258
_PUBLICKEY._serialized_start=44
_PUBLICKEY._serialized_end=107
_PRIVATEKEY._serialized_start=109

View File

@ -28,6 +28,7 @@ class _KeyTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTy
Secp256k1: _KeyType.ValueType # 2
ECDSA: _KeyType.ValueType # 3
ECC_P256: _KeyType.ValueType # 4
X25519: _KeyType.ValueType # 5
class KeyType(_KeyType, metaclass=_KeyTypeEnumTypeWrapper): ...
@ -36,6 +37,7 @@ Ed25519: KeyType.ValueType # 1
Secp256k1: KeyType.ValueType # 2
ECDSA: KeyType.ValueType # 3
ECC_P256: KeyType.ValueType # 4
X25519: KeyType.ValueType # 5
global___KeyType = KeyType
@typing.final

View File

@ -0,0 +1,5 @@
"""Bootstrap peer discovery module for py-libp2p."""
from .bootstrap import BootstrapDiscovery
__all__ = ["BootstrapDiscovery"]

View File

@ -0,0 +1,94 @@
import logging
from multiaddr import Multiaddr
from multiaddr.resolvers import DNSResolver
from libp2p.abc import ID, INetworkService, PeerInfo
from libp2p.discovery.bootstrap.utils import validate_bootstrap_addresses
from libp2p.discovery.events.peerDiscovery import peerDiscovery
from libp2p.peer.peerinfo import info_from_p2p_addr
logger = logging.getLogger("libp2p.discovery.bootstrap")
resolver = DNSResolver()
class BootstrapDiscovery:
"""
Bootstrap-based peer discovery for py-libp2p.
Connects to predefined bootstrap peers and adds them to peerstore.
"""
def __init__(self, swarm: INetworkService, bootstrap_addrs: list[str]):
self.swarm = swarm
self.peerstore = swarm.peerstore
self.bootstrap_addrs = bootstrap_addrs or []
self.discovered_peers: set[str] = set()
async def start(self) -> None:
"""Process bootstrap addresses and emit peer discovery events."""
logger.debug(
f"Starting bootstrap discovery with "
f"{len(self.bootstrap_addrs)} bootstrap addresses"
)
# Validate and filter bootstrap addresses
self.bootstrap_addrs = validate_bootstrap_addresses(self.bootstrap_addrs)
for addr_str in self.bootstrap_addrs:
try:
await self._process_bootstrap_addr(addr_str)
except Exception as e:
logger.debug(f"Failed to process bootstrap address {addr_str}: {e}")
def stop(self) -> None:
"""Clean up bootstrap discovery resources."""
logger.debug("Stopping bootstrap discovery")
self.discovered_peers.clear()
async def _process_bootstrap_addr(self, addr_str: str) -> None:
"""Convert string address to PeerInfo and add to peerstore."""
try:
multiaddr = Multiaddr(addr_str)
except Exception as e:
logger.debug(f"Invalid multiaddr format '{addr_str}': {e}")
return
if self.is_dns_addr(multiaddr):
resolved_addrs = await resolver.resolve(multiaddr)
peer_id_str = multiaddr.get_peer_id()
if peer_id_str is None:
logger.warning(f"Missing peer ID in DNS address: {addr_str}")
return
peer_id = ID.from_base58(peer_id_str)
addrs = [addr for addr in resolved_addrs]
if not addrs:
logger.warning(f"No addresses resolved for DNS address: {addr_str}")
return
peer_info = PeerInfo(peer_id, addrs)
self.add_addr(peer_info)
else:
self.add_addr(info_from_p2p_addr(multiaddr))
def is_dns_addr(self, addr: Multiaddr) -> bool:
"""Check if the address is a DNS address."""
return any(protocol.name == "dnsaddr" for protocol in addr.protocols())
def add_addr(self, peer_info: PeerInfo) -> None:
"""Add a peer to the peerstore and emit discovery event."""
# Skip if it's our own peer
if peer_info.peer_id == self.swarm.get_peer_id():
logger.debug(f"Skipping own peer ID: {peer_info.peer_id}")
return
# Always add addresses to peerstore (allows multiple addresses for same peer)
self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 10)
# Only emit discovery event if this is the first time we see this peer
peer_id_str = str(peer_info.peer_id)
if peer_id_str not in self.discovered_peers:
# Track discovered peer
self.discovered_peers.add(peer_id_str)
# Emit peer discovery event
peerDiscovery.emit_peer_discovered(peer_info)
logger.debug(f"Peer discovered: {peer_info.peer_id}")
else:
logger.debug(f"Additional addresses added for peer: {peer_info.peer_id}")

View File

@ -0,0 +1,51 @@
"""Utility functions for bootstrap discovery."""
import logging
from multiaddr import Multiaddr
from libp2p.peer.peerinfo import InvalidAddrError, PeerInfo, info_from_p2p_addr
logger = logging.getLogger("libp2p.discovery.bootstrap.utils")
def validate_bootstrap_addresses(addrs: list[str]) -> list[str]:
"""
Validate and filter bootstrap addresses.
:param addrs: List of bootstrap address strings
:return: List of valid bootstrap addresses
"""
valid_addrs = []
for addr_str in addrs:
try:
# Try to parse as multiaddr
multiaddr = Multiaddr(addr_str)
# Try to extract peer info (this validates the p2p component)
info_from_p2p_addr(multiaddr)
valid_addrs.append(addr_str)
logger.debug(f"Valid bootstrap address: {addr_str}")
except (InvalidAddrError, ValueError, Exception) as e:
logger.warning(f"Invalid bootstrap address '{addr_str}': {e}")
continue
return valid_addrs
def parse_bootstrap_peer_info(addr_str: str) -> PeerInfo | None:
"""
Parse bootstrap address string into PeerInfo.
:param addr_str: Bootstrap address string
:return: PeerInfo object or None if parsing fails
"""
try:
multiaddr = Multiaddr(addr_str)
return info_from_p2p_addr(multiaddr)
except Exception as e:
logger.error(f"Failed to parse bootstrap address '{addr_str}': {e}")
return None

View File

@ -29,6 +29,7 @@ from libp2p.custom_types import (
StreamHandlerFn,
TProtocol,
)
from libp2p.discovery.bootstrap.bootstrap import BootstrapDiscovery
from libp2p.discovery.mdns.mdns import MDNSDiscovery
from libp2p.host.defaults import (
get_default_protocols,
@ -92,6 +93,7 @@ class BasicHost(IHost):
self,
network: INetworkService,
enable_mDNS: bool = False,
bootstrap: list[str] | None = None,
default_protocols: Optional["OrderedDict[TProtocol, StreamHandlerFn]"] = None,
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
) -> None:
@ -105,6 +107,8 @@ class BasicHost(IHost):
self.multiselect_client = MultiselectClient()
if enable_mDNS:
self.mDNS = MDNSDiscovery(network)
if bootstrap:
self.bootstrap = BootstrapDiscovery(network, bootstrap)
def get_id(self) -> ID:
"""
@ -172,11 +176,16 @@ class BasicHost(IHost):
if hasattr(self, "mDNS") and self.mDNS is not None:
logger.debug("Starting mDNS Discovery")
self.mDNS.start()
if hasattr(self, "bootstrap") and self.bootstrap is not None:
logger.debug("Starting Bootstrap Discovery")
await self.bootstrap.start()
try:
yield
finally:
if hasattr(self, "mDNS") and self.mDNS is not None:
self.mDNS.stop()
if hasattr(self, "bootstrap") and self.bootstrap is not None:
self.bootstrap.stop()
return _run()

View File

@ -19,9 +19,13 @@ class RoutedHost(BasicHost):
_router: IPeerRouting
def __init__(
self, network: INetworkService, router: IPeerRouting, enable_mDNS: bool = False
self,
network: INetworkService,
router: IPeerRouting,
enable_mDNS: bool = False,
bootstrap: list[str] | None = None,
):
super().__init__(network, enable_mDNS)
super().__init__(network, enable_mDNS, bootstrap)
self._router = router
async def connect(self, peer_info: PeerInfo) -> None:

View File

@ -15,6 +15,8 @@ from libp2p.custom_types import (
from libp2p.network.stream.exceptions import (
StreamClosed,
)
from libp2p.peer.envelope import seal_record
from libp2p.peer.peer_record import PeerRecord
from libp2p.utils import (
decode_varint_with_size,
get_agent_version,
@ -63,6 +65,11 @@ def _mk_identify_protobuf(
laddrs = host.get_addrs()
protocols = tuple(str(p) for p in host.get_mux().get_protocols() if p is not None)
# Create a signed peer-record for the remote peer
record = PeerRecord(host.get_id(), host.get_addrs())
envelope = seal_record(record, host.get_private_key())
protobuf = envelope.marshal_envelope()
observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b""
return Identify(
protocol_version=PROTOCOL_VERSION,
@ -71,6 +78,7 @@ def _mk_identify_protobuf(
listen_addrs=map(_multiaddr_to_bytes, laddrs),
observed_addr=observed_addr,
protocols=protocols,
signedPeerRecord=protobuf,
)
@ -113,7 +121,7 @@ def parse_identify_response(response: bytes) -> Identify:
def identify_handler_for(
host: IHost, use_varint_format: bool = False
host: IHost, use_varint_format: bool = True
) -> StreamHandlerFn:
async def handle_identify(stream: INetStream) -> None:
# get observed address from ``stream``

View File

@ -9,4 +9,5 @@ message Identify {
repeated bytes listen_addrs = 2;
optional bytes observed_addr = 4;
repeated string protocols = 3;
optional bytes signedPeerRecord = 8;
}

View File

@ -13,7 +13,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n*libp2p/identity/identify/pb/identify.proto\x12\x0bidentify.pb\"\x8f\x01\n\x08Identify\x12\x18\n\x10protocol_version\x18\x05 \x01(\t\x12\x15\n\ragent_version\x18\x06 \x01(\t\x12\x12\n\npublic_key\x18\x01 \x01(\x0c\x12\x14\n\x0clisten_addrs\x18\x02 \x03(\x0c\x12\x15\n\robserved_addr\x18\x04 \x01(\x0c\x12\x11\n\tprotocols\x18\x03 \x03(\t')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n*libp2p/identity/identify/pb/identify.proto\x12\x0bidentify.pb\"\xa9\x01\n\x08Identify\x12\x18\n\x10protocol_version\x18\x05 \x01(\t\x12\x15\n\ragent_version\x18\x06 \x01(\t\x12\x12\n\npublic_key\x18\x01 \x01(\x0c\x12\x14\n\x0clisten_addrs\x18\x02 \x03(\x0c\x12\x15\n\robserved_addr\x18\x04 \x01(\x0c\x12\x11\n\tprotocols\x18\x03 \x03(\t\x12\x18\n\x10signedPeerRecord\x18\x08 \x01(\x0c')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.identity.identify.pb.identify_pb2', globals())
@ -21,5 +21,5 @@ if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_IDENTIFY._serialized_start=60
_IDENTIFY._serialized_end=203
_IDENTIFY._serialized_end=229
# @@protoc_insertion_point(module_scope)

View File

@ -22,10 +22,12 @@ class Identify(google.protobuf.message.Message):
LISTEN_ADDRS_FIELD_NUMBER: builtins.int
OBSERVED_ADDR_FIELD_NUMBER: builtins.int
PROTOCOLS_FIELD_NUMBER: builtins.int
SIGNEDPEERRECORD_FIELD_NUMBER: builtins.int
protocol_version: builtins.str
agent_version: builtins.str
public_key: builtins.bytes
observed_addr: builtins.bytes
signedPeerRecord: builtins.bytes
@property
def listen_addrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
@property
@ -39,8 +41,9 @@ class Identify(google.protobuf.message.Message):
listen_addrs: collections.abc.Iterable[builtins.bytes] | None = ...,
observed_addr: builtins.bytes | None = ...,
protocols: collections.abc.Iterable[builtins.str] | None = ...,
signedPeerRecord: builtins.bytes | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["agent_version", b"agent_version", "observed_addr", b"observed_addr", "protocol_version", b"protocol_version", "public_key", b"public_key"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["agent_version", b"agent_version", "listen_addrs", b"listen_addrs", "observed_addr", b"observed_addr", "protocol_version", b"protocol_version", "protocols", b"protocols", "public_key", b"public_key"]) -> None: ...
def HasField(self, field_name: typing.Literal["agent_version", b"agent_version", "observed_addr", b"observed_addr", "protocol_version", b"protocol_version", "public_key", b"public_key", "signedPeerRecord", b"signedPeerRecord"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["agent_version", b"agent_version", "listen_addrs", b"listen_addrs", "observed_addr", b"observed_addr", "protocol_version", b"protocol_version", "protocols", b"protocols", "public_key", b"public_key", "signedPeerRecord", b"signedPeerRecord"]) -> None: ...
global___Identify = Identify

View File

@ -20,6 +20,7 @@ from libp2p.custom_types import (
from libp2p.network.stream.exceptions import (
StreamClosed,
)
from libp2p.peer.envelope import consume_envelope
from libp2p.peer.id import (
ID,
)
@ -28,7 +29,7 @@ from libp2p.utils import (
varint,
)
from libp2p.utils.varint import (
decode_varint_from_bytes,
read_length_prefixed_protobuf,
)
from ..identify.identify import (
@ -66,49 +67,8 @@ def identify_push_handler_for(
peer_id = stream.muxed_conn.peer_id
try:
if use_varint_format:
# Read length-prefixed identify message from the stream
# First read the varint length prefix
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
break
length_bytes += b
if b[0] & 0x80 == 0:
break
if not length_bytes:
logger.warning("No length prefix received from peer %s", peer_id)
return
msg_length = decode_varint_from_bytes(length_bytes)
# Read the protobuf message
data = await stream.read(msg_length)
if len(data) != msg_length:
logger.warning("Incomplete message received from peer %s", peer_id)
return
else:
# Read raw protobuf message from the stream
# For raw format, we need to read all data before the stream is closed
data = b""
try:
# Read all available data in a single operation
data = await stream.read()
except StreamClosed:
# Try to read any remaining data
try:
data = await stream.read()
except Exception:
pass
# If we got no data, log a warning and return
if not data:
logger.warning(
"No data received in raw format from peer %s", peer_id
)
return
# Use the utility function to read the protobuf message
data = await read_length_prefixed_protobuf(stream, use_varint_format)
identify_msg = Identify()
identify_msg.ParseFromString(data)
@ -119,6 +79,11 @@ def identify_push_handler_for(
)
logger.debug("Successfully processed identify/push from peer %s", peer_id)
# Send acknowledgment to indicate successful processing
# This ensures the sender knows the message was received before closing
await stream.write(b"OK")
except StreamClosed:
logger.debug(
"Stream closed while processing identify/push from %s", peer_id
@ -127,7 +92,10 @@ def identify_push_handler_for(
logger.error("Error processing identify/push from %s: %s", peer_id, e)
finally:
# Close the stream after processing
await stream.close()
try:
await stream.close()
except Exception:
pass # Ignore errors when closing
return handle_identify_push
@ -173,6 +141,19 @@ async def _update_peerstore_from_identify(
except Exception as e:
logger.error("Error updating protocols for peer %s: %s", peer_id, e)
if identify_msg.HasField("signedPeerRecord"):
try:
# Convert the signed-peer-record(Envelope) from prtobuf bytes
envelope, _ = consume_envelope(
identify_msg.signedPeerRecord, "libp2p-peer-record"
)
# Use a default TTL of 2 hours (7200 seconds)
if not peerstore.consume_peer_record(envelope, 7200):
logger.error("Updating Certified-Addr-Book was unsuccessful")
except Exception as e:
logger.error(
"Error updating the certified addr book for peer %s: %s", peer_id, e
)
# Update observed address if present
if identify_msg.HasField("observed_addr") and identify_msg.observed_addr:
try:
@ -226,7 +207,20 @@ async def push_identify_to_peer(
# Send raw protobuf message
await stream.write(response)
# Close the stream
# Wait for acknowledgment from the receiver with timeout
# This ensures the message was processed before closing
try:
with trio.move_on_after(1.0): # 1 second timeout
ack = await stream.read(2) # Read "OK" acknowledgment
if ack != b"OK":
logger.warning(
"Unexpected acknowledgment from peer %s: %s", peer_id, ack
)
except Exception as e:
logger.debug("No acknowledgment received from peer %s: %s", peer_id, e)
# Continue anyway, as the message might have been processed
# Close the stream after acknowledgment (or timeout)
await stream.close()
logger.debug("Successfully pushed identify to peer %s", peer_id)

View File

@ -2,10 +2,10 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: libp2p/kad_dht/pb/kademlia.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
@ -15,19 +15,19 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xca\x03\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x1aN\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', _globals)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_globals['_RECORD']._serialized_start=36
_globals['_RECORD']._serialized_end=94
_globals['_MESSAGE']._serialized_start=97
_globals['_MESSAGE']._serialized_end=555
_globals['_MESSAGE_PEER']._serialized_start=281
_globals['_MESSAGE_PEER']._serialized_end=359
_globals['_MESSAGE_MESSAGETYPE']._serialized_start=361
_globals['_MESSAGE_MESSAGETYPE']._serialized_end=466
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_start=468
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_end=555
_RECORD._serialized_start=36
_RECORD._serialized_end=94
_MESSAGE._serialized_start=97
_MESSAGE._serialized_end=555
_MESSAGE_PEER._serialized_start=281
_MESSAGE_PEER._serialized_end=359
_MESSAGE_MESSAGETYPE._serialized_start=361
_MESSAGE_MESSAGETYPE._serialized_end=466
_MESSAGE_CONNECTIONTYPE._serialized_start=468
_MESSAGE_CONNECTIONTYPE._serialized_end=555
# @@protoc_insertion_point(module_scope)

View File

@ -3,6 +3,7 @@ from typing import (
TYPE_CHECKING,
)
from multiaddr import Multiaddr
import trio
from libp2p.abc import (
@ -147,6 +148,24 @@ class SwarmConn(INetConn):
def get_streams(self) -> tuple[NetStream, ...]:
return tuple(self.streams)
def get_transport_addresses(self) -> list[Multiaddr]:
"""
Retrieve the transport addresses used by this connection.
Returns
-------
list[Multiaddr]
A list of multiaddresses used by the transport.
"""
# Return the addresses from the peerstore for this peer
try:
peer_id = self.muxed_conn.peer_id
return self.swarm.peerstore.addrs(peer_id)
except Exception as e:
logging.warning(f"Error getting transport addresses: {e}")
return []
def remove_stream(self, stream: NetStream) -> None:
if stream not in self.streams:
return

271
libp2p/peer/envelope.py Normal file
View File

@ -0,0 +1,271 @@
from typing import Any, cast
from libp2p.crypto.ed25519 import Ed25519PublicKey
from libp2p.crypto.keys import PrivateKey, PublicKey
from libp2p.crypto.rsa import RSAPublicKey
from libp2p.crypto.secp256k1 import Secp256k1PublicKey
import libp2p.peer.pb.crypto_pb2 as cryto_pb
import libp2p.peer.pb.envelope_pb2 as pb
import libp2p.peer.pb.peer_record_pb2 as record_pb
from libp2p.peer.peer_record import (
PeerRecord,
peer_record_from_protobuf,
unmarshal_record,
)
from libp2p.utils.varint import encode_uvarint
ENVELOPE_DOMAIN = "libp2p-peer-record"
PEER_RECORD_CODEC = b"\x03\x01"
class Envelope:
"""
A signed wrapper around a serialized libp2p record.
Envelopes are cryptographically signed by the author's private key
and are scoped to a specific 'domain' to prevent cross-protocol replay.
Attributes:
public_key: The public key that can verify the envelope's signature.
payload_type: A multicodec code identifying the type of payload inside.
raw_payload: The raw serialized record data.
signature: Signature over the domain-scoped payload content.
"""
public_key: PublicKey
payload_type: bytes
raw_payload: bytes
signature: bytes
_cached_record: PeerRecord | None = None
_unmarshal_error: Exception | None = None
def __init__(
self,
public_key: PublicKey,
payload_type: bytes,
raw_payload: bytes,
signature: bytes,
):
self.public_key = public_key
self.payload_type = payload_type
self.raw_payload = raw_payload
self.signature = signature
def marshal_envelope(self) -> bytes:
"""
Serialize this Envelope into its protobuf wire format.
Converts all envelope fields into a `pb.Envelope` protobuf message
and returns the serialized bytes.
:return: Serialized envelope as bytes.
"""
pb_env = pb.Envelope(
public_key=pub_key_to_protobuf(self.public_key),
payload_type=self.payload_type,
payload=self.raw_payload,
signature=self.signature,
)
return pb_env.SerializeToString()
def validate(self, domain: str) -> None:
"""
Verify the envelope's signature within the given domain scope.
This ensures that the envelope has not been tampered with
and was signed under the correct usage context.
:param domain: Domain string that contextualizes the signature.
:raises ValueError: If the signature is invalid.
"""
unsigned = make_unsigned(domain, self.payload_type, self.raw_payload)
if not self.public_key.verify(unsigned, self.signature):
raise ValueError("Invalid envelope signature")
def record(self) -> PeerRecord:
"""
Lazily decode and return the embedded PeerRecord.
This method unmarshals the payload bytes into a `PeerRecord` instance,
using the registered codec to identify the type. The decoded result
is cached for future use.
:return: Decoded PeerRecord object.
:raises Exception: If decoding fails or payload type is unsupported.
"""
if self._cached_record is not None:
return self._cached_record
try:
if self.payload_type != PEER_RECORD_CODEC:
raise ValueError("Unsuported payload type in envelope")
msg = record_pb.PeerRecord()
msg.ParseFromString(self.raw_payload)
self._cached_record = peer_record_from_protobuf(msg)
return self._cached_record
except Exception as e:
self._unmarshal_error = e
raise
def equal(self, other: Any) -> bool:
"""
Compare this Envelope with another for structural equality.
Two envelopes are considered equal if:
- They have the same public key
- The payload type and payload bytes match
- Their signatures are identical
:param other: Another object to compare.
:return: True if equal, False otherwise.
"""
if isinstance(other, Envelope):
return (
self.public_key.__eq__(other.public_key)
and self.payload_type == other.payload_type
and self.signature == other.signature
and self.raw_payload == other.raw_payload
)
return False
def pub_key_to_protobuf(pub_key: PublicKey) -> cryto_pb.PublicKey:
"""
Convert a Python PublicKey object to its protobuf equivalent.
:param pub_key: A libp2p-compatible PublicKey instance.
:return: Serialized protobuf PublicKey message.
"""
internal_key_type = pub_key.get_type()
key_type = cast(cryto_pb.KeyType, internal_key_type.value)
data = pub_key.to_bytes()
protobuf_key = cryto_pb.PublicKey(Type=key_type, Data=data)
return protobuf_key
def pub_key_from_protobuf(pb_key: cryto_pb.PublicKey) -> PublicKey:
"""
Parse a protobuf PublicKey message into a native libp2p PublicKey.
Supports Ed25519, RSA, and Secp256k1 key types.
:param pb_key: Protobuf representation of a public key.
:return: Parsed PublicKey object.
:raises ValueError: If the key type is unrecognized.
"""
if pb_key.Type == cryto_pb.KeyType.Ed25519:
return Ed25519PublicKey.from_bytes(pb_key.Data)
elif pb_key.Type == cryto_pb.KeyType.RSA:
return RSAPublicKey.from_bytes(pb_key.Data)
elif pb_key.Type == cryto_pb.KeyType.Secp256k1:
return Secp256k1PublicKey.from_bytes(pb_key.Data)
# libp2p.crypto.ecdsa not implemented
else:
raise ValueError(f"Unknown key type: {pb_key.Type}")
def seal_record(record: PeerRecord, private_key: PrivateKey) -> Envelope:
"""
Create and sign a new Envelope from a PeerRecord.
The record is serialized and signed in the scope of its domain and codec.
The result is a self-contained, verifiable Envelope.
:param record: A PeerRecord to encapsulate.
:param private_key: The signer's private key.
:return: A signed Envelope instance.
"""
payload = record.marshal_record()
unsigned = make_unsigned(record.domain(), record.codec(), payload)
signature = private_key.sign(unsigned)
return Envelope(
public_key=private_key.get_public_key(),
payload_type=record.codec(),
raw_payload=payload,
signature=signature,
)
def consume_envelope(data: bytes, domain: str) -> tuple[Envelope, PeerRecord]:
"""
Parse, validate, and decode an Envelope from bytes.
Validates the envelope's signature using the given domain and decodes
the inner payload into a PeerRecord.
:param data: Serialized envelope bytes.
:param domain: Domain string to verify signature against.
:return: Tuple of (Envelope, PeerRecord).
:raises ValueError: If signature validation or decoding fails.
"""
env = unmarshal_envelope(data)
env.validate(domain)
record = env.record()
return env, record
def unmarshal_envelope(data: bytes) -> Envelope:
"""
Deserialize an Envelope from its wire format.
This parses the protobuf fields without verifying the signature.
:param data: Serialized envelope bytes.
:return: Parsed Envelope object.
:raises DecodeError: If protobuf parsing fails.
"""
pb_env = pb.Envelope()
pb_env.ParseFromString(data)
pk = pub_key_from_protobuf(pb_env.public_key)
return Envelope(
public_key=pk,
payload_type=pb_env.payload_type,
raw_payload=pb_env.payload,
signature=pb_env.signature,
)
def make_unsigned(domain: str, payload_type: bytes, payload: bytes) -> bytes:
"""
Build a byte buffer to be signed for an Envelope.
The unsigned byte structure is:
varint(len(domain)) || domain ||
varint(len(payload_type)) || payload_type ||
varint(len(payload)) || payload
This is the exact input used during signing and verification.
:param domain: Domain string for signature scoping.
:param payload_type: Identifier for the type of payload.
:param payload: Raw serialized payload bytes.
:return: Byte buffer to be signed or verified.
"""
fields = [domain.encode(), payload_type, payload]
buf = bytearray()
for field in fields:
buf.extend(encode_uvarint(len(field)))
buf.extend(field)
return bytes(buf)
def debug_dump_envelope(env: Envelope) -> None:
print("\n=== Envelope ===")
print(f"Payload Type: {env.payload_type!r}")
print(f"Signature: {env.signature.hex()} ({len(env.signature)} bytes)")
print(f"Raw Payload: {env.raw_payload.hex()} ({len(env.raw_payload)} bytes)")
try:
peer_record = unmarshal_record(env.raw_payload)
print("\n=== Parsed PeerRecord ===")
print(peer_record)
except Exception as e:
print("Failed to parse PeerRecord:", e)

View File

@ -1,3 +1,4 @@
import functools
import hashlib
import base58
@ -36,25 +37,23 @@ if ENABLE_INLINING:
class ID:
_bytes: bytes
_xor_id: int | None = None
_b58_str: str | None = None
def __init__(self, peer_id_bytes: bytes) -> None:
self._bytes = peer_id_bytes
@property
@functools.cached_property
def xor_id(self) -> int:
if not self._xor_id:
self._xor_id = int(sha256_digest(self._bytes).hex(), 16)
return self._xor_id
return int(sha256_digest(self._bytes).hex(), 16)
@functools.cached_property
def base58(self) -> str:
return base58.b58encode(self._bytes).decode()
def to_bytes(self) -> bytes:
return self._bytes
def to_base58(self) -> str:
if not self._b58_str:
self._b58_str = base58.b58encode(self._bytes).decode()
return self._b58_str
return self.base58
def __repr__(self) -> str:
return f"<libp2p.peer.id.ID ({self!s})>"

View File

@ -0,0 +1,22 @@
syntax = "proto3";
package libp2p.peer.pb.crypto;
option go_package = "github.com/libp2p/go-libp2p/core/crypto/pb";
enum KeyType {
RSA = 0;
Ed25519 = 1;
Secp256k1 = 2;
ECDSA = 3;
}
message PublicKey {
KeyType Type = 1;
bytes Data = 2;
}
message PrivateKey {
KeyType Type = 1;
bytes Data = 2;
}

View File

@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: libp2p/peer/pb/crypto.proto
# Protobuf Python Version: 4.25.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1blibp2p/peer/pb/crypto.proto\x12\x15libp2p.peer.pb.crypto\"G\n\tPublicKey\x12,\n\x04Type\x18\x01 \x01(\x0e\x32\x1e.libp2p.peer.pb.crypto.KeyType\x12\x0c\n\x04\x44\x61ta\x18\x02 \x01(\x0c\"H\n\nPrivateKey\x12,\n\x04Type\x18\x01 \x01(\x0e\x32\x1e.libp2p.peer.pb.crypto.KeyType\x12\x0c\n\x04\x44\x61ta\x18\x02 \x01(\x0c*9\n\x07KeyType\x12\x07\n\x03RSA\x10\x00\x12\x0b\n\x07\x45\x64\x32\x35\x35\x31\x39\x10\x01\x12\r\n\tSecp256k1\x10\x02\x12\t\n\x05\x45\x43\x44SA\x10\x03\x42,Z*github.com/libp2p/go-libp2p/core/crypto/pbb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.peer.pb.crypto_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
_globals['DESCRIPTOR']._options = None
_globals['DESCRIPTOR']._serialized_options = b'Z*github.com/libp2p/go-libp2p/core/crypto/pb'
_globals['_KEYTYPE']._serialized_start=201
_globals['_KEYTYPE']._serialized_end=258
_globals['_PUBLICKEY']._serialized_start=54
_globals['_PUBLICKEY']._serialized_end=125
_globals['_PRIVATEKEY']._serialized_start=127
_globals['_PRIVATEKEY']._serialized_end=199
# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,33 @@
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor
class KeyType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
RSA: _ClassVar[KeyType]
Ed25519: _ClassVar[KeyType]
Secp256k1: _ClassVar[KeyType]
ECDSA: _ClassVar[KeyType]
RSA: KeyType
Ed25519: KeyType
Secp256k1: KeyType
ECDSA: KeyType
class PublicKey(_message.Message):
__slots__ = ("Type", "Data")
TYPE_FIELD_NUMBER: _ClassVar[int]
DATA_FIELD_NUMBER: _ClassVar[int]
Type: KeyType
Data: bytes
def __init__(self, Type: _Optional[_Union[KeyType, str]] = ..., Data: _Optional[bytes] = ...) -> None: ...
class PrivateKey(_message.Message):
__slots__ = ("Type", "Data")
TYPE_FIELD_NUMBER: _ClassVar[int]
DATA_FIELD_NUMBER: _ClassVar[int]
Type: KeyType
Data: bytes
def __init__(self, Type: _Optional[_Union[KeyType, str]] = ..., Data: _Optional[bytes] = ...) -> None: ...

View File

@ -0,0 +1,14 @@
syntax = "proto3";
package libp2p.peer.pb.record;
import "libp2p/peer/pb/crypto.proto";
option go_package = "github.com/libp2p/go-libp2p/core/record/pb";
message Envelope {
libp2p.peer.pb.crypto.PublicKey public_key = 1;
bytes payload_type = 2;
bytes payload = 3;
bytes signature = 5;
}

View File

@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: libp2p/peer/pb/envelope.proto
# Protobuf Python Version: 4.25.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from libp2p.peer.pb import crypto_pb2 as libp2p_dot_peer_dot_pb_dot_crypto__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1dlibp2p/peer/pb/envelope.proto\x12\x15libp2p.peer.pb.record\x1a\x1blibp2p/peer/pb/crypto.proto\"z\n\x08\x45nvelope\x12\x34\n\npublic_key\x18\x01 \x01(\x0b\x32 .libp2p.peer.pb.crypto.PublicKey\x12\x14\n\x0cpayload_type\x18\x02 \x01(\x0c\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x42,Z*github.com/libp2p/go-libp2p/core/record/pbb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.peer.pb.envelope_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
_globals['DESCRIPTOR']._options = None
_globals['DESCRIPTOR']._serialized_options = b'Z*github.com/libp2p/go-libp2p/core/record/pb'
_globals['_ENVELOPE']._serialized_start=85
_globals['_ENVELOPE']._serialized_end=207
# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,18 @@
from libp2p.peer.pb import crypto_pb2 as _crypto_pb2
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor
class Envelope(_message.Message):
__slots__ = ("public_key", "payload_type", "payload", "signature")
PUBLIC_KEY_FIELD_NUMBER: _ClassVar[int]
PAYLOAD_TYPE_FIELD_NUMBER: _ClassVar[int]
PAYLOAD_FIELD_NUMBER: _ClassVar[int]
SIGNATURE_FIELD_NUMBER: _ClassVar[int]
public_key: _crypto_pb2.PublicKey
payload_type: bytes
payload: bytes
signature: bytes
def __init__(self, public_key: _Optional[_Union[_crypto_pb2.PublicKey, _Mapping]] = ..., payload_type: _Optional[bytes] = ..., payload: _Optional[bytes] = ..., signature: _Optional[bytes] = ...) -> None: ... # type: ignore[type-arg]

View File

@ -0,0 +1,31 @@
syntax = "proto3";
package peer.pb;
option go_package = "github.com/libp2p/go-libp2p/core/peer/pb";
// PeerRecord messages contain information that is useful to share with other peers.
// Currently, a PeerRecord contains the public listen addresses for a peer, but this
// is expected to expand to include other information in the future.
//
// PeerRecords are designed to be serialized to bytes and placed inside of
// SignedEnvelopes before sharing with other peers.
// See https://github.com/libp2p/go-libp2p/blob/master/core/record/pb/envelope.proto for
// the SignedEnvelope definition.
message PeerRecord {
// AddressInfo is a wrapper around a binary multiaddr. It is defined as a
// separate message to allow us to add per-address metadata in the future.
message AddressInfo {
bytes multiaddr = 1;
}
// peer_id contains a libp2p peer id in its binary representation.
bytes peer_id = 1;
// seq contains a monotonically-increasing sequence counter to order PeerRecords in time.
uint64 seq = 2;
// addresses is a list of public listen addresses for the peer.
repeated AddressInfo addresses = 3;
}

View File

@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: libp2p/peer/pb/peer_record.proto
# Protobuf Python Version: 4.25.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/peer/pb/peer_record.proto\x12\x07peer.pb\"\x80\x01\n\nPeerRecord\x12\x0f\n\x07peer_id\x18\x01 \x01(\x0c\x12\x0b\n\x03seq\x18\x02 \x01(\x04\x12\x32\n\taddresses\x18\x03 \x03(\x0b\x32\x1f.peer.pb.PeerRecord.AddressInfo\x1a \n\x0b\x41\x64\x64ressInfo\x12\x11\n\tmultiaddr\x18\x01 \x01(\x0c\x42*Z(github.com/libp2p/go-libp2p/core/peer/pbb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.peer.pb.peer_record_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
_globals['DESCRIPTOR']._options = None
_globals['DESCRIPTOR']._serialized_options = b'Z(github.com/libp2p/go-libp2p/core/peer/pb'
_globals['_PEERRECORD']._serialized_start=46
_globals['_PEERRECORD']._serialized_end=174
_globals['_PEERRECORD_ADDRESSINFO']._serialized_start=142
_globals['_PEERRECORD_ADDRESSINFO']._serialized_end=174
# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,21 @@
from google.protobuf.internal import containers as _containers
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor
class PeerRecord(_message.Message):
__slots__ = ("peer_id", "seq", "addresses")
class AddressInfo(_message.Message):
__slots__ = ("multiaddr",)
MULTIADDR_FIELD_NUMBER: _ClassVar[int]
multiaddr: bytes
def __init__(self, multiaddr: _Optional[bytes] = ...) -> None: ...
PEER_ID_FIELD_NUMBER: _ClassVar[int]
SEQ_FIELD_NUMBER: _ClassVar[int]
ADDRESSES_FIELD_NUMBER: _ClassVar[int]
peer_id: bytes
seq: int
addresses: _containers.RepeatedCompositeFieldContainer[PeerRecord.AddressInfo]
def __init__(self, peer_id: _Optional[bytes] = ..., seq: _Optional[int] = ..., addresses: _Optional[_Iterable[_Union[PeerRecord.AddressInfo, _Mapping]]] = ...) -> None: ... # type: ignore[type-arg]

251
libp2p/peer/peer_record.py Normal file
View File

@ -0,0 +1,251 @@
from collections.abc import Sequence
import threading
import time
from typing import Any
from multiaddr import Multiaddr
from libp2p.abc import IPeerRecord
from libp2p.peer.id import ID
import libp2p.peer.pb.peer_record_pb2 as pb
from libp2p.peer.peerinfo import PeerInfo
PEER_RECORD_ENVELOPE_DOMAIN = "libp2p-peer-record"
PEER_RECORD_ENVELOPE_PAYLOAD_TYPE = b"\x03\x01"
_last_timestamp_lock = threading.Lock()
_last_timestamp: int = 0
class PeerRecord(IPeerRecord):
"""
A record that contains metatdata about a peer in the libp2p network.
This includes:
- `peer_id`: The peer's globally unique indentifier.
- `addrs`: A list of the peer's publicly reachable multiaddrs.
- `seq`: A strictly monotonically increasing timestamp used
to order records over time.
PeerRecords are designed to be signed and transmitted in libp2p routing Envelopes.
"""
peer_id: ID
addrs: list[Multiaddr]
seq: int
def __init__(
self,
peer_id: ID | None = None,
addrs: list[Multiaddr] | None = None,
seq: int | None = None,
) -> None:
"""
Initialize a new PeerRecord.
If `seq` is not provided, a timestamp-based strictly increasing sequence
number will be generated.
:param peer_id: ID of the peer this record refers to.
:param addrs: Public multiaddrs of the peer.
:param seq: Monotonic sequence number.
"""
if peer_id is not None:
self.peer_id = peer_id
self.addrs = addrs or []
if seq is not None:
self.seq = seq
else:
self.seq = timestamp_seq()
def __repr__(self) -> str:
return (
f"PeerRecord(\n"
f" peer_id={self.peer_id},\n"
f" multiaddrs={[str(m) for m in self.addrs]},\n"
f" seq={self.seq}\n"
f")"
)
def domain(self) -> str:
"""
Return the domain string associated with this PeerRecord.
Used during record signing and envelope validation to identify the record type.
"""
return PEER_RECORD_ENVELOPE_DOMAIN
def codec(self) -> bytes:
"""
Return the codec identifier for PeerRecords.
This binary perfix helps distinguish PeerRecords in serialized envelopes.
"""
return PEER_RECORD_ENVELOPE_PAYLOAD_TYPE
def to_protobuf(self) -> pb.PeerRecord:
"""
Convert the current PeerRecord into a ProtoBuf PeerRecord message.
:raises ValueError: if peer_id serialization fails.
:return: A ProtoBuf-encoded PeerRecord message object.
"""
try:
id_bytes = self.peer_id.to_bytes()
except Exception as e:
raise ValueError(f"failed to marshal peer_id: {e}")
msg = pb.PeerRecord()
msg.peer_id = id_bytes
msg.seq = self.seq
msg.addresses.extend(addrs_to_protobuf(self.addrs))
return msg
def marshal_record(self) -> bytes:
"""
Serialize a PeerRecord into raw bytes suitable for embedding in an Envelope.
This is typically called during the process of signing or sealing the record.
:raises ValueError: if serialization to protobuf fails.
:return: Serialized PeerRecord bytes.
"""
try:
msg = self.to_protobuf()
return msg.SerializeToString()
except Exception as e:
raise ValueError(f"failed to marshal PeerRecord: {e}")
def equal(self, other: Any) -> bool:
"""
Check if this PeerRecord is identical to another.
Two PeerRecords are considered equal if:
- Their peer IDs match.
- Their sequence numbers are identical.
- Their address lists are identical and in the same order.
:param other: Another PeerRecord instance.
:return: True if all fields mathch, False otherwise.
"""
if isinstance(other, PeerRecord):
if self.peer_id == other.peer_id:
if self.seq == other.seq:
if len(self.addrs) == len(other.addrs):
for a1, a2 in zip(self.addrs, other.addrs):
if a1 == a2:
continue
else:
return False
return True
return False
def unmarshal_record(data: bytes) -> PeerRecord:
"""
Deserialize a PeerRecord from its serialized byte representation.
Typically used when receiveing a PeerRecord inside a signed routing Envelope.
:param data: Serialized protobuf-encoded bytes.
:raises ValueError: if parsing or conversion fails.
:reurn: A valid PeerRecord instance.
"""
if data is None:
raise ValueError("cannot unmarshal PeerRecord from None")
msg = pb.PeerRecord()
try:
msg.ParseFromString(data)
except Exception as e:
raise ValueError(f"Failed to parse PeerRecord protobuf: {e}")
try:
record = peer_record_from_protobuf(msg)
except Exception as e:
raise ValueError(f"Failed to convert protobuf to PeerRecord: {e}")
return record
def timestamp_seq() -> int:
"""
Generate a strictly increasing timestamp-based sequence number.
Ensures that even if multiple PeerRecords are generated in the same nanosecond,
their `seq` values will still be strictly increasing by using a lock to track the
last value.
:return: A strictly increasing integer timestamp.
"""
global _last_timestamp
now = int(time.time_ns())
with _last_timestamp_lock:
if now <= _last_timestamp:
now = _last_timestamp + 1
_last_timestamp = now
return now
def peer_record_from_peer_info(info: PeerInfo) -> PeerRecord:
"""
Create a PeerRecord from a PeerInfo object.
This automatically assigns a timestamp-based sequence number to the record.
:param info: A PeerInfo instance (contains peer_id and addrs).
:return: A PeerRecord instance.
"""
record = PeerRecord()
record.peer_id = info.peer_id
record.addrs = info.addrs
return record
def peer_record_from_protobuf(msg: pb.PeerRecord) -> PeerRecord:
"""
Convert a protobuf PeerRecord message into a PeerRecord object.
:param msg: Protobuf PeerRecord message.
:raises ValueError: if the peer_id cannot be parsed.
:return: A deserialized PeerRecord instance.
"""
try:
peer_id = ID(msg.peer_id)
except Exception as e:
raise ValueError(f"Failed to unmarshal peer_id: {e}")
addrs = addrs_from_protobuf(msg.addresses)
seq = msg.seq
return PeerRecord(peer_id, addrs, seq)
def addrs_from_protobuf(addrs: Sequence[pb.PeerRecord.AddressInfo]) -> list[Multiaddr]:
"""
Convert a list of protobuf address records to Multiaddr objects.
:param addrs: A list of protobuf PeerRecord.AddressInfo messages.
:return: A list of decoded Multiaddr instances (invalid ones are skipped).
"""
out = []
for addr_info in addrs:
try:
addr = Multiaddr(addr_info.multiaddr)
out.append(addr)
except Exception:
continue
return out
def addrs_to_protobuf(addrs: list[Multiaddr]) -> list[pb.PeerRecord.AddressInfo]:
"""
Convert a list of Multiaddr objects into their protobuf representation.
:param addrs: A list of Multiaddr instances.
:return: A list of PeerRecord.AddressInfo protobuf messages.
"""
out = []
for addr in addrs:
addr_info = pb.PeerRecord.AddressInfo()
addr_info.multiaddr = addr.to_bytes()
out.append(addr_info)
return out

View File

@ -23,6 +23,7 @@ from libp2p.crypto.keys import (
PrivateKey,
PublicKey,
)
from libp2p.peer.envelope import Envelope
from .id import (
ID,
@ -38,12 +39,23 @@ from .peerinfo import (
PERMANENT_ADDR_TTL = 0
class PeerRecordState:
envelope: Envelope
seq: int
def __init__(self, envelope: Envelope, seq: int):
self.envelope = envelope
self.seq = seq
class PeerStore(IPeerStore):
peer_data_map: dict[ID, PeerData]
def __init__(self) -> None:
def __init__(self, max_records: int = 10000) -> None:
self.peer_data_map = defaultdict(PeerData)
self.addr_update_channels: dict[ID, MemorySendChannel[Multiaddr]] = {}
self.peer_record_map: dict[ID, PeerRecordState] = {}
self.max_records = max_records
def peer_info(self, peer_id: ID) -> PeerInfo:
"""
@ -70,6 +82,10 @@ class PeerStore(IPeerStore):
else:
raise PeerStoreError("peer ID not found")
# Clear the peer records
if peer_id in self.peer_record_map:
self.peer_record_map.pop(peer_id, None)
def valid_peer_ids(self) -> list[ID]:
"""
:return: all of the valid peer IDs stored in peer store
@ -82,6 +98,38 @@ class PeerStore(IPeerStore):
peer_data.clear_addrs()
return valid_peer_ids
def _enforce_record_limit(self) -> None:
"""Enforce maximum number of stored records."""
if len(self.peer_record_map) > self.max_records:
# Record oldest records based on seequence number
sorted_records = sorted(
self.peer_record_map.items(), key=lambda x: x[1].seq
)
records_to_remove = len(self.peer_record_map) - self.max_records
for peer_id, _ in sorted_records[:records_to_remove]:
self.maybe_delete_peer_record(peer_id)
del self.peer_record_map[peer_id]
async def start_cleanup_task(self, cleanup_interval: int = 3600) -> None:
"""Start periodic cleanup of expired peer records and addresses."""
while True:
await trio.sleep(cleanup_interval)
self._cleanup_expired_records()
def _cleanup_expired_records(self) -> None:
"""Remove expired peer records and addresses"""
expired_peers = []
for peer_id, peer_data in self.peer_data_map.items():
if peer_data.is_expired():
expired_peers.append(peer_id)
for peer_id in expired_peers:
self.maybe_delete_peer_record(peer_id)
del self.peer_data_map[peer_id]
self._enforce_record_limit()
# --------PROTO-BOOK--------
def get_protocols(self, peer_id: ID) -> list[str]:
@ -165,6 +213,84 @@ class PeerStore(IPeerStore):
peer_data = self.peer_data_map[peer_id]
peer_data.clear_metadata()
# -----CERT-ADDR-BOOK-----
def maybe_delete_peer_record(self, peer_id: ID) -> None:
"""
Delete the signed peer record for a peer if it has no know
(non-expired) addresses.
This is a garbage collection mechanism: if all addresses for a peer have expired
or been cleared, there's no point holding onto its signed `Envelope`
:param peer_id: The peer whose record we may delete/
"""
if peer_id in self.peer_record_map:
if not self.addrs(peer_id):
self.peer_record_map.pop(peer_id, None)
def consume_peer_record(self, envelope: Envelope, ttl: int) -> bool:
"""
Accept and store a signed PeerRecord, unless it's older than
the one already stored.
This function:
- Extracts the peer ID and sequence number from the envelope
- Rejects the record if it's older (lower seq)
- Updates the stored peer record and replaces associated addresses if accepted
:param envelope: Signed envelope containing a PeerRecord.
:param ttl: Time-to-live for the included multiaddrs (in seconds).
:return: True if the record was accepted and stored; False if it was rejected.
"""
record = envelope.record()
peer_id = record.peer_id
existing = self.peer_record_map.get(peer_id)
if existing and existing.seq > record.seq:
return False # reject older record
new_addrs = set(record.addrs)
self.peer_record_map[peer_id] = PeerRecordState(envelope, record.seq)
self.peer_data_map[peer_id].clear_addrs()
self.add_addrs(peer_id, list(new_addrs), ttl)
return True
def consume_peer_records(self, envelopes: list[Envelope], ttl: int) -> list[bool]:
"""Consume multiple peer records in a single operation."""
results = []
for envelope in envelopes:
results.append(self.consume_peer_record(envelope, ttl))
return results
def get_peer_record(self, peer_id: ID) -> Envelope | None:
"""
Retrieve the most recent signed PeerRecord `Envelope` for a peer, if it exists
and is still relevant.
First, it runs cleanup via `maybe_delete_peer_record` to purge stale data.
Then it checks whether the peer has valid, unexpired addresses before
returning the associated envelope.
:param peer_id: The peer to look up.
:return: The signed Envelope if the peer is known and has valid
addresses; None otherwise.
"""
self.maybe_delete_peer_record(peer_id)
# Check if the peer has any valid addresses
if (
peer_id in self.peer_data_map
and not self.peer_data_map[peer_id].is_expired()
):
state = self.peer_record_map.get(peer_id)
if state is not None:
return state.envelope
return None
# -------ADDR-BOOK--------
def add_addr(self, peer_id: ID, addr: Multiaddr, ttl: int = 0) -> None:
@ -193,6 +319,8 @@ class PeerStore(IPeerStore):
except trio.WouldBlock:
pass # Or consider logging / dropping / replacing stream
self.maybe_delete_peer_record(peer_id)
def addrs(self, peer_id: ID) -> list[Multiaddr]:
"""
:param peer_id: peer ID to get addrs for
@ -216,6 +344,8 @@ class PeerStore(IPeerStore):
if peer_id in self.peer_data_map:
self.peer_data_map[peer_id].clear_addrs()
self.maybe_delete_peer_record(peer_id)
def peers_with_addrs(self) -> list[ID]:
"""
:return: all of the peer IDs which has addrsfloat stored in peer store

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: rpc.proto
# source: libp2p/pubsub/pb/rpc.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
@ -13,39 +13,39 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\trpc.proto\x12\tpubsub.pb\"\xb4\x01\n\x03RPC\x12-\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x16.pubsub.pb.RPC.SubOpts\x12#\n\x07publish\x18\x02 \x03(\x0b\x32\x12.pubsub.pb.Message\x12*\n\x07\x63ontrol\x18\x03 \x01(\x0b\x32\x19.pubsub.pb.ControlMessage\x1a-\n\x07SubOpts\x12\x11\n\tsubscribe\x18\x01 \x01(\x08\x12\x0f\n\x07topicid\x18\x02 \x01(\t\"i\n\x07Message\x12\x0f\n\x07\x66rom_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\r\n\x05seqno\x18\x03 \x01(\x0c\x12\x10\n\x08topicIDs\x18\x04 \x03(\t\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x12\x0b\n\x03key\x18\x06 \x01(\x0c\"\xb0\x01\n\x0e\x43ontrolMessage\x12&\n\x05ihave\x18\x01 \x03(\x0b\x32\x17.pubsub.pb.ControlIHave\x12&\n\x05iwant\x18\x02 \x03(\x0b\x32\x17.pubsub.pb.ControlIWant\x12&\n\x05graft\x18\x03 \x03(\x0b\x32\x17.pubsub.pb.ControlGraft\x12&\n\x05prune\x18\x04 \x03(\x0b\x32\x17.pubsub.pb.ControlPrune\"3\n\x0c\x43ontrolIHave\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\x12\n\nmessageIDs\x18\x02 \x03(\t\"\"\n\x0c\x43ontrolIWant\x12\x12\n\nmessageIDs\x18\x01 \x03(\t\"\x1f\n\x0c\x43ontrolGraft\x12\x0f\n\x07topicID\x18\x01 \x01(\t\"T\n\x0c\x43ontrolPrune\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\"\n\x05peers\x18\x02 \x03(\x0b\x32\x13.pubsub.pb.PeerInfo\x12\x0f\n\x07\x62\x61\x63koff\x18\x03 \x01(\x04\"4\n\x08PeerInfo\x12\x0e\n\x06peerID\x18\x01 \x01(\x0c\x12\x18\n\x10signedPeerRecord\x18\x02 \x01(\x0c\"\x87\x03\n\x0fTopicDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\x04\x61uth\x18\x02 \x01(\x0b\x32#.pubsub.pb.TopicDescriptor.AuthOpts\x12/\n\x03\x65nc\x18\x03 \x01(\x0b\x32\".pubsub.pb.TopicDescriptor.EncOpts\x1a|\n\x08\x41uthOpts\x12:\n\x04mode\x18\x01 \x01(\x0e\x32,.pubsub.pb.TopicDescriptor.AuthOpts.AuthMode\x12\x0c\n\x04keys\x18\x02 \x03(\x0c\"&\n\x08\x41uthMode\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03KEY\x10\x01\x12\x07\n\x03WOT\x10\x02\x1a\x83\x01\n\x07\x45ncOpts\x12\x38\n\x04mode\x18\x01 \x01(\x0e\x32*.pubsub.pb.TopicDescriptor.EncOpts.EncMode\x12\x11\n\tkeyHashes\x18\x02 \x03(\x0c\"+\n\x07\x45ncMode\x12\x08\n\x04NONE\x10\x00\x12\r\n\tSHAREDKEY\x10\x01\x12\x07\n\x03WOT\x10\x02')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1alibp2p/pubsub/pb/rpc.proto\x12\tpubsub.pb\"\xb4\x01\n\x03RPC\x12-\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x16.pubsub.pb.RPC.SubOpts\x12#\n\x07publish\x18\x02 \x03(\x0b\x32\x12.pubsub.pb.Message\x12*\n\x07\x63ontrol\x18\x03 \x01(\x0b\x32\x19.pubsub.pb.ControlMessage\x1a-\n\x07SubOpts\x12\x11\n\tsubscribe\x18\x01 \x01(\x08\x12\x0f\n\x07topicid\x18\x02 \x01(\t\"i\n\x07Message\x12\x0f\n\x07\x66rom_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\r\n\x05seqno\x18\x03 \x01(\x0c\x12\x10\n\x08topicIDs\x18\x04 \x03(\t\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x12\x0b\n\x03key\x18\x06 \x01(\x0c\"\xb0\x01\n\x0e\x43ontrolMessage\x12&\n\x05ihave\x18\x01 \x03(\x0b\x32\x17.pubsub.pb.ControlIHave\x12&\n\x05iwant\x18\x02 \x03(\x0b\x32\x17.pubsub.pb.ControlIWant\x12&\n\x05graft\x18\x03 \x03(\x0b\x32\x17.pubsub.pb.ControlGraft\x12&\n\x05prune\x18\x04 \x03(\x0b\x32\x17.pubsub.pb.ControlPrune\"3\n\x0c\x43ontrolIHave\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\x12\n\nmessageIDs\x18\x02 \x03(\t\"\"\n\x0c\x43ontrolIWant\x12\x12\n\nmessageIDs\x18\x01 \x03(\t\"\x1f\n\x0c\x43ontrolGraft\x12\x0f\n\x07topicID\x18\x01 \x01(\t\"T\n\x0c\x43ontrolPrune\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\"\n\x05peers\x18\x02 \x03(\x0b\x32\x13.pubsub.pb.PeerInfo\x12\x0f\n\x07\x62\x61\x63koff\x18\x03 \x01(\x04\"4\n\x08PeerInfo\x12\x0e\n\x06peerID\x18\x01 \x01(\x0c\x12\x18\n\x10signedPeerRecord\x18\x02 \x01(\x0c\"\x87\x03\n\x0fTopicDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\x04\x61uth\x18\x02 \x01(\x0b\x32#.pubsub.pb.TopicDescriptor.AuthOpts\x12/\n\x03\x65nc\x18\x03 \x01(\x0b\x32\".pubsub.pb.TopicDescriptor.EncOpts\x1a|\n\x08\x41uthOpts\x12:\n\x04mode\x18\x01 \x01(\x0e\x32,.pubsub.pb.TopicDescriptor.AuthOpts.AuthMode\x12\x0c\n\x04keys\x18\x02 \x03(\x0c\"&\n\x08\x41uthMode\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03KEY\x10\x01\x12\x07\n\x03WOT\x10\x02\x1a\x83\x01\n\x07\x45ncOpts\x12\x38\n\x04mode\x18\x01 \x01(\x0e\x32*.pubsub.pb.TopicDescriptor.EncOpts.EncMode\x12\x11\n\tkeyHashes\x18\x02 \x03(\x0c\"+\n\x07\x45ncMode\x12\x08\n\x04NONE\x10\x00\x12\r\n\tSHAREDKEY\x10\x01\x12\x07\n\x03WOT\x10\x02')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'rpc_pb2', globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.pubsub.pb.rpc_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_RPC._serialized_start=25
_RPC._serialized_end=205
_RPC_SUBOPTS._serialized_start=160
_RPC_SUBOPTS._serialized_end=205
_MESSAGE._serialized_start=207
_MESSAGE._serialized_end=312
_CONTROLMESSAGE._serialized_start=315
_CONTROLMESSAGE._serialized_end=491
_CONTROLIHAVE._serialized_start=493
_CONTROLIHAVE._serialized_end=544
_CONTROLIWANT._serialized_start=546
_CONTROLIWANT._serialized_end=580
_CONTROLGRAFT._serialized_start=582
_CONTROLGRAFT._serialized_end=613
_CONTROLPRUNE._serialized_start=615
_CONTROLPRUNE._serialized_end=699
_PEERINFO._serialized_start=701
_PEERINFO._serialized_end=753
_TOPICDESCRIPTOR._serialized_start=756
_TOPICDESCRIPTOR._serialized_end=1147
_TOPICDESCRIPTOR_AUTHOPTS._serialized_start=889
_TOPICDESCRIPTOR_AUTHOPTS._serialized_end=1013
_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_start=975
_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_end=1013
_TOPICDESCRIPTOR_ENCOPTS._serialized_start=1016
_TOPICDESCRIPTOR_ENCOPTS._serialized_end=1147
_TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_start=1104
_TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_end=1147
_RPC._serialized_start=42
_RPC._serialized_end=222
_RPC_SUBOPTS._serialized_start=177
_RPC_SUBOPTS._serialized_end=222
_MESSAGE._serialized_start=224
_MESSAGE._serialized_end=329
_CONTROLMESSAGE._serialized_start=332
_CONTROLMESSAGE._serialized_end=508
_CONTROLIHAVE._serialized_start=510
_CONTROLIHAVE._serialized_end=561
_CONTROLIWANT._serialized_start=563
_CONTROLIWANT._serialized_end=597
_CONTROLGRAFT._serialized_start=599
_CONTROLGRAFT._serialized_end=630
_CONTROLPRUNE._serialized_start=632
_CONTROLPRUNE._serialized_end=716
_PEERINFO._serialized_start=718
_PEERINFO._serialized_end=770
_TOPICDESCRIPTOR._serialized_start=773
_TOPICDESCRIPTOR._serialized_end=1164
_TOPICDESCRIPTOR_AUTHOPTS._serialized_start=906
_TOPICDESCRIPTOR_AUTHOPTS._serialized_end=1030
_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_start=992
_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_end=1030
_TOPICDESCRIPTOR_ENCOPTS._serialized_start=1033
_TOPICDESCRIPTOR_ENCOPTS._serialized_end=1164
_TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_start=1121
_TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_end=1164
# @@protoc_insertion_point(module_scope)

View File

@ -11,6 +11,10 @@ import functools
import hashlib
import logging
import time
from typing import (
NamedTuple,
cast,
)
import base58
import trio
@ -26,6 +30,8 @@ from libp2p.crypto.keys import (
PrivateKey,
)
from libp2p.custom_types import (
AsyncValidatorFn,
SyncValidatorFn,
TProtocol,
ValidatorFn,
)
@ -71,11 +77,6 @@ from .pubsub_notifee import (
from .subscription import (
TrioSubscriptionAPI,
)
from .validation_throttler import (
TopicValidator,
ValidationResult,
ValidationThrottler,
)
from .validators import (
PUBSUB_SIGNING_PREFIX,
signature_validator,
@ -96,6 +97,14 @@ def get_content_addressed_msg_id(msg: rpc_pb2.Message) -> bytes:
return base64.b64encode(hashlib.sha256(msg.data).digest())
class TopicValidator(NamedTuple):
validator: ValidatorFn
is_async: bool
MAX_CONCURRENT_VALIDATORS = 10
class Pubsub(Service, IPubsub):
host: IHost
@ -103,6 +112,7 @@ class Pubsub(Service, IPubsub):
peer_receive_channel: trio.MemoryReceiveChannel[ID]
dead_peer_receive_channel: trio.MemoryReceiveChannel[ID]
_validator_semaphore: trio.Semaphore
seen_messages: LastSeenCache
@ -137,11 +147,7 @@ class Pubsub(Service, IPubsub):
msg_id_constructor: Callable[
[rpc_pb2.Message], bytes
] = get_peer_and_seqno_msg_id,
# TODO: these values have been copied from Go, but try to tune these dynamically
validation_queue_size: int = 32,
global_throttle_limit: int = 8192,
default_topic_throttle_limit: int = 1024,
validation_worker_count: int | None = None,
max_concurrent_validator_count: int = MAX_CONCURRENT_VALIDATORS,
) -> None:
"""
Construct a new Pubsub object, which is responsible for handling all
@ -167,6 +173,7 @@ class Pubsub(Service, IPubsub):
# Therefore, we can only close from the receive side.
self.peer_receive_channel = peer_receive
self.dead_peer_receive_channel = dead_peer_receive
self._validator_semaphore = trio.Semaphore(max_concurrent_validator_count)
# Register a notifee
self.host.get_network().register_notifee(
PubsubNotifee(peer_send, dead_peer_send)
@ -202,15 +209,7 @@ class Pubsub(Service, IPubsub):
# Create peers map, which maps peer_id (as string) to stream (to a given peer)
self.peers = {}
# Validation Throttler
self.validation_throttler = ValidationThrottler(
queue_size=validation_queue_size,
global_throttle_limit=global_throttle_limit,
default_topic_throttle_limit=default_topic_throttle_limit,
worker_count=validation_worker_count or 4,
)
# Keep a mapping of topic -> TopicValidator for easier lookup
# Map of topic to topic validator
self.topic_validators = {}
self.counter = int(time.time())
@ -222,19 +221,10 @@ class Pubsub(Service, IPubsub):
self.event_handle_dead_peer_queue_started = trio.Event()
async def run(self) -> None:
self.manager.run_daemon_task(self._start_validation_throttler)
self.manager.run_daemon_task(self.handle_peer_queue)
self.manager.run_daemon_task(self.handle_dead_peer_queue)
await self.manager.wait_finished()
async def _start_validation_throttler(self) -> None:
"""Start validation throttler in current nursery context"""
async with trio.open_nursery() as nursery:
await self.validation_throttler.start(nursery)
# Keep nursery alive until service stops
while self.manager.is_running:
await self.manager.wait_finished()
@property
def my_id(self) -> ID:
return self.host.get_id()
@ -314,12 +304,7 @@ class Pubsub(Service, IPubsub):
)
def set_topic_validator(
self,
topic: str,
validator: ValidatorFn,
is_async_validator: bool,
timeout: float | None = None,
throttle_limit: int | None = None,
self, topic: str, validator: ValidatorFn, is_async_validator: bool
) -> None:
"""
Register a validator under the given topic. One topic can only have one
@ -328,18 +313,8 @@ class Pubsub(Service, IPubsub):
:param topic: the topic to register validator under
:param validator: the validator used to validate messages published to the topic
:param is_async_validator: indicate if the validator is an asynchronous validator
:param timeout: optional timeout for the validator
:param throttle_limit: optional throttle limit for the validator
""" # noqa: E501
# Create throttled topic validator
topic_validator = self.validation_throttler.create_topic_validator(
topic=topic,
validator=validator,
is_async=is_async_validator,
timeout=timeout,
throttle_limit=throttle_limit,
)
self.topic_validators[topic] = topic_validator
self.topic_validators[topic] = TopicValidator(validator, is_async_validator)
def remove_topic_validator(self, topic: str) -> None:
"""
@ -349,18 +324,17 @@ class Pubsub(Service, IPubsub):
"""
self.topic_validators.pop(topic, None)
def get_msg_validators(self, msg: rpc_pb2.Message) -> list[TopicValidator]:
def get_msg_validators(self, msg: rpc_pb2.Message) -> tuple[TopicValidator, ...]:
"""
Get all validators corresponding to the topics in the message.
:param msg: the message published to the topic
:return: list of topic validators for the message's topics
"""
return [
return tuple(
self.topic_validators[topic]
for topic in msg.topicIDs
if topic in self.topic_validators
]
)
def add_to_blacklist(self, peer_id: ID) -> None:
"""
@ -689,63 +663,60 @@ class Pubsub(Service, IPubsub):
logger.debug("successfully published message %s", msg)
async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
async def validate_msg(
self,
msg_forwarder: ID,
msg: rpc_pb2.Message,
) -> None:
"""
Validate the received message.
:param msg_forwarder: the peer who forward us the message.
:param msg: the message.
"""
# Get applicable validators for this message
validators = self.get_msg_validators(msg)
sync_topic_validators: list[SyncValidatorFn] = []
async_topic_validators: list[AsyncValidatorFn] = []
for topic_validator in self.get_msg_validators(msg):
if topic_validator.is_async:
async_topic_validators.append(
cast(AsyncValidatorFn, topic_validator.validator)
)
else:
sync_topic_validators.append(
cast(SyncValidatorFn, topic_validator.validator)
)
if not validators:
# No validators, accept immediately
return
for validator in sync_topic_validators:
if not validator(msg_forwarder, msg):
raise ValidationError(f"Validation failed for msg={msg}")
# Use trio.Event for async coordination
validation_event = trio.Event()
result_container: dict[str, ValidationResult | None | Exception] = {
"result": None,
"error": None,
}
if len(async_topic_validators) > 0:
# Appends to lists are thread safe in CPython
results: list[bool] = []
def handle_validation_result(
result: ValidationResult, error: Exception | None
) -> None:
result_container["result"] = result
result_container["error"] = error
validation_event.set()
async with trio.open_nursery() as nursery:
for async_validator in async_topic_validators:
nursery.start_soon(
self._run_async_validator,
async_validator,
msg_forwarder,
msg,
results,
)
# Submit for throttled validation
success = await self.validation_throttler.submit_validation(
validators=validators,
msg_forwarder=msg_forwarder,
msg=msg,
result_callback=handle_validation_result,
)
if not all(results):
raise ValidationError(f"Validation failed for msg={msg}")
if not success:
# Validation was throttled at queue level
raise ValidationError("Validation throttled at queue level")
# Wait for validation result
await validation_event.wait()
result = result_container["result"]
error = result_container["error"]
if error:
raise ValidationError(f"Validation error: {error}")
if result == ValidationResult.REJECT:
raise ValidationError("Message validation rejected")
elif result == ValidationResult.THROTTLED:
raise ValidationError("Message validation throttled")
elif result == ValidationResult.IGNORE:
# Treat IGNORE as rejection for now, or you could silently drop
raise ValidationError("Message validation ignored")
# ACCEPT case - just return normally
async def _run_async_validator(
self,
func: AsyncValidatorFn,
msg_forwarder: ID,
msg: rpc_pb2.Message,
results: list[bool],
) -> None:
async with self._validator_semaphore:
result = await func(msg_forwarder, msg)
results.append(result)
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
"""

View File

@ -1,314 +0,0 @@
from collections.abc import (
Callable,
)
from dataclasses import dataclass
from enum import Enum
import logging
from typing import (
NamedTuple,
cast,
)
import trio
from libp2p.custom_types import AsyncValidatorFn, ValidatorFn
from libp2p.peer.id import (
ID,
)
from .pb import (
rpc_pb2,
)
logger = logging.getLogger("libp2p.pubsub.validation")
class ValidationResult(Enum):
ACCEPT = "accept"
REJECT = "reject"
IGNORE = "ignore"
THROTTLED = "throttled"
@dataclass
class ValidationRequest:
"""Request for message validation"""
validators: list["TopicValidator"]
msg_forwarder: ID # peer ID
msg: rpc_pb2.Message # message object
result_callback: Callable[[ValidationResult, Exception | None], None]
class TopicValidator(NamedTuple):
topic: str
validator: ValidatorFn
is_async: bool
timeout: float | None = None
# Per-topic throttle semaphore
throttle_semaphore: trio.Semaphore | None = None
class ValidationThrottler:
"""Manages all validation throttling mechanisms"""
def __init__(
self,
queue_size: int = 32,
global_throttle_limit: int = 8192,
default_topic_throttle_limit: int = 1024,
worker_count: int | None = None,
):
# 1. Queue-level throttling - bounded memory channel
self._validation_send, self._validation_receive = trio.open_memory_channel[
ValidationRequest
](queue_size)
# 2. Global validation throttling - limits total concurrent async validations
self._global_throttle = trio.Semaphore(global_throttle_limit)
# 3. Per-topic throttling - each validator gets its own semaphore
self._default_topic_throttle_limit = default_topic_throttle_limit
# Worker management
# TODO: Find a better way to manage worker count
self._worker_count = worker_count or 4
self._running = False
async def start(self, nursery: trio.Nursery) -> None:
"""Start the validation workers"""
self._running = True
# Start validation worker tasks
for i in range(self._worker_count):
nursery.start_soon(self._validation_worker, f"worker-{i}")
async def stop(self) -> None:
"""Stop the validation system"""
self._running = False
await self._validation_send.aclose()
def create_topic_validator(
self,
topic: str,
validator: ValidatorFn,
is_async: bool,
timeout: float | None = None,
throttle_limit: int | None = None,
) -> TopicValidator:
"""Create a new topic validator with its own throttle"""
limit = throttle_limit or self._default_topic_throttle_limit
throttle_sem = trio.Semaphore(limit)
return TopicValidator(
topic=topic,
validator=validator,
is_async=is_async,
timeout=timeout,
throttle_semaphore=throttle_sem,
)
async def submit_validation(
self,
validators: list[TopicValidator],
msg_forwarder: ID,
msg: rpc_pb2.Message,
result_callback: Callable[[ValidationResult, Exception | None], None],
) -> bool:
"""
Submit a message for validation.
Returns True if queued successfully, False if queue is full (throttled).
"""
if not self._running:
result_callback(
ValidationResult.REJECT, Exception("Validation system not running")
)
return False
request = ValidationRequest(
validators=validators,
msg_forwarder=msg_forwarder,
msg=msg,
result_callback=result_callback,
)
try:
# This will raise trio.WouldBlock if queue is full
self._validation_send.send_nowait(request)
return True
except trio.WouldBlock:
# Queue-level throttling: drop the message
logger.debug(
"Validation queue full, dropping message from %s", msg_forwarder
)
result_callback(
ValidationResult.THROTTLED, Exception("Validation queue full")
)
return False
async def _validation_worker(self, worker_id: str) -> None:
"""Worker that processes validation requests"""
logger.debug("Validation worker %s started", worker_id)
async with self._validation_receive:
async for request in self._validation_receive:
if not self._running:
break
try:
# Process the validation request
result = await self._validate_message(request)
request.result_callback(result, None)
except Exception as e:
logger.exception("Error in validation worker %s", worker_id)
request.result_callback(ValidationResult.REJECT, e)
logger.debug("Validation worker %s stopped", worker_id)
async def _validate_message(self, request: ValidationRequest) -> ValidationResult:
"""Core validation logic with throttling"""
validators = request.validators
msg_forwarder = request.msg_forwarder
msg = request.msg
if not validators:
return ValidationResult.ACCEPT
# Separate sync and async validators
sync_validators = [v for v in validators if not v.is_async]
async_validators = [v for v in validators if v.is_async]
# Run synchronous validators first
for validator in sync_validators:
try:
# Apply per-topic throttling even for sync validators
if validator.throttle_semaphore:
validator.throttle_semaphore.acquire_nowait()
try:
result = validator.validator(msg_forwarder, msg)
if not result:
return ValidationResult.REJECT
finally:
validator.throttle_semaphore.release()
else:
result = validator.validator(msg_forwarder, msg)
if not result:
return ValidationResult.REJECT
except trio.WouldBlock:
# Per-topic throttling for sync validator
logger.debug("Sync validation throttled for topic %s", validator.topic)
return ValidationResult.THROTTLED
except Exception as e:
logger.exception(
"Sync validator failed for topic %s: %s", validator.topic, e
)
return ValidationResult.REJECT
# Handle async validators with global + per-topic throttling
if async_validators:
return await self._validate_async_validators(
async_validators, msg_forwarder, msg
)
return ValidationResult.ACCEPT
async def _validate_async_validators(
self, validators: list[TopicValidator], msg_forwarder: ID, msg: rpc_pb2.Message
) -> ValidationResult:
"""Handle async validators with proper throttling"""
if len(validators) == 1:
# Fast path for single validator
return await self._validate_single_async_validator(
validators[0], msg_forwarder, msg
)
# Multiple async validators - run them concurrently
try:
# Try to acquire global throttle slot
self._global_throttle.acquire_nowait()
except trio.WouldBlock:
logger.debug(
"Global validation throttle exceeded, dropping message from %s",
msg_forwarder,
)
return ValidationResult.THROTTLED
try:
async with trio.open_nursery() as nursery:
results = {}
async def run_validator(validator: TopicValidator, index: int) -> None:
"""Run a single async validator and store the result"""
nonlocal results
result = await self._validate_single_async_validator(
validator, msg_forwarder, msg
)
results[index] = result
# Start all validators concurrently
for i, validator in enumerate(validators):
nursery.start_soon(run_validator, validator, i)
# Process results - any reject or throttle causes overall failure
final_result = ValidationResult.ACCEPT
for result in results.values():
if result == ValidationResult.REJECT:
return ValidationResult.REJECT
elif result == ValidationResult.THROTTLED:
final_result = ValidationResult.THROTTLED
elif (
result == ValidationResult.IGNORE
and final_result == ValidationResult.ACCEPT
):
final_result = ValidationResult.IGNORE
return final_result
finally:
self._global_throttle.release()
return ValidationResult.IGNORE
async def _validate_single_async_validator(
self, validator: TopicValidator, msg_forwarder: ID, msg: rpc_pb2.Message
) -> ValidationResult:
"""Validate with a single async validator"""
# Apply per-topic throttling
if validator.throttle_semaphore:
try:
validator.throttle_semaphore.acquire_nowait()
except trio.WouldBlock:
logger.debug(
"Per-topic validation throttled for topic %s", validator.topic
)
return ValidationResult.THROTTLED
else:
# Fallback if no throttle semaphore configured
pass
try:
# Apply timeout if configured
result: bool
if validator.timeout:
with trio.fail_after(validator.timeout):
func = cast(AsyncValidatorFn, validator.validator)
result = await func(msg_forwarder, msg)
else:
func = cast(AsyncValidatorFn, validator.validator)
result = await func(msg_forwarder, msg)
return ValidationResult.ACCEPT if result else ValidationResult.REJECT
except trio.TooSlowError:
logger.debug("Validation timeout for topic %s", validator.topic)
return ValidationResult.IGNORE
except Exception as e:
logger.exception(
"Async validator failed for topic %s: %s", validator.topic, e
)
return ValidationResult.REJECT
finally:
if validator.throttle_semaphore:
validator.throttle_semaphore.release()
return ValidationResult.IGNORE

View File

@ -15,6 +15,10 @@ from libp2p.relay.circuit_v2 import (
RelayLimits,
RelayResourceManager,
Reservation,
DCUTR_PROTOCOL_ID,
DCUtRProtocol,
ReachabilityChecker,
is_private_ip,
)
__all__ = [
@ -25,4 +29,9 @@ __all__ = [
"RelayLimits",
"RelayResourceManager",
"Reservation",
"DCUtRProtocol",
"DCUTR_PROTOCOL_ID",
"ReachabilityChecker",
"is_private_ip"
]

View File

@ -5,6 +5,16 @@ This package implements the Circuit Relay v2 protocol as specified in:
https://github.com/libp2p/specs/blob/master/relay/circuit-v2.md
"""
from .dcutr import (
DCUtRProtocol,
)
from .dcutr import PROTOCOL_ID as DCUTR_PROTOCOL_ID
from .nat import (
ReachabilityChecker,
is_private_ip,
)
from .discovery import (
RelayDiscovery,
)
@ -29,4 +39,8 @@ __all__ = [
"RelayResourceManager",
"CircuitV2Transport",
"RelayDiscovery",
"DCUtRProtocol",
"DCUTR_PROTOCOL_ID",
"ReachabilityChecker",
"is_private_ip",
]

View File

@ -0,0 +1,580 @@
"""
Direct Connection Upgrade through Relay (DCUtR) protocol implementation.
This module implements the DCUtR protocol as specified in:
https://github.com/libp2p/specs/blob/master/relay/DCUtR.md
DCUtR enables peers behind NAT to establish direct connections
using hole punching techniques.
"""
import logging
import time
from typing import Any
from multiaddr import Multiaddr
import trio
from libp2p.abc import (
IHost,
INetConn,
INetStream,
)
from libp2p.custom_types import (
TProtocol,
)
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.relay.circuit_v2.nat import (
ReachabilityChecker,
)
from libp2p.relay.circuit_v2.pb.dcutr_pb2 import (
HolePunch,
)
from libp2p.tools.async_service import (
Service,
)
logger = logging.getLogger(__name__)
# Protocol ID for DCUtR
PROTOCOL_ID = TProtocol("/libp2p/dcutr")
# Maximum message size for DCUtR (4KiB as per spec)
MAX_MESSAGE_SIZE = 4 * 1024
# Timeouts
STREAM_READ_TIMEOUT = 30 # seconds
STREAM_WRITE_TIMEOUT = 30 # seconds
DIAL_TIMEOUT = 10 # seconds
# Maximum number of hole punch attempts per peer
MAX_HOLE_PUNCH_ATTEMPTS = 5
# Delay between retry attempts
HOLE_PUNCH_RETRY_DELAY = 30 # seconds
# Maximum observed addresses to exchange
MAX_OBSERVED_ADDRS = 20
class DCUtRProtocol(Service):
"""
DCUtRProtocol implements the Direct Connection Upgrade through Relay protocol.
This protocol allows two NATed peers to establish direct connections through
hole punching, after they have established an initial connection through a relay.
"""
def __init__(self, host: IHost):
"""
Initialize the DCUtR protocol.
Parameters
----------
host : IHost
The libp2p host this protocol is running on
"""
super().__init__()
self.host = host
self.event_started = trio.Event()
self._hole_punch_attempts: dict[ID, int] = {}
self._direct_connections: set[ID] = set()
self._in_progress: set[ID] = set()
self._reachability_checker = ReachabilityChecker(host)
self._nursery: trio.Nursery | None = None
async def run(self, *, task_status: Any = trio.TASK_STATUS_IGNORED) -> None:
"""Run the protocol service."""
try:
# Register the DCUtR protocol handler
logger.debug("Registering DCUtR protocol handler")
self.host.set_stream_handler(PROTOCOL_ID, self._handle_dcutr_stream)
# Signal that we're ready
self.event_started.set()
# Start the service
async with trio.open_nursery() as nursery:
self._nursery = nursery
task_status.started()
logger.debug("DCUtR protocol service started")
# Wait for service to be stopped
await self.manager.wait_finished()
finally:
# Clean up
try:
# Use empty async lambda instead of None for stream handler
async def empty_handler(_: INetStream) -> None:
pass
self.host.set_stream_handler(PROTOCOL_ID, empty_handler)
logger.debug("DCUtR protocol handler unregistered")
except Exception as e:
logger.error("Error unregistering DCUtR protocol handler: %s", str(e))
# Clear state
self._hole_punch_attempts.clear()
self._direct_connections.clear()
self._in_progress.clear()
self._nursery = None
async def _handle_dcutr_stream(self, stream: INetStream) -> None:
"""
Handle incoming DCUtR streams.
Parameters
----------
stream : INetStream
The incoming stream
"""
try:
# Get the remote peer ID
remote_peer_id = stream.muxed_conn.peer_id
logger.debug("Received DCUtR stream from peer %s", remote_peer_id)
# Check if we already have a direct connection
if await self._have_direct_connection(remote_peer_id):
logger.debug(
"Already have direct connection to %s, closing stream",
remote_peer_id,
)
await stream.close()
return
# Check if there's already an active hole punch attempt
if remote_peer_id in self._in_progress:
logger.debug("Hole punch already in progress with %s", remote_peer_id)
# Let the existing attempt continue
await stream.close()
return
# Mark as in progress
self._in_progress.add(remote_peer_id)
try:
# Read the CONNECT message
with trio.fail_after(STREAM_READ_TIMEOUT):
msg_bytes = await stream.read(MAX_MESSAGE_SIZE)
# Parse the message
connect_msg = HolePunch()
connect_msg.ParseFromString(msg_bytes)
# Verify it's a CONNECT message
if connect_msg.type != HolePunch.CONNECT:
logger.warning("Expected CONNECT message, got %s", connect_msg.type)
await stream.close()
return
logger.debug(
"Received CONNECT message from %s with %d addresses",
remote_peer_id,
len(connect_msg.ObsAddrs),
)
# Process observed addresses from the peer
peer_addrs = self._decode_observed_addrs(list(connect_msg.ObsAddrs))
logger.debug("Decoded %d valid addresses from peer", len(peer_addrs))
# Store the addresses in the peerstore
if peer_addrs:
self.host.get_peerstore().add_addrs(
remote_peer_id, peer_addrs, 10 * 60
) # 10 minute TTL
# Send our CONNECT message with our observed addresses
our_addrs = await self._get_observed_addrs()
response = HolePunch()
response.type = HolePunch.CONNECT
response.ObsAddrs.extend(our_addrs)
with trio.fail_after(STREAM_WRITE_TIMEOUT):
await stream.write(response.SerializeToString())
logger.debug(
"Sent CONNECT response to %s with %d addresses",
remote_peer_id,
len(our_addrs),
)
# Wait for SYNC message
with trio.fail_after(STREAM_READ_TIMEOUT):
sync_bytes = await stream.read(MAX_MESSAGE_SIZE)
# Parse the SYNC message
sync_msg = HolePunch()
sync_msg.ParseFromString(sync_bytes)
# Verify it's a SYNC message
if sync_msg.type != HolePunch.SYNC:
logger.warning("Expected SYNC message, got %s", sync_msg.type)
await stream.close()
return
logger.debug("Received SYNC message from %s", remote_peer_id)
# Perform hole punch
success = await self._perform_hole_punch(remote_peer_id, peer_addrs)
if success:
logger.info(
"Successfully established direct connection with %s",
remote_peer_id,
)
else:
logger.warning(
"Failed to establish direct connection with %s", remote_peer_id
)
except trio.TooSlowError:
logger.warning("Timeout in DCUtR protocol with peer %s", remote_peer_id)
except Exception as e:
logger.error(
"Error in DCUtR protocol with peer %s: %s", remote_peer_id, str(e)
)
finally:
# Clean up
self._in_progress.discard(remote_peer_id)
await stream.close()
except Exception as e:
logger.error("Error handling DCUtR stream: %s", str(e))
await stream.close()
async def initiate_hole_punch(self, peer_id: ID) -> bool:
"""
Initiate a hole punch with a peer.
Parameters
----------
peer_id : ID
The peer to hole punch with
Returns
-------
bool
True if hole punch was successful, False otherwise
"""
# Check if we already have a direct connection
if await self._have_direct_connection(peer_id):
logger.debug("Already have direct connection to %s", peer_id)
return True
# Check if there's already an active hole punch attempt
if peer_id in self._in_progress:
logger.debug("Hole punch already in progress with %s", peer_id)
return False
# Check if we've exceeded the maximum number of attempts
attempts = self._hole_punch_attempts.get(peer_id, 0)
if attempts >= MAX_HOLE_PUNCH_ATTEMPTS:
logger.warning("Maximum hole punch attempts reached for peer %s", peer_id)
return False
# Mark as in progress and increment attempt counter
self._in_progress.add(peer_id)
self._hole_punch_attempts[peer_id] = attempts + 1
try:
# Open a DCUtR stream to the peer
logger.debug("Opening DCUtR stream to peer %s", peer_id)
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
if not stream:
logger.warning("Failed to open DCUtR stream to peer %s", peer_id)
return False
try:
# Send our CONNECT message with our observed addresses
our_addrs = await self._get_observed_addrs()
connect_msg = HolePunch()
connect_msg.type = HolePunch.CONNECT
connect_msg.ObsAddrs.extend(our_addrs)
start_time = time.time()
with trio.fail_after(STREAM_WRITE_TIMEOUT):
await stream.write(connect_msg.SerializeToString())
logger.debug(
"Sent CONNECT message to %s with %d addresses",
peer_id,
len(our_addrs),
)
# Receive the peer's CONNECT message
with trio.fail_after(STREAM_READ_TIMEOUT):
resp_bytes = await stream.read(MAX_MESSAGE_SIZE)
# Calculate RTT
rtt = time.time() - start_time
# Parse the response
resp = HolePunch()
resp.ParseFromString(resp_bytes)
# Verify it's a CONNECT message
if resp.type != HolePunch.CONNECT:
logger.warning("Expected CONNECT message, got %s", resp.type)
return False
logger.debug(
"Received CONNECT response from %s with %d addresses",
peer_id,
len(resp.ObsAddrs),
)
# Process observed addresses from the peer
peer_addrs = self._decode_observed_addrs(list(resp.ObsAddrs))
logger.debug("Decoded %d valid addresses from peer", len(peer_addrs))
# Store the addresses in the peerstore
if peer_addrs:
self.host.get_peerstore().add_addrs(
peer_id, peer_addrs, 10 * 60
) # 10 minute TTL
# Send SYNC message with timing information
# We'll use a future time that's 2*RTT from now to ensure both sides
# are ready
punch_time = time.time() + (2 * rtt) + 1 # Add 1 second buffer
sync_msg = HolePunch()
sync_msg.type = HolePunch.SYNC
with trio.fail_after(STREAM_WRITE_TIMEOUT):
await stream.write(sync_msg.SerializeToString())
logger.debug("Sent SYNC message to %s", peer_id)
# Perform the synchronized hole punch
success = await self._perform_hole_punch(
peer_id, peer_addrs, punch_time
)
if success:
logger.info(
"Successfully established direct connection with %s", peer_id
)
return True
else:
logger.warning(
"Failed to establish direct connection with %s", peer_id
)
return False
except trio.TooSlowError:
logger.warning("Timeout in DCUtR protocol with peer %s", peer_id)
return False
except Exception as e:
logger.error(
"Error in DCUtR protocol with peer %s: %s", peer_id, str(e)
)
return False
finally:
await stream.close()
except Exception as e:
logger.error(
"Error initiating hole punch with peer %s: %s", peer_id, str(e)
)
return False
finally:
self._in_progress.discard(peer_id)
# This should never be reached, but add explicit return for type checking
return False
async def _perform_hole_punch(
self, peer_id: ID, addrs: list[Multiaddr], punch_time: float | None = None
) -> bool:
"""
Perform a hole punch attempt with a peer.
Parameters
----------
peer_id : ID
The peer to hole punch with
addrs : list[Multiaddr]
List of addresses to try
punch_time : Optional[float]
Time to perform the punch (if None, do it immediately)
Returns
-------
bool
True if hole punch was successful
"""
if not addrs:
logger.warning("No addresses to try for hole punch with %s", peer_id)
return False
# If punch_time is specified, wait until that time
if punch_time is not None:
now = time.time()
if punch_time > now:
wait_time = punch_time - now
logger.debug("Waiting %.2f seconds before hole punch", wait_time)
await trio.sleep(wait_time)
# Try to dial each address
logger.debug(
"Starting hole punch with peer %s using %d addresses", peer_id, len(addrs)
)
# Filter to only include non-relay addresses
direct_addrs = [
addr for addr in addrs if not str(addr).startswith("/p2p-circuit")
]
if not direct_addrs:
logger.warning("No direct addresses found for peer %s", peer_id)
return False
# Start dialing attempts in parallel
async with trio.open_nursery() as nursery:
for addr in direct_addrs[
:5
]: # Limit to 5 addresses to avoid too many connections
nursery.start_soon(self._dial_peer, peer_id, addr)
# Check if we established a direct connection
return await self._have_direct_connection(peer_id)
async def _dial_peer(self, peer_id: ID, addr: Multiaddr) -> None:
"""
Attempt to dial a peer at a specific address.
Parameters
----------
peer_id : ID
The peer to dial
addr : Multiaddr
The address to dial
"""
try:
logger.debug("Attempting to dial %s at %s", peer_id, addr)
# Create peer info
peer_info = PeerInfo(peer_id, [addr])
# Try to connect with timeout
with trio.fail_after(DIAL_TIMEOUT):
await self.host.connect(peer_info)
logger.info("Successfully connected to %s at %s", peer_id, addr)
# Add to direct connections set
self._direct_connections.add(peer_id)
except trio.TooSlowError:
logger.debug("Timeout dialing %s at %s", peer_id, addr)
except Exception as e:
logger.debug("Error dialing %s at %s: %s", peer_id, addr, str(e))
async def _have_direct_connection(self, peer_id: ID) -> bool:
"""
Check if we already have a direct connection to a peer.
Parameters
----------
peer_id : ID
The peer to check
Returns
-------
bool
True if we have a direct connection, False otherwise
"""
# Check our direct connections cache first
if peer_id in self._direct_connections:
return True
# Check if the peer is connected
network = self.host.get_network()
conn_or_conns = network.connections.get(peer_id)
if not conn_or_conns:
return False
# Handle both single connection and list of connections
connections: list[INetConn] = (
[conn_or_conns] if not isinstance(conn_or_conns, list) else conn_or_conns
)
# Check if any connection is direct (not relayed)
for conn in connections:
# Get the transport addresses
addrs = conn.get_transport_addresses()
# If any address doesn't start with /p2p-circuit, it's a direct connection
if any(not str(addr).startswith("/p2p-circuit") for addr in addrs):
# Cache this result
self._direct_connections.add(peer_id)
return True
return False
async def _get_observed_addrs(self) -> list[bytes]:
"""
Get our observed addresses to share with the peer.
Returns
-------
List[bytes]
List of observed addresses as bytes
"""
# Get all listen addresses
addrs = self.host.get_addrs()
# Filter out relay addresses
direct_addrs = [
addr for addr in addrs if not str(addr).startswith("/p2p-circuit")
]
# Limit the number of addresses
if len(direct_addrs) > MAX_OBSERVED_ADDRS:
direct_addrs = direct_addrs[:MAX_OBSERVED_ADDRS]
# Convert to bytes
addr_bytes = [addr.to_bytes() for addr in direct_addrs]
return addr_bytes
def _decode_observed_addrs(self, addr_bytes: list[bytes]) -> list[Multiaddr]:
"""
Decode observed addresses received from a peer.
Parameters
----------
addr_bytes : List[bytes]
The encoded addresses
Returns
-------
List[Multiaddr]
The decoded multiaddresses
"""
result = []
for addr_byte in addr_bytes:
try:
addr = Multiaddr(addr_byte)
# Validate the address (basic check)
if str(addr).startswith("/ip"):
result.append(addr)
except Exception as e:
logger.debug("Error decoding multiaddr: %s", str(e))
return result

View File

@ -0,0 +1,300 @@
"""
NAT traversal utilities for libp2p.
This module provides utilities for NAT traversal and reachability detection.
"""
import ipaddress
import logging
from multiaddr import (
Multiaddr,
)
from libp2p.abc import (
IHost,
INetConn,
)
from libp2p.peer.id import (
ID,
)
logger = logging.getLogger("libp2p.relay.circuit_v2.nat")
# Timeout for reachability checks
REACHABILITY_TIMEOUT = 10 # seconds
# Define private IP ranges
PRIVATE_IP_RANGES = [
("10.0.0.0", "10.255.255.255"), # Class A private network: 10.0.0.0/8
("172.16.0.0", "172.31.255.255"), # Class B private network: 172.16.0.0/12
("192.168.0.0", "192.168.255.255"), # Class C private network: 192.168.0.0/16
]
# Link-local address range: 169.254.0.0/16
LINK_LOCAL_RANGE = ("169.254.0.0", "169.254.255.255")
# Loopback address range: 127.0.0.0/8
LOOPBACK_RANGE = ("127.0.0.0", "127.255.255.255")
def ip_to_int(ip: str) -> int:
"""
Convert an IP address to an integer.
Parameters
----------
ip : str
IP address to convert
Returns
-------
int
Integer representation of the IP
"""
try:
return int(ipaddress.IPv4Address(ip))
except ipaddress.AddressValueError:
# Handle IPv6 addresses
return int(ipaddress.IPv6Address(ip))
def is_ip_in_range(ip: str, start_range: str, end_range: str) -> bool:
"""
Check if an IP address is within a range.
Parameters
----------
ip : str
IP address to check
start_range : str
Start of the range
end_range : str
End of the range
Returns
-------
bool
True if the IP is in the range
"""
try:
ip_int = ip_to_int(ip)
start_int = ip_to_int(start_range)
end_int = ip_to_int(end_range)
return start_int <= ip_int <= end_int
except Exception:
return False
def is_private_ip(ip: str) -> bool:
"""
Check if an IP address is private.
Parameters
----------
ip : str
IP address to check
Returns
-------
bool
True if IP is private
"""
for start_range, end_range in PRIVATE_IP_RANGES:
if is_ip_in_range(ip, start_range, end_range):
return True
# Check for link-local addresses
if is_ip_in_range(ip, *LINK_LOCAL_RANGE):
return True
# Check for loopback addresses
if is_ip_in_range(ip, *LOOPBACK_RANGE):
return True
return False
def extract_ip_from_multiaddr(addr: Multiaddr) -> str | None:
"""
Extract the IP address from a multiaddr.
Parameters
----------
addr : Multiaddr
Multiaddr to extract from
Returns
-------
Optional[str]
IP address or None if not found
"""
# Convert to string representation
addr_str = str(addr)
# Look for IPv4 address
ipv4_start = addr_str.find("/ip4/")
if ipv4_start != -1:
# Extract the IPv4 address
ipv4_end = addr_str.find("/", ipv4_start + 5)
if ipv4_end != -1:
return addr_str[ipv4_start + 5 : ipv4_end]
# Look for IPv6 address
ipv6_start = addr_str.find("/ip6/")
if ipv6_start != -1:
# Extract the IPv6 address
ipv6_end = addr_str.find("/", ipv6_start + 5)
if ipv6_end != -1:
return addr_str[ipv6_start + 5 : ipv6_end]
return None
class ReachabilityChecker:
"""
Utility class for checking peer reachability.
This class assesses whether a peer's addresses are likely
to be directly reachable or behind NAT.
"""
def __init__(self, host: IHost):
"""
Initialize the reachability checker.
Parameters
----------
host : IHost
The libp2p host
"""
self.host = host
self._peer_reachability: dict[ID, bool] = {}
self._known_public_peers: set[ID] = set()
def is_addr_public(self, addr: Multiaddr) -> bool:
"""
Check if an address is likely to be publicly reachable.
Parameters
----------
addr : Multiaddr
The multiaddr to check
Returns
-------
bool
True if address is likely public
"""
# Extract the IP address
ip = extract_ip_from_multiaddr(addr)
if not ip:
return False
# Check if it's a private IP
return not is_private_ip(ip)
def get_public_addrs(self, addrs: list[Multiaddr]) -> list[Multiaddr]:
"""
Filter a list of addresses to only include likely public ones.
Parameters
----------
addrs : List[Multiaddr]
List of addresses to filter
Returns
-------
List[Multiaddr]
List of likely public addresses
"""
return [addr for addr in addrs if self.is_addr_public(addr)]
async def check_peer_reachability(self, peer_id: ID) -> bool:
"""
Check if a peer is directly reachable.
Parameters
----------
peer_id : ID
The peer ID to check
Returns
-------
bool
True if peer is likely directly reachable
"""
# Check if we already know
if peer_id in self._peer_reachability:
return self._peer_reachability[peer_id]
# Check if the peer is connected
network = self.host.get_network()
connections: INetConn | list[INetConn] | None = network.connections.get(peer_id)
if not connections:
# Not connected, can't determine reachability
return False
# Check if any connection is direct (not relayed)
if isinstance(connections, list):
for conn in connections:
# Get the transport addresses
addrs = conn.get_transport_addresses()
# If any address doesn't start with /p2p-circuit,
# it's a direct connection
if any(not str(addr).startswith("/p2p-circuit") for addr in addrs):
self._peer_reachability[peer_id] = True
return True
else:
# Handle single connection case
addrs = connections.get_transport_addresses()
if any(not str(addr).startswith("/p2p-circuit") for addr in addrs):
self._peer_reachability[peer_id] = True
return True
# Get the peer's addresses from peerstore
try:
addrs = self.host.get_peerstore().addrs(peer_id)
# Check if peer has any public addresses
public_addrs = self.get_public_addrs(addrs)
if public_addrs:
self._peer_reachability[peer_id] = True
return True
except Exception as e:
logger.debug("Error getting peer addresses: %s", str(e))
# Default to not directly reachable
self._peer_reachability[peer_id] = False
return False
async def check_self_reachability(self) -> tuple[bool, list[Multiaddr]]:
"""
Check if this host is likely directly reachable.
Returns
-------
Tuple[bool, List[Multiaddr]]
Tuple of (is_reachable, public_addresses)
"""
# Get all host addresses
addrs = self.host.get_addrs()
# Filter for public addresses
public_addrs = self.get_public_addrs(addrs)
# If we have public addresses, assume we're reachable
# This is a simplified assumption - real reachability would need
# external checking
is_reachable = len(public_addrs) > 0
return is_reachable, public_addrs

View File

@ -5,6 +5,11 @@ Contains generated protobuf code for circuit_v2 relay protocol.
"""
# Import the classes to be accessible directly from the package
from .dcutr_pb2 import (
HolePunch,
)
from .circuit_pb2 import (
HopMessage,
Limit,
@ -13,4 +18,4 @@ from .circuit_pb2 import (
StopMessage,
)
__all__ = ["HopMessage", "Limit", "Reservation", "Status", "StopMessage"]
__all__ = ["HopMessage", "Limit", "Reservation", "Status", "StopMessage", "HolePunch"]

View File

@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: libp2p/relay/circuit_v2/pb/circuit.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
@ -12,11 +11,14 @@ from google.protobuf import symbol_database as _symbol_database
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n(libp2p/relay/circuit_v2/pb/circuit.proto\x12\rcircuit.pb.v2\"\xf3\x01\n\nHopMessage\x12,\n\x04type\x18\x01 \x01(\x0e\x32\x1e.circuit.pb.v2.HopMessage.Type\x12\x0c\n\x04peer\x18\x02 \x01(\x0c\x12/\n\x0breservation\x18\x03 \x01(\x0b\x32\x1a.circuit.pb.v2.Reservation\x12#\n\x05limit\x18\x04 \x01(\x0b\x32\x14.circuit.pb.v2.Limit\x12%\n\x06status\x18\x05 \x01(\x0b\x32\x15.circuit.pb.v2.Status\",\n\x04Type\x12\x0b\n\x07RESERVE\x10\x00\x12\x0b\n\x07\x43ONNECT\x10\x01\x12\n\n\x06STATUS\x10\x02\"\x92\x01\n\x0bStopMessage\x12-\n\x04type\x18\x01 \x01(\x0e\x32\x1f.circuit.pb.v2.StopMessage.Type\x12\x0c\n\x04peer\x18\x02 \x01(\x0c\x12%\n\x06status\x18\x03 \x01(\x0b\x32\x15.circuit.pb.v2.Status\"\x1f\n\x04Type\x12\x0b\n\x07\x43ONNECT\x10\x00\x12\n\n\x06STATUS\x10\x01\"A\n\x0bReservation\x12\x0f\n\x07voucher\x18\x01 \x01(\x0c\x12\x11\n\tsignature\x18\x02 \x01(\x0c\x12\x0e\n\x06\x65xpire\x18\x03 \x01(\x03\"\'\n\x05Limit\x12\x10\n\x08\x64uration\x18\x01 \x01(\x03\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x03\"\xf6\x01\n\x06Status\x12(\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x1a.circuit.pb.v2.Status.Code\x12\x0f\n\x07message\x18\x02 \x01(\t\"\xb0\x01\n\x04\x43ode\x12\x06\n\x02OK\x10\x00\x12\x17\n\x13RESERVATION_REFUSED\x10\x64\x12\x1b\n\x17RESOURCE_LIMIT_EXCEEDED\x10\x65\x12\x15\n\x11PERMISSION_DENIED\x10\x66\x12\x16\n\x11\x43ONNECTION_FAILED\x10\xc8\x01\x12\x11\n\x0c\x44IAL_REFUSED\x10\xc9\x01\x12\x10\n\x0bSTOP_FAILED\x10\xac\x02\x12\x16\n\x11MALFORMED_MESSAGE\x10\x90\x03\x62\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.relay.circuit_v2.pb.circuit_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_HOPMESSAGE._serialized_start=60
_HOPMESSAGE._serialized_end=303

View File

@ -0,0 +1,14 @@
syntax = "proto2";
package holepunch.pb;
message HolePunch {
enum Type {
CONNECT = 100;
SYNC = 300;
}
required Type type = 1;
repeated bytes ObsAddrs = 2;
}

View File

@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: libp2p/relay/circuit_v2/pb/dcutr.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n&libp2p/relay/circuit_v2/pb/dcutr.proto\x12\x0cholepunch.pb\"i\n\tHolePunch\x12*\n\x04type\x18\x01 \x02(\x0e\x32\x1c.holepunch.pb.HolePunch.Type\x12\x10\n\x08ObsAddrs\x18\x02 \x03(\x0c\"\x1e\n\x04Type\x12\x0b\n\x07\x43ONNECT\x10\x64\x12\t\n\x04SYNC\x10\xac\x02')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.relay.circuit_v2.pb.dcutr_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_HOLEPUNCH._serialized_start=56
_HOLEPUNCH._serialized_end=161
_HOLEPUNCH_TYPE._serialized_start=131
_HOLEPUNCH_TYPE._serialized_end=161
# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,53 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import sys
import typing
if sys.version_info >= (3, 10):
import typing as typing_extensions
else:
import typing_extensions
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class HolePunch(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class _Type:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _TypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[HolePunch._Type.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
CONNECT: HolePunch._Type.ValueType # 100
SYNC: HolePunch._Type.ValueType # 300
class Type(_Type, metaclass=_TypeEnumTypeWrapper): ...
CONNECT: HolePunch.Type.ValueType # 100
SYNC: HolePunch.Type.ValueType # 300
TYPE_FIELD_NUMBER: builtins.int
OBSADDRS_FIELD_NUMBER: builtins.int
type: global___HolePunch.Type.ValueType
@property
def ObsAddrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
def __init__(
self,
*,
type: global___HolePunch.Type.ValueType | None = ...,
ObsAddrs: collections.abc.Iterable[builtins.bytes] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["type", b"type"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["ObsAddrs", b"ObsAddrs", "type", b"type"]) -> None: ...
global___HolePunch = HolePunch

View File

@ -41,7 +41,8 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
read_writer: NoisePacketReadWriter
noise_state: NoiseState
# FIXME: This prefix is added in msg#3 in Go. Check whether it's a desired behavior.
# NOTE: This prefix is added in msg#3 in Go.
# Support in py-libp2p is available but not used
prefix: bytes = b"\x00" * 32
def __init__(self, conn: IRawConnection, noise_state: NoiseState) -> None:

View File

@ -29,11 +29,6 @@ class Transport(ISecureTransport):
early_data: bytes | None
with_noise_pipes: bool
# NOTE: Implementations that support Noise Pipes must decide whether to use
# an XX or IK handshake based on whether they possess a cached static
# Noise key for the remote peer.
# TODO: A storage of seen noise static keys for pattern IK?
def __init__(
self,
libp2p_keypair: KeyPair,

View File

@ -1,3 +1,5 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from types import (
TracebackType,
)
@ -32,6 +34,72 @@ if TYPE_CHECKING:
)
class ReadWriteLock:
"""
A read-write lock that allows multiple concurrent readers
or one exclusive writer, implemented using Trio primitives.
"""
def __init__(self) -> None:
self._readers = 0
self._readers_lock = trio.Lock() # Protects access to _readers count
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
async def acquire_read(self) -> None:
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
try:
async with self._readers_lock:
if self._readers == 0:
await self._writer_lock.acquire()
self._readers += 1
except trio.Cancelled:
raise
async def release_read(self) -> None:
"""Release a read lock."""
async with self._readers_lock:
if self._readers == 1:
self._writer_lock.release()
self._readers -= 1
async def acquire_write(self) -> None:
"""Acquire an exclusive write lock."""
try:
await self._writer_lock.acquire()
except trio.Cancelled:
raise
def release_write(self) -> None:
"""Release the exclusive write lock."""
self._writer_lock.release()
@asynccontextmanager
async def read_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a read lock safely."""
acquire = False
try:
await self.acquire_read()
acquire = True
yield
finally:
if acquire:
with trio.CancelScope() as scope:
scope.shield = True
await self.release_read()
@asynccontextmanager
async def write_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a write lock safely."""
acquire = False
try:
await self.acquire_write()
acquire = True
yield
finally:
if acquire:
self.release_write()
class MplexStream(IMuxedStream):
"""
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
@ -46,7 +114,7 @@ class MplexStream(IMuxedStream):
read_deadline: int | None
write_deadline: int | None
# TODO: Add lock for read/write to avoid interleaving receiving messages?
rw_lock: ReadWriteLock
close_lock: trio.Lock
# NOTE: `dataIn` is size of 8 in Go implementation.
@ -80,6 +148,7 @@ class MplexStream(IMuxedStream):
self.event_remote_closed = trio.Event()
self.event_reset = trio.Event()
self.close_lock = trio.Lock()
self.rw_lock = ReadWriteLock()
self.incoming_data_channel = incoming_data_channel
self._buf = bytearray()
@ -113,48 +182,49 @@ class MplexStream(IMuxedStream):
:param n: number of bytes to read
:return: bytes actually read
"""
if n is not None and n < 0:
raise ValueError(
"the number of bytes to read `n` must be non-negative or "
f"`None` to indicate read until EOF, got n={n}"
)
if self.event_reset.is_set():
raise MplexStreamReset
if n is None:
return await self._read_until_eof()
if len(self._buf) == 0:
data: bytes
# Peek whether there is data available. If yes, we just read until there is
# no data, then return.
try:
data = self.incoming_data_channel.receive_nowait()
self._buf.extend(data)
except trio.EndOfChannel:
raise MplexStreamEOF
except trio.WouldBlock:
# We know `receive` will be blocked here. Wait for data here with
# `receive` and catch all kinds of errors here.
async with self.rw_lock.read_lock():
if n is not None and n < 0:
raise ValueError(
"the number of bytes to read `n` must be non-negative or "
f"`None` to indicate read until EOF, got n={n}"
)
if self.event_reset.is_set():
raise MplexStreamReset
if n is None:
return await self._read_until_eof()
if len(self._buf) == 0:
data: bytes
# Peek whether there is data available. If yes, we just read until
# there is no data, then return.
try:
data = await self.incoming_data_channel.receive()
data = self.incoming_data_channel.receive_nowait()
self._buf.extend(data)
except trio.EndOfChannel:
if self.event_reset.is_set():
raise MplexStreamReset
if self.event_remote_closed.is_set():
raise MplexStreamEOF
except trio.ClosedResourceError as error:
# Probably `incoming_data_channel` is closed in `reset` when we are
# waiting for `receive`.
if self.event_reset.is_set():
raise MplexStreamReset
raise Exception(
"`incoming_data_channel` is closed but stream is not reset. "
"This should never happen."
) from error
self._buf.extend(self._read_return_when_blocked())
payload = self._buf[:n]
self._buf = self._buf[len(payload) :]
return bytes(payload)
raise MplexStreamEOF
except trio.WouldBlock:
# We know `receive` will be blocked here. Wait for data here with
# `receive` and catch all kinds of errors here.
try:
data = await self.incoming_data_channel.receive()
self._buf.extend(data)
except trio.EndOfChannel:
if self.event_reset.is_set():
raise MplexStreamReset
if self.event_remote_closed.is_set():
raise MplexStreamEOF
except trio.ClosedResourceError as error:
# Probably `incoming_data_channel` is closed in `reset` when
# we are waiting for `receive`.
if self.event_reset.is_set():
raise MplexStreamReset
raise Exception(
"`incoming_data_channel` is closed but stream is not reset."
"This should never happen."
) from error
self._buf.extend(self._read_return_when_blocked())
payload = self._buf[:n]
self._buf = self._buf[len(payload) :]
return bytes(payload)
async def write(self, data: bytes) -> None:
"""
@ -162,22 +232,21 @@ class MplexStream(IMuxedStream):
:return: number of bytes written
"""
if self.event_local_closed.is_set():
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
flag = (
HeaderTags.MessageInitiator
if self.is_initiator
else HeaderTags.MessageReceiver
)
await self.muxed_conn.send_message(flag, data, self.stream_id)
async with self.rw_lock.write_lock():
if self.event_local_closed.is_set():
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
flag = (
HeaderTags.MessageInitiator
if self.is_initiator
else HeaderTags.MessageReceiver
)
await self.muxed_conn.send_message(flag, data, self.stream_id)
async def close(self) -> None:
"""
Closing a stream closes it for writing and closes the remote end for
reading but allows writing in the other direction.
"""
# TODO error handling with timeout
async with self.close_lock:
if self.event_local_closed.is_set():
return
@ -185,8 +254,17 @@ class MplexStream(IMuxedStream):
flag = (
HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver
)
# TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown.
await self.muxed_conn.send_message(flag, None, self.stream_id)
try:
with trio.fail_after(5): # timeout in seconds
await self.muxed_conn.send_message(flag, None, self.stream_id)
except trio.TooSlowError:
raise TimeoutError("Timeout while trying to close the stream")
except MuxedConnUnavailable:
if not self.muxed_conn.event_shutting_down.is_set():
raise RuntimeError(
"Failed to send close message and Mplex isn't shutting down"
)
_is_remote_closed: bool
async with self.close_lock:

View File

@ -45,6 +45,9 @@ from libp2p.stream_muxer.exceptions import (
MuxedStreamReset,
)
# Configure logger for this module
logger = logging.getLogger("libp2p.stream_muxer.yamux")
PROTOCOL_ID = "/yamux/1.0.0"
TYPE_DATA = 0x0
TYPE_WINDOW_UPDATE = 0x1
@ -98,13 +101,13 @@ class YamuxStream(IMuxedStream):
# Flow control: Check if we have enough send window
total_len = len(data)
sent = 0
logging.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
while sent < total_len:
# Wait for available window with timeout
timeout = False
async with self.window_lock:
if self.send_window == 0:
logging.debug(
logger.debug(
f"Stream {self.stream_id}: Window is zero, waiting for update"
)
# Release lock and wait with timeout
@ -152,12 +155,12 @@ class YamuxStream(IMuxedStream):
"""
if increment <= 0:
# If increment is zero or negative, skip sending update
logging.debug(
logger.debug(
f"Stream {self.stream_id}: Skipping window update"
f"(increment={increment})"
)
return
logging.debug(
logger.debug(
f"Stream {self.stream_id}: Sending window update with increment={increment}"
)
@ -185,7 +188,7 @@ class YamuxStream(IMuxedStream):
# If the stream is closed for receiving and the buffer is empty, raise EOF
if self.recv_closed and not self.conn.stream_buffers.get(self.stream_id):
logging.debug(
logger.debug(
f"Stream {self.stream_id}: Stream closed for receiving and buffer empty"
)
raise MuxedStreamEOF("Stream is closed for receiving")
@ -198,7 +201,7 @@ class YamuxStream(IMuxedStream):
# If buffer is not available, check if stream is closed
if buffer is None:
logging.debug(f"Stream {self.stream_id}: No buffer available")
logger.debug(f"Stream {self.stream_id}: No buffer available")
raise MuxedStreamEOF("Stream buffer closed")
# If we have data in buffer, process it
@ -210,34 +213,34 @@ class YamuxStream(IMuxedStream):
# Send window update for the chunk we just read
async with self.window_lock:
self.recv_window += len(chunk)
logging.debug(f"Stream {self.stream_id}: Update {len(chunk)}")
logger.debug(f"Stream {self.stream_id}: Update {len(chunk)}")
await self.send_window_update(len(chunk), skip_lock=True)
# If stream is closed (FIN received) and buffer is empty, break
if self.recv_closed and len(buffer) == 0:
logging.debug(f"Stream {self.stream_id}: Closed with empty buffer")
logger.debug(f"Stream {self.stream_id}: Closed with empty buffer")
break
# If stream was reset, raise reset error
if self.reset_received:
logging.debug(f"Stream {self.stream_id}: Stream was reset")
logger.debug(f"Stream {self.stream_id}: Stream was reset")
raise MuxedStreamReset("Stream was reset")
# Wait for more data or stream closure
logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
logger.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
await self.conn.stream_events[self.stream_id].wait()
self.conn.stream_events[self.stream_id] = trio.Event()
# After loop exit, first check if we have data to return
if data:
logging.debug(
logger.debug(
f"Stream {self.stream_id}: Returning {len(data)} bytes after loop"
)
return data
# No data accumulated, now check why we exited the loop
if self.conn.event_shutting_down.is_set():
logging.debug(f"Stream {self.stream_id}: Connection shutting down")
logger.debug(f"Stream {self.stream_id}: Connection shutting down")
raise MuxedStreamEOF("Connection shut down")
# Return empty data
@ -246,7 +249,7 @@ class YamuxStream(IMuxedStream):
data = await self.conn.read_stream(self.stream_id, n)
async with self.window_lock:
self.recv_window += len(data)
logging.debug(
logger.debug(
f"Stream {self.stream_id}: Sending window update after read, "
f"increment={len(data)}"
)
@ -255,7 +258,7 @@ class YamuxStream(IMuxedStream):
async def close(self) -> None:
if not self.send_closed:
logging.debug(f"Half-closing stream {self.stream_id} (local end)")
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
)
@ -271,7 +274,7 @@ class YamuxStream(IMuxedStream):
async def reset(self) -> None:
if not self.closed:
logging.debug(f"Resetting stream {self.stream_id}")
logger.debug(f"Resetting stream {self.stream_id}")
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
)
@ -349,7 +352,7 @@ class Yamux(IMuxedConn):
self._nursery: Nursery | None = None
async def start(self) -> None:
logging.debug(f"Starting Yamux for {self.peer_id}")
logger.debug(f"Starting Yamux for {self.peer_id}")
if self.event_started.is_set():
return
async with trio.open_nursery() as nursery:
@ -362,7 +365,7 @@ class Yamux(IMuxedConn):
return self.is_initiator_value
async def close(self, error_code: int = GO_AWAY_NORMAL) -> None:
logging.debug(f"Closing Yamux connection with code {error_code}")
logger.debug(f"Closing Yamux connection with code {error_code}")
async with self.streams_lock:
if not self.event_shutting_down.is_set():
try:
@ -371,7 +374,7 @@ class Yamux(IMuxedConn):
)
await self.secured_conn.write(header)
except Exception as e:
logging.debug(f"Failed to send GO_AWAY: {e}")
logger.debug(f"Failed to send GO_AWAY: {e}")
self.event_shutting_down.set()
for stream in self.streams.values():
stream.closed = True
@ -382,12 +385,12 @@ class Yamux(IMuxedConn):
self.stream_events.clear()
try:
await self.secured_conn.close()
logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
logger.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
except Exception as e:
logging.debug(f"Error closing secured_conn for peer {self.peer_id}: {e}")
logger.debug(f"Error closing secured_conn for peer {self.peer_id}: {e}")
self.event_closed.set()
if self.on_close:
logging.debug(f"Calling on_close in Yamux.close for peer {self.peer_id}")
logger.debug(f"Calling on_close in Yamux.close for peer {self.peer_id}")
if inspect.iscoroutinefunction(self.on_close):
if self.on_close is not None:
await self.on_close()
@ -416,7 +419,7 @@ class Yamux(IMuxedConn):
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_SYN, stream_id, 0
)
logging.debug(f"Sending SYN header for stream {stream_id}")
logger.debug(f"Sending SYN header for stream {stream_id}")
await self.secured_conn.write(header)
return stream
except Exception as e:
@ -424,32 +427,32 @@ class Yamux(IMuxedConn):
raise e
async def accept_stream(self) -> IMuxedStream:
logging.debug("Waiting for new stream")
logger.debug("Waiting for new stream")
try:
stream = await self.new_stream_receive_channel.receive()
logging.debug(f"Received stream {stream.stream_id}")
logger.debug(f"Received stream {stream.stream_id}")
return stream
except trio.EndOfChannel:
raise MuxedStreamError("No new streams available")
async def read_stream(self, stream_id: int, n: int = -1) -> bytes:
logging.debug(f"Reading from stream {self.peer_id}:{stream_id}, n={n}")
logger.debug(f"Reading from stream {self.peer_id}:{stream_id}, n={n}")
if n is None:
n = -1
while True:
async with self.streams_lock:
if stream_id not in self.streams:
logging.debug(f"Stream {self.peer_id}:{stream_id} unknown")
logger.debug(f"Stream {self.peer_id}:{stream_id} unknown")
raise MuxedStreamEOF("Stream closed")
if self.event_shutting_down.is_set():
logging.debug(
logger.debug(
f"Stream {self.peer_id}:{stream_id}: connection shutting down"
)
raise MuxedStreamEOF("Connection shut down")
stream = self.streams[stream_id]
buffer = self.stream_buffers.get(stream_id)
logging.debug(
logger.debug(
f"Stream {self.peer_id}:{stream_id}: "
f"closed={stream.closed}, "
f"recv_closed={stream.recv_closed}, "
@ -457,7 +460,7 @@ class Yamux(IMuxedConn):
f"buffer_len={len(buffer) if buffer else 0}"
)
if buffer is None:
logging.debug(
logger.debug(
f"Stream {self.peer_id}:{stream_id}:"
f"Buffer gone, assuming closed"
)
@ -470,7 +473,7 @@ class Yamux(IMuxedConn):
else:
data = bytes(buffer[:n])
del buffer[:n]
logging.debug(
logger.debug(
f"Returning {len(data)} bytes"
f"from stream {self.peer_id}:{stream_id}, "
f"buffer_len={len(buffer)}"
@ -478,7 +481,7 @@ class Yamux(IMuxedConn):
return data
# If reset received and buffer is empty, raise reset
if stream.reset_received:
logging.debug(
logger.debug(
f"Stream {self.peer_id}:{stream_id}:"
f"reset_received=True, raising MuxedStreamReset"
)
@ -491,7 +494,7 @@ class Yamux(IMuxedConn):
else:
data = bytes(buffer[:n])
del buffer[:n]
logging.debug(
logger.debug(
f"Returning {len(data)} bytes"
f"from stream {self.peer_id}:{stream_id}, "
f"buffer_len={len(buffer)}"
@ -499,21 +502,21 @@ class Yamux(IMuxedConn):
return data
# Check if stream is closed
if stream.closed:
logging.debug(
logger.debug(
f"Stream {self.peer_id}:{stream_id}:"
f"closed=True, raising MuxedStreamReset"
)
raise MuxedStreamReset("Stream is reset or closed")
# Check if recv_closed and buffer empty
if stream.recv_closed:
logging.debug(
logger.debug(
f"Stream {self.peer_id}:{stream_id}:"
f"recv_closed=True, buffer empty, raising EOF"
)
raise MuxedStreamEOF("Stream is closed for receiving")
# Wait for data if stream is still open
logging.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}")
logger.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}")
try:
await self.stream_events[stream_id].wait()
self.stream_events[stream_id] = trio.Event()
@ -528,7 +531,7 @@ class Yamux(IMuxedConn):
try:
header = await self.secured_conn.read(HEADER_SIZE)
if not header or len(header) < HEADER_SIZE:
logging.debug(
logger.debug(
f"Connection closed orincomplete header for peer {self.peer_id}"
)
self.event_shutting_down.set()
@ -537,7 +540,7 @@ class Yamux(IMuxedConn):
version, typ, flags, stream_id, length = struct.unpack(
YAMUX_HEADER_FORMAT, header
)
logging.debug(
logger.debug(
f"Received header for peer {self.peer_id}:"
f"type={typ}, flags={flags}, stream_id={stream_id},"
f"length={length}"
@ -558,7 +561,7 @@ class Yamux(IMuxedConn):
0,
)
await self.secured_conn.write(ack_header)
logging.debug(
logger.debug(
f"Sending stream {stream_id}"
f"to channel for peer {self.peer_id}"
)
@ -576,7 +579,7 @@ class Yamux(IMuxedConn):
elif typ == TYPE_DATA and flags & FLAG_RST:
async with self.streams_lock:
if stream_id in self.streams:
logging.debug(
logger.debug(
f"Resetting stream {stream_id} for peer {self.peer_id}"
)
self.streams[stream_id].closed = True
@ -585,27 +588,27 @@ class Yamux(IMuxedConn):
elif typ == TYPE_DATA and flags & FLAG_ACK:
async with self.streams_lock:
if stream_id in self.streams:
logging.debug(
logger.debug(
f"Received ACK for stream"
f"{stream_id} for peer {self.peer_id}"
)
elif typ == TYPE_GO_AWAY:
error_code = length
if error_code == GO_AWAY_NORMAL:
logging.debug(
logger.debug(
f"Received GO_AWAY for peer"
f"{self.peer_id}: Normal termination"
)
elif error_code == GO_AWAY_PROTOCOL_ERROR:
logging.error(
logger.error(
f"Received GO_AWAY for peer{self.peer_id}: Protocol error"
)
elif error_code == GO_AWAY_INTERNAL_ERROR:
logging.error(
logger.error(
f"Received GO_AWAY for peer {self.peer_id}: Internal error"
)
else:
logging.error(
logger.error(
f"Received GO_AWAY for peer {self.peer_id}"
f"with unknown error code: {error_code}"
)
@ -614,7 +617,7 @@ class Yamux(IMuxedConn):
break
elif typ == TYPE_PING:
if flags & FLAG_SYN:
logging.debug(
logger.debug(
f"Received ping request with value"
f"{length} for peer {self.peer_id}"
)
@ -623,7 +626,7 @@ class Yamux(IMuxedConn):
)
await self.secured_conn.write(ping_header)
elif flags & FLAG_ACK:
logging.debug(
logger.debug(
f"Received ping response with value"
f"{length} for peer {self.peer_id}"
)
@ -637,7 +640,7 @@ class Yamux(IMuxedConn):
self.stream_buffers[stream_id].extend(data)
self.stream_events[stream_id].set()
if flags & FLAG_FIN:
logging.debug(
logger.debug(
f"Received FIN for stream {self.peer_id}:"
f"{stream_id}, marking recv_closed"
)
@ -645,7 +648,7 @@ class Yamux(IMuxedConn):
if self.streams[stream_id].send_closed:
self.streams[stream_id].closed = True
except Exception as e:
logging.error(f"Error reading data for stream {stream_id}: {e}")
logger.error(f"Error reading data for stream {stream_id}: {e}")
# Mark stream as closed on read error
async with self.streams_lock:
if stream_id in self.streams:
@ -659,7 +662,7 @@ class Yamux(IMuxedConn):
if stream_id in self.streams:
stream = self.streams[stream_id]
async with stream.window_lock:
logging.debug(
logger.debug(
f"Received window update for stream"
f"{self.peer_id}:{stream_id},"
f" increment: {increment}"
@ -674,7 +677,7 @@ class Yamux(IMuxedConn):
and details.get("requested_count") == 2
and details.get("received_count") == 0
):
logging.info(
logger.info(
f"Stream closed cleanly for peer {self.peer_id}"
+ f" (IncompleteReadError: {details})"
)
@ -682,15 +685,32 @@ class Yamux(IMuxedConn):
await self._cleanup_on_error()
break
else:
logging.error(
logger.error(
f"Error in handle_incoming for peer {self.peer_id}: "
+ f"{type(e).__name__}: {str(e)}"
)
else:
logging.error(
f"Error in handle_incoming for peer {self.peer_id}: "
+ f"{type(e).__name__}: {str(e)}"
)
# Handle RawConnError with more nuance
if isinstance(e, RawConnError):
error_msg = str(e)
# If RawConnError is empty, it's likely normal cleanup
if not error_msg.strip():
logger.info(
f"RawConnError (empty) during cleanup for peer "
f"{self.peer_id} (normal connection shutdown)"
)
else:
# Log non-empty RawConnError as warning
logger.warning(
f"RawConnError during connection handling for peer "
f"{self.peer_id}: {error_msg}"
)
else:
# Log all other errors normally
logger.error(
f"Error in handle_incoming for peer {self.peer_id}: "
+ f"{type(e).__name__}: {str(e)}"
)
# Don't crash the whole connection for temporary errors
if self.event_shutting_down.is_set() or isinstance(
e, (RawConnError, OSError)
@ -720,9 +740,9 @@ class Yamux(IMuxedConn):
# Close the secured connection
try:
await self.secured_conn.close()
logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
logger.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
except Exception as close_error:
logging.error(
logger.error(
f"Error closing secured_conn for peer {self.peer_id}: {close_error}"
)
@ -731,14 +751,14 @@ class Yamux(IMuxedConn):
# Call on_close callback if provided
if self.on_close:
logging.debug(f"Calling on_close for peer {self.peer_id}")
logger.debug(f"Calling on_close for peer {self.peer_id}")
try:
if inspect.iscoroutinefunction(self.on_close):
await self.on_close()
else:
self.on_close()
except Exception as callback_error:
logging.error(f"Error in on_close callback: {callback_error}")
logger.error(f"Error in on_close callback: {callback_error}")
# Cancel nursery tasks
if self._nursery:

View File

@ -9,6 +9,7 @@ from libp2p.utils.varint import (
read_varint_prefixed_bytes,
decode_varint_from_bytes,
decode_varint_with_size,
read_length_prefixed_protobuf,
)
from libp2p.utils.version import (
get_agent_version,
@ -24,4 +25,5 @@ __all__ = [
"read_varint_prefixed_bytes",
"decode_varint_from_bytes",
"decode_varint_with_size",
"read_length_prefixed_protobuf",
]

View File

@ -1,7 +1,9 @@
import itertools
import logging
import math
from typing import BinaryIO
from libp2p.abc import INetStream
from libp2p.exceptions import (
ParseError,
)
@ -25,42 +27,41 @@ HIGH_MASK = 2**7
SHIFT_64_BIT_MAX = int(math.ceil(64 / 7)) * 7
def encode_uvarint(number: int) -> bytes:
"""Pack `number` into varint bytes."""
buf = b""
while True:
towrite = number & 0x7F
number >>= 7
if number:
buf += bytes((towrite | 0x80,))
else:
buf += bytes((towrite,))
def encode_uvarint(value: int) -> bytes:
"""Encode an unsigned integer as a varint."""
if value < 0:
raise ValueError("Cannot encode negative value as uvarint")
result = bytearray()
while value >= 0x80:
result.append((value & 0x7F) | 0x80)
value >>= 7
result.append(value & 0x7F)
return bytes(result)
def decode_uvarint(data: bytes) -> int:
"""Decode a varint from bytes."""
if not data:
raise ParseError("Unexpected end of data")
result = 0
shift = 0
for byte in data:
result |= (byte & 0x7F) << shift
if (byte & 0x80) == 0:
break
return buf
shift += 7
if shift >= 64:
raise ValueError("Varint too long")
return result
def decode_varint_from_bytes(data: bytes) -> int:
"""
Decode a varint from bytes and return the value.
This is a synchronous version of decode_uvarint_from_stream for already-read bytes.
"""
res = 0
for shift in itertools.count(0, 7):
if shift > SHIFT_64_BIT_MAX:
raise ParseError("Integer is too large...")
if not data:
raise ParseError("Unexpected end of data")
value = data[0]
data = data[1:]
res += (value & LOW_MASK) << shift
if not value & HIGH_MASK:
break
return res
"""Decode a varint from bytes (alias for decode_uvarint for backward comp)."""
return decode_uvarint(data)
async def decode_uvarint_from_stream(reader: Reader) -> int:
@ -84,34 +85,33 @@ async def decode_uvarint_from_stream(reader: Reader) -> int:
def decode_varint_with_size(data: bytes) -> tuple[int, int]:
"""
Decode a varint from bytes and return (value, bytes_consumed).
Returns (0, 0) if the data doesn't start with a valid varint.
Decode a varint from bytes and return both the value and the number of bytes
consumed.
Returns:
Tuple[int, int]: (value, bytes_consumed)
"""
try:
# Calculate how many bytes the varint consumes
varint_size = 0
for i, byte in enumerate(data):
varint_size += 1
if (byte & 0x80) == 0:
break
result = 0
shift = 0
bytes_consumed = 0
if varint_size == 0:
return 0, 0
for byte in data:
result |= (byte & 0x7F) << shift
bytes_consumed += 1
if (byte & 0x80) == 0:
break
shift += 7
if shift >= 64:
raise ValueError("Varint too long")
# Extract just the varint bytes
varint_bytes = data[:varint_size]
# Decode the varint
value = decode_varint_from_bytes(varint_bytes)
return value, varint_size
except Exception:
return 0, 0
return result, bytes_consumed
def encode_varint_prefixed(msg_bytes: bytes) -> bytes:
varint_len = encode_uvarint(len(msg_bytes))
return varint_len + msg_bytes
def encode_varint_prefixed(data: bytes) -> bytes:
"""Encode data with a varint length prefix."""
length_bytes = encode_uvarint(len(data))
return length_bytes + data
async def read_varint_prefixed_bytes(reader: Reader) -> bytes:
@ -138,3 +138,95 @@ async def read_delim(reader: Reader) -> bytes:
f'`msg_bytes` is not delimited by b"\\n": `msg_bytes`={msg_bytes!r}'
)
return msg_bytes[:-1]
def read_varint_prefixed_bytes_sync(
stream: BinaryIO, max_length: int = 1024 * 1024
) -> bytes:
"""
Read varint-prefixed bytes from a stream.
Args:
stream: A stream-like object with a read() method
max_length: Maximum allowed data length to prevent memory exhaustion
Returns:
bytes: The data without the length prefix
Raises:
ValueError: If the length prefix is invalid or too large
EOFError: If the stream ends unexpectedly
"""
# Read the varint length prefix
length_bytes = b""
while True:
byte_data = stream.read(1)
if not byte_data:
raise EOFError("Stream ended while reading varint length prefix")
length_bytes += byte_data
if byte_data[0] & 0x80 == 0:
break
# Decode the length
length = decode_uvarint(length_bytes)
if length > max_length:
raise ValueError(f"Data length {length} exceeds maximum allowed {max_length}")
# Read the data
data = stream.read(length)
if len(data) != length:
raise EOFError(f"Expected {length} bytes, got {len(data)}")
return data
async def read_length_prefixed_protobuf(
stream: INetStream, use_varint_format: bool = True, max_length: int = 1024 * 1024
) -> bytes:
"""Read a protobuf message from a stream, handling both formats."""
if use_varint_format:
# Read length-prefixed protobuf message from the stream
# First read the varint length prefix
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
raise Exception("No length prefix received")
length_bytes += b
if b[0] & 0x80 == 0:
break
msg_length = decode_varint_from_bytes(length_bytes)
if msg_length > max_length:
raise Exception(
f"Message length {msg_length} exceeds maximum allowed {max_length}"
)
# Read the protobuf message
data = await stream.read(msg_length)
if len(data) != msg_length:
raise Exception(
f"Incomplete message: expected {msg_length}, got {len(data)}"
)
return data
else:
# Read raw protobuf message from the stream
# For raw format, read all available data in one go
data = await stream.read()
# If we got no data, raise an exception
if not data:
raise Exception("No data received in raw format")
if len(data) > max_length:
raise Exception(
f"Message length {len(data)} exceeds maximum allowed {max_length}"
)
return data

View File

@ -0,0 +1 @@
remove FIXME comment since it's obsolete and 32-byte prefix support is there but not enabled by default

View File

@ -0,0 +1 @@
Added `Bootstrap` peer discovery module that allows nodes to connect to predefined bootstrap peers for network discovery.

View File

@ -0,0 +1 @@
Add lock for read/write to avoid interleaving receiving messages in mplex_stream.py

View File

@ -0,0 +1 @@
[mplex] Add timeout and error handling during stream close

View File

@ -0,0 +1,2 @@
Added the `Certified Addr-Book` interface supported by `Envelope` and `PeerRecord` class.
Integrated the signed-peer-record transfer in the identify/push protocols.

View File

@ -0,0 +1,2 @@
Added throttling for async topic validators in validate_msg, enforcing a
concurrency limit to prevent resource exhaustion under heavy load.

View File

@ -0,0 +1 @@
Pin py-multiaddr dependency to specific git commit db8124e2321f316d3b7d2733c7df11d6ad9c03e6

View File

@ -0,0 +1 @@
Replace the libp2p.peer.ID cache attributes with functools.cached_property functional decorator.

View File

@ -0,0 +1 @@
Clarified the requirement for a trailing newline in newsfragments to pass lint checks.

View File

@ -0,0 +1 @@
Fixed incorrect handling of raw protobuf format in identify protocol. The identify example now properly handles both raw and length-prefixed (varint) message formats, provides better error messages, and displays connection status with peer IDs. Replaced mock-based tests with comprehensive real network integration tests for both formats.

View File

@ -0,0 +1 @@
Fixed incorrect handling of raw protobuf format in identify push protocol. The identify push example now properly handles both raw and length-prefixed (varint) message formats, provides better error messages, and displays connection status with peer IDs. Replaced mock-based tests with comprehensive real network integration tests for both formats.

View File

@ -0,0 +1 @@
Yamux RawConnError Logging Refactor - Improved error handling and debug logging

View File

@ -0,0 +1 @@
The TODO IK patterns in Noise has been deprecated in specs: https://github.com/libp2p/specs/tree/master/noise#handshake-pattern

View File

@ -0,0 +1 @@
Recompiled protobufs that were out of date and added a `make` rule so that protobufs are always up to date.

View File

@ -0,0 +1,3 @@
Remove the already completed TODO tasks in Peerstore:
TODO: Set up an async task for periodic peer-store cleanup for expired addresses and records.
TODO: Make proper use of this function

View File

@ -18,12 +18,19 @@ Each file should be named like `<ISSUE>.<TYPE>.rst`, where
- `performance`
- `removal`
So for example: `123.feature.rst`, `456.bugfix.rst`
So for example: `1024.feature.rst`
**Important**: Ensure the file ends with a newline character (`\n`) to pass GitHub tox linting checks.
```
Added support for Ed25519 key generation in libp2p peer identity creation.
```
If the PR fixes an issue, use that number here. If there is no issue,
then open up the PR first and use the PR number for the newsfragment.
Note that the `towncrier` tool will automatically
**Note** that the `towncrier` tool will automatically
reflow your text, so don't try to do any fancy formatting. Run
`towncrier build --draft` to get a preview of what the release notes entry
will look like in the final release notes.

View File

@ -19,10 +19,11 @@ dependencies = [
"exceptiongroup>=1.2.0; python_version < '3.11'",
"grpcio>=1.41.0",
"lru-dict>=1.1.6",
"multiaddr>=0.0.9",
# "multiaddr>=0.0.9",
"multiaddr @ git+https://github.com/multiformats/py-multiaddr.git@db8124e2321f316d3b7d2733c7df11d6ad9c03e6",
"mypy-protobuf>=3.0.0",
"noiseprotocol>=0.3.0",
"protobuf>=4.21.0,<5.0.0",
"protobuf>=4.25.0,<5.0.0",
"pycryptodome>=3.9.2",
"pymultihash>=0.8.2",
"pynacl>=1.3.0",

View File

@ -13,6 +13,8 @@ from libp2p.identity.identify.identify import (
_multiaddr_to_bytes,
parse_identify_response,
)
from libp2p.peer.envelope import Envelope, consume_envelope, unmarshal_envelope
from libp2p.peer.peer_record import unmarshal_record
from tests.utils.factories import (
host_pair_factory,
)
@ -40,6 +42,19 @@ async def test_identify_protocol(security_protocol):
# Parse the response (handles both old and new formats)
identify_response = parse_identify_response(response)
# Validate the recieved envelope and then store it in the certified-addr-book
envelope, record = consume_envelope(
identify_response.signedPeerRecord, "libp2p-peer-record"
)
assert host_b.peerstore.consume_peer_record(envelope, ttl=7200)
# Check if the peer_id in the record is same as of host_a
assert record.peer_id == host_a.get_id()
# Check if the peer-record is correctly consumed
assert host_a.get_addrs() == host_b.peerstore.addrs(host_a.get_id())
assert isinstance(host_b.peerstore.get_peer_record(host_a.get_id()), Envelope)
logger.debug("host_a: %s", host_a.get_addrs())
logger.debug("host_b: %s", host_b.get_addrs())
@ -71,5 +86,14 @@ async def test_identify_protocol(security_protocol):
# Check protocols
assert set(identify_response.protocols) == set(host_a.get_mux().get_protocols())
# sanity check
assert identify_response == _mk_identify_protobuf(host_a, cleaned_addr)
# sanity check if the peer_id of the identify msg are same
assert (
unmarshal_record(
unmarshal_envelope(identify_response.signedPeerRecord).raw_payload
).peer_id
== unmarshal_record(
unmarshal_envelope(
_mk_identify_protobuf(host_a, cleaned_addr).signedPeerRecord
).raw_payload
).peer_id
)

View File

@ -0,0 +1,241 @@
import logging
import pytest
from libp2p.custom_types import TProtocol
from libp2p.identity.identify.identify import (
AGENT_VERSION,
ID,
PROTOCOL_VERSION,
_multiaddr_to_bytes,
identify_handler_for,
parse_identify_response,
)
from tests.utils.factories import host_pair_factory
logger = logging.getLogger("libp2p.identity.identify-integration-test")
@pytest.mark.trio
async def test_identify_protocol_varint_format_integration(security_protocol):
"""Test identify protocol with varint format in real network scenario."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=True)
)
# Make identify request
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
assert result.listen_addrs == [
_multiaddr_to_bytes(addr) for addr in host_a.get_addrs()
]
@pytest.mark.trio
async def test_identify_protocol_raw_format_integration(security_protocol):
"""Test identify protocol with raw format in real network scenario."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=False)
)
# Make identify request
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
assert result.listen_addrs == [
_multiaddr_to_bytes(addr) for addr in host_a.get_addrs()
]
@pytest.mark.trio
async def test_identify_default_format_behavior(security_protocol):
"""Test identify protocol uses correct default format."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Use default identify handler (should use varint format)
host_a.set_stream_handler(ID, identify_handler_for(host_a))
# Make identify request
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_cross_format_compatibility_varint_to_raw(security_protocol):
"""Test varint dialer with raw listener compatibility."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Host A uses raw format
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=False)
)
# Host B makes request (will automatically detect format)
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response (should work with automatic format detection)
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_cross_format_compatibility_raw_to_varint(security_protocol):
"""Test raw dialer with varint listener compatibility."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Host A uses varint format
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=True)
)
# Host B makes request (will automatically detect format)
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response (should work with automatic format detection)
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_format_detection_robustness(security_protocol):
"""Test identify protocol format detection is robust with various message sizes."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Test both formats with different message sizes
for use_varint in [True, False]:
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=use_varint)
)
# Make identify request
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_large_message_handling(security_protocol):
"""Test identify protocol handles large messages with many protocols."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add many protocols to make the message larger
async def dummy_handler(stream):
pass
for i in range(10):
host_a.set_stream_handler(TProtocol(f"/test/protocol/{i}"), dummy_handler)
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=True)
)
# Make identify request
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_message_equivalence_real_network(security_protocol):
"""Test that both formats produce equivalent messages in real network."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Test varint format
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=True)
)
stream_varint = await host_b.new_stream(host_a.get_id(), (ID,))
response_varint = await stream_varint.read(8192)
await stream_varint.close()
# Test raw format
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=False)
)
stream_raw = await host_b.new_stream(host_a.get_id(), (ID,))
response_raw = await stream_raw.read(8192)
await stream_raw.close()
# Parse both responses
result_varint = parse_identify_response(response_varint)
result_raw = parse_identify_response(response_raw)
# Both should produce identical parsed results
assert result_varint.agent_version == result_raw.agent_version
assert result_varint.protocol_version == result_raw.protocol_version
assert result_varint.public_key == result_raw.public_key
assert result_varint.listen_addrs == result_raw.listen_addrs

View File

@ -1,410 +0,0 @@
import pytest
from libp2p.identity.identify.identify import (
_mk_identify_protobuf,
)
from libp2p.identity.identify.pb.identify_pb2 import (
Identify,
)
from libp2p.io.abc import Closer, Reader, Writer
from libp2p.utils.varint import (
decode_varint_from_bytes,
encode_varint_prefixed,
)
from tests.utils.factories import (
host_pair_factory,
)
class MockStream(Reader, Writer, Closer):
"""Mock stream for testing identify protocol compatibility."""
def __init__(self, data: bytes):
self.data = data
self.position = 0
self.closed = False
async def read(self, n: int | None = None) -> bytes:
if self.closed or self.position >= len(self.data):
return b""
if n is None:
n = len(self.data) - self.position
result = self.data[self.position : self.position + n]
self.position += len(result)
return result
async def write(self, data: bytes) -> None:
# Mock write - just store the data
pass
async def close(self) -> None:
self.closed = True
def create_identify_message(host, observed_multiaddr=None):
"""Create an identify protobuf message."""
return _mk_identify_protobuf(host, observed_multiaddr)
def create_new_format_message(identify_msg):
"""Create a new format (length-prefixed) identify message."""
msg_bytes = identify_msg.SerializeToString()
return encode_varint_prefixed(msg_bytes)
def create_old_format_message(identify_msg):
"""Create an old format (raw protobuf) identify message."""
return identify_msg.SerializeToString()
async def read_new_format_message(stream) -> bytes:
"""Read a new format (length-prefixed) identify message."""
# Read varint length prefix
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
break
length_bytes += b
if b[0] & 0x80 == 0:
break
if not length_bytes:
raise ValueError("No length prefix received")
msg_length = decode_varint_from_bytes(length_bytes)
# Read the protobuf message
response = await stream.read(msg_length)
if len(response) != msg_length:
raise ValueError("Incomplete message received")
return response
async def read_old_format_message(stream) -> bytes:
"""Read an old format (raw protobuf) identify message."""
# Read all available data
response = b""
while True:
chunk = await stream.read(4096)
if not chunk:
break
response += chunk
return response
async def read_compatible_message(stream) -> bytes:
"""Read an identify message in either old or new format."""
# Try to read a few bytes to detect the format
first_bytes = await stream.read(10)
if not first_bytes:
raise ValueError("No data received")
# Try to decode as varint length prefix (new format)
try:
msg_length = decode_varint_from_bytes(first_bytes)
# Validate that the length is reasonable (not too large)
if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB
# Calculate how many bytes the varint consumed
varint_len = 0
for i, byte in enumerate(first_bytes):
varint_len += 1
if (byte & 0x80) == 0:
break
# Read the remaining protobuf message
remaining_bytes = await stream.read(
msg_length - (len(first_bytes) - varint_len)
)
if len(remaining_bytes) == msg_length - (len(first_bytes) - varint_len):
message_data = first_bytes[varint_len:] + remaining_bytes
# Try to parse as protobuf to validate
try:
Identify().ParseFromString(message_data)
return message_data
except Exception:
# If protobuf parsing fails, fall back to old format
pass
except Exception:
pass
# Fall back to old format (raw protobuf)
response = first_bytes
# Read more data if available
while True:
chunk = await stream.read(4096)
if not chunk:
break
response += chunk
return response
async def read_compatible_message_simple(stream) -> bytes:
"""Read a message in either old or new format (simplified version for testing)."""
# Try to read a few bytes to detect the format
first_bytes = await stream.read(10)
if not first_bytes:
raise ValueError("No data received")
# Try to decode as varint length prefix (new format)
try:
msg_length = decode_varint_from_bytes(first_bytes)
# Validate that the length is reasonable (not too large)
if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB
# Calculate how many bytes the varint consumed
varint_len = 0
for i, byte in enumerate(first_bytes):
varint_len += 1
if (byte & 0x80) == 0:
break
# Read the remaining message
remaining_bytes = await stream.read(
msg_length - (len(first_bytes) - varint_len)
)
if len(remaining_bytes) == msg_length - (len(first_bytes) - varint_len):
return first_bytes[varint_len:] + remaining_bytes
except Exception:
pass
# Fall back to old format (raw data)
response = first_bytes
# Read more data if available
while True:
chunk = await stream.read(4096)
if not chunk:
break
response += chunk
return response
def detect_format(data):
"""Detect if data is in new or old format (varint-prefixed or raw protobuf)."""
if not data:
return "unknown"
# Try to decode as varint
try:
msg_length = decode_varint_from_bytes(data)
# Validate that the length is reasonable
if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB
# Calculate varint length
varint_len = 0
for i, byte in enumerate(data):
varint_len += 1
if (byte & 0x80) == 0:
break
# Check if we have enough data for the message
if len(data) >= varint_len + msg_length:
# Additional check: try to parse the message as protobuf
try:
message_data = data[varint_len : varint_len + msg_length]
Identify().ParseFromString(message_data)
return "new"
except Exception:
# If protobuf parsing fails, it's probably not a valid new format
pass
except Exception:
pass
# If varint decoding fails or length is unreasonable, assume old format
return "old"
@pytest.mark.trio
async def test_identify_new_format_compatibility(security_protocol):
"""Test that identify protocol works with new format (length-prefixed) messages."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Create new format message
new_format_data = create_new_format_message(identify_msg)
# Create mock stream with new format data
stream = MockStream(new_format_data)
# Read using new format reader
response = await read_new_format_message(stream)
# Parse the response
parsed_msg = Identify()
parsed_msg.ParseFromString(response)
# Verify the message content
assert parsed_msg.protocol_version == identify_msg.protocol_version
assert parsed_msg.agent_version == identify_msg.agent_version
assert parsed_msg.public_key == identify_msg.public_key
@pytest.mark.trio
async def test_identify_old_format_compatibility(security_protocol):
"""Test that identify protocol works with old format (raw protobuf) messages."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Create old format message
old_format_data = create_old_format_message(identify_msg)
# Create mock stream with old format data
stream = MockStream(old_format_data)
# Read using old format reader
response = await read_old_format_message(stream)
# Parse the response
parsed_msg = Identify()
parsed_msg.ParseFromString(response)
# Verify the message content
assert parsed_msg.protocol_version == identify_msg.protocol_version
assert parsed_msg.agent_version == identify_msg.agent_version
assert parsed_msg.public_key == identify_msg.public_key
@pytest.mark.trio
async def test_identify_backward_compatibility_old_format(security_protocol):
"""Test backward compatibility reader with old format messages."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Create old format message
old_format_data = create_old_format_message(identify_msg)
# Create mock stream with old format data
stream = MockStream(old_format_data)
# Read using old format reader (which should work reliably)
response = await read_old_format_message(stream)
# Parse the response
parsed_msg = Identify()
parsed_msg.ParseFromString(response)
# Verify the message content
assert parsed_msg.protocol_version == identify_msg.protocol_version
assert parsed_msg.agent_version == identify_msg.agent_version
assert parsed_msg.public_key == identify_msg.public_key
@pytest.mark.trio
async def test_identify_backward_compatibility_new_format(security_protocol):
"""Test backward compatibility reader with new format messages."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Create new format message
new_format_data = create_new_format_message(identify_msg)
# Create mock stream with new format data
stream = MockStream(new_format_data)
# Read using new format reader (which should work reliably)
response = await read_new_format_message(stream)
# Parse the response
parsed_msg = Identify()
parsed_msg.ParseFromString(response)
# Verify the message content
assert parsed_msg.protocol_version == identify_msg.protocol_version
assert parsed_msg.agent_version == identify_msg.agent_version
assert parsed_msg.public_key == identify_msg.public_key
@pytest.mark.trio
async def test_identify_format_detection(security_protocol):
"""Test that the format detection works correctly."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Test new format detection
new_format_data = create_new_format_message(identify_msg)
format_type = detect_format(new_format_data)
assert format_type == "new", "New format should be detected correctly"
# Test old format detection
old_format_data = create_old_format_message(identify_msg)
format_type = detect_format(old_format_data)
assert format_type == "old", "Old format should be detected correctly"
@pytest.mark.trio
async def test_identify_error_handling(security_protocol):
"""Test error handling for malformed messages."""
from libp2p.exceptions import ParseError
# Test with empty data
stream = MockStream(b"")
with pytest.raises(ValueError, match="No data received"):
await read_compatible_message(stream)
# Test with incomplete varint
stream = MockStream(b"\x80") # Incomplete varint
with pytest.raises(ParseError, match="Unexpected end of data"):
await read_new_format_message(stream)
# Test with invalid protobuf data
stream = MockStream(b"\x05invalid") # Length prefix but invalid protobuf
with pytest.raises(Exception): # Should fail when parsing protobuf
response = await read_new_format_message(stream)
Identify().ParseFromString(response)
@pytest.mark.trio
async def test_identify_message_equivalence(security_protocol):
"""Test that old and new format messages are equivalent."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Create both formats
new_format_data = create_new_format_message(identify_msg)
old_format_data = create_old_format_message(identify_msg)
# Extract the protobuf message from new format
varint_len = 0
for i, byte in enumerate(new_format_data):
varint_len += 1
if (byte & 0x80) == 0:
break
new_format_protobuf = new_format_data[varint_len:]
# The protobuf messages should be identical
assert new_format_protobuf == old_format_data, (
"Protobuf messages should be identical in both formats"
)

View File

@ -0,0 +1,552 @@
import logging
import pytest
import trio
from libp2p.custom_types import TProtocol
from libp2p.identity.identify_push.identify_push import (
ID_PUSH,
identify_push_handler_for,
push_identify_to_peer,
push_identify_to_peers,
)
from tests.utils.factories import host_pair_factory
logger = logging.getLogger("libp2p.identity.identify-push-integration-test")
@pytest.mark.trio
async def test_identify_push_protocol_varint_format_integration(security_protocol):
"""Test identify/push protocol with varint format in real network scenario."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to host_b so it has something to push
async def dummy_handler(stream):
pass
host_b.set_stream_handler(TProtocol("/test/protocol/1"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/2"), dummy_handler)
# Set up identify/push handler on host_a
host_a.set_stream_handler(
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=True)
)
# Push identify information from host_b to host_a
await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=True)
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
# Check that addresses were added
addrs = peerstore_a.addrs(peer_id_b)
assert len(addrs) > 0
# Check that protocols were added
protocols = peerstore_a.get_protocols(peer_id_b)
assert protocols is not None
# The protocols should include the dummy protocols we added
assert len(protocols) >= 2 # Should include the dummy protocols
@pytest.mark.trio
async def test_identify_push_protocol_raw_format_integration(security_protocol):
"""Test identify/push protocol with raw format in real network scenario."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to both hosts
async def dummy_handler(stream):
pass
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
# Set up identify/push handler on host_a
host_a.set_stream_handler(
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=False)
)
# Push identify information from host_b to host_a
await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=False)
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
# Check that addresses were added
addrs = peerstore_a.addrs(peer_id_b)
assert len(addrs) > 0
# Check that protocols were added
protocols = peerstore_a.get_protocols(peer_id_b)
assert protocols is not None
assert len(protocols) >= 1 # Should include the dummy protocol
@pytest.mark.trio
async def test_identify_push_default_format_behavior(security_protocol):
"""Test identify/push protocol uses correct default format."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to both hosts
async def dummy_handler(stream):
pass
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
# Use default identify/push handler (should use varint format)
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
# Push identify information from host_b to host_a
await push_identify_to_peer(host_b, host_a.get_id())
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
# Check that protocols were added
protocols = peerstore_a.get_protocols(peer_id_b)
assert protocols is not None
assert len(protocols) >= 1 # Should include the dummy protocol
@pytest.mark.trio
async def test_identify_push_cross_format_compatibility_varint_to_raw(
security_protocol,
):
"""Test varint pusher with raw listener compatibility."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Use an event to signal when handler is ready
handler_ready = trio.Event()
# Create a wrapper handler that signals when ready
original_handler = identify_push_handler_for(host_a, use_varint_format=False)
async def wrapped_handler(stream):
handler_ready.set() # Signal that handler is ready
await original_handler(stream)
# Host A uses raw format with wrapped handler
host_a.set_stream_handler(ID_PUSH, wrapped_handler)
# Host B pushes with varint format (should fail gracefully)
success = await push_identify_to_peer(
host_b, host_a.get_id(), use_varint_format=True
)
# This should fail due to format mismatch
# Note: The format detection might be more robust than expected
# so we just check that the operation completes
assert isinstance(success, bool)
@pytest.mark.trio
async def test_identify_push_cross_format_compatibility_raw_to_varint(
security_protocol,
):
"""Test raw pusher with varint listener compatibility."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Use an event to signal when handler is ready
handler_ready = trio.Event()
# Create a wrapper handler that signals when ready
original_handler = identify_push_handler_for(host_a, use_varint_format=True)
async def wrapped_handler(stream):
handler_ready.set() # Signal that handler is ready
await original_handler(stream)
# Host A uses varint format with wrapped handler
host_a.set_stream_handler(ID_PUSH, wrapped_handler)
# Host B pushes with raw format (should fail gracefully)
success = await push_identify_to_peer(
host_b, host_a.get_id(), use_varint_format=False
)
# This should fail due to format mismatch
# Note: The format detection might be more robust than expected
# so we just check that the operation completes
assert isinstance(success, bool)
@pytest.mark.trio
async def test_identify_push_multiple_peers_integration(security_protocol):
"""Test identify/push protocol with multiple peers."""
# Create two hosts using the factory
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create a third host following the pattern from test_identify_push.py
import multiaddr
from libp2p import new_host
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.peer.peerinfo import info_from_p2p_addr
# Create a new key pair for host_c
key_pair_c = create_new_key_pair()
host_c = new_host(key_pair=key_pair_c)
# Set up identify/push handlers on all hosts
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
host_b.set_stream_handler(ID_PUSH, identify_push_handler_for(host_b))
host_c.set_stream_handler(ID_PUSH, identify_push_handler_for(host_c))
# Start listening on a random port using the run context manager
listen_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
async with host_c.run([listen_addr]):
# Connect host_c to host_a and host_b using the correct pattern
await host_c.connect(info_from_p2p_addr(host_a.get_addrs()[0]))
await host_c.connect(info_from_p2p_addr(host_b.get_addrs()[0]))
# Push identify information from host_a to all connected peers
await push_identify_to_peers(host_a)
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Check that host_b's peerstore has been updated with host_a's information
peerstore_b = host_b.get_peerstore()
peer_id_a = host_a.get_id()
# Check that the peer is in the peerstore
assert peer_id_a in peerstore_b.peer_ids()
# Check that host_c's peerstore has been updated with host_a's information
peerstore_c = host_c.get_peerstore()
# Check that the peer is in the peerstore
assert peer_id_a in peerstore_c.peer_ids()
# Test for push_identify to only connected peers and not all peers
# Disconnect a from c.
await host_c.disconnect(host_a.get_id())
await push_identify_to_peers(host_c)
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Check that host_a's peerstore has not been updated with host_c's info
assert host_c.get_id() not in host_a.get_peerstore().peer_ids()
# Check that host_b's peerstore has been updated with host_c's info
assert host_c.get_id() in host_b.get_peerstore().peer_ids()
@pytest.mark.trio
async def test_identify_push_large_message_handling(security_protocol):
"""Test identify/push protocol handles large messages with many protocols."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add many protocols to make the message larger
async def dummy_handler(stream):
pass
for i in range(10):
host_b.set_stream_handler(TProtocol(f"/test/protocol/{i}"), dummy_handler)
# Also add some protocols to host_a to ensure it has protocols to push
for i in range(5):
host_a.set_stream_handler(TProtocol(f"/test/protocol/a{i}"), dummy_handler)
# Set up identify/push handler on host_a
host_a.set_stream_handler(
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=True)
)
# Push identify information from host_b to host_a
success = await push_identify_to_peer(
host_b, host_a.get_id(), use_varint_format=True
)
assert success
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated with all protocols
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
protocols = peerstore_a.get_protocols(peer_id_b)
assert protocols is not None
assert len(protocols) >= 10 # Should include the dummy protocols
@pytest.mark.trio
async def test_identify_push_peerstore_update_completeness(security_protocol):
"""Test that identify/push updates all relevant peerstore information."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to both hosts
async def dummy_handler(stream):
pass
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
# Set up identify/push handler on host_a
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
# Push identify information from host_b to host_a
await push_identify_to_peer(host_b, host_a.get_id())
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
# Check that protocols were added
protocols = peerstore_a.get_protocols(peer_id_b)
assert protocols is not None
assert len(protocols) > 0
# Check that addresses were added
addrs = peerstore_a.addrs(peer_id_b)
assert len(addrs) > 0
# Check that public key was added
pubkey = peerstore_a.pubkey(peer_id_b)
assert pubkey is not None
assert pubkey.serialize() == host_b.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_push_concurrent_requests(security_protocol):
"""Test identify/push protocol handles concurrent requests properly."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to both hosts
async def dummy_handler(stream):
pass
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
# Set up identify/push handler on host_a
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
# Make multiple concurrent push requests
results = []
async def push_identify():
result = await push_identify_to_peer(host_b, host_a.get_id())
results.append(result)
# Run multiple concurrent pushes using nursery
async with trio.open_nursery() as nursery:
for _ in range(3):
nursery.start_soon(push_identify)
# All should succeed
assert len(results) == 3
assert all(results)
# Wait a bit for the pushes to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
protocols = peerstore_a.get_protocols(peer_id_b)
assert protocols is not None
assert len(protocols) > 0
@pytest.mark.trio
async def test_identify_push_stream_handling(security_protocol):
"""Test identify/push protocol properly handles stream lifecycle."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to both hosts
async def dummy_handler(stream):
pass
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
# Set up identify/push handler on host_a
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
# Push identify information from host_b to host_a
success = await push_identify_to_peer(host_b, host_a.get_id())
assert success
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
protocols = peerstore_a.get_protocols(peer_id_b)
assert protocols is not None
assert len(protocols) > 0
@pytest.mark.trio
async def test_identify_push_error_handling(security_protocol):
"""Test identify/push protocol handles errors gracefully."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create a handler that raises an exception but catches it to prevent test
# failure
async def error_handler(stream):
try:
await stream.close()
raise Exception("Test error")
except Exception:
# Catch the exception to prevent it from propagating up
pass
host_a.set_stream_handler(ID_PUSH, error_handler)
# Push should complete (message sent) but handler should fail gracefully
success = await push_identify_to_peer(host_b, host_a.get_id())
assert success # The push operation itself succeeds (message sent)
# Wait a bit for the handler to process
await trio.sleep(0.1)
# Verify that the error was handled gracefully (no test failure)
# The handler caught the exception and didn't propagate it
@pytest.mark.trio
async def test_identify_push_message_equivalence_real_network(security_protocol):
"""Test that both formats produce equivalent peerstore updates in real network."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to both hosts
async def dummy_handler(stream):
pass
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
# Test varint format
host_a.set_stream_handler(
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=True)
)
await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=True)
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Get peerstore state after varint push
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
protocols_varint = peerstore_a.get_protocols(peer_id_b)
addrs_varint = peerstore_a.addrs(peer_id_b)
# Clear peerstore for next test
peerstore_a.clear_addrs(peer_id_b)
peerstore_a.clear_protocol_data(peer_id_b)
# Test raw format
host_a.set_stream_handler(
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=False)
)
await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=False)
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Get peerstore state after raw push
protocols_raw = peerstore_a.get_protocols(peer_id_b)
addrs_raw = peerstore_a.addrs(peer_id_b)
# Both should produce equivalent peerstore updates
# Check that both formats successfully updated protocols
assert protocols_varint is not None
assert protocols_raw is not None
assert len(protocols_varint) > 0
assert len(protocols_raw) > 0
# Check that both formats successfully updated addresses
assert addrs_varint is not None
assert addrs_raw is not None
assert len(addrs_varint) > 0
assert len(addrs_raw) > 0
# Both should contain the same essential information
# (exact address lists might differ due to format-specific handling)
assert set(protocols_varint) == set(protocols_raw)
@pytest.mark.trio
async def test_identify_push_with_observed_address(security_protocol):
"""Test identify/push protocol includes observed address information."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to both hosts
async def dummy_handler(stream):
pass
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
# Set up identify/push handler on host_a
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
# Get host_b's address as observed by host_a
from multiaddr import Multiaddr
host_b_addr = host_b.get_addrs()[0]
observed_multiaddr = Multiaddr(str(host_b_addr))
# Push identify information with observed address
await push_identify_to_peer(
host_b, host_a.get_id(), observed_multiaddr=observed_multiaddr
)
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
# Check that addresses were added
addrs = peerstore_a.addrs(peer_id_b)
assert len(addrs) > 0
# The observed address should be among the stored addresses
addr_strings = [str(addr) for addr in addrs]
assert str(observed_multiaddr) in addr_strings

View File

@ -3,7 +3,10 @@ from multiaddr import (
Multiaddr,
)
from libp2p.crypto.rsa import create_new_key_pair
from libp2p.peer.envelope import Envelope, seal_record
from libp2p.peer.id import ID
from libp2p.peer.peer_record import PeerRecord
from libp2p.peer.peerstore import (
PeerStore,
PeerStoreError,
@ -84,3 +87,53 @@ def test_peers_with_addrs():
store.clear_addrs(ID(b"peer2"))
assert set(store.peers_with_addrs()) == {ID(b"peer3")}
def test_ceritified_addr_book():
store = PeerStore()
key_pair = create_new_key_pair()
peer_id = ID.from_pubkey(key_pair.public_key)
addrs = [
Multiaddr("/ip4/127.0.0.1/tcp/9000"),
Multiaddr("/ip4/127.0.0.1/tcp/9001"),
]
ttl = 60
# Construct signed PereRecord
record = PeerRecord(peer_id, addrs, 21)
envelope = seal_record(record, key_pair.private_key)
result = store.consume_peer_record(envelope, ttl)
assert result is True
# Retrieve the record
retrieved = store.get_peer_record(peer_id)
assert retrieved is not None
assert isinstance(retrieved, Envelope)
addr_list = store.addrs(peer_id)
assert set(addr_list) == set(addrs)
# Now try to push an older record (should be rejected)
old_record = PeerRecord(peer_id, [Multiaddr("/ip4/10.0.0.1/tcp/4001")], 20)
old_envelope = seal_record(old_record, key_pair.private_key)
result = store.consume_peer_record(old_envelope, ttl)
assert result is False
# Push a new record (should override)
new_addrs = [Multiaddr("/ip4/192.168.0.1/tcp/5001")]
new_record = PeerRecord(peer_id, new_addrs, 23)
new_envelope = seal_record(new_record, key_pair.private_key)
result = store.consume_peer_record(new_envelope, ttl)
assert result is True
# Confirm the record is updated
latest = store.get_peer_record(peer_id)
assert isinstance(latest, Envelope)
assert latest.record().seq == 23
# Merged addresses = old addres + new_addrs
expected_addrs = set(new_addrs)
actual_addrs = set(store.addrs(peer_id))
assert actual_addrs == expected_addrs

View File

@ -0,0 +1,129 @@
from multiaddr import Multiaddr
from libp2p.crypto.rsa import (
create_new_key_pair,
)
from libp2p.peer.envelope import (
Envelope,
consume_envelope,
make_unsigned,
seal_record,
unmarshal_envelope,
)
from libp2p.peer.id import ID
import libp2p.peer.pb.crypto_pb2 as crypto_pb
import libp2p.peer.pb.envelope_pb2 as env_pb
from libp2p.peer.peer_record import PeerRecord
DOMAIN = "libp2p-peer-record"
def test_basic_protobuf_serialization_deserialization():
pubkey = crypto_pb.PublicKey()
pubkey.Type = crypto_pb.KeyType.Ed25519
pubkey.Data = b"\x01\x02\x03"
env = env_pb.Envelope()
env.public_key.CopyFrom(pubkey)
env.payload_type = b"\x03\x01"
env.payload = b"test-payload"
env.signature = b"signature-bytes"
serialized = env.SerializeToString()
new_env = env_pb.Envelope()
new_env.ParseFromString(serialized)
assert new_env.public_key.Type == crypto_pb.KeyType.Ed25519
assert new_env.public_key.Data == b"\x01\x02\x03"
assert new_env.payload_type == b"\x03\x01"
assert new_env.payload == b"test-payload"
assert new_env.signature == b"signature-bytes"
def test_enevelope_marshal_unmarshal_roundtrip():
keypair = create_new_key_pair()
pubkey = keypair.public_key
private_key = keypair.private_key
payload_type = b"\x03\x01"
payload = b"test-record"
sig = private_key.sign(make_unsigned(DOMAIN, payload_type, payload))
env = Envelope(pubkey, payload_type, payload, sig)
serialized = env.marshal_envelope()
new_env = unmarshal_envelope(serialized)
assert new_env.public_key == pubkey
assert new_env.payload_type == payload_type
assert new_env.raw_payload == payload
assert new_env.signature == sig
def test_seal_and_consume_envelope_roundtrip():
keypair = create_new_key_pair()
priv_key = keypair.private_key
pub_key = keypair.public_key
peer_id = ID.from_pubkey(pub_key)
addrs = [Multiaddr("/ip4/127.0.0.1/tcp/4001"), Multiaddr("/ip4/127.0.0.1/tcp/4002")]
seq = 12345
record = PeerRecord(peer_id=peer_id, addrs=addrs, seq=seq)
# Seal
envelope = seal_record(record, priv_key)
serialized = envelope.marshal_envelope()
# Consume
env, rec = consume_envelope(serialized, record.domain())
# Assertions
assert env.public_key == pub_key
assert rec.peer_id == peer_id
assert rec.seq == seq
assert rec.addrs == addrs
def test_envelope_equal():
# Create a new keypair
keypair = create_new_key_pair()
private_key = keypair.private_key
# Create a mock PeerRecord
record = PeerRecord(
peer_id=ID.from_base58("QmNM23MiU1Kd7yfiKVdUnaDo8RYca8By4zDmr7uSaVV8Px"),
seq=1,
addrs=[Multiaddr("/ip4/127.0.0.1/tcp/4001")],
)
# Seal it into an Envelope
env1 = seal_record(record, private_key)
# Create a second identical envelope
env2 = Envelope(
public_key=env1.public_key,
payload_type=env1.payload_type,
raw_payload=env1.raw_payload,
signature=env1.signature,
)
# They should be equal
assert env1.equal(env2)
# Now change something — payload type
env2.payload_type = b"\x99\x99"
assert not env1.equal(env2)
# Restore payload_type but change signature
env2.payload_type = env1.payload_type
env2.signature = b"wrong-signature"
assert not env1.equal(env2)
# Restore signature but change payload
env2.signature = env1.signature
env2.raw_payload = b"tampered"
assert not env1.equal(env2)
# Finally, test with a non-envelope object
assert not env1.equal("not-an-envelope")

View File

@ -0,0 +1,112 @@
import time
from multiaddr import Multiaddr
from libp2p.peer.id import ID
import libp2p.peer.pb.peer_record_pb2 as pb
from libp2p.peer.peer_record import (
PeerRecord,
addrs_from_protobuf,
peer_record_from_protobuf,
unmarshal_record,
)
# Testing methods from PeerRecord base class and PeerRecord protobuf:
def test_basic_protobuf_serializatrion_deserialization():
record = pb.PeerRecord()
record.seq = 1
serialized = record.SerializeToString()
new_record = pb.PeerRecord()
new_record.ParseFromString(serialized)
assert new_record.seq == 1
def test_timestamp_seq_monotonicity():
rec1 = PeerRecord()
time.sleep(1)
rec2 = PeerRecord()
assert isinstance(rec1.seq, int)
assert isinstance(rec2.seq, int)
assert rec2.seq > rec1.seq, f"Expected seq2 ({rec2.seq}) > seq1 ({rec1.seq})"
def test_addrs_from_protobuf_multiple_addresses():
ma1 = Multiaddr("/ip4/127.0.0.1/tcp/4001")
ma2 = Multiaddr("/ip4/127.0.0.1/tcp/4002")
addr_info1 = pb.PeerRecord.AddressInfo()
addr_info1.multiaddr = ma1.to_bytes()
addr_info2 = pb.PeerRecord.AddressInfo()
addr_info2.multiaddr = ma2.to_bytes()
result = addrs_from_protobuf([addr_info1, addr_info2])
assert result == [ma1, ma2]
def test_peer_record_from_protobuf():
peer_id = ID.from_base58("QmNM23MiU1Kd7yfiKVdUnaDo8RYca8By4zDmr7uSaVV8Px")
record = pb.PeerRecord()
record.peer_id = peer_id.to_bytes()
record.seq = 42
for addr_str in ["/ip4/127.0.0.1/tcp/4001", "/ip4/127.0.0.1/tcp/4002"]:
ma = Multiaddr(addr_str)
addr_info = pb.PeerRecord.AddressInfo()
addr_info.multiaddr = ma.to_bytes()
record.addresses.append(addr_info)
result = peer_record_from_protobuf(record)
assert result.peer_id == peer_id
assert result.seq == 42
assert len(result.addrs) == 2
assert str(result.addrs[0]) == "/ip4/127.0.0.1/tcp/4001"
assert str(result.addrs[1]) == "/ip4/127.0.0.1/tcp/4002"
def test_to_protobuf_generates_correct_message():
peer_id = ID.from_base58("QmNM23MiU1Kd7yfiKVdUnaDo8RYca8By4zDmr7uSaVV8Px")
addrs = [Multiaddr("/ip4/127.0.0.1/tcp/4001")]
seq = 12345
record = PeerRecord(peer_id, addrs, seq)
proto = record.to_protobuf()
assert isinstance(proto, pb.PeerRecord)
assert proto.peer_id == peer_id.to_bytes()
assert proto.seq == seq
assert len(proto.addresses) == 1
assert proto.addresses[0].multiaddr == addrs[0].to_bytes()
def test_unmarshal_record_roundtrip():
record = PeerRecord(
peer_id=ID.from_base58("QmNM23MiU1Kd7yfiKVdUnaDo8RYca8By4zDmr7uSaVV8Px"),
addrs=[Multiaddr("/ip4/127.0.0.1/tcp/4001")],
seq=999,
)
serialized = record.to_protobuf().SerializeToString()
deserialized = unmarshal_record(serialized)
assert deserialized.peer_id == record.peer_id
assert deserialized.seq == record.seq
assert len(deserialized.addrs) == 1
assert deserialized.addrs[0] == record.addrs[0]
def test_marshal_record_and_equal():
peer_id = ID.from_base58("QmNM23MiU1Kd7yfiKVdUnaDo8RYca8By4zDmr7uSaVV8Px")
addrs = [Multiaddr("/ip4/127.0.0.1/tcp/4001")]
original = PeerRecord(peer_id, addrs)
serialized = original.marshal_record()
deserailzed = unmarshal_record(serialized)
assert original.equal(deserailzed)

View File

@ -120,3 +120,30 @@ async def test_addr_stream_yields_new_addrs():
nursery.cancel_scope.cancel()
assert collected == [addr1, addr2]
@pytest.mark.trio
async def test_cleanup_task_remove_expired_data():
store = PeerStore()
peer_id = ID(b"peer123")
addr = Multiaddr("/ip4/127.0.0.1/tcp/4040")
# Insert addrs with short TTL (0.01s)
store.add_addr(peer_id, addr, 1)
assert store.addrs(peer_id) == [addr]
assert peer_id in store.peer_data_map
# Start cleanup task in a nursery
async with trio.open_nursery() as nursery:
# Run the cleanup task with a short interval so it runs soon
nursery.start_soon(store.start_cleanup_task, 1)
# Sleep long enough for TTL to expire and cleanup to run
await trio.sleep(3)
# Cancel the nursery to stop background tasks
nursery.cancel_scope.cancel()
# Confirm the peer data is gone from the peer_data_map
assert peer_id not in store.peer_data_map

View File

@ -5,10 +5,12 @@ import inspect
from typing import (
NamedTuple,
)
from unittest.mock import patch
import pytest
import trio
from libp2p.custom_types import AsyncValidatorFn
from libp2p.exceptions import (
ValidationError,
)
@ -243,7 +245,37 @@ async def test_get_msg_validators():
((False, True), (True, False), (True, True)),
)
@pytest.mark.trio
async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
async def test_validate_msg_with_throttle_condition(
is_topic_1_val_passed, is_topic_2_val_passed
):
CONCURRENCY_LIMIT = 10
state = {
"concurrency_counter": 0,
"max_observed": 0,
}
lock = trio.Lock()
async def mock_run_async_validator(
self,
func: AsyncValidatorFn,
msg_forwarder: ID,
msg: rpc_pb2.Message,
results: list[bool],
) -> None:
async with self._validator_semaphore:
async with lock:
state["concurrency_counter"] += 1
if state["concurrency_counter"] > state["max_observed"]:
state["max_observed"] = state["concurrency_counter"]
try:
result = await func(msg_forwarder, msg)
results.append(result)
finally:
async with lock:
state["concurrency_counter"] -= 1
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
def passed_sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
@ -280,11 +312,19 @@ async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
seqno=b"\x00" * 8,
)
if is_topic_1_val_passed and is_topic_2_val_passed:
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
else:
with pytest.raises(ValidationError):
with patch(
"libp2p.pubsub.pubsub.Pubsub._run_async_validator",
new=mock_run_async_validator,
):
if is_topic_1_val_passed and is_topic_2_val_passed:
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
else:
with pytest.raises(ValidationError):
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
assert state["max_observed"] <= CONCURRENCY_LIMIT, (
f"Max concurrency observed: {state['max_observed']}"
)
@pytest.mark.trio

View File

@ -0,0 +1,563 @@
"""Integration tests for DCUtR protocol with real libp2p hosts using circuit relay."""
import logging
from unittest.mock import AsyncMock, MagicMock
import pytest
from multiaddr import Multiaddr
import trio
from libp2p.relay.circuit_v2.dcutr import (
MAX_HOLE_PUNCH_ATTEMPTS,
PROTOCOL_ID,
DCUtRProtocol,
)
from libp2p.relay.circuit_v2.pb.dcutr_pb2 import (
HolePunch,
)
from libp2p.relay.circuit_v2.protocol import (
DEFAULT_RELAY_LIMITS,
CircuitV2Protocol,
)
from libp2p.tools.async_service import (
background_trio_service,
)
from tests.utils.factories import (
HostFactory,
)
logger = logging.getLogger(__name__)
# Test timeouts
SLEEP_TIME = 0.5 # seconds
@pytest.mark.trio
async def test_dcutr_through_relay_connection():
"""
Test DCUtR protocol where peers are connected via relay,
then upgrade to direct.
"""
# Create three hosts: two peers and one relay
async with HostFactory.create_batch_and_listen(3) as hosts:
peer1, peer2, relay = hosts
# Create circuit relay protocol for the relay
relay_protocol = CircuitV2Protocol(relay, DEFAULT_RELAY_LIMITS, allow_hop=True)
# Create DCUtR protocols for both peers
dcutr1 = DCUtRProtocol(peer1)
dcutr2 = DCUtRProtocol(peer2)
# Track if DCUtR stream handlers were called
handler1_called = False
handler2_called = False
# Override stream handlers to track calls
original_handler1 = dcutr1._handle_dcutr_stream
original_handler2 = dcutr2._handle_dcutr_stream
async def tracked_handler1(stream):
nonlocal handler1_called
handler1_called = True
await original_handler1(stream)
async def tracked_handler2(stream):
nonlocal handler2_called
handler2_called = True
await original_handler2(stream)
dcutr1._handle_dcutr_stream = tracked_handler1
dcutr2._handle_dcutr_stream = tracked_handler2
# Start all protocols
async with background_trio_service(relay_protocol):
async with background_trio_service(dcutr1):
async with background_trio_service(dcutr2):
await relay_protocol.event_started.wait()
await dcutr1.event_started.wait()
await dcutr2.event_started.wait()
# Connect both peers to the relay
relay_addrs = relay.get_addrs()
# Add relay addresses to both peers' peerstores
for addr in relay_addrs:
peer1.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
peer2.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
# Connect peers to relay
await peer1.connect(relay.get_peerstore().peer_info(relay.get_id()))
await peer2.connect(relay.get_peerstore().peer_info(relay.get_id()))
await trio.sleep(0.1)
# Verify peers are connected to relay
assert relay.get_id() in [
peer_id for peer_id in peer1.get_network().connections.keys()
]
assert relay.get_id() in [
peer_id for peer_id in peer2.get_network().connections.keys()
]
# Verify peers are NOT directly connected to each other
assert peer2.get_id() not in [
peer_id for peer_id in peer1.get_network().connections.keys()
]
assert peer1.get_id() not in [
peer_id for peer_id in peer2.get_network().connections.keys()
]
# Now test DCUtR: peer1 opens a DCUtR stream to peer2 through the
# relay
# This should trigger the DCUtR protocol for hole punching
try:
# Create a circuit relay multiaddr for peer2 through the relay
relay_addr = relay_addrs[0]
circuit_addr = Multiaddr(
f"{relay_addr}/p2p-circuit/p2p/{peer2.get_id()}"
)
# Add the circuit address to peer1's peerstore
peer1.get_peerstore().add_addrs(
peer2.get_id(), [circuit_addr], 3600
)
# Open a DCUtR stream from peer1 to peer2 through the relay
stream = await peer1.new_stream(peer2.get_id(), [PROTOCOL_ID])
# Send a CONNECT message with observed addresses
peer1_addrs = peer1.get_addrs()
connect_msg = HolePunch(
type=HolePunch.CONNECT,
ObsAddrs=[addr.to_bytes() for addr in peer1_addrs[:2]],
)
await stream.write(connect_msg.SerializeToString())
# Wait for the message to be processed
await trio.sleep(0.2)
# Verify that the DCUtR stream handler was called on peer2
assert handler2_called, (
"DCUtR stream handler should have been called on peer2"
)
# Close the stream
await stream.close()
except Exception as e:
logger.info(
"Expected error when trying to open DCUtR stream through "
"relay: %s",
e,
)
# This might fail because we need more setup, but the important
# thing is testing the right scenario
# Wait a bit more
await trio.sleep(0.1)
@pytest.mark.trio
async def test_dcutr_relay_to_direct_upgrade():
"""Test the complete flow: relay connection -> DCUtR -> direct connection."""
# Create three hosts: two peers and one relay
async with HostFactory.create_batch_and_listen(3) as hosts:
peer1, peer2, relay = hosts
# Create circuit relay protocol for the relay
relay_protocol = CircuitV2Protocol(relay, DEFAULT_RELAY_LIMITS, allow_hop=True)
# Create DCUtR protocols for both peers
dcutr1 = DCUtRProtocol(peer1)
dcutr2 = DCUtRProtocol(peer2)
# Track messages received
messages_received = []
# Override stream handler to capture messages
original_handler = dcutr2._handle_dcutr_stream
async def message_capturing_handler(stream):
try:
# Read the message
msg_data = await stream.read()
hole_punch = HolePunch()
hole_punch.ParseFromString(msg_data)
messages_received.append(hole_punch)
# Send a SYNC response
sync_msg = HolePunch(type=HolePunch.SYNC)
await stream.write(sync_msg.SerializeToString())
await original_handler(stream)
except Exception as e:
logger.error(f"Error in message capturing handler: {e}")
await stream.close()
dcutr2._handle_dcutr_stream = message_capturing_handler
# Start all protocols
async with background_trio_service(relay_protocol):
async with background_trio_service(dcutr1):
async with background_trio_service(dcutr2):
await relay_protocol.event_started.wait()
await dcutr1.event_started.wait()
await dcutr2.event_started.wait()
# Re-register the handler with the host
dcutr2.host.set_stream_handler(
PROTOCOL_ID, message_capturing_handler
)
# Connect both peers to the relay
relay_addrs = relay.get_addrs()
# Add relay addresses to both peers' peerstores
for addr in relay_addrs:
peer1.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
peer2.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
# Connect peers to relay
await peer1.connect(relay.get_peerstore().peer_info(relay.get_id()))
await peer2.connect(relay.get_peerstore().peer_info(relay.get_id()))
await trio.sleep(0.1)
# Verify peers are connected to relay but not to each other
assert relay.get_id() in [
peer_id for peer_id in peer1.get_network().connections.keys()
]
assert relay.get_id() in [
peer_id for peer_id in peer2.get_network().connections.keys()
]
assert peer2.get_id() not in [
peer_id for peer_id in peer1.get_network().connections.keys()
]
# Try to open a DCUtR stream through the relay
try:
# Create a circuit relay multiaddr for peer2 through the relay
relay_addr = relay_addrs[0]
circuit_addr = Multiaddr(
f"{relay_addr}/p2p-circuit/p2p/{peer2.get_id()}"
)
# Add the circuit address to peer1's peerstore
peer1.get_peerstore().add_addrs(
peer2.get_id(), [circuit_addr], 3600
)
# Open a DCUtR stream from peer1 to peer2 through the relay
stream = await peer1.new_stream(peer2.get_id(), [PROTOCOL_ID])
# Send a CONNECT message with observed addresses
peer1_addrs = peer1.get_addrs()
connect_msg = HolePunch(
type=HolePunch.CONNECT,
ObsAddrs=[addr.to_bytes() for addr in peer1_addrs[:2]],
)
await stream.write(connect_msg.SerializeToString())
# Wait for the message to be processed
await trio.sleep(0.2)
# Verify that the CONNECT message was received
assert len(messages_received) == 1, (
"Should have received one message"
)
assert messages_received[0].type == HolePunch.CONNECT, (
"Should have received CONNECT message"
)
assert len(messages_received[0].ObsAddrs) == 2, (
"Should have received 2 observed addresses"
)
# Close the stream
await stream.close()
except Exception as e:
logger.info(
"Expected error when trying to open DCUtR stream through "
"relay: %s",
e,
)
# Wait a bit more
await trio.sleep(0.1)
@pytest.mark.trio
async def test_dcutr_hole_punch_through_relay():
"""Test hole punching when peers are connected through relay."""
# Create three hosts: two peers and one relay
async with HostFactory.create_batch_and_listen(3) as hosts:
peer1, peer2, relay = hosts
# Create circuit relay protocol for the relay
relay_protocol = CircuitV2Protocol(relay, DEFAULT_RELAY_LIMITS, allow_hop=True)
# Create DCUtR protocols for both peers
dcutr1 = DCUtRProtocol(peer1)
dcutr2 = DCUtRProtocol(peer2)
# Start all protocols
async with background_trio_service(relay_protocol):
async with background_trio_service(dcutr1):
async with background_trio_service(dcutr2):
await relay_protocol.event_started.wait()
await dcutr1.event_started.wait()
await dcutr2.event_started.wait()
# Connect both peers to the relay
relay_addrs = relay.get_addrs()
# Add relay addresses to both peers' peerstores
for addr in relay_addrs:
peer1.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
peer2.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
# Connect peers to relay
await peer1.connect(relay.get_peerstore().peer_info(relay.get_id()))
await peer2.connect(relay.get_peerstore().peer_info(relay.get_id()))
await trio.sleep(0.1)
# Verify peers are connected to relay but not to each other
assert relay.get_id() in [
peer_id for peer_id in peer1.get_network().connections.keys()
]
assert relay.get_id() in [
peer_id for peer_id in peer2.get_network().connections.keys()
]
assert peer2.get_id() not in [
peer_id for peer_id in peer1.get_network().connections.keys()
]
# Check if there's already a direct connection (should be False)
has_direct = await dcutr1._have_direct_connection(peer2.get_id())
assert not has_direct, "Peers should not have a direct connection"
# Try to initiate a hole punch (this should work through the relay
# connection)
# In a real scenario, this would be called after establishing a
# relay connection
result = await dcutr1.initiate_hole_punch(peer2.get_id())
# This should attempt hole punching but likely fail due to no public
# addresses
# The important thing is that the DCUtR protocol logic is executed
logger.info(
"Hole punch result: %s",
result,
)
assert result is not None, "Hole punch result should not be None"
assert isinstance(result, bool), (
"Hole punch result should be a boolean"
)
# Wait a bit more
await trio.sleep(0.1)
@pytest.mark.trio
async def test_dcutr_relay_connection_verification():
"""Test that DCUtR works correctly when peers are connected via relay."""
# Create three hosts: two peers and one relay
async with HostFactory.create_batch_and_listen(3) as hosts:
peer1, peer2, relay = hosts
# Create circuit relay protocol for the relay
relay_protocol = CircuitV2Protocol(relay, DEFAULT_RELAY_LIMITS, allow_hop=True)
# Create DCUtR protocols for both peers
dcutr1 = DCUtRProtocol(peer1)
dcutr2 = DCUtRProtocol(peer2)
# Start all protocols
async with background_trio_service(relay_protocol):
async with background_trio_service(dcutr1):
async with background_trio_service(dcutr2):
await relay_protocol.event_started.wait()
await dcutr1.event_started.wait()
await dcutr2.event_started.wait()
# Connect both peers to the relay
relay_addrs = relay.get_addrs()
# Add relay addresses to both peers' peerstores
for addr in relay_addrs:
peer1.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
peer2.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
# Connect peers to relay
await peer1.connect(relay.get_peerstore().peer_info(relay.get_id()))
await peer2.connect(relay.get_peerstore().peer_info(relay.get_id()))
await trio.sleep(0.1)
# Verify peers are connected to relay
assert relay.get_id() in [
peer_id for peer_id in peer1.get_network().connections.keys()
]
assert relay.get_id() in [
peer_id for peer_id in peer2.get_network().connections.keys()
]
# Verify peers are NOT directly connected to each other
assert peer2.get_id() not in [
peer_id for peer_id in peer1.get_network().connections.keys()
]
assert peer1.get_id() not in [
peer_id for peer_id in peer2.get_network().connections.keys()
]
# Test getting observed addresses (real implementation)
observed_addrs1 = await dcutr1._get_observed_addrs()
observed_addrs2 = await dcutr2._get_observed_addrs()
assert isinstance(observed_addrs1, list)
assert isinstance(observed_addrs2, list)
# Should contain the hosts' actual addresses
assert len(observed_addrs1) > 0, (
"Peer1 should have observed addresses"
)
assert len(observed_addrs2) > 0, (
"Peer2 should have observed addresses"
)
# Test decoding observed addresses
test_addrs = [
Multiaddr("/ip4/127.0.0.1/tcp/1234").to_bytes(),
Multiaddr("/ip4/192.168.1.1/tcp/5678").to_bytes(),
b"invalid-addr", # This should be filtered out
]
decoded = dcutr1._decode_observed_addrs(test_addrs)
assert len(decoded) == 2, "Should decode 2 valid addresses"
assert all(str(addr).startswith("/ip4/") for addr in decoded)
# Wait a bit more
await trio.sleep(0.1)
@pytest.mark.trio
async def test_dcutr_relay_error_handling():
"""Test DCUtR error handling when working through relay connections."""
# Create three hosts: two peers and one relay
async with HostFactory.create_batch_and_listen(3) as hosts:
peer1, peer2, relay = hosts
# Create circuit relay protocol for the relay
relay_protocol = CircuitV2Protocol(relay, DEFAULT_RELAY_LIMITS, allow_hop=True)
# Create DCUtR protocols for both peers
dcutr1 = DCUtRProtocol(peer1)
dcutr2 = DCUtRProtocol(peer2)
# Start all protocols
async with background_trio_service(relay_protocol):
async with background_trio_service(dcutr1):
async with background_trio_service(dcutr2):
await relay_protocol.event_started.wait()
await dcutr1.event_started.wait()
await dcutr2.event_started.wait()
# Connect both peers to the relay
relay_addrs = relay.get_addrs()
# Add relay addresses to both peers' peerstores
for addr in relay_addrs:
peer1.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
peer2.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
# Connect peers to relay
await peer1.connect(relay.get_peerstore().peer_info(relay.get_id()))
await peer2.connect(relay.get_peerstore().peer_info(relay.get_id()))
await trio.sleep(0.1)
# Test with a stream that times out
timeout_stream = MagicMock()
timeout_stream.muxed_conn.peer_id = peer2.get_id()
timeout_stream.read = AsyncMock(side_effect=trio.TooSlowError())
timeout_stream.write = AsyncMock()
timeout_stream.close = AsyncMock()
# This should not raise an exception, just log and close
await dcutr1._handle_dcutr_stream(timeout_stream)
# Verify stream was closed
assert timeout_stream.close.called
# Test with malformed message
malformed_stream = MagicMock()
malformed_stream.muxed_conn.peer_id = peer2.get_id()
malformed_stream.read = AsyncMock(return_value=b"not-a-protobuf")
malformed_stream.write = AsyncMock()
malformed_stream.close = AsyncMock()
# This should not raise an exception, just log and close
await dcutr1._handle_dcutr_stream(malformed_stream)
# Verify stream was closed
assert malformed_stream.close.called
# Wait a bit more
await trio.sleep(0.1)
@pytest.mark.trio
async def test_dcutr_relay_attempt_limiting():
"""Test DCUtR attempt limiting when working through relay connections."""
# Create three hosts: two peers and one relay
async with HostFactory.create_batch_and_listen(3) as hosts:
peer1, peer2, relay = hosts
# Create circuit relay protocol for the relay
relay_protocol = CircuitV2Protocol(relay, DEFAULT_RELAY_LIMITS, allow_hop=True)
# Create DCUtR protocols for both peers
dcutr1 = DCUtRProtocol(peer1)
dcutr2 = DCUtRProtocol(peer2)
# Start all protocols
async with background_trio_service(relay_protocol):
async with background_trio_service(dcutr1):
async with background_trio_service(dcutr2):
await relay_protocol.event_started.wait()
await dcutr1.event_started.wait()
await dcutr2.event_started.wait()
# Connect both peers to the relay
relay_addrs = relay.get_addrs()
# Add relay addresses to both peers' peerstores
for addr in relay_addrs:
peer1.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
peer2.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
# Connect peers to relay
await peer1.connect(relay.get_peerstore().peer_info(relay.get_id()))
await peer2.connect(relay.get_peerstore().peer_info(relay.get_id()))
await trio.sleep(0.1)
# Set max attempts reached
dcutr1._hole_punch_attempts[peer2.get_id()] = (
MAX_HOLE_PUNCH_ATTEMPTS
)
# Try to initiate hole punch - should fail due to max attempts
result = await dcutr1.initiate_hole_punch(peer2.get_id())
assert result is False, "Hole punch should fail due to max attempts"
# Reset attempts
dcutr1._hole_punch_attempts.clear()
# Add to direct connections
dcutr1._direct_connections.add(peer2.get_id())
# Try to initiate hole punch - should succeed immediately
result = await dcutr1.initiate_hole_punch(peer2.get_id())
assert result is True, (
"Hole punch should succeed for already connected peers"
)
# Wait a bit more
await trio.sleep(0.1)

View File

@ -0,0 +1,208 @@
"""Unit tests for DCUtR protocol."""
import logging
from unittest.mock import AsyncMock, MagicMock
import pytest
import trio
from libp2p.abc import INetStream
from libp2p.peer.id import ID
from libp2p.relay.circuit_v2.dcutr import (
MAX_HOLE_PUNCH_ATTEMPTS,
DCUtRProtocol,
)
from libp2p.relay.circuit_v2.pb.dcutr_pb2 import HolePunch
from libp2p.tools.async_service import background_trio_service
logger = logging.getLogger(__name__)
@pytest.mark.trio
async def test_dcutr_protocol_initialization():
"""Test DCUtR protocol initialization."""
mock_host = MagicMock()
dcutr = DCUtRProtocol(mock_host)
# Test that the protocol is initialized correctly
assert dcutr.host == mock_host
assert not dcutr.event_started.is_set()
assert dcutr._hole_punch_attempts == {}
assert dcutr._direct_connections == set()
assert dcutr._in_progress == set()
# Test that the protocol can be started
async with background_trio_service(dcutr):
# Wait for the protocol to start
await dcutr.event_started.wait()
# Verify that the stream handler was registered
mock_host.set_stream_handler.assert_called_once()
# Verify that the event is set
assert dcutr.event_started.is_set()
@pytest.mark.trio
async def test_dcutr_message_exchange():
"""Test DCUtR message exchange."""
mock_host = MagicMock()
dcutr = DCUtRProtocol(mock_host)
# Test that the protocol can be started
async with background_trio_service(dcutr):
# Wait for the protocol to start
await dcutr.event_started.wait()
# Test CONNECT message
connect_msg = HolePunch(
type=HolePunch.CONNECT,
ObsAddrs=[b"/ip4/127.0.0.1/tcp/1234", b"/ip4/192.168.1.1/tcp/5678"],
)
# Test SYNC message
sync_msg = HolePunch(type=HolePunch.SYNC)
# Verify message types
assert connect_msg.type == HolePunch.CONNECT
assert sync_msg.type == HolePunch.SYNC
assert len(connect_msg.ObsAddrs) == 2
@pytest.mark.trio
async def test_dcutr_error_handling(monkeypatch):
"""Test DCUtR error handling."""
mock_host = MagicMock()
dcutr = DCUtRProtocol(mock_host)
async with background_trio_service(dcutr):
await dcutr.event_started.wait()
# Simulate a stream that times out
class TimeoutStream(INetStream):
def __init__(self):
self._protocol = None
self.muxed_conn = MagicMock(peer_id=ID(b"peer"))
async def read(self, n: int | None = None) -> bytes:
await trio.sleep(0.2)
raise trio.TooSlowError()
async def write(self, data: bytes) -> None:
return None
async def close(self, *args, **kwargs):
return None
async def reset(self):
return None
def get_protocol(self):
return self._protocol
def set_protocol(self, protocol_id):
self._protocol = protocol_id
def get_remote_address(self):
return ("127.0.0.1", 1234)
# Should not raise, just log and close
await dcutr._handle_dcutr_stream(TimeoutStream())
# Simulate a stream with malformed message
class MalformedStream(INetStream):
def __init__(self):
self._protocol = None
self.muxed_conn = MagicMock(peer_id=ID(b"peer"))
async def read(self, n: int | None = None) -> bytes:
return b"not-a-protobuf"
async def write(self, data: bytes) -> None:
return None
async def close(self, *args, **kwargs):
return None
async def reset(self):
return None
def get_protocol(self):
return self._protocol
def set_protocol(self, protocol_id):
self._protocol = protocol_id
def get_remote_address(self):
return ("127.0.0.1", 1234)
await dcutr._handle_dcutr_stream(MalformedStream())
@pytest.mark.trio
async def test_dcutr_max_attempts_and_already_connected():
"""Test max hole punch attempts and already-connected peer."""
mock_host = MagicMock()
dcutr = DCUtRProtocol(mock_host)
peer_id = ID(b"peer")
# Simulate already having a direct connection
dcutr._direct_connections.add(peer_id)
result = await dcutr.initiate_hole_punch(peer_id)
assert result is True
# Remove direct connection, simulate max attempts
dcutr._direct_connections.clear()
dcutr._hole_punch_attempts[peer_id] = MAX_HOLE_PUNCH_ATTEMPTS
result = await dcutr.initiate_hole_punch(peer_id)
assert result is False
@pytest.mark.trio
async def test_dcutr_observed_addr_encoding_decoding():
"""Test observed address encoding/decoding."""
from multiaddr import Multiaddr
mock_host = MagicMock()
dcutr = DCUtRProtocol(mock_host)
# Simulate valid and invalid multiaddrs as bytes
valid = [
Multiaddr("/ip4/127.0.0.1/tcp/1234").to_bytes(),
Multiaddr("/ip4/192.168.1.1/tcp/5678").to_bytes(),
]
invalid = [b"not-a-multiaddr", b""]
decoded = dcutr._decode_observed_addrs(valid + invalid)
assert len(decoded) == 2
@pytest.mark.trio
async def test_dcutr_real_perform_hole_punch(monkeypatch):
"""Test initiate_hole_punch with real _perform_hole_punch logic (mock network)."""
mock_host = MagicMock()
dcutr = DCUtRProtocol(mock_host)
peer_id = ID(b"peer")
# Patch methods to simulate a successful punch
monkeypatch.setattr(dcutr, "_have_direct_connection", AsyncMock(return_value=False))
monkeypatch.setattr(
dcutr,
"_get_observed_addrs",
AsyncMock(return_value=[b"/ip4/127.0.0.1/tcp/1234"]),
)
mock_stream = MagicMock()
mock_stream.read = AsyncMock(
side_effect=[
HolePunch(
type=HolePunch.CONNECT, ObsAddrs=[b"/ip4/192.168.1.1/tcp/4321"]
).SerializeToString(),
HolePunch(type=HolePunch.SYNC).SerializeToString(),
]
)
mock_stream.write = AsyncMock()
mock_stream.close = AsyncMock()
mock_stream.muxed_conn = MagicMock(peer_id=peer_id)
mock_host.new_stream = AsyncMock(return_value=mock_stream)
monkeypatch.setattr(dcutr, "_perform_hole_punch", AsyncMock(return_value=True))
result = await dcutr.initiate_hole_punch(peer_id)
assert result is True

View File

@ -0,0 +1,297 @@
"""Tests for NAT traversal utilities."""
from unittest.mock import MagicMock
import pytest
from multiaddr import Multiaddr
from libp2p.peer.id import ID
from libp2p.relay.circuit_v2.nat import (
ReachabilityChecker,
extract_ip_from_multiaddr,
ip_to_int,
is_ip_in_range,
is_private_ip,
)
def test_ip_to_int_ipv4():
"""Test converting IPv4 addresses to integers."""
assert ip_to_int("192.168.1.1") == 3232235777
assert ip_to_int("10.0.0.1") == 167772161
assert ip_to_int("127.0.0.1") == 2130706433
def test_ip_to_int_ipv6():
"""Test converting IPv6 addresses to integers."""
# Test with a simple IPv6 address
ipv6_int = ip_to_int("::1")
assert isinstance(ipv6_int, int)
assert ipv6_int > 0
def test_ip_to_int_invalid():
"""Test handling of invalid IP addresses."""
with pytest.raises(ValueError):
ip_to_int("invalid-ip")
def test_is_ip_in_range():
"""Test IP range checking."""
# Test within range
assert is_ip_in_range("192.168.1.5", "192.168.1.1", "192.168.1.10") is True
assert is_ip_in_range("10.0.0.5", "10.0.0.0", "10.0.0.255") is True
# Test outside range
assert is_ip_in_range("192.168.2.5", "192.168.1.1", "192.168.1.10") is False
assert is_ip_in_range("8.8.8.8", "10.0.0.0", "10.0.0.255") is False
def test_is_ip_in_range_invalid():
"""Test IP range checking with invalid inputs."""
assert is_ip_in_range("invalid", "192.168.1.1", "192.168.1.10") is False
assert is_ip_in_range("192.168.1.5", "invalid", "192.168.1.10") is False
def test_is_private_ip():
"""Test private IP detection."""
# Private IPs
assert is_private_ip("192.168.1.1") is True
assert is_private_ip("10.0.0.1") is True
assert is_private_ip("172.16.0.1") is True
assert is_private_ip("127.0.0.1") is True # Loopback
assert is_private_ip("169.254.1.1") is True # Link-local
# Public IPs
assert is_private_ip("8.8.8.8") is False
assert is_private_ip("1.1.1.1") is False
assert is_private_ip("208.67.222.222") is False
def test_extract_ip_from_multiaddr():
"""Test IP extraction from multiaddrs."""
# IPv4 addresses
addr1 = Multiaddr("/ip4/192.168.1.1/tcp/1234")
assert extract_ip_from_multiaddr(addr1) == "192.168.1.1"
addr2 = Multiaddr("/ip4/10.0.0.1/udp/5678")
assert extract_ip_from_multiaddr(addr2) == "10.0.0.1"
# IPv6 addresses
addr3 = Multiaddr("/ip6/::1/tcp/1234")
assert extract_ip_from_multiaddr(addr3) == "::1"
addr4 = Multiaddr("/ip6/2001:db8::1/udp/5678")
assert extract_ip_from_multiaddr(addr4) == "2001:db8::1"
# No IP address
addr5 = Multiaddr("/dns4/example.com/tcp/1234")
assert extract_ip_from_multiaddr(addr5) is None
# Complex multiaddr (without p2p to avoid base58 issues)
addr6 = Multiaddr("/ip4/192.168.1.1/tcp/1234/udp/5678")
assert extract_ip_from_multiaddr(addr6) == "192.168.1.1"
def test_reachability_checker_init():
"""Test ReachabilityChecker initialization."""
mock_host = MagicMock()
checker = ReachabilityChecker(mock_host)
assert checker.host == mock_host
assert checker._peer_reachability == {}
assert checker._known_public_peers == set()
def test_reachability_checker_is_addr_public():
"""Test public address detection."""
mock_host = MagicMock()
checker = ReachabilityChecker(mock_host)
# Public addresses
public_addr1 = Multiaddr("/ip4/8.8.8.8/tcp/1234")
assert checker.is_addr_public(public_addr1) is True
public_addr2 = Multiaddr("/ip4/1.1.1.1/udp/5678")
assert checker.is_addr_public(public_addr2) is True
# Private addresses
private_addr1 = Multiaddr("/ip4/192.168.1.1/tcp/1234")
assert checker.is_addr_public(private_addr1) is False
private_addr2 = Multiaddr("/ip4/10.0.0.1/udp/5678")
assert checker.is_addr_public(private_addr2) is False
private_addr3 = Multiaddr("/ip4/127.0.0.1/tcp/1234")
assert checker.is_addr_public(private_addr3) is False
# No IP address
dns_addr = Multiaddr("/dns4/example.com/tcp/1234")
assert checker.is_addr_public(dns_addr) is False
def test_reachability_checker_get_public_addrs():
"""Test filtering for public addresses."""
mock_host = MagicMock()
checker = ReachabilityChecker(mock_host)
addrs = [
Multiaddr("/ip4/8.8.8.8/tcp/1234"), # Public
Multiaddr("/ip4/192.168.1.1/tcp/1234"), # Private
Multiaddr("/ip4/1.1.1.1/udp/5678"), # Public
Multiaddr("/ip4/10.0.0.1/tcp/1234"), # Private
Multiaddr("/dns4/example.com/tcp/1234"), # DNS
]
public_addrs = checker.get_public_addrs(addrs)
assert len(public_addrs) == 2
assert Multiaddr("/ip4/8.8.8.8/tcp/1234") in public_addrs
assert Multiaddr("/ip4/1.1.1.1/udp/5678") in public_addrs
@pytest.mark.trio
async def test_check_peer_reachability_connected_direct():
"""Test peer reachability when directly connected."""
mock_host = MagicMock()
mock_network = MagicMock()
mock_host.get_network.return_value = mock_network
peer_id = ID(b"test-peer-id")
mock_conn = MagicMock()
mock_conn.get_transport_addresses.return_value = [
Multiaddr("/ip4/192.168.1.1/tcp/1234") # Direct connection
]
mock_network.connections = {peer_id: mock_conn}
checker = ReachabilityChecker(mock_host)
result = await checker.check_peer_reachability(peer_id)
assert result is True
assert checker._peer_reachability[peer_id] is True
@pytest.mark.trio
async def test_check_peer_reachability_connected_relay():
"""Test peer reachability when connected through relay."""
mock_host = MagicMock()
mock_network = MagicMock()
mock_host.get_network.return_value = mock_network
peer_id = ID(b"test-peer-id")
mock_conn = MagicMock()
mock_conn.get_transport_addresses.return_value = [
Multiaddr("/p2p-circuit/ip4/192.168.1.1/tcp/1234") # Relay connection
]
mock_network.connections = {peer_id: mock_conn}
# Mock peerstore with public addresses
mock_peerstore = MagicMock()
mock_peerstore.addrs.return_value = [
Multiaddr("/ip4/8.8.8.8/tcp/1234") # Public address
]
mock_host.get_peerstore.return_value = mock_peerstore
checker = ReachabilityChecker(mock_host)
result = await checker.check_peer_reachability(peer_id)
assert result is True
assert checker._peer_reachability[peer_id] is True
@pytest.mark.trio
async def test_check_peer_reachability_not_connected():
"""Test peer reachability when not connected."""
mock_host = MagicMock()
mock_network = MagicMock()
mock_host.get_network.return_value = mock_network
peer_id = ID(b"test-peer-id")
mock_network.connections = {} # No connections
checker = ReachabilityChecker(mock_host)
result = await checker.check_peer_reachability(peer_id)
assert result is False
# When not connected, the method doesn't add to cache
assert peer_id not in checker._peer_reachability
@pytest.mark.trio
async def test_check_peer_reachability_cached():
"""Test that peer reachability results are cached."""
mock_host = MagicMock()
checker = ReachabilityChecker(mock_host)
peer_id = ID(b"test-peer-id")
checker._peer_reachability[peer_id] = True
result = await checker.check_peer_reachability(peer_id)
assert result is True
# Should not call host methods when cached
mock_host.get_network.assert_not_called()
@pytest.mark.trio
async def test_check_self_reachability_with_public_addrs():
"""Test self reachability when host has public addresses."""
mock_host = MagicMock()
mock_host.get_addrs.return_value = [
Multiaddr("/ip4/8.8.8.8/tcp/1234"), # Public
Multiaddr("/ip4/192.168.1.1/tcp/1234"), # Private
Multiaddr("/ip4/1.1.1.1/udp/5678"), # Public
]
checker = ReachabilityChecker(mock_host)
is_reachable, public_addrs = await checker.check_self_reachability()
assert is_reachable is True
assert len(public_addrs) == 2
assert Multiaddr("/ip4/8.8.8.8/tcp/1234") in public_addrs
assert Multiaddr("/ip4/1.1.1.1/udp/5678") in public_addrs
@pytest.mark.trio
async def test_check_self_reachability_no_public_addrs():
"""Test self reachability when host has no public addresses."""
mock_host = MagicMock()
mock_host.get_addrs.return_value = [
Multiaddr("/ip4/192.168.1.1/tcp/1234"), # Private
Multiaddr("/ip4/10.0.0.1/udp/5678"), # Private
Multiaddr("/ip4/127.0.0.1/tcp/1234"), # Loopback
]
checker = ReachabilityChecker(mock_host)
is_reachable, public_addrs = await checker.check_self_reachability()
assert is_reachable is False
assert len(public_addrs) == 0
@pytest.mark.trio
async def test_check_peer_reachability_multiple_connections():
"""Test peer reachability with multiple connections."""
mock_host = MagicMock()
mock_network = MagicMock()
mock_host.get_network.return_value = mock_network
peer_id = ID(b"test-peer-id")
mock_conn1 = MagicMock()
mock_conn1.get_transport_addresses.return_value = [
Multiaddr("/p2p-circuit/ip4/192.168.1.1/tcp/1234") # Relay
]
mock_conn2 = MagicMock()
mock_conn2.get_transport_addresses.return_value = [
Multiaddr("/ip4/192.168.1.1/tcp/1234") # Direct
]
mock_network.connections = {peer_id: [mock_conn1, mock_conn2]}
checker = ReachabilityChecker(mock_host)
result = await checker.check_peer_reachability(peer_id)
assert result is True
assert checker._peer_reachability[peer_id] is True

View File

@ -8,6 +8,7 @@ from libp2p.stream_muxer.mplex.exceptions import (
MplexStreamClosed,
MplexStreamEOF,
MplexStreamReset,
MuxedConnUnavailable,
)
from libp2p.stream_muxer.mplex.mplex import (
MPLEX_MESSAGE_CHANNEL_SIZE,
@ -213,3 +214,39 @@ async def test_mplex_stream_reset(mplex_stream_pair):
# `reset` should do nothing as well.
await stream_0.reset()
await stream_1.reset()
@pytest.mark.trio
async def test_mplex_stream_close_timeout(monkeypatch, mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
# (simulate hanging)
async def fake_send_message(*args, **kwargs):
await trio.sleep_forever()
monkeypatch.setattr(stream_0.muxed_conn, "send_message", fake_send_message)
with pytest.raises(TimeoutError):
await stream_0.close()
@pytest.mark.trio
async def test_mplex_stream_close_mux_unavailable(monkeypatch, mplex_stream_pair):
stream_0, _ = mplex_stream_pair
# Patch send_message to raise MuxedConnUnavailable
def raise_unavailable(*args, **kwargs):
raise MuxedConnUnavailable("Simulated conn drop")
monkeypatch.setattr(stream_0.muxed_conn, "send_message", raise_unavailable)
# Case 1: Mplex is shutting down — should not raise
stream_0.muxed_conn.event_shutting_down.set()
await stream_0.close() # Should NOT raise
# Case 2: Mplex is NOT shutting down — should raise RuntimeError
stream_0.event_local_closed = trio.Event() # Reset since it was set in first call
stream_0.muxed_conn.event_shutting_down = trio.Event() # Unset the shutdown flag
with pytest.raises(RuntimeError, match="Failed to send close message"):
await stream_0.close()

View File

@ -0,0 +1,590 @@
from typing import Any, cast
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import trio
from trio.testing import wait_all_tasks_blocked
from libp2p.stream_muxer.exceptions import (
MuxedConnUnavailable,
)
from libp2p.stream_muxer.mplex.constants import HeaderTags
from libp2p.stream_muxer.mplex.datastructures import StreamID
from libp2p.stream_muxer.mplex.exceptions import (
MplexStreamClosed,
MplexStreamEOF,
MplexStreamReset,
)
from libp2p.stream_muxer.mplex.mplex_stream import MplexStream
class MockMuxedConn:
"""A mock Mplex connection for testing purposes."""
def __init__(self):
self.sent_messages = []
self.streams: dict[StreamID, MplexStream] = {}
self.streams_lock = trio.Lock()
self.is_unavailable = False
async def send_message(
self, flag: HeaderTags, data: bytes | None, stream_id: StreamID
) -> None:
"""Mocks sending a message over the connection."""
if self.is_unavailable:
raise MuxedConnUnavailable("Connection is unavailable")
self.sent_messages.append((flag, data, stream_id))
# Yield to allow other tasks to run
await trio.lowlevel.checkpoint()
def get_remote_address(self) -> tuple[str, int]:
"""Mocks getting the remote address."""
return "127.0.0.1", 4001
@pytest.fixture
async def mplex_stream():
"""Provides a fully initialized MplexStream and its communication channels."""
# Use a buffered channel to prevent deadlocks in simple tests
send_chan, recv_chan = trio.open_memory_channel(10)
stream_id = StreamID(1, is_initiator=True)
muxed_conn = MockMuxedConn()
stream = MplexStream("test-stream", stream_id, cast(Any, muxed_conn), recv_chan)
muxed_conn.streams[stream_id] = stream
yield stream, send_chan, muxed_conn
# Cleanup: Close channels and reset stream state
await send_chan.aclose()
await recv_chan.aclose()
# Reset stream state to prevent cross-test contamination
stream.event_local_closed = trio.Event()
stream.event_remote_closed = trio.Event()
stream.event_reset = trio.Event()
# ===============================================
# 1. Tests for Stream-Level Lock Integration
# ===============================================
@pytest.mark.trio
async def test_stream_write_is_protected_by_rwlock(mplex_stream):
"""Verify that stream.write() acquires and releases the write lock."""
stream, _, muxed_conn = mplex_stream
# Mock lock methods
original_acquire = stream.rw_lock.acquire_write
original_release = stream.rw_lock.release_write
stream.rw_lock.acquire_write = AsyncMock(wraps=original_acquire)
stream.rw_lock.release_write = MagicMock(wraps=original_release)
await stream.write(b"test data")
stream.rw_lock.acquire_write.assert_awaited_once()
stream.rw_lock.release_write.assert_called_once()
# Verify the message was actually sent
assert len(muxed_conn.sent_messages) == 1
flag, data, stream_id = muxed_conn.sent_messages[0]
assert flag == HeaderTags.MessageInitiator
assert data == b"test data"
assert stream_id == stream.stream_id
@pytest.mark.trio
async def test_stream_read_is_protected_by_rwlock(mplex_stream):
"""Verify that stream.read() acquires and releases the read lock."""
stream, send_chan, _ = mplex_stream
# Mock lock methods
original_acquire = stream.rw_lock.acquire_read
original_release = stream.rw_lock.release_read
stream.rw_lock.acquire_read = AsyncMock(wraps=original_acquire)
stream.rw_lock.release_read = AsyncMock(wraps=original_release)
await send_chan.send(b"hello")
result = await stream.read(5)
stream.rw_lock.acquire_read.assert_awaited_once()
stream.rw_lock.release_read.assert_awaited_once()
assert result == b"hello"
@pytest.mark.trio
async def test_multiple_readers_can_coexist(mplex_stream):
"""Verify multiple readers can operate concurrently."""
stream, send_chan, _ = mplex_stream
# Send enough data for both reads
await send_chan.send(b"data1")
await send_chan.send(b"data2")
# Track lock acquisition order
acquisition_order = []
release_order = []
# Patch lock methods to track concurrency
original_acquire = stream.rw_lock.acquire_read
original_release = stream.rw_lock.release_read
async def tracked_acquire():
nonlocal acquisition_order
acquisition_order.append("start")
await original_acquire()
acquisition_order.append("acquired")
async def tracked_release():
nonlocal release_order
release_order.append("start")
await original_release()
release_order.append("released")
with (
patch.object(
stream.rw_lock, "acquire_read", side_effect=tracked_acquire, autospec=True
),
patch.object(
stream.rw_lock, "release_read", side_effect=tracked_release, autospec=True
),
):
# Execute concurrent reads
async with trio.open_nursery() as nursery:
nursery.start_soon(stream.read, 5)
nursery.start_soon(stream.read, 5)
# Verify both reads happened
assert acquisition_order.count("start") == 2
assert acquisition_order.count("acquired") == 2
assert release_order.count("start") == 2
assert release_order.count("released") == 2
@pytest.mark.trio
async def test_writer_blocks_readers(mplex_stream):
"""Verify that a writer blocks all readers and new readers queue behind."""
stream, send_chan, _ = mplex_stream
writer_acquired = trio.Event()
readers_ready = trio.Event()
writer_finished = trio.Event()
all_readers_started = trio.Event()
all_readers_done = trio.Event()
counters = {"reader_start_count": 0, "reader_done_count": 0}
reader_target = 3
reader_start_lock = trio.Lock()
# Patch write lock to control test flow
original_acquire_write = stream.rw_lock.acquire_write
original_release_write = stream.rw_lock.release_write
async def tracked_acquire_write():
await original_acquire_write()
writer_acquired.set()
# Wait for readers to queue up
await readers_ready.wait()
# Must be synchronous since real release_write is sync
def tracked_release_write():
original_release_write()
writer_finished.set()
with (
patch.object(
stream.rw_lock, "acquire_write", side_effect=tracked_acquire_write
),
patch.object(
stream.rw_lock, "release_write", side_effect=tracked_release_write
),
):
async with trio.open_nursery() as nursery:
# Start writer
nursery.start_soon(stream.write, b"test")
await writer_acquired.wait()
# Start readers
async def reader_task():
async with reader_start_lock:
counters["reader_start_count"] += 1
if counters["reader_start_count"] == reader_target:
all_readers_started.set()
try:
# This will block until data is available
await stream.read(5)
except (MplexStreamReset, MplexStreamEOF):
pass
finally:
async with reader_start_lock:
counters["reader_done_count"] += 1
if counters["reader_done_count"] == reader_target:
all_readers_done.set()
for _ in range(reader_target):
nursery.start_soon(reader_task)
# Wait until all readers are started
await all_readers_started.wait()
# Let the writer finish and release the lock
readers_ready.set()
await writer_finished.wait()
# Send data to unblock the readers
for i in range(reader_target):
await send_chan.send(b"data" + str(i).encode())
# Wait for all readers to finish
await all_readers_done.wait()
@pytest.mark.trio
async def test_writer_waits_for_readers(mplex_stream):
"""Verify a writer waits for existing readers to complete."""
stream, send_chan, _ = mplex_stream
readers_started = trio.Event()
writer_entered = trio.Event()
writer_acquiring = trio.Event()
readers_finished = trio.Event()
# Send data for readers
await send_chan.send(b"data1")
await send_chan.send(b"data2")
# Patch read lock to control test flow
original_acquire_read = stream.rw_lock.acquire_read
async def tracked_acquire_read():
await original_acquire_read()
readers_started.set()
# Wait until readers are allowed to finish
await readers_finished.wait()
# Patch write lock to detect when writer is blocked
original_acquire_write = stream.rw_lock.acquire_write
async def tracked_acquire_write():
writer_acquiring.set()
await original_acquire_write()
writer_entered.set()
with (
patch.object(stream.rw_lock, "acquire_read", side_effect=tracked_acquire_read),
patch.object(
stream.rw_lock, "acquire_write", side_effect=tracked_acquire_write
),
):
async with trio.open_nursery() as nursery:
# Start readers
nursery.start_soon(stream.read, 5)
nursery.start_soon(stream.read, 5)
# Wait for at least one reader to acquire the lock
await readers_started.wait()
# Start writer (should block)
nursery.start_soon(stream.write, b"test")
# Wait for writer to start acquiring lock
await writer_acquiring.wait()
# Verify writer hasn't entered critical section
assert not writer_entered.is_set()
# Allow readers to finish
readers_finished.set()
# Verify writer can proceed
await writer_entered.wait()
@pytest.mark.trio
async def test_lock_behavior_during_cancellation(mplex_stream):
"""Verify that a lock is released when a task holding it is cancelled."""
stream, _, _ = mplex_stream
reader_acquired_lock = trio.Event()
async def cancellable_reader(task_status):
async with stream.rw_lock.read_lock():
reader_acquired_lock.set()
task_status.started()
# Wait indefinitely until cancelled.
await trio.sleep_forever()
async with trio.open_nursery() as nursery:
# Start the reader and wait for it to acquire the lock.
await nursery.start(cancellable_reader)
await reader_acquired_lock.wait()
# Now that the reader has the lock, cancel the nursery.
# This will cancel the reader task, and its lock should be released.
nursery.cancel_scope.cancel()
# After the nursery is cancelled, the reader should have released the lock.
# To verify, we try to acquire a write lock. If the read lock was not
# released, this will time out.
with trio.move_on_after(1) as cancel_scope:
async with stream.rw_lock.write_lock():
pass
if cancel_scope.cancelled_caught:
pytest.fail(
"Write lock could not be acquired after a cancelled reader, "
"indicating the read lock was not released."
)
@pytest.mark.trio
async def test_concurrent_read_write_sequence(mplex_stream):
"""Verify complex sequence of interleaved reads and writes."""
stream, send_chan, _ = mplex_stream
results = []
# Use a mock to intercept writes and feed them back to the read channel
original_write = stream.write
reader1_finished = trio.Event()
writer1_finished = trio.Event()
reader2_finished = trio.Event()
async def mocked_write(data: bytes) -> None:
await original_write(data)
# Simulate the other side receiving the data and sending a response
# by putting data into the read channel.
await send_chan.send(data)
with patch.object(stream, "write", wraps=mocked_write) as patched_write:
async with trio.open_nursery() as nursery:
# Test scenario:
# 1. Reader 1 starts, waits for data.
# 2. Writer 1 writes, which gets fed back to the stream.
# 3. Reader 2 starts, reads what Writer 1 wrote.
# 4. Writer 2 writes.
async def reader1():
nonlocal results
results.append("R1 start")
data = await stream.read(5)
results.append(data)
results.append("R1 done")
reader1_finished.set()
async def writer1():
nonlocal results
await reader1_finished.wait()
results.append("W1 start")
await stream.write(b"write1")
results.append("W1 done")
writer1_finished.set()
async def reader2():
nonlocal results
await writer1_finished.wait()
# This will read the data from writer1
results.append("R2 start")
data = await stream.read(6)
results.append(data)
results.append("R2 done")
reader2_finished.set()
async def writer2():
nonlocal results
await reader2_finished.wait()
results.append("W2 start")
await stream.write(b"write2")
results.append("W2 done")
# Execute sequence
nursery.start_soon(reader1)
nursery.start_soon(writer1)
nursery.start_soon(reader2)
nursery.start_soon(writer2)
await send_chan.send(b"data1")
# Verify sequence and that write was called
assert patched_write.call_count == 2
assert results == [
"R1 start",
b"data1",
"R1 done",
"W1 start",
"W1 done",
"R2 start",
b"write1",
"R2 done",
"W2 start",
"W2 done",
]
# ===============================================
# 2. Tests for Reset, EOF, and Close Interactions
# ===============================================
@pytest.mark.trio
async def test_read_after_remote_close_triggers_eof(mplex_stream):
"""Verify reading from a remotely closed stream returns EOF correctly."""
stream, send_chan, _ = mplex_stream
# Send some data that can be read first
await send_chan.send(b"data")
# Close the channel to signify no more data will ever arrive
await send_chan.aclose()
# Mark the stream as remotely closed
stream.event_remote_closed.set()
# The first read should succeed, consuming the buffered data
data = await stream.read(4)
assert data == b"data"
# Now that the buffer is empty and the channel is closed, this should raise EOF
with pytest.raises(MplexStreamEOF):
await stream.read(1)
@pytest.mark.trio
async def test_read_on_closed_stream_raises_eof(mplex_stream):
"""Test that reading from a closed stream with no data raises EOF."""
stream, send_chan, _ = mplex_stream
stream.event_remote_closed.set()
await send_chan.aclose() # Ensure the channel is closed
# Reading from a stream that is closed and has no data should raise EOF
with pytest.raises(MplexStreamEOF):
await stream.read(100)
@pytest.mark.trio
async def test_write_to_locally_closed_stream_raises(mplex_stream):
"""Verify writing to a locally closed stream raises MplexStreamClosed."""
stream, _, _ = mplex_stream
stream.event_local_closed.set()
with pytest.raises(MplexStreamClosed):
await stream.write(b"this should fail")
@pytest.mark.trio
async def test_read_from_reset_stream_raises(mplex_stream):
"""Verify reading from a reset stream raises MplexStreamReset."""
stream, _, _ = mplex_stream
stream.event_reset.set()
with pytest.raises(MplexStreamReset):
await stream.read(10)
@pytest.mark.trio
async def test_write_to_reset_stream_raises(mplex_stream):
"""Verify writing to a reset stream raises MplexStreamClosed."""
stream, _, _ = mplex_stream
# A stream reset implies it's also locally closed.
await stream.reset()
# The `write` method checks `event_local_closed`, which `reset` sets.
with pytest.raises(MplexStreamClosed):
await stream.write(b"this should also fail")
@pytest.mark.trio
async def test_stream_reset_cleans_up_resources(mplex_stream):
"""Verify reset() cleans up stream state and resources."""
stream, _, muxed_conn = mplex_stream
stream_id = stream.stream_id
assert stream_id in muxed_conn.streams
await stream.reset()
assert stream.event_reset.is_set()
assert stream.event_local_closed.is_set()
assert stream.event_remote_closed.is_set()
assert (HeaderTags.ResetInitiator, None, stream_id) in muxed_conn.sent_messages
assert stream_id not in muxed_conn.streams
# Verify the underlying data channel is closed
with pytest.raises(trio.ClosedResourceError):
await stream.incoming_data_channel.receive()
# ===============================================
# 3. Rigorous Concurrency Tests with Events
# ===============================================
@pytest.mark.trio
async def test_writer_is_blocked_by_reader_using_events(mplex_stream):
"""Verify a writer must wait for a reader using trio.Event for synchronization."""
stream, _, _ = mplex_stream
reader_has_lock = trio.Event()
writer_finished = trio.Event()
async def reader():
async with stream.rw_lock.read_lock():
reader_has_lock.set()
# Hold the lock until the writer has finished its attempt
await writer_finished.wait()
async def writer():
await reader_has_lock.wait()
# This call will now block until the reader releases the lock
await stream.write(b"data")
writer_finished.set()
async with trio.open_nursery() as nursery:
nursery.start_soon(reader)
nursery.start_soon(writer)
# Verify writer is blocked
await wait_all_tasks_blocked()
assert not writer_finished.is_set()
# Signal the reader to finish
writer_finished.set()
@pytest.mark.trio
async def test_multiple_readers_can_read_concurrently_using_events(mplex_stream):
"""Verify that multiple readers can acquire a read lock simultaneously."""
stream, _, _ = mplex_stream
counters = {"readers_in_critical_section": 0}
lock = trio.Lock() # To safely mutate the counter
reader1_acquired = trio.Event()
reader2_acquired = trio.Event()
all_readers_finished = trio.Event()
async def concurrent_reader(event_to_set: trio.Event):
async with stream.rw_lock.read_lock():
async with lock:
counters["readers_in_critical_section"] += 1
event_to_set.set()
# Wait until all readers have finished before exiting the lock context
await all_readers_finished.wait()
async with lock:
counters["readers_in_critical_section"] -= 1
async with trio.open_nursery() as nursery:
nursery.start_soon(concurrent_reader, reader1_acquired)
nursery.start_soon(concurrent_reader, reader2_acquired)
# Wait for both readers to acquire their locks
await reader1_acquired.wait()
await reader2_acquired.wait()
# Check that both were in the critical section at the same time
async with lock:
assert counters["readers_in_critical_section"] == 2
# Signal for all readers to finish
all_readers_finished.set()
# Verify they exit cleanly
await wait_all_tasks_blocked()
async with lock:
assert counters["readers_in_critical_section"] == 0

View File

@ -0,0 +1 @@
"""Discovery tests for py-libp2p."""

View File

@ -0,0 +1 @@
"""Bootstrap discovery tests for py-libp2p."""

View File

@ -0,0 +1,52 @@
#!/usr/bin/env python3
"""
Test the full bootstrap discovery integration
"""
import secrets
import pytest
from libp2p import new_host
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.host.basic_host import BasicHost
@pytest.mark.trio
async def test_bootstrap_integration():
"""Test bootstrap integration with new_host"""
# Test bootstrap addresses
bootstrap_addrs = [
"/ip4/104.131.131.82/tcp/4001/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SznbYGzPwp8qDrq",
"/ip4/104.236.179.241/tcp/4001/p2p/QmSoLPppuBtQSGwKDZT2M73ULpjvfd3aZ6ha4oFGL1KrGM",
]
# Generate key pair
secret = secrets.token_bytes(32)
key_pair = create_new_key_pair(secret)
# Create host with bootstrap
host = new_host(key_pair=key_pair, bootstrap=bootstrap_addrs)
# Verify bootstrap discovery is set up (cast to BasicHost for type checking)
assert isinstance(host, BasicHost), "Host should be a BasicHost instance"
assert hasattr(host, "bootstrap"), "Host should have bootstrap attribute"
assert host.bootstrap is not None, "Bootstrap discovery should be initialized"
assert len(host.bootstrap.bootstrap_addrs) == len(bootstrap_addrs), (
"Bootstrap addresses should match"
)
def test_bootstrap_no_addresses():
"""Test that bootstrap is not initialized when no addresses provided"""
secret = secrets.token_bytes(32)
key_pair = create_new_key_pair(secret)
# Create host without bootstrap
host = new_host(key_pair=key_pair)
# Verify bootstrap is not initialized
assert isinstance(host, BasicHost)
assert not hasattr(host, "bootstrap") or host.bootstrap is None, (
"Bootstrap should not be initialized when no addresses provided"
)

View File

@ -0,0 +1,39 @@
#!/usr/bin/env python3
"""
Test bootstrap address validation
"""
from libp2p.discovery.bootstrap.utils import (
parse_bootstrap_peer_info,
validate_bootstrap_addresses,
)
def test_validate_addresses():
"""Test validation with a mix of valid and invalid addresses in one list."""
addresses = [
# Valid - using proper peer IDs
"/ip4/104.131.131.82/tcp/4001/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ",
"/ip4/104.236.179.241/tcp/4001/p2p/QmSoLPppuBtQSGwKDZT2M73ULpjvfd3aZ6ha4oFGL1KrGM",
# Invalid
"invalid-address",
"/ip4/192.168.1.1/tcp/4001", # Missing p2p part
"", # Empty
"/ip4/127.0.0.1/tcp/4001/p2p/InvalidPeerID", # Bad peer ID
]
valid_expected = [
"/ip4/104.131.131.82/tcp/4001/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ",
"/ip4/104.236.179.241/tcp/4001/p2p/QmSoLPppuBtQSGwKDZT2M73ULpjvfd3aZ6ha4oFGL1KrGM",
]
validated = validate_bootstrap_addresses(addresses)
assert validated == valid_expected, (
f"Expected only valid addresses, got: {validated}"
)
for addr in addresses:
peer_info = parse_bootstrap_peer_info(addr)
if addr in valid_expected:
assert peer_info is not None and peer_info.peer_id is not None, (
f"Should parse valid address: {addr}"
)
else:
assert peer_info is None, f"Should not parse invalid address: {addr}"