Mirror of https://github.com/varun-r-mallya/py-libp2p.git, synced 2025-12-31 20:36:24 +00:00
Compare commits: 244 commits (async-vali...333d56dc00)

| SHA1 | Author | Date |
|---|---|---|
| 333d56dc00 | |||
| 2535305123 | |||
| 9df542f97f | |||
| 93fe070cfb | |||
| 7a4c955c98 | |||
| 934f49af83 | |||
| 970b535b25 | |||
| fc6b290c56 | |||
| ef6557518c | |||
| 1783a6b0b9 | |||
| 1077516196 | |||
| aad87f983f | |||
| 7d6eb28d7c | |||
| 9a06ee429f | |||
| 526b65e1d5 | |||
| 59e1d9ae39 | |||
| d620270eaf | |||
| 31040931ea | |||
| 96e2149f4d | |||
| cb5bfeda39 | |||
| b26e8333bd | |||
| d99b67eafa | |||
| cdfb083c06 | |||
| d4c387f923 | |||
| 56526b4870 | |||
| 3c52b859ba | |||
| 426aae7efb | |||
| 999315a74a | |||
| df39e240e7 | |||
| 5c11ac20e7 | |||
| 9fa3afbb04 | |||
| c577fd2f71 | |||
| 9f80dbae12 | |||
| c08007feda | |||
| c2c4228591 | |||
| 943bcc4d36 | |||
| 2006b2c92c | |||
| fe3f7adc1b | |||
| 7b2d637382 | |||
| 91bee9df89 | |||
| 5bf9c7b537 | |||
| 8958c0fac3 | |||
| 091ac082b9 | |||
| 15f4a399ec | |||
| 3917d7b596 | |||
| 3aacb3a391 | |||
| ba39e91a2e | |||
| 57d1c9d807 | |||
| efc899e872 | |||
| cea1985c5c | |||
| 702ad4876e | |||
| a21d9e878b | |||
| 5ab68026d6 | |||
| d1792588f9 | |||
| 53db128f69 | |||
| cacb3c8aca | |||
| 6214697349 | |||
| cda50e0ead | |||
| 05fde3ad40 | |||
| 292bd1a942 | |||
| c9795e3138 | |||
| b80817b5ae | |||
| 6c6adf7459 | |||
| 79f3a173f4 | |||
| 7fb3c2da9f | |||
| 6b7f50be3d | |||
| 6a0a7c21e8 | |||
| fde8c8f127 | |||
| bc1b1ed6ae | |||
| 63a8458d45 | |||
| ed91ee0c31 | |||
| 75ffb791ac | |||
| cf48d2e9a4 | |||
| 88a1f0a390 | |||
| b38d504fc1 | |||
| 3bd6d1f579 | |||
| b6cbd78943 | |||
| ed2716c1bf | |||
| 9efc5a1bd1 | |||
| 5b9bec8e28 | |||
| c2c91b8c58 | |||
| 8a2d1f7045 | |||
| 94d695c6bc | |||
| 905f3a5708 | |||
| dabb3a0962 | |||
| 69d5274891 | |||
| 3ff5728209 | |||
| a1b16248d3 | |||
| 55dd8835a7 | |||
| e20a9a3814 | |||
| 7f6469d5d4 | |||
| ee66958e7f | |||
| c306400bd9 | |||
| 05b372b1eb | |||
| e4ab3cb2c5 | |||
| fe71c479dc | |||
| 95e1f62870 | |||
| 9378490dcb | |||
| a2fcf33bc1 | |||
| b363d1d6d0 | |||
| 9a0f224a1c | |||
| 13379e38d8 | |||
| 09d2110d65 | |||
| cff0bfc17d | |||
| a2ad10b1e4 | |||
| 7c2014087f | |||
| 37df8d679d | |||
| 5c78a41552 | |||
| 90f143cd88 | |||
| 1ecff5437c | |||
| aa7276c863 | |||
| b838a0e3b6 | |||
| b01596ad92 | |||
| 1565d409e8 | |||
| 400ee9b896 | |||
| 2730db4285 | |||
| bb896dac2c | |||
| a14c42ef73 | |||
| af61523c87 | |||
| d2fdf70692 | |||
| 09cd8b37ed | |||
| 1ea50a3cf3 | |||
| f4247faa51 | |||
| 92e79bbb3f | |||
| eb3121b818 | |||
| 787648177f | |||
| fc9b28910a | |||
| 26d0ed2d81 | |||
| 618aff9368 | |||
| 32e545d9c7 | |||
| e712e6c0c4 | |||
| 59a898c8ce | |||
| fa174230ba | |||
| b840eaa7e1 | |||
| b143c96abd | |||
| 678b920992 | |||
| cb11f076c8 | |||
| 9ed44f5fa3 | |||
| 8786f06862 | |||
| 8c96c5a941 | |||
| 16445714f7 | |||
| 64bc388b33 | |||
| 09e151aafc | |||
| 2d335d4394 | |||
| 8b8b051885 | |||
| 07c8d4cd1f | |||
| 09e6feea8e | |||
| 601a8a3ef0 | |||
| 9d597012cc | |||
| 8625226be8 | |||
| c2b1738cd9 | |||
| 83acc38281 | |||
| 1899dac84c | |||
| aab2a0b603 | |||
| bab08c0900 | |||
| 6431fb8788 | |||
| c8053417d5 | |||
| 6eba9d8ca0 | |||
| 0e1b738cbb | |||
| 2ff5ae9c90 | |||
| ecc443dcfe | |||
| aa6039bcd3 | |||
| 8352d19113 | |||
| ceb9f7d3f7 | |||
| 9b667bd472 | |||
| eca548851b | |||
| e91f458446 | |||
| 0416572457 | |||
| 39375fb338 | |||
| 8bf261ca77 | |||
| 3a927c8419 | |||
| ec92af20e7 | |||
| 01db5d5fa0 | |||
| 21ee417793 | |||
| 37e4fee9f8 | |||
| c277cce2ed | |||
| 048e6deb96 | |||
| 2dc2dd4670 | |||
| e6a355d395 | |||
| 7b181f3ce5 | |||
| 0606788ab6 | |||
| 7d62a2f558 | |||
| 26fd169ccc | |||
| 99db5b309f | |||
| 7cfe5b9dc7 | |||
| 092b9c0c57 | |||
| fcf0546831 | |||
| 85bad2d0ae | |||
| 11560f5cc9 | |||
| 3507531344 | |||
| c9162beb2b | |||
| f587e50cab | |||
| d1a0f4f767 | |||
| 3ca27c6e93 | |||
| b4482e1a5e | |||
| ae82895d86 | |||
| 9f40d97a05 | |||
| 6fe28dcdd3 | |||
| 41b1ecb67c | |||
| e3c9b4bd54 | |||
| e132b154e3 | |||
| 430527625b | |||
| 93fc063e70 | |||
| 5315816521 | |||
| 42f07ae1ab | |||
| 773962c070 | |||
| ab94e77310 | |||
| 23622ea1a0 | |||
| 6aeb217349 | |||
| 003e7bf278 | |||
| 6f33cde9a9 | |||
| 9f38d48e26 | |||
| 2c1e50428a | |||
| 9e76940e75 | |||
| 9cd3805542 | |||
| d03bdd75d6 | |||
| 8ff7bb1f20 | |||
| 21db1c3b72 | |||
| 3592ad308f | |||
| 9669a92976 | |||
| 2dfee68f20 | |||
| 198208aef3 | |||
| cda163fc48 | |||
| 26ed99dafd | |||
| a26fd95854 | |||
| 2965b4e364 | |||
| 242998ae9d | |||
| 5f497c7f5d | |||
| e65e38a3f1 | |||
| 8fb664bfdf | |||
| 3dcd99a2d1 | |||
| 75abc8b863 | |||
| 91dca97d83 | |||
| 80c686ddce | |||
| dcb199a6b7 | |||
| 16be6fab85 | |||
| cbb1e26a4f | |||
| 69a2cb00ba | |||
| ec20ca81dd | |||
| b5ec1bd7ee | |||
| ddbd190993 | |||
| 36be4c354b | |||
| befb2d31db | |||
| 12ad2dcdf4 |
.gitignore (vendored, 3 changes)

@@ -178,3 +178,6 @@ env.bak/
#lockfiles
uv.lock
poetry.lock

# Sphinx documentation build
_build/
Makefile (8 changes)

@@ -60,6 +60,7 @@ PB = libp2p/crypto/pb/crypto.proto \
	libp2p/identity/identify/pb/identify.proto \
	libp2p/host/autonat/pb/autonat.proto \
	libp2p/relay/circuit_v2/pb/circuit.proto \
	libp2p/relay/circuit_v2/pb/dcutr.proto \
	libp2p/kad_dht/pb/kademlia.proto

PY = $(PB:.proto=_pb2.py)

@@ -68,6 +69,8 @@ PYI = $(PB:.proto=_pb2.pyi)
## Set default to `protobufs`, otherwise `format` is called when typing only `make`
all: protobufs

.PHONY: protobufs clean-proto

protobufs: $(PY)

%_pb2.py: %.proto

@@ -76,6 +79,11 @@ protobufs: $(PY)
clean-proto:
	rm -f $(PY) $(PYI)

# Force protobuf regeneration by making them always out of date
$(PY): FORCE

FORCE:

# docs commands

docs: check-docs
README.md (52 changes)

@@ -12,13 +12,13 @@
|
||||
[](https://github.com/libp2p/py-libp2p/actions/workflows/tox.yml)
|
||||
[](http://py-libp2p.readthedocs.io/en/latest/?badge=latest)
|
||||
|
||||
> ⚠️ **Warning:** py-libp2p is an experimental and work-in-progress repo under development. We do not yet recommend using py-libp2p in production environments.
|
||||
> py-libp2p has moved beyond its experimental roots and is steadily progressing toward production readiness. The core features are stable, and we’re focused on refining performance, expanding protocol support, and ensuring smooth interop with other libp2p implementations. We welcome contributions and real-world usage feedback to help us reach full production maturity.
|
||||
|
||||
Read more in the [documentation on ReadTheDocs](https://py-libp2p.readthedocs.io/). [View the release notes](https://py-libp2p.readthedocs.io/en/latest/release_notes.html).
|
||||
|
||||
## Maintainers
|
||||
|
||||
Currently maintained by [@pacrob](https://github.com/pacrob), [@seetadev](https://github.com/seetadev) and [@dhuseby](https://github.com/dhuseby), looking for assistance!
|
||||
Currently maintained by [@pacrob](https://github.com/pacrob), [@seetadev](https://github.com/seetadev) and [@dhuseby](https://github.com/dhuseby). Please reach out to us for collaboration or active feedback. If you have questions, feel free to open a new [discussion](https://github.com/libp2p/py-libp2p/discussions). We are also available on the libp2p Discord — join us at #py-libp2p [sub-channel](https://discord.gg/d92MEugb).
|
||||
|
||||
## Feature Breakdown
|
||||
|
||||
@ -34,19 +34,19 @@ ______________________________________________________________________
|
||||
| -------------------------------------- | :--------: | :---------------------------------------------------------------------------------: |
|
||||
| **`libp2p-tcp`** | ✅ | [source](https://github.com/libp2p/py-libp2p/blob/main/libp2p/transport/tcp/tcp.py) |
|
||||
| **`libp2p-quic`** | 🌱 | |
|
||||
| **`libp2p-websocket`** | ❌ | |
|
||||
| **`libp2p-webrtc-browser-to-server`** | ❌ | |
|
||||
| **`libp2p-webrtc-private-to-private`** | ❌ | |
|
||||
| **`libp2p-websocket`** | 🌱 | |
|
||||
| **`libp2p-webrtc-browser-to-server`** | 🌱 | |
|
||||
| **`libp2p-webrtc-private-to-private`** | 🌱 | |
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
### NAT Traversal
|
||||
|
||||
| **NAT Traversal** | **Status** |
|
||||
| ----------------------------- | :--------: |
|
||||
| **`libp2p-circuit-relay-v2`** | ❌ |
|
||||
| **`libp2p-autonat`** | ❌ |
|
||||
| **`libp2p-hole-punching`** | ❌ |
|
||||
| **NAT Traversal** | **Status** | **Source** |
|
||||
| ----------------------------- | :--------: | :-----------------------------------------------------------------------------: |
|
||||
| **`libp2p-circuit-relay-v2`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/relay/circuit_v2) |
|
||||
| **`libp2p-autonat`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/host/autonat) |
|
||||
| **`libp2p-hole-punching`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/relay/circuit_v2) |
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
@ -54,27 +54,27 @@ ______________________________________________________________________
|
||||
|
||||
| **Secure Communication** | **Status** | **Source** |
|
||||
| ------------------------ | :--------: | :---------------------------------------------------------------------------: |
|
||||
| **`libp2p-noise`** | 🌱 | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/security/noise) |
|
||||
| **`libp2p-tls`** | ❌ | |
|
||||
| **`libp2p-noise`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/security/noise) |
|
||||
| **`libp2p-tls`** | 🌱 | |
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
### Discovery
|
||||
|
||||
| **Discovery** | **Status** |
|
||||
| -------------------- | :--------: |
|
||||
| **`bootstrap`** | ❌ |
|
||||
| **`random-walk`** | ❌ |
|
||||
| **`mdns-discovery`** | ❌ |
|
||||
| **`rendezvous`** | ❌ |
|
||||
| **Discovery** | **Status** | **Source** |
|
||||
| -------------------- | :--------: | :--------------------------------------------------------------------------------: |
|
||||
| **`bootstrap`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/bootstrap) |
|
||||
| **`random-walk`** | 🌱 | |
|
||||
| **`mdns-discovery`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/mdns) |
|
||||
| **`rendezvous`** | 🌱 | |
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
### Peer Routing
|
||||
|
||||
| **Peer Routing** | **Status** |
|
||||
| -------------------- | :--------: |
|
||||
| **`libp2p-kad-dht`** | ❌ |
|
||||
| **Peer Routing** | **Status** | **Source** |
|
||||
| -------------------- | :--------: | :--------------------------------------------------------------------: |
|
||||
| **`libp2p-kad-dht`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/kad_dht) |
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
@ -89,10 +89,10 @@ ______________________________________________________________________
|
||||
|
||||
### Stream Muxers
|
||||
|
||||
| **Stream Muxers** | **Status** | **Status** |
|
||||
| ------------------ | :--------: | :----------------------------------------------------------------------------------------: |
|
||||
| **`libp2p-yamux`** | 🌱 | |
|
||||
| **`libp2p-mplex`** | 🛠️ | [source](https://github.com/libp2p/py-libp2p/blob/main/libp2p/stream_muxer/mplex/mplex.py) |
|
||||
| **Stream Muxers** | **Status** | **Source** |
|
||||
| ------------------ | :--------: | :-------------------------------------------------------------------------------: |
|
||||
| **`libp2p-yamux`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/stream_muxer/yamux) |
|
||||
| **`libp2p-mplex`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/stream_muxer/mplex) |
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
@ -100,7 +100,7 @@ ______________________________________________________________________
|
||||
|
||||
| **Storage** | **Status** |
|
||||
| ------------------- | :--------: |
|
||||
| **`libp2p-record`** | ❌ |
|
||||
| **`libp2p-record`** | 🌱 |
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
|
||||
docs/examples.multiple_connections.rst (new file, 194 lines)

@@ -0,0 +1,194 @@
|
||||
Multiple Connections Per Peer
|
||||
=============================
|
||||
|
||||
This example demonstrates how to use the multiple connections per peer feature in py-libp2p.
|
||||
|
||||
Overview
|
||||
--------
|
||||
|
||||
The multiple connections per peer feature allows a libp2p node to maintain multiple network connections to the same peer. This provides several benefits:
|
||||
|
||||
- **Improved reliability**: If one connection fails, others remain available
|
||||
- **Better performance**: Load can be distributed across multiple connections
|
||||
- **Enhanced throughput**: Multiple streams can be created in parallel
|
||||
- **Fault tolerance**: Redundant connections provide backup paths
|
||||
|
||||
Configuration
|
||||
-------------
|
||||
|
||||
The feature is configured through the `ConnectionConfig` class:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from libp2p.network.swarm import ConnectionConfig
|
||||
|
||||
# Default configuration
|
||||
config = ConnectionConfig()
|
||||
print(f"Max connections per peer: {config.max_connections_per_peer}")
|
||||
print(f"Load balancing strategy: {config.load_balancing_strategy}")
|
||||
|
||||
# Custom configuration
|
||||
custom_config = ConnectionConfig(
|
||||
max_connections_per_peer=5,
|
||||
connection_timeout=60.0,
|
||||
load_balancing_strategy="least_loaded"
|
||||
)
|
||||
|
||||
Load Balancing Strategies
|
||||
-------------------------
|
||||
|
||||
Two load balancing strategies are available:
|
||||
|
||||
**Round Robin** (default)
|
||||
Cycles through connections in order, distributing load evenly.
|
||||
|
||||
**Least Loaded**
|
||||
Selects the connection with the fewest active streams.
|
||||
|
||||
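The sketch below illustrates what the two strategies do. It assumes each connection object exposes a ``get_streams()`` method returning its active streams (as py-libp2p connection objects do); the helper functions themselves are illustrative, not the swarm's internal implementation.

.. code-block:: python

    from itertools import cycle

    def round_robin_picker(connections):
        """Return a callable that hands out connections in a fixed rotation."""
        rotation = cycle(connections)
        return lambda: next(rotation)

    def least_loaded(connections):
        """Pick the connection with the fewest active streams."""
        return min(connections, key=lambda conn: len(conn.get_streams()))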
API Usage
|
||||
---------
|
||||
|
||||
The new API provides direct access to multiple connections:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from libp2p import new_swarm
|
||||
|
||||
# Create swarm with multiple connections support
|
||||
swarm = new_swarm()
|
||||
|
||||
# Dial a peer - returns list of connections
|
||||
connections = await swarm.dial_peer(peer_id)
|
||||
print(f"Established {len(connections)} connections")
|
||||
|
||||
# Get all connections to a peer
|
||||
peer_connections = swarm.get_connections(peer_id)
|
||||
|
||||
# Get all connections (across all peers)
|
||||
all_connections = swarm.get_connections()
|
||||
|
||||
# Get the complete connections map
|
||||
connections_map = swarm.get_connections_map()
|
||||
|
||||
# Backward compatibility - get single connection
|
||||
single_conn = swarm.get_connection(peer_id)
|
||||
|
||||
Backward Compatibility
|
||||
----------------------
|
||||
|
||||
Existing code continues to work through backward compatibility features:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Legacy 1:1 mapping (returns first connection for each peer)
|
||||
legacy_connections = swarm.connections_legacy
|
||||
|
||||
# Single connection access (returns first available connection)
|
||||
conn = swarm.get_connection(peer_id)
|
||||
|
||||
Example
|
||||
-------
|
||||
|
||||
A complete working example is available in the `examples/doc-examples/multiple_connections_example.py` file.
|
||||
|
||||
Production Configuration
|
||||
-------------------------
|
||||
|
||||
For production use, consider these settings:
|
||||
|
||||
**RetryConfig Parameters**
|
||||
|
||||
The `RetryConfig` class controls connection retry behavior with exponential backoff:
|
||||
|
||||
- **max_retries**: Maximum number of retry attempts before giving up (default: 3)
|
||||
- **initial_delay**: Initial delay in seconds before the first retry (default: 0.1s)
|
||||
- **max_delay**: Maximum delay cap to prevent excessive wait times (default: 30.0s)
|
||||
- **backoff_multiplier**: Exponential backoff multiplier - each retry multiplies delay by this factor (default: 2.0)
|
||||
- **jitter_factor**: Random jitter (0.0-1.0) to prevent synchronized retries (default: 0.1)
|
||||
|
||||
**ConnectionConfig Parameters**
|
||||
|
||||
The `ConnectionConfig` class manages multi-connection behavior:
|
||||
|
||||
- **max_connections_per_peer**: Maximum connections allowed to a single peer (default: 3)
|
||||
- **connection_timeout**: Timeout for establishing new connections in seconds (default: 30.0s)
|
||||
- **load_balancing_strategy**: Strategy for distributing streams ("round_robin" or "least_loaded")
|
||||
|
||||
**Load Balancing Strategies Explained**
|
||||
|
||||
- **round_robin**: Cycles through connections in order, distributing load evenly. Simple and predictable.
|
||||
- **least_loaded**: Selects the connection with the fewest active streams. Better for performance but more complex.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from libp2p.network.swarm import ConnectionConfig, RetryConfig
|
||||
|
||||
# Production-ready configuration
|
||||
retry_config = RetryConfig(
|
||||
max_retries=3, # Maximum retry attempts before giving up
|
||||
initial_delay=0.1, # Start with 100ms delay
|
||||
max_delay=30.0, # Cap exponential backoff at 30 seconds
|
||||
backoff_multiplier=2.0, # Double delay each retry (100ms -> 200ms -> 400ms)
|
||||
jitter_factor=0.1 # Add 10% random jitter to prevent thundering herd
|
||||
)
|
||||
|
||||
connection_config = ConnectionConfig(
|
||||
max_connections_per_peer=3, # Allow up to 3 connections per peer
|
||||
connection_timeout=30.0, # 30 second timeout for new connections
|
||||
load_balancing_strategy="round_robin" # Simple, predictable load distribution
|
||||
)
|
||||
|
||||
swarm = new_swarm(
|
||||
retry_config=retry_config,
|
||||
connection_config=connection_config
|
||||
)
|
||||
|
||||
**How RetryConfig Works in Practice**
|
||||
|
||||
With the configuration above (``max_retries=3``), connection attempts follow this pattern:

1. **Attempt 1**: Immediate connection attempt
2. **Attempt 2**: Wait 100ms ± 10ms jitter, then retry
3. **Attempt 3**: Wait 200ms ± 20ms jitter, then retry
4. **Attempt 4**: Wait 400ms ± 40ms jitter, then retry
5. **Give up**: After 3 retries (4 attempts in total), the connection fails

With a larger ``max_retries``, the delay keeps doubling (800ms, 1.6s, 3.2s, and so on) until it reaches the 30-second ``max_delay`` cap.
|
||||
|
||||
The jitter prevents multiple clients from retrying simultaneously, reducing server load.
|
||||
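The backoff arithmetic above can be reproduced in a few lines of Python. The formula below is an assumption pieced together from the parameter descriptions, not a copy of the library's internals:

.. code-block:: python

    import random

    def backoff_delays(max_retries=3, initial_delay=0.1, max_delay=30.0,
                       backoff_multiplier=2.0, jitter_factor=0.1):
        """Yield the delay (in seconds) before each retry attempt."""
        delay = initial_delay
        for _ in range(max_retries):
            jitter = delay * jitter_factor * (2 * random.random() - 1)
            yield min(delay, max_delay) + jitter
            delay = min(delay * backoff_multiplier, max_delay)

    print([round(d, 3) for d in backoff_delays()])  # roughly [0.1, 0.2, 0.4] with jitter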
|
||||
**Parameter Tuning Guidelines**
|
||||
|
||||
**For Development/Testing:**
|
||||
- Use lower `max_retries` (1-2) and shorter delays for faster feedback
|
||||
- Example: `RetryConfig(max_retries=2, initial_delay=0.01, max_delay=0.1)`
|
||||
|
||||
**For Production:**
|
||||
- Use moderate `max_retries` (3-5) with reasonable delays for reliability
|
||||
- Example: `RetryConfig(max_retries=5, initial_delay=0.1, max_delay=60.0)`
|
||||
|
||||
**For High-Latency Networks:**
|
||||
- Use higher `max_retries` (5-10) with longer delays
|
||||
- Example: `RetryConfig(max_retries=8, initial_delay=0.5, max_delay=120.0)`
|
||||
|
||||
**For Load Balancing:**
|
||||
- Use `round_robin` for simple, predictable behavior
|
||||
- Use `least_loaded` when you need optimal performance and can handle complexity
|
||||
|
||||
Architecture
|
||||
------------
|
||||
|
||||
The implementation follows the same architectural patterns as the Go and JavaScript reference implementations:
|
||||
|
||||
- **Core data structure**: `dict[ID, list[INetConn]]` for 1:many mapping
|
||||
- **API consistency**: Methods like `get_connections()` match reference implementations
|
||||
- **Load balancing**: Integrated at the API level for optimal performance
|
||||
- **Backward compatibility**: Maintains existing interfaces for gradual migration
|
||||
|
||||
This design ensures consistency across libp2p implementations while providing the benefits of multiple connections per peer.
|
||||
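For orientation, the core 1:many mapping can be pictured with the short sketch below. The type names in the comment are taken from the list above; the helper functions are hand-written for illustration and are not the swarm's actual attribute layout.

.. code-block:: python

    # Conceptual sketch: one peer ID maps to every live connection to that peer.
    connections: dict = {}  # dict[ID, list[INetConn]]

    def add_connection(peer_id, conn) -> None:
        connections.setdefault(peer_id, []).append(conn)

    def first_connection(peer_id):
        # Backward-compatible single-connection view
        conns = connections.get(peer_id, [])
        return conns[0] if conns else None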
docs/examples.random_walk.rst (new file, 131 lines)

@@ -0,0 +1,131 @@
|
||||
Random Walk Example
|
||||
===================
|
||||
|
||||
This example demonstrates the Random Walk module's peer discovery capabilities using real libp2p hosts and Kademlia DHT.
|
||||
It shows how the Random Walk module automatically discovers new peers and maintains routing table health.
|
||||
|
||||
The Random Walk implementation performs the following key operations:
|
||||
|
||||
* **Automatic Peer Discovery**: Generates random peer IDs and queries the DHT network to discover new peers
|
||||
* **Routing Table Maintenance**: Periodically refreshes the routing table to maintain network connectivity
|
||||
* **Connection Management**: Maintains optimal connections to healthy peers in the network
|
||||
* **Real-time Statistics**: Displays routing table size, connected peers, and peerstore statistics
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python -m pip install libp2p
|
||||
Collecting libp2p
|
||||
...
|
||||
Successfully installed libp2p-x.x.x
|
||||
$ cd examples/random_walk
|
||||
$ python random_walk.py --mode server
|
||||
2025-08-12 19:51:25,424 - random-walk-example - INFO - === Random Walk Example for py-libp2p ===
|
||||
2025-08-12 19:51:25,424 - random-walk-example - INFO - Mode: server, Port: 0 Demo interval: 30s
|
||||
2025-08-12 19:51:25,426 - random-walk-example - INFO - Starting server node on port 45123
|
||||
2025-08-12 19:51:25,426 - random-walk-example - INFO - Node peer ID: 16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
|
||||
2025-08-12 19:51:25,426 - random-walk-example - INFO - Node address: /ip4/0.0.0.0/tcp/45123/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
|
||||
2025-08-12 19:51:25,427 - random-walk-example - INFO - Initial routing table size: 0
|
||||
2025-08-12 19:51:25,427 - random-walk-example - INFO - DHT service started in SERVER mode
|
||||
2025-08-12 19:51:25,430 - libp2p.discovery.random_walk.rt_refresh_manager - INFO - RT Refresh Manager started
|
||||
2025-08-12 19:51:55,432 - random-walk-example - INFO - --- Iteration 1 ---
|
||||
2025-08-12 19:51:55,432 - random-walk-example - INFO - Routing table size: 15
|
||||
2025-08-12 19:51:55,432 - random-walk-example - INFO - Connected peers: 8
|
||||
2025-08-12 19:51:55,432 - random-walk-example - INFO - Peerstore size: 42
|
||||
|
||||
You can also run the example in client mode:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python random_walk.py --mode client
|
||||
2025-08-12 19:52:15,424 - random-walk-example - INFO - === Random Walk Example for py-libp2p ===
|
||||
2025-08-12 19:52:15,424 - random-walk-example - INFO - Mode: client, Port: 0 Demo interval: 30s
|
||||
2025-08-12 19:52:15,426 - random-walk-example - INFO - Starting client node on port 51234
|
||||
2025-08-12 19:52:15,426 - random-walk-example - INFO - Node peer ID: 16Uiu2HAmAbc123xyz...
|
||||
2025-08-12 19:52:15,427 - random-walk-example - INFO - DHT service started in CLIENT mode
|
||||
2025-08-12 19:52:45,432 - random-walk-example - INFO - --- Iteration 1 ---
|
||||
2025-08-12 19:52:45,432 - random-walk-example - INFO - Routing table size: 8
|
||||
2025-08-12 19:52:45,432 - random-walk-example - INFO - Connected peers: 5
|
||||
2025-08-12 19:52:45,432 - random-walk-example - INFO - Peerstore size: 25
|
||||
|
||||
Command Line Options
|
||||
--------------------
|
||||
|
||||
The example supports several command-line options:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python random_walk.py --help
|
||||
usage: random_walk.py [-h] [--mode {server,client}] [--port PORT]
|
||||
[--demo-interval DEMO_INTERVAL] [--verbose]
|
||||
|
||||
Random Walk Example for py-libp2p Kademlia DHT
|
||||
|
||||
optional arguments:
|
||||
-h, --help show this help message and exit
|
||||
--mode {server,client}
|
||||
Node mode: server (DHT server), or client (DHT client)
|
||||
--port PORT Port to listen on (0 for random)
|
||||
--demo-interval DEMO_INTERVAL
|
||||
Interval between random walk demonstrations in seconds
|
||||
--verbose Enable verbose logging
|
||||
|
||||
Key Features Demonstrated
|
||||
-------------------------
|
||||
|
||||
**Automatic Random Walk Discovery**
|
||||
The example shows how the Random Walk module automatically:
|
||||
|
||||
* Generates random 256-bit peer IDs for discovery queries (see the sketch after this list)
|
||||
* Performs concurrent random walks to maximize peer discovery
|
||||
* Validates discovered peers and adds them to the routing table
|
||||
* Maintains routing table health through periodic refreshes
|
||||
|
||||
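As the first item above mentions, each walk starts from a uniformly random 256-bit target. Generating such a target is as simple as the sketch below (the real module wraps this in its own helpers):

.. code-block:: python

    import secrets

    # 32 random bytes = a uniformly random 256-bit key to look up in the DHT
    random_target = secrets.token_bytes(32)
    print(random_target.hex())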
**Real-time Network Statistics**
|
||||
The example displays live statistics every 30 seconds (configurable):
|
||||
|
||||
* **Routing Table Size**: Number of peers in the Kademlia routing table
|
||||
* **Connected Peers**: Number of actively connected peers
|
||||
* **Peerstore Size**: Total number of known peers with addresses
|
||||
|
||||
**Connection Management**
|
||||
The example includes sophisticated connection management:
|
||||
|
||||
* Automatically maintains connections to healthy peers
|
||||
* Filters for compatible peers (TCP + IPv4 addresses; see the sketch after this list)
|
||||
* Reconnects to maintain optimal network connectivity
|
||||
* Handles connection failures gracefully
|
||||
|
||||
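A minimal version of the compatibility filter mentioned above might look like this (assuming peer addresses are ``multiaddr.Multiaddr`` objects, as elsewhere in py-libp2p):

.. code-block:: python

    from multiaddr import Multiaddr

    def is_tcp_ipv4(addr: Multiaddr) -> bool:
        """Keep only peers reachable over TCP/IPv4."""
        names = {proto.name for proto in addr.protocols()}
        return "ip4" in names and "tcp" in names

    print(is_tcp_ipv4(Multiaddr("/ip4/127.0.0.1/tcp/4001")))  # True
    print(is_tcp_ipv4(Multiaddr("/ip6/::1/tcp/4001")))        # False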
**DHT Integration**
|
||||
Shows seamless integration between Random Walk and Kademlia DHT:
|
||||
|
||||
* RT Refresh Manager coordinates with the DHT routing table
|
||||
* Peer discovery feeds directly into DHT operations
|
||||
* Both SERVER and CLIENT modes supported
|
||||
* Bootstrap connectivity to public IPFS nodes
|
||||
|
||||
Understanding the Output
|
||||
------------------------
|
||||
|
||||
When you run the example, you'll see periodic statistics that show how the Random Walk module is working:
|
||||
|
||||
* **Initial Phase**: Routing table starts empty and quickly discovers peers
|
||||
* **Growth Phase**: Routing table size increases as more peers are discovered
|
||||
* **Maintenance Phase**: Routing table size stabilizes as the system maintains optimal peer connections
|
||||
|
||||
The Random Walk module runs automatically in the background, performing peer discovery queries every few minutes to ensure the routing table remains populated with fresh, reachable peers.
|
||||
|
||||
Configuration
|
||||
-------------
|
||||
|
||||
The Random Walk module can be configured through the following parameters in ``libp2p.discovery.random_walk.config``:
|
||||
|
||||
* ``RANDOM_WALK_ENABLED``: Enable/disable automatic random walks (default: True)
|
||||
* ``REFRESH_INTERVAL``: Time between automatic refreshes in seconds (default: 300)
|
||||
* ``RANDOM_WALK_CONCURRENCY``: Number of concurrent random walks (default: 3)
|
||||
* ``MIN_RT_REFRESH_THRESHOLD``: Minimum routing table size before triggering refresh (default: 4)
|
||||
|
||||
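A hedged illustration of how these settings could be inspected, or tweaked for a test network, is shown below; it assumes the constants are plain module-level attributes of ``libp2p.discovery.random_walk.config``, as the list above suggests:

.. code-block:: python

    from libp2p.discovery.random_walk import config

    # Inspect the shipped defaults
    print(config.RANDOM_WALK_ENABLED)      # True
    print(config.REFRESH_INTERVAL)         # 300 (seconds)
    print(config.RANDOM_WALK_CONCURRENCY)  # 3 concurrent walks

    # For a local test network, a faster refresh cycle may be convenient
    # (only effective if changed before the DHT service starts).
    config.REFRESH_INTERVAL = 60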
See Also
|
||||
--------
|
||||
|
||||
* :doc:`examples.kademlia` - Kademlia DHT value storage and content routing
|
||||
* :doc:`libp2p.discovery.random_walk` - Random Walk module API documentation
|
||||
@@ -14,3 +14,5 @@ Examples
   examples.circuit_relay
   examples.kademlia
   examples.mDNS
   examples.random_walk
   examples.multiple_connections
docs/libp2p.discovery.bootstrap.rst (new file, 13 lines)

@@ -0,0 +1,13 @@
|
||||
libp2p.discovery.bootstrap package
|
||||
==================================
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
Module contents
|
||||
---------------
|
||||
|
||||
.. automodule:: libp2p.discovery.bootstrap
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
docs/libp2p.discovery.random_walk.rst (new file, 48 lines)

@@ -0,0 +1,48 @@
|
||||
libp2p.discovery.random_walk package
|
||||
====================================
|
||||
|
||||
The Random Walk module implements a peer discovery mechanism.
|
||||
It performs random walks through the DHT network to discover new peers and maintain routing table health through periodic refreshes.
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
libp2p.discovery.random_walk.config module
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automodule:: libp2p.discovery.random_walk.config
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.discovery.random_walk.exceptions module
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automodule:: libp2p.discovery.random_walk.exceptions
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.discovery.random_walk.random_walk module
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automodule:: libp2p.discovery.random_walk.random_walk
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.discovery.random_walk.rt_refresh_manager module
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automodule:: libp2p.discovery.random_walk.rt_refresh_manager
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Module contents
|
||||
---------------
|
||||
|
||||
.. automodule:: libp2p.discovery.random_walk
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
@@ -7,8 +7,10 @@ Subpackages
.. toctree::
   :maxdepth: 4

   libp2p.discovery.bootstrap
   libp2p.discovery.events
   libp2p.discovery.mdns
   libp2p.discovery.random_walk

Submodules
----------
examples/advanced/network_discover.py (new file, 63 lines)

@@ -0,0 +1,63 @@
|
||||
"""
|
||||
Advanced demonstration of Thin Waist address handling.
|
||||
|
||||
Run:
|
||||
python -m examples.advanced.network_discover
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
try:
|
||||
from libp2p.utils.address_validation import (
|
||||
expand_wildcard_address,
|
||||
get_available_interfaces,
|
||||
get_optimal_binding_address,
|
||||
)
|
||||
except ImportError:
|
||||
# Fallbacks if utilities are missing
|
||||
def get_available_interfaces(port: int, protocol: str = "tcp"):
|
||||
return [Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")]
|
||||
|
||||
def expand_wildcard_address(addr: Multiaddr, port: int | None = None):
|
||||
if port is None:
|
||||
return [addr]
|
||||
addr_str = str(addr).rsplit("/", 1)[0]
|
||||
return [Multiaddr(addr_str + f"/{port}")]
|
||||
|
||||
def get_optimal_binding_address(port: int, protocol: str = "tcp"):
|
||||
return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
port = 8080
|
||||
interfaces = get_available_interfaces(port)
|
||||
print(f"Discovered interfaces for port {port}:")
|
||||
for a in interfaces:
|
||||
print(f" - {a}")
|
||||
|
||||
wildcard_v4 = Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
expanded_v4 = expand_wildcard_address(wildcard_v4)
|
||||
print("\nExpanded IPv4 wildcard:")
|
||||
for a in expanded_v4:
|
||||
print(f" - {a}")
|
||||
|
||||
wildcard_v6 = Multiaddr(f"/ip6/::/tcp/{port}")
|
||||
expanded_v6 = expand_wildcard_address(wildcard_v6)
|
||||
print("\nExpanded IPv6 wildcard:")
|
||||
for a in expanded_v6:
|
||||
print(f" - {a}")
|
||||
|
||||
print("\nOptimal binding address heuristic result:")
|
||||
print(f" -> {get_optimal_binding_address(port)}")
|
||||
|
||||
override_port = 9000
|
||||
overridden = expand_wildcard_address(wildcard_v4, port=override_port)
|
||||
print(f"\nPort override expansion to {override_port}:")
|
||||
for a in overridden:
|
||||
print(f" - {a}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
examples/bootstrap/bootstrap.py (new file, 136 lines)

@@ -0,0 +1,136 @@
|
||||
import argparse
|
||||
import logging
|
||||
import secrets
|
||||
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p import new_host
|
||||
from libp2p.abc import PeerInfo
|
||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||
from libp2p.discovery.events.peerDiscovery import peerDiscovery
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger("libp2p.discovery.bootstrap")
|
||||
logger.setLevel(logging.INFO)
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(
|
||||
logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
)
|
||||
logger.addHandler(handler)
|
||||
|
||||
# Configure root logger to only show warnings and above to reduce noise
|
||||
# This prevents verbose DEBUG messages from multiaddr, DNS, etc.
|
||||
logging.getLogger().setLevel(logging.WARNING)
|
||||
|
||||
# Specifically silence noisy libraries
|
||||
logging.getLogger("multiaddr").setLevel(logging.WARNING)
|
||||
logging.getLogger("root").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def on_peer_discovery(peer_info: PeerInfo) -> None:
|
||||
"""Handler for peer discovery events."""
|
||||
logger.info(f"🔍 Discovered peer: {peer_info.peer_id}")
|
||||
logger.debug(f" Addresses: {[str(addr) for addr in peer_info.addrs]}")
|
||||
|
||||
|
||||
# Example bootstrap peers
|
||||
BOOTSTRAP_PEERS = [
|
||||
"/dnsaddr/github.com/p2p/QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN",
|
||||
"/dnsaddr/cloudflare.com/p2p/QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN",
|
||||
"/dnsaddr/google.com/p2p/QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN",
|
||||
"/dnsaddr/bootstrap.libp2p.io/p2p/QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN",
|
||||
"/dnsaddr/bootstrap.libp2p.io/p2p/QmbLHAnMoJPWSCR5Zhtx6BHJX9KiKNN6tpvbUcqanj75Nb",
|
||||
"/ip4/104.131.131.82/tcp/4001/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ",
|
||||
"/ip6/2604:a880:1:20::203:d001/tcp/4001/p2p/QmSoLPppuBtQSGwKDZT2M73ULpjvfd3aZ6ha4oFGL1KrGM",
|
||||
"/ip4/128.199.219.111/tcp/4001/p2p/QmSoLV4Bbm51jM9C4gDYZQ9Cy3U6aXMJDAbzgu2fzaDs64",
|
||||
"/ip4/104.236.76.40/tcp/4001/p2p/QmSoLV4Bbm51jM9C4gDYZQ9Cy3U6aXMJDAbzgu2fzaDs64",
|
||||
"/ip4/178.62.158.247/tcp/4001/p2p/QmSoLer265NRgSp2LA3dPaeykiS1J6DifTC88f5uVQKNAd",
|
||||
"/ip6/2604:a880:1:20::203:d001/tcp/4001/p2p/QmSoLPppuBtQSGwKDZT2M73ULpjvfd3aZ6ha4oFGL1KrGM",
|
||||
"/ip6/2400:6180:0:d0::151:6001/tcp/4001/p2p/QmSoLSafTMBsPKadTEgaXctDQVcqN88CNLHXMkTNwMKPnu",
|
||||
"/ip6/2a03:b0c0:0:1010::23:1001/tcp/4001/p2p/QmSoLueR4xBeUbY9WZ9xGUUxunbKWcrNFTDAadQJmocnWm",
|
||||
]
|
||||
|
||||
|
||||
async def run(port: int, bootstrap_addrs: list[str]) -> None:
|
||||
"""Run the bootstrap discovery example."""
|
||||
# Generate key pair
|
||||
secret = secrets.token_bytes(32)
|
||||
key_pair = create_new_key_pair(secret)
|
||||
|
||||
# Create listen address
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
|
||||
# Register peer discovery handler
|
||||
peerDiscovery.register_peer_discovered_handler(on_peer_discovery)
|
||||
|
||||
logger.info("🚀 Starting Bootstrap Discovery Example")
|
||||
logger.info(f"📍 Listening on: {listen_addr}")
|
||||
logger.info(f"🌐 Bootstrap peers: {len(bootstrap_addrs)}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Bootstrap Discovery Example")
|
||||
print("=" * 60)
|
||||
print("This example demonstrates connecting to bootstrap peers.")
|
||||
print("Watch the logs for peer discovery events!")
|
||||
print("Press Ctrl+C to exit.")
|
||||
print("=" * 60)
|
||||
|
||||
# Create and run host with bootstrap discovery
|
||||
host = new_host(key_pair=key_pair, bootstrap=bootstrap_addrs)
|
||||
|
||||
try:
|
||||
async with host.run(listen_addrs=[listen_addr]):
|
||||
# Keep running and log peer discovery events
|
||||
await trio.sleep_forever()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("👋 Shutting down...")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Main entry point."""
|
||||
description = """
|
||||
Bootstrap Discovery Example for py-libp2p
|
||||
|
||||
This example demonstrates how to use bootstrap peers for peer discovery.
|
||||
Bootstrap peers are predefined peers that help new nodes join the network.
|
||||
|
||||
Usage:
|
||||
python bootstrap.py -p 8000
|
||||
python bootstrap.py -p 8001 --custom-bootstrap \\
|
||||
"/ip4/127.0.0.1/tcp/8000/p2p/QmYourPeerID"
|
||||
"""
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=description, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p", "--port", default=0, type=int, help="Port to listen on (default: random)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--custom-bootstrap",
|
||||
nargs="*",
|
||||
help="Custom bootstrap addresses (space-separated)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v", "--verbose", action="store_true", help="Enable verbose output"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.verbose:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
# Use custom bootstrap addresses if provided, otherwise use defaults
|
||||
bootstrap_addrs = (
|
||||
args.custom_bootstrap if args.custom_bootstrap else BOOTSTRAP_PEERS
|
||||
)
|
||||
|
||||
try:
|
||||
trio.run(run, args.port, bootstrap_addrs)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Exiting...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -43,6 +43,9 @@ async def run(port: int, destination: str) -> None:
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
host = new_host()
|
||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
||||
# Start the peer-store cleanup task
|
||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||
|
||||
if not destination: # its the server
|
||||
|
||||
async def stream_handler(stream: INetStream) -> None:
|
||||
|
||||
@ -24,13 +24,8 @@ async def main():
|
||||
noise_transport = NoiseTransport(
|
||||
# local_key_pair: The key pair used for libp2p identity and authentication
|
||||
libp2p_keypair=key_pair,
|
||||
# noise_privkey: The private key used for Noise protocol encryption
|
||||
noise_privkey=key_pair.private_key,
|
||||
# early_data: Optional data to send during the handshake
|
||||
# (None means no early data)
|
||||
early_data=None,
|
||||
# with_noise_pipes: Whether to use Noise pipes for additional security features
|
||||
with_noise_pipes=False,
|
||||
# TODO: add early data
|
||||
)
|
||||
|
||||
# Create a security options dictionary mapping protocol ID to transport
|
||||
|
||||
@ -28,9 +28,7 @@ async def main():
|
||||
noise_privkey=key_pair.private_key,
|
||||
# early_data: Optional data to send during the handshake
|
||||
# (None means no early data)
|
||||
early_data=None,
|
||||
# with_noise_pipes: Whether to use Noise pipes for additional security features
|
||||
with_noise_pipes=False,
|
||||
# TODO: add early data
|
||||
)
|
||||
|
||||
# Create a security options dictionary mapping protocol ID to transport
|
||||
|
||||
@ -31,9 +31,7 @@ async def main():
|
||||
noise_privkey=key_pair.private_key,
|
||||
# early_data: Optional data to send during the handshake
|
||||
# (None means no early data)
|
||||
early_data=None,
|
||||
# with_noise_pipes: Whether to use Noise pipes for additional security features
|
||||
with_noise_pipes=False,
|
||||
# TODO: add early data
|
||||
)
|
||||
|
||||
# Create a security options dictionary mapping protocol ID to transport
|
||||
|
||||
@ -28,9 +28,7 @@ async def main():
|
||||
noise_privkey=key_pair.private_key,
|
||||
# early_data: Optional data to send during the handshake
|
||||
# (None means no early data)
|
||||
early_data=None,
|
||||
# with_noise_pipes: Whether to use Noise pipes for additional security features
|
||||
with_noise_pipes=False,
|
||||
# TODO: add early data
|
||||
)
|
||||
|
||||
# Create a security options dictionary mapping protocol ID to transport
|
||||
|
||||
examples/doc-examples/multiple_connections_example.py (new file, 170 lines)

@@ -0,0 +1,170 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Example demonstrating multiple connections per peer support in libp2p.
|
||||
|
||||
This example shows how to:
|
||||
1. Configure multiple connections per peer
|
||||
2. Use different load balancing strategies
|
||||
3. Access multiple connections through the new API
|
||||
4. Maintain backward compatibility
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p import new_swarm
|
||||
from libp2p.network.swarm import ConnectionConfig, RetryConfig
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def example_basic_multiple_connections() -> None:
|
||||
"""Example of basic multiple connections per peer usage."""
|
||||
logger.info("Creating swarm with multiple connections support...")
|
||||
|
||||
# Create swarm with default configuration
|
||||
swarm = new_swarm()
|
||||
default_connection = ConnectionConfig()
|
||||
|
||||
logger.info(f"Swarm created with peer ID: {swarm.get_peer_id()}")
|
||||
logger.info(
|
||||
f"Connection config: max_connections_per_peer="
|
||||
f"{default_connection.max_connections_per_peer}"
|
||||
)
|
||||
|
||||
await swarm.close()
|
||||
logger.info("Basic multiple connections example completed")
|
||||
|
||||
|
||||
async def example_custom_connection_config() -> None:
|
||||
"""Example of custom connection configuration."""
|
||||
logger.info("Creating swarm with custom connection configuration...")
|
||||
|
||||
# Custom connection configuration for high-performance scenarios
|
||||
connection_config = ConnectionConfig(
|
||||
max_connections_per_peer=5, # More connections per peer
|
||||
connection_timeout=60.0, # Longer timeout
|
||||
load_balancing_strategy="least_loaded", # Use least loaded strategy
|
||||
)
|
||||
|
||||
# Create swarm with custom connection config
|
||||
swarm = new_swarm(connection_config=connection_config)
|
||||
|
||||
logger.info("Custom connection config applied:")
|
||||
logger.info(
|
||||
f" Max connections per peer: {connection_config.max_connections_per_peer}"
|
||||
)
|
||||
logger.info(f" Connection timeout: {connection_config.connection_timeout}s")
|
||||
logger.info(
|
||||
f" Load balancing strategy: {connection_config.load_balancing_strategy}"
|
||||
)
|
||||
|
||||
await swarm.close()
|
||||
logger.info("Custom connection config example completed")
|
||||
|
||||
|
||||
async def example_multiple_connections_api() -> None:
|
||||
"""Example of using the new multiple connections API."""
|
||||
logger.info("Demonstrating multiple connections API...")
|
||||
|
||||
connection_config = ConnectionConfig(
|
||||
max_connections_per_peer=3, load_balancing_strategy="round_robin"
|
||||
)
|
||||
|
||||
swarm = new_swarm(connection_config=connection_config)
|
||||
|
||||
logger.info("Multiple connections API features:")
|
||||
logger.info(" - dial_peer() returns list[INetConn]")
|
||||
logger.info(" - get_connections(peer_id) returns list[INetConn]")
|
||||
logger.info(" - get_connections_map() returns dict[ID, list[INetConn]]")
|
||||
logger.info(
|
||||
" - get_connection(peer_id) returns INetConn | None (backward compatibility)"
|
||||
)
|
||||
|
||||
await swarm.close()
|
||||
logger.info("Multiple connections API example completed")
|
||||
|
||||
|
||||
async def example_backward_compatibility() -> None:
|
||||
"""Example of backward compatibility features."""
|
||||
logger.info("Demonstrating backward compatibility...")
|
||||
|
||||
swarm = new_swarm()
|
||||
|
||||
logger.info("Backward compatibility features:")
|
||||
logger.info(" - connections_legacy property provides 1:1 mapping")
|
||||
logger.info(" - get_connection() method for single connection access")
|
||||
logger.info(" - Existing code continues to work")
|
||||
|
||||
await swarm.close()
|
||||
logger.info("Backward compatibility example completed")
|
||||
|
||||
|
||||
async def example_production_ready_config() -> None:
|
||||
"""Example of production-ready configuration."""
|
||||
logger.info("Creating swarm with production-ready configuration...")
|
||||
|
||||
# Production-ready retry configuration
|
||||
retry_config = RetryConfig(
|
||||
max_retries=3, # Reasonable retry limit
|
||||
initial_delay=0.1, # Quick initial retry
|
||||
max_delay=30.0, # Cap exponential backoff
|
||||
backoff_multiplier=2.0, # Standard exponential backoff
|
||||
jitter_factor=0.1, # Small jitter to prevent thundering herd
|
||||
)
|
||||
|
||||
# Production-ready connection configuration
|
||||
connection_config = ConnectionConfig(
|
||||
max_connections_per_peer=3, # Balance between performance and resource usage
|
||||
connection_timeout=30.0, # Reasonable timeout
|
||||
load_balancing_strategy="round_robin", # Simple, predictable strategy
|
||||
)
|
||||
|
||||
# Create swarm with production config
|
||||
swarm = new_swarm(retry_config=retry_config, connection_config=connection_config)
|
||||
|
||||
logger.info("Production-ready configuration applied:")
|
||||
logger.info(
|
||||
f" Retry: {retry_config.max_retries} retries, "
|
||||
f"{retry_config.max_delay}s max delay"
|
||||
)
|
||||
logger.info(f" Connections: {connection_config.max_connections_per_peer} per peer")
|
||||
logger.info(f" Load balancing: {connection_config.load_balancing_strategy}")
|
||||
|
||||
await swarm.close()
|
||||
logger.info("Production-ready configuration example completed")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Run all examples."""
|
||||
logger.info("Multiple Connections Per Peer Examples")
|
||||
logger.info("=" * 50)
|
||||
|
||||
try:
|
||||
await example_basic_multiple_connections()
|
||||
logger.info("-" * 30)
|
||||
|
||||
await example_custom_connection_config()
|
||||
logger.info("-" * 30)
|
||||
|
||||
await example_multiple_connections_api()
|
||||
logger.info("-" * 30)
|
||||
|
||||
await example_backward_compatibility()
|
||||
logger.info("-" * 30)
|
||||
|
||||
await example_production_ready_config()
|
||||
logger.info("-" * 30)
|
||||
|
||||
logger.info("All examples completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Example failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
trio.run(main)
|
||||
@ -1,4 +1,6 @@
|
||||
import argparse
|
||||
import random
|
||||
import secrets
|
||||
|
||||
import multiaddr
|
||||
import trio
|
||||
@ -12,49 +14,71 @@ from libp2p.crypto.secp256k1 import (
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.network.stream.exceptions import (
|
||||
StreamEOF,
|
||||
)
|
||||
from libp2p.network.stream.net_stream import (
|
||||
INetStream,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
info_from_p2p_addr,
|
||||
)
|
||||
from libp2p.utils.address_validation import (
|
||||
find_free_port,
|
||||
get_available_interfaces,
|
||||
)
|
||||
|
||||
PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
||||
MAX_READ_LEN = 2**32 - 1
|
||||
|
||||
|
||||
async def _echo_stream_handler(stream: INetStream) -> None:
|
||||
# Wait until EOF
|
||||
msg = await stream.read(MAX_READ_LEN)
|
||||
await stream.write(msg)
|
||||
await stream.close()
|
||||
try:
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
print(f"Received connection from {peer_id}")
|
||||
# Wait until EOF
|
||||
msg = await stream.read(MAX_READ_LEN)
|
||||
print(f"Echoing message: {msg.decode('utf-8')}")
|
||||
await stream.write(msg)
|
||||
except StreamEOF:
|
||||
print("Stream closed by remote peer.")
|
||||
except Exception as e:
|
||||
print(f"Error in echo handler: {e}")
|
||||
finally:
|
||||
await stream.close()
|
||||
|
||||
|
||||
async def run(port: int, destination: str, seed: int | None = None) -> None:
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
if port <= 0:
|
||||
port = find_free_port()
|
||||
listen_addr = get_available_interfaces(port)
|
||||
|
||||
if seed:
|
||||
import random
|
||||
|
||||
random.seed(seed)
|
||||
secret_number = random.getrandbits(32 * 8)
|
||||
secret = secret_number.to_bytes(length=32, byteorder="big")
|
||||
else:
|
||||
import secrets
|
||||
|
||||
secret = secrets.token_bytes(32)
|
||||
|
||||
host = new_host(key_pair=create_new_key_pair(secret))
|
||||
async with host.run(listen_addrs=[listen_addr]):
|
||||
async with host.run(listen_addrs=listen_addr), trio.open_nursery() as nursery:
|
||||
# Start the peer-store cleanup task
|
||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||
|
||||
print(f"I am {host.get_id().to_string()}")
|
||||
|
||||
if not destination: # its the server
|
||||
host.set_stream_handler(PROTOCOL_ID, _echo_stream_handler)
|
||||
|
||||
# Print all listen addresses with peer ID (JS parity)
|
||||
print("Listener ready, listening on:\n")
|
||||
peer_id = host.get_id().to_string()
|
||||
for addr in listen_addr:
|
||||
print(f"{addr}/p2p/{peer_id}")
|
||||
|
||||
print(
|
||||
"Run this from the same folder in another console:\n\n"
|
||||
f"echo-demo "
|
||||
f"-d {host.get_addrs()[0]}\n"
|
||||
"\nRun this from the same folder in another console:\n\n"
|
||||
f"echo-demo -d {host.get_addrs()[0]}\n"
|
||||
)
|
||||
print("Waiting for incoming connections...")
|
||||
await trio.sleep_forever()
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import argparse
|
||||
import base64
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import multiaddr
|
||||
import trio
|
||||
@ -13,6 +14,8 @@ from libp2p.identity.identify.identify import (
|
||||
identify_handler_for,
|
||||
parse_identify_response,
|
||||
)
|
||||
from libp2p.identity.identify.pb.identify_pb2 import Identify
|
||||
from libp2p.peer.envelope import debug_dump_envelope, unmarshal_envelope
|
||||
from libp2p.peer.peerinfo import (
|
||||
info_from_p2p_addr,
|
||||
)
|
||||
@ -31,10 +34,11 @@ def decode_multiaddrs(raw_addrs):
|
||||
return decoded_addrs
|
||||
|
||||
|
||||
def print_identify_response(identify_response):
|
||||
def print_identify_response(identify_response: Identify):
|
||||
"""Pretty-print Identify response."""
|
||||
public_key_b64 = base64.b64encode(identify_response.public_key).decode("utf-8")
|
||||
listen_addrs = decode_multiaddrs(identify_response.listen_addrs)
|
||||
signed_peer_record = unmarshal_envelope(identify_response.signedPeerRecord)
|
||||
try:
|
||||
observed_addr_decoded = decode_multiaddrs([identify_response.observed_addr])
|
||||
except Exception:
|
||||
@ -50,6 +54,8 @@ def print_identify_response(identify_response):
|
||||
f" Agent Version: {identify_response.agent_version}"
|
||||
)
|
||||
|
||||
debug_dump_envelope(signed_peer_record)
|
||||
|
||||
|
||||
async def run(port: int, destination: str, use_varint_format: bool = True) -> None:
|
||||
localhost_ip = "0.0.0.0"
|
||||
@ -60,58 +66,158 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No
|
||||
host_a = new_host()
|
||||
|
||||
# Set up identify handler with specified format
|
||||
# Set use_varint_format = False if you want to check out the signed peer record
|
||||
identify_handler = identify_handler_for(
|
||||
host_a, use_varint_format=use_varint_format
|
||||
)
|
||||
host_a.set_stream_handler(IDENTIFY_PROTOCOL_ID, identify_handler)
|
||||
|
||||
async with host_a.run(listen_addrs=[listen_addr]):
|
||||
async with (
|
||||
host_a.run(listen_addrs=[listen_addr]),
|
||||
trio.open_nursery() as nursery,
|
||||
):
|
||||
# Start the peer-store cleanup task
|
||||
nursery.start_soon(host_a.get_peerstore().start_cleanup_task, 60)
|
||||
|
||||
# Get the actual address and replace 0.0.0.0 with 127.0.0.1 for client
|
||||
# connections
|
||||
server_addr = str(host_a.get_addrs()[0])
|
||||
client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/")
|
||||
|
||||
format_name = "length-prefixed" if use_varint_format else "raw protobuf"
|
||||
format_flag = "--raw-format" if not use_varint_format else ""
|
||||
print(
|
||||
f"First host listening (using {format_name} format). "
|
||||
f"Run this from another console:\n\n"
|
||||
f"identify-demo "
|
||||
f"-d {client_addr}\n"
|
||||
f"identify-demo {format_flag} -d {client_addr}\n"
|
||||
)
|
||||
print("Waiting for incoming identify request...")
|
||||
await trio.sleep_forever()
|
||||
|
||||
# Add a custom handler to show connection events
|
||||
async def custom_identify_handler(stream):
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
print(f"\n🔗 Received identify request from peer: {peer_id}")
|
||||
|
||||
# Show remote address in multiaddr format
|
||||
try:
|
||||
from libp2p.identity.identify.identify import (
|
||||
_remote_address_to_multiaddr,
|
||||
)
|
||||
|
||||
remote_address = stream.get_remote_address()
|
||||
if remote_address:
|
||||
observed_multiaddr = _remote_address_to_multiaddr(
|
||||
remote_address
|
||||
)
|
||||
# Add the peer ID to create a complete multiaddr
|
||||
complete_multiaddr = f"{observed_multiaddr}/p2p/{peer_id}"
|
||||
print(f" Remote address: {complete_multiaddr}")
|
||||
else:
|
||||
print(f" Remote address: {remote_address}")
|
||||
except Exception:
|
||||
print(f" Remote address: {stream.get_remote_address()}")
|
||||
|
||||
# Call the original handler
|
||||
await identify_handler(stream)
|
||||
|
||||
print(f"✅ Successfully processed identify request from {peer_id}")
|
||||
|
||||
# Replace the handler with our custom one
|
||||
host_a.set_stream_handler(IDENTIFY_PROTOCOL_ID, custom_identify_handler)
|
||||
|
||||
try:
|
||||
await trio.sleep_forever()
|
||||
except KeyboardInterrupt:
|
||||
print("\n🛑 Shutting down listener...")
|
||||
logger.info("Listener interrupted by user")
|
||||
return
|
||||
|
||||
else:
    # Create second host (dialer)
    listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}")
    host_b = new_host()

    async with host_b.run(listen_addrs=[listen_addr]):
    async with (
        host_b.run(listen_addrs=[listen_addr]),
        trio.open_nursery() as nursery,
    ):
        # Start the peer-store cleanup task
        nursery.start_soon(host_b.get_peerstore().start_cleanup_task, 60)

        # Connect to the first host
        print(f"dialer (host_b) listening on {host_b.get_addrs()[0]}")
        maddr = multiaddr.Multiaddr(destination)
        info = info_from_p2p_addr(maddr)
        print(f"Second host connecting to peer: {info.peer_id}")

        await host_b.connect(info)
        try:
            await host_b.connect(info)
        except Exception as e:
            error_msg = str(e)
            if "unable to connect" in error_msg or "SwarmException" in error_msg:
                print(f"\n❌ Cannot connect to peer: {info.peer_id}")
                print(f" Address: {destination}")
                print(f" Error: {error_msg}")
                print(
                    "\n💡 Make sure the peer is running and the address is correct."
                )
                return
            else:
                # Re-raise other exceptions
                raise

        stream = await host_b.new_stream(info.peer_id, (IDENTIFY_PROTOCOL_ID,))

        try:
            print("Starting identify protocol...")

            # Read the complete response (could be either format)
            # Read a larger chunk to get all the data before stream closes
            response = await stream.read(8192)  # Read enough data in one go
            # Read the response using the utility function
            from libp2p.utils.varint import read_length_prefixed_protobuf

            response = await read_length_prefixed_protobuf(
                stream, use_varint_format
            )
            full_response = response

            await stream.close()

            # Parse the response using the robust protocol-level function
            # This handles both old and new formats automatically
            identify_msg = parse_identify_response(response)
            identify_msg = parse_identify_response(full_response)
            print_identify_response(identify_msg)

        except Exception as e:
            print(f"Identify protocol error: {e}")
            error_msg = str(e)
            print(f"Identify protocol error: {error_msg}")

            # Check for specific format mismatch errors
            if "Error parsing message" in error_msg or "DecodeError" in error_msg:
                print("\n" + "=" * 60)
                print("FORMAT MISMATCH DETECTED!")
                print("=" * 60)
                if use_varint_format:
                    print(
                        "You are using length-prefixed format (default) but the "
                        "listener"
                    )
                    print("is using raw protobuf format.")
                    print(
                        "\nTo fix this, run the dialer with the --raw-format flag:"
                    )
                    print(f"identify-demo --raw-format -d {destination}")
                else:
                    print("You are using raw protobuf format but the listener")
                    print("is using length-prefixed format (default).")
                    print(
                        "\nTo fix this, run the dialer without the --raw-format "
                        "flag:"
                    )
                    print(f"identify-demo -d {destination}")
                print("=" * 60)
            else:
                import traceback

                traceback.print_exc()

            return

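# --- Illustrative sketch (not part of the diff above) ---
# A condensed view of the dialer-side identify exchange shown in this example,
# assuming the same helpers are already in scope (new_host, info_from_p2p_addr,
# IDENTIFY_PROTOCOL_ID, read_length_prefixed_protobuf, parse_identify_response).
async def fetch_identify(destination: str, use_varint_format: bool = True):
    host = new_host()
    async with host.run(listen_addrs=[multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")]):
        info = info_from_p2p_addr(multiaddr.Multiaddr(destination))
        await host.connect(info)
        stream = await host.new_stream(info.peer_id, (IDENTIFY_PROTOCOL_ID,))
        # Read either the length-prefixed (new) or raw protobuf (old) payload.
        response = await read_length_prefixed_protobuf(stream, use_varint_format)
        await stream.close()
        return parse_identify_response(response)
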
@ -147,16 +253,27 @@ def main() -> None:
            "length-prefixed (new format)"
        ),
    )

    args = parser.parse_args()

    # Determine format: raw format if --raw-format is specified, otherwise
    # length-prefixed
    use_varint_format = not args.raw_format
    # Determine format: use varint (length-prefixed) if --raw-format is specified,
    # otherwise use raw protobuf format (old format)
    use_varint_format = args.raw_format

    try:
        trio.run(run, *(args.port, args.destination, use_varint_format))
        if args.destination:
            # Run in dialer mode
            trio.run(run, *(args.port, args.destination, use_varint_format))
        else:
            # Run in listener mode
            trio.run(run, *(args.port, args.destination, use_varint_format))
    except KeyboardInterrupt:
        pass
        print("\n👋 Goodbye!")
        logger.info("Application interrupted by user")
    except Exception as e:
        print(f"\n❌ Error: {str(e)}")
        logger.error("Error: %s", str(e))
        sys.exit(1)


if __name__ == "__main__":

@ -11,23 +11,26 @@ This example shows how to:
|
||||
|
||||
import logging
|
||||
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p import (
|
||||
new_host,
|
||||
)
|
||||
from libp2p.abc import (
|
||||
INetStream,
|
||||
)
|
||||
from libp2p.crypto.secp256k1 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.identity.identify import (
|
||||
identify_handler_for,
|
||||
from libp2p.identity.identify.pb.identify_pb2 import (
|
||||
Identify,
|
||||
)
|
||||
from libp2p.identity.identify_push import (
|
||||
ID_PUSH,
|
||||
identify_push_handler_for,
|
||||
push_identify_to_peer,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
@ -38,8 +41,145 @@ from libp2p.peer.peerinfo import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_custom_identify_handler(host, host_name: str):
|
||||
"""Create a custom identify handler that displays received information."""
|
||||
|
||||
async def handle_identify(stream: INetStream) -> None:
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
print(f"\n🔍 {host_name} received identify request from peer: {peer_id}")
|
||||
|
||||
# Get the standard identify response using the existing function
|
||||
from libp2p.identity.identify.identify import (
|
||||
_mk_identify_protobuf,
|
||||
_remote_address_to_multiaddr,
|
||||
)
|
||||
|
||||
# Get observed address
|
||||
observed_multiaddr = None
|
||||
try:
|
||||
remote_address = stream.get_remote_address()
|
||||
if remote_address:
|
||||
observed_multiaddr = _remote_address_to_multiaddr(remote_address)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Build the identify protobuf
|
||||
identify_msg = _mk_identify_protobuf(host, observed_multiaddr)
|
||||
response_data = identify_msg.SerializeToString()
|
||||
|
||||
print(f" 📋 {host_name} identify information:")
|
||||
if identify_msg.HasField("protocol_version"):
|
||||
print(f" Protocol Version: {identify_msg.protocol_version}")
|
||||
if identify_msg.HasField("agent_version"):
|
||||
print(f" Agent Version: {identify_msg.agent_version}")
|
||||
if identify_msg.HasField("public_key"):
|
||||
print(f" Public Key: {identify_msg.public_key.hex()[:16]}...")
|
||||
if identify_msg.listen_addrs:
|
||||
print(" Listen Addresses:")
|
||||
for addr_bytes in identify_msg.listen_addrs:
|
||||
addr = multiaddr.Multiaddr(addr_bytes)
|
||||
print(f" - {addr}")
|
||||
if identify_msg.protocols:
|
||||
print(" Supported Protocols:")
|
||||
for protocol in identify_msg.protocols:
|
||||
print(f" - {protocol}")
|
||||
|
||||
# Send the response
|
||||
await stream.write(response_data)
|
||||
await stream.close()
|
||||
|
||||
return handle_identify
|
||||
|
||||
|
||||
def create_custom_identify_push_handler(host, host_name: str):
    """Create a custom identify/push handler that displays received information."""

    async def handle_identify_push(stream: INetStream) -> None:
        peer_id = stream.muxed_conn.peer_id
        print(f"\n📤 {host_name} received identify/push from peer: {peer_id}")

        try:
            # Read the identify message using the utility function
            from libp2p.utils.varint import read_length_prefixed_protobuf

            data = await read_length_prefixed_protobuf(stream, use_varint_format=True)

            # Parse the identify message
            identify_msg = Identify()
            identify_msg.ParseFromString(data)

            print(" 📋 Received identify information:")
            if identify_msg.HasField("protocol_version"):
                print(f" Protocol Version: {identify_msg.protocol_version}")
            if identify_msg.HasField("agent_version"):
                print(f" Agent Version: {identify_msg.agent_version}")
            if identify_msg.HasField("public_key"):
                print(f" Public Key: {identify_msg.public_key.hex()[:16]}...")
            if identify_msg.HasField("observed_addr") and identify_msg.observed_addr:
                observed_addr = multiaddr.Multiaddr(identify_msg.observed_addr)
                print(f" Observed Address: {observed_addr}")
            if identify_msg.listen_addrs:
                print(" Listen Addresses:")
                for addr_bytes in identify_msg.listen_addrs:
                    addr = multiaddr.Multiaddr(addr_bytes)
                    print(f" - {addr}")
            if identify_msg.protocols:
                print(" Supported Protocols:")
                for protocol in identify_msg.protocols:
                    print(f" - {protocol}")

            # Update the peerstore with the new information
            from libp2p.identity.identify_push.identify_push import (
                _update_peerstore_from_identify,
            )

            await _update_peerstore_from_identify(
                host.get_peerstore(), peer_id, identify_msg
            )

            print(f" ✅ {host_name} updated peerstore with new information")

        except Exception as e:
            print(f" ❌ Error processing identify/push: {e}")
        finally:
            await stream.close()

    return handle_identify_push


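# --- Illustrative sketch (not part of the diff above) ---
# The "length-prefixed" format that read_length_prefixed_protobuf handles simply
# prepends an unsigned-varint encoding of the payload length to the serialized
# protobuf bytes; the "raw" (old) format is the serialized protobuf alone.
# encode_uvarint below is a local helper written for illustration only, not a
# py-libp2p API.
def encode_uvarint(value: int) -> bytes:
    out = bytearray()
    while True:
        byte = value & 0x7F
        value >>= 7
        if value:
            out.append(byte | 0x80)  # more bytes follow
        else:
            out.append(byte)
            return bytes(out)


def frame_length_prefixed(payload: bytes) -> bytes:
    # Length-prefixed framing: <uvarint length><protobuf bytes>
    return encode_uvarint(len(payload)) + payload

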
async def display_peerstore_info(host, host_name: str, peer_id, description: str):
|
||||
"""Display peerstore information for a specific peer."""
|
||||
peerstore = host.get_peerstore()
|
||||
|
||||
try:
|
||||
addrs = peerstore.addrs(peer_id)
|
||||
except Exception:
|
||||
addrs = []
|
||||
|
||||
try:
|
||||
protocols = peerstore.get_protocols(peer_id)
|
||||
except Exception:
|
||||
protocols = []
|
||||
|
||||
print(f"\n📚 {host_name} peerstore for {description}:")
|
||||
print(f" Peer ID: {peer_id}")
|
||||
if addrs:
|
||||
print(" Addresses:")
|
||||
for addr in addrs:
|
||||
print(f" - {addr}")
|
||||
else:
|
||||
print(" Addresses: None")
|
||||
|
||||
if protocols:
|
||||
print(" Protocols:")
|
||||
for protocol in protocols:
|
||||
print(f" - {protocol}")
|
||||
else:
|
||||
print(" Protocols: None")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
print("\n==== Starting Identify-Push Example ====\n")
|
||||
print("\n==== Starting Enhanced Identify-Push Example ====\n")
|
||||
|
||||
# Create key pairs for the two hosts
|
||||
key_pair_1 = create_new_key_pair()
|
||||
@ -48,45 +188,57 @@ async def main() -> None:
|
||||
# Create the first host
|
||||
host_1 = new_host(key_pair=key_pair_1)
|
||||
|
||||
# Set up the identify and identify/push handlers
|
||||
host_1.set_stream_handler(TProtocol("/ipfs/id/1.0.0"), identify_handler_for(host_1))
|
||||
host_1.set_stream_handler(ID_PUSH, identify_push_handler_for(host_1))
|
||||
# Set up custom identify and identify/push handlers
|
||||
host_1.set_stream_handler(
|
||||
TProtocol("/ipfs/id/1.0.0"), create_custom_identify_handler(host_1, "Host 1")
|
||||
)
|
||||
host_1.set_stream_handler(
|
||||
ID_PUSH, create_custom_identify_push_handler(host_1, "Host 1")
|
||||
)
|
||||
|
||||
# Create the second host
|
||||
host_2 = new_host(key_pair=key_pair_2)
|
||||
|
||||
# Set up the identify and identify/push handlers
|
||||
host_2.set_stream_handler(TProtocol("/ipfs/id/1.0.0"), identify_handler_for(host_2))
|
||||
host_2.set_stream_handler(ID_PUSH, identify_push_handler_for(host_2))
|
||||
# Set up custom identify and identify/push handlers
|
||||
host_2.set_stream_handler(
|
||||
TProtocol("/ipfs/id/1.0.0"), create_custom_identify_handler(host_2, "Host 2")
|
||||
)
|
||||
host_2.set_stream_handler(
|
||||
ID_PUSH, create_custom_identify_push_handler(host_2, "Host 2")
|
||||
)
|
||||
|
||||
# Start listening on random ports using the run context manager
|
||||
import multiaddr
|
||||
|
||||
listen_addr_1 = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
|
||||
listen_addr_2 = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
|
||||
|
||||
async with host_1.run([listen_addr_1]), host_2.run([listen_addr_2]):
|
||||
async with (
|
||||
host_1.run([listen_addr_1]),
|
||||
host_2.run([listen_addr_2]),
|
||||
trio.open_nursery() as nursery,
|
||||
):
|
||||
# Start the peer-store cleanup task
|
||||
nursery.start_soon(host_1.get_peerstore().start_cleanup_task, 60)
|
||||
nursery.start_soon(host_2.get_peerstore().start_cleanup_task, 60)
|
||||
|
||||
# Get the addresses of both hosts
|
||||
addr_1 = host_1.get_addrs()[0]
|
||||
logger.info(f"Host 1 listening on {addr_1}")
|
||||
print(f"Host 1 listening on {addr_1}")
|
||||
print(f"Peer ID: {host_1.get_id().pretty()}")
|
||||
|
||||
addr_2 = host_2.get_addrs()[0]
|
||||
logger.info(f"Host 2 listening on {addr_2}")
|
||||
print(f"Host 2 listening on {addr_2}")
|
||||
print(f"Peer ID: {host_2.get_id().pretty()}")
|
||||
|
||||
print("\nConnecting Host 2 to Host 1...")
|
||||
print("🏠 Host Configuration:")
|
||||
print(f" Host 1: {addr_1}")
|
||||
print(f" Host 1 Peer ID: {host_1.get_id().pretty()}")
|
||||
print(f" Host 2: {addr_2}")
|
||||
print(f" Host 2 Peer ID: {host_2.get_id().pretty()}")
|
||||
|
||||
print("\n🔗 Connecting Host 2 to Host 1...")
|
||||
|
||||
# Connect host_2 to host_1
|
||||
peer_info = info_from_p2p_addr(addr_1)
|
||||
await host_2.connect(peer_info)
|
||||
logger.info("Host 2 connected to Host 1")
|
||||
print("Host 2 successfully connected to Host 1")
|
||||
print("✅ Host 2 successfully connected to Host 1")
|
||||
|
||||
# Run the identify protocol from host_2 to host_1
|
||||
# (so Host 1 learns Host 2's address)
|
||||
print("\n🔄 Running identify protocol (Host 2 → Host 1)...")
|
||||
from libp2p.identity.identify.identify import ID as IDENTIFY_PROTOCOL_ID
|
||||
|
||||
stream = await host_2.new_stream(host_1.get_id(), (IDENTIFY_PROTOCOL_ID,))
|
||||
@ -94,64 +246,58 @@ async def main() -> None:
|
||||
await stream.close()
|
||||
|
||||
# Run the identify protocol from host_1 to host_2
|
||||
# (so Host 2 learns Host 1's address)
|
||||
print("\n🔄 Running identify protocol (Host 1 → Host 2)...")
|
||||
stream = await host_1.new_stream(host_2.get_id(), (IDENTIFY_PROTOCOL_ID,))
|
||||
response = await stream.read()
|
||||
await stream.close()
|
||||
|
||||
# --- NEW CODE: Update Host 1's peerstore with Host 2's addresses ---
|
||||
from libp2p.identity.identify.pb.identify_pb2 import (
|
||||
Identify,
|
||||
)
|
||||
|
||||
# Update Host 1's peerstore with Host 2's addresses
|
||||
identify_msg = Identify()
|
||||
identify_msg.ParseFromString(response)
|
||||
peerstore_1 = host_1.get_peerstore()
|
||||
peer_id_2 = host_2.get_id()
|
||||
for addr_bytes in identify_msg.listen_addrs:
|
||||
maddr = multiaddr.Multiaddr(addr_bytes)
|
||||
# TTL can be any positive int
|
||||
peerstore_1.add_addr(
|
||||
peer_id_2,
|
||||
maddr,
|
||||
ttl=3600,
|
||||
)
|
||||
# --- END NEW CODE ---
|
||||
peerstore_1.add_addr(peer_id_2, maddr, ttl=3600)
|
||||
|
||||
# Now Host 1's peerstore should have Host 2's address
|
||||
peerstore_1 = host_1.get_peerstore()
|
||||
peer_id_2 = host_2.get_id()
|
||||
addrs_1_for_2 = peerstore_1.addrs(peer_id_2)
|
||||
logger.info(
|
||||
f"[DEBUG] Host 1 peerstore addresses for Host 2 before push: "
|
||||
f"{addrs_1_for_2}"
|
||||
)
|
||||
print(
|
||||
f"[DEBUG] Host 1 peerstore addresses for Host 2 before push: "
|
||||
f"{addrs_1_for_2}"
|
||||
# Display peerstore information before push
|
||||
await display_peerstore_info(
|
||||
host_1, "Host 1", peer_id_2, "Host 2 (before push)"
|
||||
)
|
||||
|
||||
# Push identify information from host_1 to host_2
|
||||
logger.info("Host 1 pushing identify information to Host 2")
|
||||
print("\nHost 1 pushing identify information to Host 2...")
|
||||
print("\n📤 Host 1 pushing identify information to Host 2...")
|
||||
|
||||
try:
|
||||
# Call push_identify_to_peer which now returns a boolean
|
||||
success = await push_identify_to_peer(host_1, host_2.get_id())
|
||||
|
||||
if success:
|
||||
logger.info("Identify push completed successfully")
|
||||
print("Identify push completed successfully!")
|
||||
print("✅ Identify push completed successfully!")
|
||||
else:
|
||||
logger.warning("Identify push didn't complete successfully")
|
||||
print("\nWarning: Identify push didn't complete successfully")
|
||||
print("⚠️ Identify push didn't complete successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during identify push: {str(e)}")
|
||||
print(f"\nError during identify push: {str(e)}")
|
||||
print(f"❌ Error during identify push: {str(e)}")
|
||||
|
||||
# Add this at the end of your async with block:
|
||||
await trio.sleep(0.5) # Give background tasks time to finish
|
||||
# Give a moment for the identify/push processing to complete
|
||||
await trio.sleep(0.5)
|
||||
|
||||
# Display peerstore information after push
|
||||
await display_peerstore_info(host_1, "Host 1", peer_id_2, "Host 2 (after push)")
|
||||
await display_peerstore_info(
|
||||
host_2, "Host 2", host_1.get_id(), "Host 1 (after push)"
|
||||
)
|
||||
|
||||
# Give more time for background tasks to finish and connections to stabilize
|
||||
print("\n⏳ Waiting for background tasks to complete...")
|
||||
await trio.sleep(1.0)
|
||||
|
||||
# Gracefully close connections to prevent connection errors
|
||||
print("🔌 Closing connections...")
|
||||
await host_2.disconnect(host_1.get_id())
|
||||
await trio.sleep(0.2)
|
||||
|
||||
print("\n🎉 Example completed successfully!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -41,6 +41,9 @@ from libp2p.identity.identify import (
|
||||
ID as ID_IDENTIFY,
|
||||
identify_handler_for,
|
||||
)
|
||||
from libp2p.identity.identify.identify import (
|
||||
_remote_address_to_multiaddr,
|
||||
)
|
||||
from libp2p.identity.identify.pb.identify_pb2 import (
|
||||
Identify,
|
||||
)
|
||||
@ -72,40 +75,30 @@ def custom_identify_push_handler_for(host, use_varint_format: bool = True):
|
||||
async def handle_identify_push(stream: INetStream) -> None:
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
|
||||
# Get remote address information
|
||||
try:
|
||||
if use_varint_format:
|
||||
# Read length-prefixed identify message from the stream
|
||||
from libp2p.utils.varint import decode_varint_from_bytes
|
||||
remote_address = stream.get_remote_address()
|
||||
if remote_address:
|
||||
observed_multiaddr = _remote_address_to_multiaddr(remote_address)
|
||||
logger.info(
|
||||
"Connection from remote peer %s, address: %s, multiaddr: %s",
|
||||
peer_id,
|
||||
remote_address,
|
||||
observed_multiaddr,
|
||||
)
|
||||
print(f"\n🔗 Received identify/push request from peer: {peer_id}")
|
||||
# Add the peer ID to create a complete multiaddr
|
||||
complete_multiaddr = f"{observed_multiaddr}/p2p/{peer_id}"
|
||||
print(f" Remote address: {complete_multiaddr}")
|
||||
except Exception as e:
|
||||
logger.error("Error getting remote address: %s", e)
|
||||
print(f"\n🔗 Received identify/push request from peer: {peer_id}")
|
||||
|
||||
# First read the varint length prefix
|
||||
length_bytes = b""
|
||||
while True:
|
||||
b = await stream.read(1)
|
||||
if not b:
|
||||
break
|
||||
length_bytes += b
|
||||
if b[0] & 0x80 == 0:
|
||||
break
|
||||
try:
|
||||
# Use the utility function to read the protobuf message
|
||||
from libp2p.utils.varint import read_length_prefixed_protobuf
|
||||
|
||||
if not length_bytes:
|
||||
logger.warning("No length prefix received from peer %s", peer_id)
|
||||
return
|
||||
|
||||
msg_length = decode_varint_from_bytes(length_bytes)
|
||||
|
||||
# Read the protobuf message
|
||||
data = await stream.read(msg_length)
|
||||
if len(data) != msg_length:
|
||||
logger.warning("Incomplete message received from peer %s", peer_id)
|
||||
return
|
||||
else:
|
||||
# Read raw protobuf message from the stream
|
||||
data = b""
|
||||
while True:
|
||||
chunk = await stream.read(4096)
|
||||
if not chunk:
|
||||
break
|
||||
data += chunk
|
||||
data = await read_length_prefixed_protobuf(stream, use_varint_format)
|
||||
|
||||
identify_msg = Identify()
|
||||
identify_msg.ParseFromString(data)
|
||||
@ -155,11 +148,41 @@ def custom_identify_push_handler_for(host, use_varint_format: bool = True):
|
||||
await _update_peerstore_from_identify(peerstore, peer_id, identify_msg)
|
||||
|
||||
logger.info("Successfully processed identify/push from peer %s", peer_id)
|
||||
print(f"\nSuccessfully processed identify/push from peer {peer_id}")
|
||||
print(f"✅ Successfully processed identify/push from peer {peer_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error processing identify/push from %s: %s", peer_id, e)
|
||||
print(f"\nError processing identify/push from {peer_id}: {e}")
|
||||
error_msg = str(e)
|
||||
logger.error(
|
||||
"Error processing identify/push from %s: %s", peer_id, error_msg
|
||||
)
|
||||
print(f"\nError processing identify/push from {peer_id}: {error_msg}")
|
||||
|
||||
# Check for specific format mismatch errors
|
||||
if (
|
||||
"Error parsing message" in error_msg
|
||||
or "DecodeError" in error_msg
|
||||
or "ParseFromString" in error_msg
|
||||
):
|
||||
print("\n" + "=" * 60)
|
||||
print("FORMAT MISMATCH DETECTED!")
|
||||
print("=" * 60)
|
||||
if use_varint_format:
|
||||
print(
|
||||
"You are using length-prefixed format (default) but the "
|
||||
"dialer is using raw protobuf format."
|
||||
)
|
||||
print("\nTo fix this, run the dialer with the --raw-format flag:")
|
||||
print(
|
||||
"identify-push-listener-dialer-demo --raw-format -d <ADDRESS>"
|
||||
)
|
||||
else:
|
||||
print("You are using raw protobuf format but the dialer")
|
||||
print("is using length-prefixed format (default).")
|
||||
print(
|
||||
"\nTo fix this, run the dialer without the --raw-format flag:"
|
||||
)
|
||||
print("identify-push-listener-dialer-demo -d <ADDRESS>")
|
||||
print("=" * 60)
|
||||
finally:
|
||||
# Close the stream after processing
|
||||
await stream.close()
|
||||
@ -167,7 +190,9 @@ def custom_identify_push_handler_for(host, use_varint_format: bool = True):
|
||||
return handle_identify_push
|
||||
|
||||
|
||||
async def run_listener(port: int, use_varint_format: bool = True) -> None:
|
||||
async def run_listener(
|
||||
port: int, use_varint_format: bool = True, raw_format_flag: bool = False
|
||||
) -> None:
|
||||
"""Run a host in listener mode."""
|
||||
format_name = "length-prefixed" if use_varint_format else "raw protobuf"
|
||||
print(
|
||||
@ -187,29 +212,41 @@ async def run_listener(port: int, use_varint_format: bool = True) -> None:
|
||||
)
|
||||
host.set_stream_handler(
|
||||
ID_IDENTIFY_PUSH,
|
||||
identify_push_handler_for(host, use_varint_format=use_varint_format),
|
||||
custom_identify_push_handler_for(host, use_varint_format=use_varint_format),
|
||||
)
|
||||
|
||||
# Start listening
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
|
||||
async with host.run([listen_addr]):
|
||||
addr = host.get_addrs()[0]
|
||||
logger.info("Listener host ready!")
|
||||
print("Listener host ready!")
|
||||
try:
|
||||
async with host.run([listen_addr]):
|
||||
addr = host.get_addrs()[0]
|
||||
logger.info("Listener host ready!")
|
||||
print("Listener host ready!")
|
||||
|
||||
logger.info(f"Listening on: {addr}")
|
||||
print(f"Listening on: {addr}")
|
||||
logger.info(f"Listening on: {addr}")
|
||||
print(f"Listening on: {addr}")
|
||||
|
||||
logger.info(f"Peer ID: {host.get_id().pretty()}")
|
||||
print(f"Peer ID: {host.get_id().pretty()}")
|
||||
logger.info(f"Peer ID: {host.get_id().pretty()}")
|
||||
print(f"Peer ID: {host.get_id().pretty()}")
|
||||
|
||||
print("\nRun dialer with command:")
|
||||
print(f"identify-push-listener-dialer-demo -d {addr}")
|
||||
print("\nWaiting for incoming connections... (Ctrl+C to exit)")
|
||||
print("\nRun dialer with command:")
|
||||
if raw_format_flag:
|
||||
print(f"identify-push-listener-dialer-demo -d {addr} --raw-format")
|
||||
else:
|
||||
print(f"identify-push-listener-dialer-demo -d {addr}")
|
||||
print("\nWaiting for incoming identify/push requests... (Ctrl+C to exit)")
|
||||
|
||||
# Keep running until interrupted
|
||||
await trio.sleep_forever()
|
||||
# Keep running until interrupted
|
||||
try:
|
||||
await trio.sleep_forever()
|
||||
except KeyboardInterrupt:
|
||||
print("\n🛑 Shutting down listener...")
|
||||
logger.info("Listener interrupted by user")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"Listener error: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def run_dialer(
|
||||
@ -256,7 +293,9 @@ async def run_dialer(
|
||||
try:
|
||||
await host.connect(peer_info)
|
||||
logger.info("Successfully connected to listener!")
|
||||
print("Successfully connected to listener!")
|
||||
print("✅ Successfully connected to listener!")
|
||||
print(f" Connected to: {peer_info.peer_id}")
|
||||
print(f" Full address: {destination}")
|
||||
|
||||
# Push identify information to the listener
|
||||
logger.info("Pushing identify information to listener...")
|
||||
@ -270,7 +309,7 @@ async def run_dialer(
|
||||
|
||||
if success:
|
||||
logger.info("Identify push completed successfully!")
|
||||
print("Identify push completed successfully!")
|
||||
print("✅ Identify push completed successfully!")
|
||||
|
||||
logger.info("Example completed successfully!")
|
||||
print("\nExample completed successfully!")
|
||||
@ -281,17 +320,57 @@ async def run_dialer(
|
||||
logger.warning("Example completed with warnings.")
|
||||
print("Example completed with warnings.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during identify push: {str(e)}")
|
||||
print(f"\nError during identify push: {str(e)}")
|
||||
error_msg = str(e)
|
||||
logger.error(f"Error during identify push: {error_msg}")
|
||||
print(f"\nError during identify push: {error_msg}")
|
||||
|
||||
# Check for specific format mismatch errors
|
||||
if (
|
||||
"Error parsing message" in error_msg
|
||||
or "DecodeError" in error_msg
|
||||
or "ParseFromString" in error_msg
|
||||
):
|
||||
print("\n" + "=" * 60)
|
||||
print("FORMAT MISMATCH DETECTED!")
|
||||
print("=" * 60)
|
||||
if use_varint_format:
|
||||
print(
|
||||
"You are using length-prefixed format (default) but the "
|
||||
"listener is using raw protobuf format."
|
||||
)
|
||||
print(
|
||||
"\nTo fix this, run the dialer with the --raw-format flag:"
|
||||
)
|
||||
print(
|
||||
f"identify-push-listener-dialer-demo --raw-format -d "
|
||||
f"{destination}"
|
||||
)
|
||||
else:
|
||||
print("You are using raw protobuf format but the listener")
|
||||
print("is using length-prefixed format (default).")
|
||||
print(
|
||||
"\nTo fix this, run the dialer without the --raw-format "
|
||||
"flag:"
|
||||
)
|
||||
print(f"identify-push-listener-dialer-demo -d {destination}")
|
||||
print("=" * 60)
|
||||
|
||||
logger.error("Example completed with errors.")
|
||||
print("Example completed with errors.")
|
||||
# Continue execution despite the push error
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during dialer operation: {str(e)}")
|
||||
print(f"\nError during dialer operation: {str(e)}")
|
||||
raise
|
||||
error_msg = str(e)
|
||||
if "unable to connect" in error_msg or "SwarmException" in error_msg:
|
||||
print(f"\n❌ Cannot connect to peer: {peer_info.peer_id}")
|
||||
print(f" Address: {destination}")
|
||||
print(f" Error: {error_msg}")
|
||||
print("\n💡 Make sure the peer is running and the address is correct.")
|
||||
return
|
||||
else:
|
||||
logger.error(f"Error during dialer operation: {error_msg}")
|
||||
print(f"\nError during dialer operation: {error_msg}")
|
||||
raise
|
||||
|
||||
|
||||
def main() -> None:
|
||||
@ -301,12 +380,21 @@ def main() -> None:
|
||||
Without arguments, it runs as a listener on random port.
|
||||
With -d parameter, it runs as a dialer on random port.
|
||||
|
||||
Port 0 (default) means the OS will automatically assign an available port.
|
||||
This prevents port conflicts when running multiple instances.
|
||||
|
||||
Use --raw-format to send raw protobuf messages (old format) instead of
|
||||
length-prefixed protobuf messages (new format, default).
|
||||
"""
|
||||
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--port",
|
||||
default=0,
|
||||
type=int,
|
||||
help="source port number (0 = random available port)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--destination",
|
||||
@ -321,6 +409,7 @@ def main() -> None:
|
||||
"length-prefixed (new format)"
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Determine format: raw format if --raw-format is specified, otherwise
|
||||
@ -333,12 +422,12 @@ def main() -> None:
|
||||
trio.run(run_dialer, args.port, args.destination, use_varint_format)
|
||||
else:
|
||||
# Run in listener mode with random available port if not specified
|
||||
trio.run(run_listener, args.port, use_varint_format)
|
||||
trio.run(run_listener, args.port, use_varint_format, args.raw_format)
|
||||
except KeyboardInterrupt:
|
||||
print("\nInterrupted by user")
|
||||
logger.info("Interrupted by user")
|
||||
print("\n👋 Goodbye!")
|
||||
logger.info("Application interrupted by user")
|
||||
except Exception as e:
|
||||
print(f"\nError: {str(e)}")
|
||||
print(f"\n❌ Error: {str(e)}")
|
||||
logger.error("Error: %s", str(e))
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@ -151,7 +151,10 @@ async def run_node(
|
||||
host = new_host(key_pair=key_pair)
|
||||
listen_addr = Multiaddr(f"/ip4/127.0.0.1/tcp/{port}")
|
||||
|
||||
async with host.run(listen_addrs=[listen_addr]):
|
||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
||||
# Start the peer-store cleanup task
|
||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||
|
||||
peer_id = host.get_id().pretty()
|
||||
addr_str = f"/ip4/127.0.0.1/tcp/{port}/p2p/{peer_id}"
|
||||
await connect_to_bootstrap_nodes(host, bootstrap_nodes)
|
||||
@ -224,7 +227,7 @@ async def run_node(
|
||||
|
||||
# Keep the node running
|
||||
while True:
|
||||
logger.debug(
|
||||
logger.info(
|
||||
"Status - Connected peers: %d,"
|
||||
"Peers in store: %d, Values in store: %d",
|
||||
len(dht.host.get_connected_peers()),
|
||||
|
||||
@ -46,7 +46,10 @@ async def run(port: int) -> None:
|
||||
|
||||
logger.info("Starting peer Discovery")
|
||||
host = new_host(key_pair=key_pair, enable_mDNS=True)
|
||||
async with host.run(listen_addrs=[listen_addr]):
|
||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
||||
# Start the peer-store cleanup task
|
||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||
|
||||
await trio.sleep_forever()
|
||||
|
||||
|
||||
|
||||
@ -59,6 +59,9 @@ async def run(port: int, destination: str) -> None:
|
||||
host = new_host(listen_addrs=[listen_addr])
|
||||
|
||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
||||
# Start the peer-store cleanup task
|
||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||
|
||||
if not destination:
|
||||
host.set_stream_handler(PING_PROTOCOL_ID, handle_ping)
|
||||
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import argparse
|
||||
import logging
|
||||
import socket
|
||||
|
||||
import base58
|
||||
import multiaddr
|
||||
@ -31,6 +30,9 @@ from libp2p.stream_muxer.mplex.mplex import (
|
||||
from libp2p.tools.async_service.trio_service import (
|
||||
background_trio_service,
|
||||
)
|
||||
from libp2p.utils.address_validation import (
|
||||
find_free_port,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@ -77,13 +79,6 @@ async def publish_loop(pubsub, topic, termination_event):
|
||||
await trio.sleep(1) # Avoid tight loop on error
|
||||
|
||||
|
||||
def find_free_port():
|
||||
"""Find a free port on localhost."""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0)) # Bind to a free port provided by the OS
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
async def monitor_peer_topics(pubsub, nursery, termination_event):
|
||||
"""
|
||||
Monitor for new topics that peers are subscribed to and
|
||||
@ -144,6 +139,9 @@ async def run(topic: str, destination: str | None, port: int | None) -> None:
|
||||
pubsub = Pubsub(host, gossipsub)
|
||||
termination_event = trio.Event() # Event to signal termination
|
||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
||||
# Start the peer-store cleanup task
|
||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||
|
||||
logger.info(f"Node started with peer ID: {host.get_id()}")
|
||||
logger.info(f"Listening on: {listen_addr}")
|
||||
logger.info("Initializing PubSub and GossipSub...")
|
||||
|
||||
221
examples/random_walk/random_walk.py
Normal file
221
examples/random_walk/random_walk.py
Normal file
@ -0,0 +1,221 @@
|
||||
"""
|
||||
Random Walk Example for py-libp2p Kademlia DHT
|
||||
|
||||
This example demonstrates the Random Walk module's peer discovery capabilities
|
||||
using real libp2p hosts and Kademlia DHT. It shows how the Random Walk module
|
||||
automatically discovers new peers and maintains routing table health.
|
||||
|
||||
Usage:
|
||||
# Start server nodes (they will discover peers via random walk)
|
||||
python3 random_walk.py --mode server
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import random
|
||||
import secrets
|
||||
import sys
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p import new_host
|
||||
from libp2p.abc import IHost
|
||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||
from libp2p.kad_dht.kad_dht import DHTMode, KadDHT
|
||||
from libp2p.tools.async_service import background_trio_service
|
||||
|
||||
|
||||
# Simple logging configuration
|
||||
def setup_logging(verbose: bool = False):
|
||||
"""Setup unified logging configuration."""
|
||||
level = logging.DEBUG if verbose else logging.INFO
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[logging.StreamHandler()],
|
||||
)
|
||||
|
||||
# Configure key module loggers
|
||||
for module in ["libp2p.discovery.random_walk", "libp2p.kad_dht"]:
|
||||
logging.getLogger(module).setLevel(level)
|
||||
|
||||
# Suppress noisy logs
|
||||
logging.getLogger("multiaddr").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
logger = logging.getLogger("random-walk-example")
|
||||
|
||||
# Default bootstrap nodes
|
||||
DEFAULT_BOOTSTRAP_NODES = [
|
||||
"/ip4/104.131.131.82/tcp/4001/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ"
|
||||
]
|
||||
|
||||
|
||||
def filter_compatible_peer_info(peer_info) -> bool:
|
||||
"""Filter peer info to check if it has compatible addresses (TCP + IPv4)."""
|
||||
if not hasattr(peer_info, "addrs") or not peer_info.addrs:
|
||||
return False
|
||||
|
||||
for addr in peer_info.addrs:
|
||||
addr_str = str(addr)
|
||||
if "/tcp/" in addr_str and "/ip4/" in addr_str and "/quic" not in addr_str:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def maintain_connections(host: IHost) -> None:
|
||||
"""Maintain connections to ensure the host remains connected to healthy peers."""
|
||||
while True:
|
||||
try:
|
||||
connected_peers = host.get_connected_peers()
|
||||
list_peers = host.get_peerstore().peers_with_addrs()
|
||||
|
||||
if len(connected_peers) < 20:
|
||||
logger.debug("Reconnecting to maintain peer connections...")
|
||||
|
||||
# Find compatible peers
|
||||
compatible_peers = []
|
||||
for peer_id in list_peers:
|
||||
try:
|
||||
peer_info = host.get_peerstore().peer_info(peer_id)
|
||||
if filter_compatible_peer_info(peer_info):
|
||||
compatible_peers.append(peer_id)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Connect to random subset of compatible peers
|
||||
if compatible_peers:
|
||||
random_peers = random.sample(
|
||||
compatible_peers, min(50, len(compatible_peers))
|
||||
)
|
||||
for peer_id in random_peers:
|
||||
if peer_id not in connected_peers:
|
||||
try:
|
||||
with trio.move_on_after(5):
|
||||
peer_info = host.get_peerstore().peer_info(peer_id)
|
||||
await host.connect(peer_info)
|
||||
logger.debug(f"Connected to peer: {peer_id}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to connect to {peer_id}: {e}")
|
||||
|
||||
await trio.sleep(15)
|
||||
except Exception as e:
|
||||
logger.error(f"Error maintaining connections: {e}")
|
||||
|
||||
|
||||
async def demonstrate_random_walk_discovery(dht: KadDHT, interval: int = 30) -> None:
|
||||
"""Demonstrate Random Walk peer discovery with periodic statistics."""
|
||||
iteration = 0
|
||||
while True:
|
||||
iteration += 1
|
||||
logger.info(f"--- Iteration {iteration} ---")
|
||||
logger.info(f"Routing table size: {dht.get_routing_table_size()}")
|
||||
logger.info(f"Connected peers: {len(dht.host.get_connected_peers())}")
|
||||
logger.info(f"Peerstore size: {len(dht.host.get_peerstore().peer_ids())}")
|
||||
await trio.sleep(interval)
|
||||
|
||||
|
||||
async def run_node(port: int, mode: str, demo_interval: int = 30) -> None:
|
||||
"""Run a node that demonstrates Random Walk peer discovery."""
|
||||
try:
|
||||
if port <= 0:
|
||||
port = random.randint(10000, 60000)
|
||||
|
||||
logger.info(f"Starting {mode} node on port {port}")
|
||||
|
||||
# Determine DHT mode
|
||||
dht_mode = DHTMode.SERVER if mode == "server" else DHTMode.CLIENT
|
||||
|
||||
# Create host and DHT
|
||||
key_pair = create_new_key_pair(secrets.token_bytes(32))
|
||||
host = new_host(key_pair=key_pair, bootstrap=DEFAULT_BOOTSTRAP_NODES)
|
||||
listen_addr = Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
|
||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
||||
# Start maintenance tasks
|
||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||
nursery.start_soon(maintain_connections, host)
|
||||
|
||||
peer_id = host.get_id().pretty()
|
||||
logger.info(f"Node peer ID: {peer_id}")
|
||||
logger.info(f"Node address: /ip4/0.0.0.0/tcp/{port}/p2p/{peer_id}")
|
||||
|
||||
# Create and start DHT with Random Walk enabled
|
||||
dht = KadDHT(host, dht_mode, enable_random_walk=True)
|
||||
logger.info(f"Initial routing table size: {dht.get_routing_table_size()}")
|
||||
|
||||
async with background_trio_service(dht):
|
||||
logger.info(f"DHT service started in {dht_mode.value} mode")
|
||||
logger.info(f"Random Walk enabled: {dht.is_random_walk_enabled()}")
|
||||
|
||||
async with trio.open_nursery() as task_nursery:
|
||||
# Start demonstration and status reporting
|
||||
task_nursery.start_soon(
|
||||
demonstrate_random_walk_discovery, dht, demo_interval
|
||||
)
|
||||
|
||||
# Periodic status updates
|
||||
async def status_reporter():
|
||||
while True:
|
||||
await trio.sleep(30)
|
||||
logger.debug(
|
||||
f"Connected: {len(dht.host.get_connected_peers())}, "
|
||||
f"Routing table: {dht.get_routing_table_size()}, "
|
||||
f"Peerstore: {len(dht.host.get_peerstore().peer_ids())}"
|
||||
)
|
||||
|
||||
task_nursery.start_soon(status_reporter)
|
||||
await trio.sleep_forever()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Node error: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Random Walk Example for py-libp2p Kademlia DHT",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
choices=["server", "client"],
|
||||
default="server",
|
||||
help="Node mode: server (DHT server), or client (DHT client)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=0, help="Port to listen on (0 for random)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--demo-interval",
|
||||
type=int,
|
||||
default=30,
|
||||
help="Interval between random walk demonstrations in seconds",
|
||||
)
|
||||
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the random walk example."""
|
||||
try:
|
||||
args = parse_args()
|
||||
setup_logging(args.verbose)
|
||||
|
||||
logger.info("=== Random Walk Example for py-libp2p ===")
|
||||
logger.info(
|
||||
f"Mode: {args.mode}, Port: {args.port} Demo interval: {args.demo_interval}s"
|
||||
)
|
||||
|
||||
trio.run(run_node, args.port, args.mode, args.demo_interval)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Received interrupt signal, shutting down...")
|
||||
except Exception as e:
|
||||
logger.critical(f"Example failed: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -1,3 +1,5 @@
|
||||
"""Libp2p Python implementation."""
|
||||
|
||||
from collections.abc import (
|
||||
Mapping,
|
||||
Sequence,
|
||||
@ -6,15 +8,12 @@ from importlib.metadata import version as __version
|
||||
from typing import (
|
||||
Literal,
|
||||
Optional,
|
||||
Type,
|
||||
cast,
|
||||
)
|
||||
|
||||
import multiaddr
|
||||
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
IMuxedConn,
|
||||
INetworkService,
|
||||
IPeerRouting,
|
||||
IPeerStore,
|
||||
@ -32,9 +31,6 @@ from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
TSecurityOptions,
|
||||
)
|
||||
from libp2p.discovery.mdns.mdns import (
|
||||
MDNSDiscovery,
|
||||
)
|
||||
from libp2p.host.basic_host import (
|
||||
BasicHost,
|
||||
)
|
||||
@ -42,6 +38,8 @@ from libp2p.host.routed_host import (
|
||||
RoutedHost,
|
||||
)
|
||||
from libp2p.network.swarm import (
|
||||
ConnectionConfig,
|
||||
RetryConfig,
|
||||
Swarm,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
@ -49,22 +47,25 @@ from libp2p.peer.id import (
|
||||
)
|
||||
from libp2p.peer.peerstore import (
|
||||
PeerStore,
|
||||
create_signed_peer_record,
|
||||
)
|
||||
from libp2p.security.insecure.transport import (
|
||||
PLAINTEXT_PROTOCOL_ID,
|
||||
InsecureTransport,
|
||||
)
|
||||
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
|
||||
from libp2p.security.noise.transport import Transport as NoiseTransport
|
||||
from libp2p.security.noise.transport import (
|
||||
PROTOCOL_ID as NOISE_PROTOCOL_ID,
|
||||
Transport as NoiseTransport,
|
||||
)
|
||||
import libp2p.security.secio.transport as secio
|
||||
from libp2p.stream_muxer.mplex.mplex import (
|
||||
MPLEX_PROTOCOL_ID,
|
||||
Mplex,
|
||||
)
|
||||
from libp2p.stream_muxer.yamux.yamux import (
|
||||
PROTOCOL_ID as YAMUX_PROTOCOL_ID,
|
||||
Yamux,
|
||||
)
|
||||
from libp2p.stream_muxer.yamux.yamux import PROTOCOL_ID as YAMUX_PROTOCOL_ID
|
||||
from libp2p.transport.tcp.tcp import (
|
||||
TCP,
|
||||
)
|
||||
@ -87,7 +88,6 @@ MUXER_MPLEX = "MPLEX"
DEFAULT_NEGOTIATE_TIMEOUT = 5


def set_default_muxer(muxer_name: Literal["YAMUX", "MPLEX"]) -> None:
    """
    Set the default multiplexer protocol to use.
@ -155,7 +155,6 @@ def get_default_muxer_options() -> TMuxerOptions:
    else:  # YAMUX is default
        return create_yamux_muxer_option()


def new_swarm(
    key_pair: KeyPair | None = None,
    muxer_opt: TMuxerOptions | None = None,
@ -163,6 +162,8 @@ def new_swarm(
    peerstore_opt: IPeerStore | None = None,
    muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
    listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
    retry_config: Optional["RetryConfig"] = None,
    connection_config: Optional["ConnectionConfig"] = None,
) -> INetworkService:
    """
    Create a swarm instance based on the parameters.
@ -239,7 +240,14 @@ def new_swarm(
    # Store our key pair in peerstore
    peerstore.add_key_pair(id_opt, key_pair)

    return Swarm(id_opt, peerstore, upgrader, transport)
    return Swarm(
        id_opt,
        peerstore,
        upgrader,
        transport,
        retry_config=retry_config,
        connection_config=connection_config
    )


def new_host(
@ -251,6 +259,7 @@ def new_host(
    muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
    listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
    enable_mDNS: bool = False,
    bootstrap: list[str] | None = None,
    negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
) -> IHost:
    """
@ -264,6 +273,7 @@ def new_host(
    :param muxer_preference: optional explicit muxer preference
    :param listen_addrs: optional list of multiaddrs to listen on
    :param enable_mDNS: whether to enable mDNS discovery
    :param bootstrap: optional list of bootstrap peer addresses as strings
    :return: return a host instance
    """
    swarm = new_swarm(
@ -276,7 +286,13 @@ def new_host(
    )

    if disc_opt is not None:
        return RoutedHost(swarm, disc_opt, enable_mDNS)
    return BasicHost(network=swarm,enable_mDNS=enable_mDNS , negotitate_timeout=negotiate_timeout)
        return RoutedHost(swarm, disc_opt, enable_mDNS, bootstrap)
    return BasicHost(
        network=swarm,
        enable_mDNS=enable_mDNS,
        bootstrap=bootstrap,
        negotitate_timeout=negotiate_timeout
    )


__version__ = __version("libp2p")

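# --- Illustrative sketch (not part of the diff above) ---
# How the parameters added in this diff might be used. Assumes RetryConfig and
# ConnectionConfig (imported from libp2p.network.swarm, as shown earlier in this
# file) can be constructed with their defaults; the bootstrap address is only an
# example value.
custom_swarm = new_swarm(
    retry_config=RetryConfig(),
    connection_config=ConnectionConfig(),
)
host = new_host(
    enable_mDNS=True,
    bootstrap=[
        "/ip4/104.131.131.82/tcp/4001/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ"
    ],
)
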
334
libp2p/abc.py
@ -16,6 +16,7 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncContextManager,
|
||||
Optional,
|
||||
)
|
||||
|
||||
from multiaddr import (
|
||||
@ -41,20 +42,19 @@ from libp2p.io.abc import (
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
import libp2p.peer.pb.peer_record_pb2 as pb
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libp2p.peer.envelope import Envelope
|
||||
from libp2p.peer.peer_record import PeerRecord
|
||||
from libp2p.protocol_muxer.multiselect import Multiselect
|
||||
from libp2p.pubsub.pubsub import (
|
||||
Pubsub,
|
||||
)
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libp2p.protocol_muxer.multiselect import Multiselect
|
||||
|
||||
from libp2p.pubsub.pb import (
|
||||
rpc_pb2,
|
||||
)
|
||||
@ -357,6 +357,14 @@ class INetConn(Closer):
|
||||
:return: A tuple containing instances of INetStream.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_transport_addresses(self) -> list[Multiaddr]:
|
||||
"""
|
||||
Retrieve the transport addresses used by this connection.
|
||||
|
||||
:return: A list of multiaddresses used by the transport.
|
||||
"""
|
||||
|
||||
|
||||
# -------------------------- peermetadata interface.py --------------------------
|
||||
|
||||
@ -493,6 +501,71 @@ class IAddrBook(ABC):
|
||||
"""
|
||||
|
||||
|
||||
# ------------------ certified-addr-book interface.py ---------------------
|
||||
class ICertifiedAddrBook(ABC):
|
||||
"""
|
||||
Interface for a certified address book.
|
||||
|
||||
Provides methods for managing signed peer records
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def consume_peer_record(self, envelope: "Envelope", ttl: int) -> bool:
|
||||
"""
|
||||
Accept and store a signed PeerRecord, unless it's older than
|
||||
the one already stored.
|
||||
|
||||
This function:
|
||||
- Extracts the peer ID and sequence number from the envelope
|
||||
- Rejects the record if it's older (lower seq)
|
||||
- Updates the stored peer record and replaces associated
|
||||
addresses if accepted
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
envelope:
|
||||
Signed envelope containing a PeerRecord.
|
||||
ttl:
|
||||
Time-to-live for the included multiaddrs (in seconds).
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_peer_record(self, peer_id: ID) -> Optional["Envelope"]:
|
||||
"""
|
||||
Retrieve the most recent signed PeerRecord `Envelope` for a peer, if it exists
|
||||
and is still relevant.
|
||||
|
||||
First, it runs cleanup via `maybe_delete_peer_record` to purge stale data.
|
||||
Then it checks whether the peer has valid, unexpired addresses before
|
||||
returning the associated envelope.
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer to look up.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def maybe_delete_peer_record(self, peer_id: ID) -> None:
|
||||
"""
|
||||
Delete the signed peer record for a peer if it has no know
|
||||
(non-expired) addresses.
|
||||
|
||||
This is a garbage collection mechanism: if all addresses for a peer have expired
|
||||
or been cleared, there's no point holding onto its signed `Envelope`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer whose record we may delete.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# -------------------------- keybook interface.py --------------------------
|
||||
|
||||
|
||||
@ -758,7 +831,9 @@ class IProtoBook(ABC):
|
||||
# -------------------------- peerstore interface.py --------------------------
|
||||
|
||||
|
||||
class IPeerStore(IPeerMetadata, IAddrBook, IKeyBook, IMetrics, IProtoBook):
|
||||
class IPeerStore(
|
||||
IPeerMetadata, IAddrBook, ICertifiedAddrBook, IKeyBook, IMetrics, IProtoBook
|
||||
):
|
||||
"""
|
||||
Interface for a peer store.
|
||||
|
||||
@ -893,7 +968,73 @@ class IPeerStore(IPeerMetadata, IAddrBook, IKeyBook, IMetrics, IProtoBook):
|
||||
|
||||
"""
|
||||
|
||||
# --------CERTIFIED-ADDR-BOOK----------
|
||||
|
||||
@abstractmethod
|
||||
def get_local_record(self) -> Optional["Envelope"]:
|
||||
"""Get the local-peer-record wrapped in Envelope"""
|
||||
|
||||
@abstractmethod
|
||||
def set_local_record(self, envelope: "Envelope") -> None:
|
||||
"""Set the local-peer-record wrapped in Envelope"""
|
||||
|
||||
@abstractmethod
|
||||
def consume_peer_record(self, envelope: "Envelope", ttl: int) -> bool:
|
||||
"""
|
||||
Accept and store a signed PeerRecord, unless it's older
|
||||
than the one already stored.
|
||||
|
||||
This function:
|
||||
- Extracts the peer ID and sequence number from the envelope
|
||||
- Rejects the record if it's older (lower seq)
|
||||
- Updates the stored peer record and replaces associated addresses if accepted
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
envelope:
|
||||
Signed envelope containing a PeerRecord.
|
||||
ttl:
|
||||
Time-to-live for the included multiaddrs (in seconds).
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_peer_record(self, peer_id: ID) -> Optional["Envelope"]:
|
||||
"""
|
||||
Retrieve the most recent signed PeerRecord `Envelope` for a peer, if it exists
|
||||
and is still relevant.
|
||||
|
||||
First, it runs cleanup via `maybe_delete_peer_record` to purge stale data.
|
||||
Then it checks whether the peer has valid, unexpired addresses before
|
||||
returning the associated envelope.
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer to look up.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def maybe_delete_peer_record(self, peer_id: ID) -> None:
|
||||
"""
|
||||
Delete the signed peer record for a peer if it has no
|
||||
know (non-expired) addresses.
|
||||
|
||||
This is a garbage collection mechanism: if all addresses for a peer have expired
|
||||
or been cleared, there's no point holding onto its signed `Envelope`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer whose record we may delete.
|
||||
|
||||
"""
|
||||
|
||||
# --------KEY-BOOK----------
|
||||
|
||||
@abstractmethod
|
||||
def pubkey(self, peer_id: ID) -> PublicKey:
|
||||
"""
|
||||
@ -1202,6 +1343,10 @@ class IPeerStore(IPeerMetadata, IAddrBook, IKeyBook, IMetrics, IProtoBook):
|
||||
def clear_peerdata(self, peer_id: ID) -> None:
|
||||
"""clear_peerdata"""
|
||||
|
||||
@abstractmethod
|
||||
async def start_cleanup_task(self, cleanup_interval: int = 3600) -> None:
|
||||
"""Start periodic cleanup of expired peer records and addresses."""
|
||||
|
||||
|
||||
# -------------------------- listener interface.py --------------------------
|
||||
|
||||
@ -1267,15 +1412,16 @@ class INetwork(ABC):
|
||||
----------
|
||||
peerstore : IPeerStore
|
||||
The peer store for managing peer information.
|
||||
connections : dict[ID, INetConn]
|
||||
A mapping of peer IDs to network connections.
|
||||
connections : dict[ID, list[INetConn]]
|
||||
A mapping of peer IDs to lists of network connections
|
||||
(multiple connections per peer).
|
||||
listeners : dict[str, IListener]
|
||||
A mapping of listener identifiers to listener instances.
|
||||
|
||||
"""
|
||||
|
||||
peerstore: IPeerStore
|
||||
connections: dict[ID, INetConn]
|
||||
connections: dict[ID, list[INetConn]]
|
||||
listeners: dict[str, IListener]
|
||||
|
||||
@abstractmethod
|
||||
@ -1291,9 +1437,56 @@ class INetwork(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def dial_peer(self, peer_id: ID) -> INetConn:
|
||||
def get_connections(self, peer_id: ID | None = None) -> list[INetConn]:
|
||||
"""
|
||||
Create a connection to the specified peer.
|
||||
Get connections for peer (like JS getConnections, Go ConnsToPeer).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID | None
|
||||
The peer ID to get connections for. If None, returns all connections.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[INetConn]
|
||||
List of connections to the specified peer, or all connections
|
||||
if peer_id is None.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_connections_map(self) -> dict[ID, list[INetConn]]:
|
||||
"""
|
||||
Get all connections map (like JS getConnectionsMap).
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[ID, list[INetConn]]
|
||||
The complete mapping of peer IDs to their connection lists.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_connection(self, peer_id: ID) -> INetConn | None:
|
||||
"""
|
||||
Get single connection for backward compatibility.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer ID to get a connection for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
INetConn | None
|
||||
The first available connection, or None if no connections exist.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def dial_peer(self, peer_id: ID) -> list[INetConn]:
|
||||
"""
|
||||
Create connections to the specified peer with load balancing.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -1302,8 +1495,8 @@ class INetwork(ABC):
|
||||
|
||||
Returns
|
||||
-------
|
||||
INetConn
|
||||
The network connection instance to the specified peer.
|
||||
list[INetConn]
|
||||
List of established connections to the peer.
|
||||
|
||||
Raises
|
||||
------
|
||||
@ -1689,6 +1882,121 @@ class IHost(ABC):
|
||||
"""
|
||||
|
||||
|
||||
# -------------------------- peer-record interface.py --------------------------
|
||||
class IPeerRecord(ABC):
|
||||
"""
|
||||
Interface for a libp2p PeerRecord object.
|
||||
|
||||
A PeerRecord contains metadata about a peer such as its ID, public addresses,
|
||||
and a strictly increasing sequence number for versioning.
|
||||
|
||||
PeerRecords are used in signed routing Envelopes for secure peer data propagation.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def domain(self) -> str:
|
||||
"""
|
||||
Return the domain string for this record type.
|
||||
|
||||
Used in envelope validation to distinguish different record types.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def codec(self) -> bytes:
|
||||
"""
|
||||
Return a binary codec prefix that identifies the PeerRecord type.
|
||||
|
||||
This is prepended in signed envelopes to allow type-safe decoding.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def to_protobuf(self) -> pb.PeerRecord:
|
||||
"""
|
||||
Convert this PeerRecord into its Protobuf representation.
|
||||
|
||||
:raises ValueError: if serialization fails (e.g., invalid peer ID).
|
||||
:return: A populated protobuf `PeerRecord` message.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def marshal_record(self) -> bytes:
|
||||
"""
|
||||
Serialize this PeerRecord into a byte string.
|
||||
|
||||
Used when signing or sealing the record in an envelope.
|
||||
|
||||
:raises ValueError: if protobuf serialization fails.
|
||||
:return: Byte-encoded PeerRecord.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def equal(self, other: object) -> bool:
|
||||
"""
|
||||
Compare this PeerRecord with another for equality.
|
||||
|
||||
Two PeerRecords are considered equal if:
|
||||
- They have the same `peer_id`
|
||||
- Their `seq` numbers match
|
||||
- Their address lists are identical and ordered
|
||||
|
||||
:param other: Object to compare with.
|
||||
:return: True if equal, False otherwise.
|
||||
"""
|
||||
|
||||
|
||||
# -------------------------- envelope interface.py --------------------------
|
||||
class IEnvelope(ABC):
|
||||
@abstractmethod
|
||||
def marshal_envelope(self) -> bytes:
|
||||
"""
|
||||
Serialize this Envelope into its protobuf wire format.
|
||||
|
||||
Converts all envelope fields into a `pb.Envelope` protobuf message
|
||||
and returns the serialized bytes.
|
||||
|
||||
:return: Serialized envelope as bytes.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def validate(self, domain: str) -> None:
|
||||
"""
|
||||
Verify the envelope's signature within the given domain scope.
|
||||
|
||||
This ensures that the envelope has not been tampered with
|
||||
and was signed under the correct usage context.
|
||||
|
||||
:param domain: Domain string that contextualizes the signature.
|
||||
:raises ValueError: If the signature is invalid.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def record(self) -> "PeerRecord":
|
||||
"""
|
||||
Lazily decode and return the embedded PeerRecord.
|
||||
|
||||
This method unmarshals the payload bytes into a `PeerRecord` instance,
|
||||
using the registered codec to identify the type. The decoded result
|
||||
is cached for future use.
|
||||
|
||||
:return: Decoded PeerRecord object.
|
||||
:raises Exception: If decoding fails or payload type is unsupported.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def equal(self, other: Any) -> bool:
|
||||
"""
|
||||
Compare this Envelope with another for structural equality.
|
||||
|
||||
Two envelopes are considered equal if:
|
||||
- They have the same public key
|
||||
- The payload type and payload bytes match
|
||||
- Their signatures are identical
|
||||
|
||||
:param other: Another object to compare.
|
||||
:return: True if equal, False otherwise.
|
||||
"""
|
||||
|
||||
|
||||
# -------------------------- peerdata interface.py --------------------------
|
||||
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ _sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1dlibp2p/crypto/pb/crypto.proto\x12\tcrypto.pb\"?\n\tPublicKey\x12$\n\x08key_type\x18\x01 \x02(\x0e\x32\x12.crypto.pb.KeyType\x12\x0c\n\x04\x64\x61ta\x18\x02 \x02(\x0c\"@\n\nPrivateKey\x12$\n\x08key_type\x18\x01 \x02(\x0e\x32\x12.crypto.pb.KeyType\x12\x0c\n\x04\x64\x61ta\x18\x02 \x02(\x0c*G\n\x07KeyType\x12\x07\n\x03RSA\x10\x00\x12\x0b\n\x07\x45\x64\x32\x35\x35\x31\x39\x10\x01\x12\r\n\tSecp256k1\x10\x02\x12\t\n\x05\x45\x43\x44SA\x10\x03\x12\x0c\n\x08\x45\x43\x43_P256\x10\x04')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1dlibp2p/crypto/pb/crypto.proto\x12\tcrypto.pb\"?\n\tPublicKey\x12$\n\x08key_type\x18\x01 \x02(\x0e\x32\x12.crypto.pb.KeyType\x12\x0c\n\x04\x64\x61ta\x18\x02 \x02(\x0c\"@\n\nPrivateKey\x12$\n\x08key_type\x18\x01 \x02(\x0e\x32\x12.crypto.pb.KeyType\x12\x0c\n\x04\x64\x61ta\x18\x02 \x02(\x0c*S\n\x07KeyType\x12\x07\n\x03RSA\x10\x00\x12\x0b\n\x07\x45\x64\x32\x35\x35\x31\x39\x10\x01\x12\r\n\tSecp256k1\x10\x02\x12\t\n\x05\x45\x43\x44SA\x10\x03\x12\x0c\n\x08\x45\x43\x43_P256\x10\x04\x12\n\n\x06X25519\x10\x05')
|
||||
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.crypto.pb.crypto_pb2', globals())
|
||||
@ -21,7 +21,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
_KEYTYPE._serialized_start=175
|
||||
_KEYTYPE._serialized_end=246
|
||||
_KEYTYPE._serialized_end=258
|
||||
_PUBLICKEY._serialized_start=44
|
||||
_PUBLICKEY._serialized_end=107
|
||||
_PRIVATEKEY._serialized_start=109
|
||||
|
||||
@ -28,6 +28,7 @@ class _KeyTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTy
|
||||
Secp256k1: _KeyType.ValueType # 2
|
||||
ECDSA: _KeyType.ValueType # 3
|
||||
ECC_P256: _KeyType.ValueType # 4
|
||||
X25519: _KeyType.ValueType # 5
|
||||
|
||||
class KeyType(_KeyType, metaclass=_KeyTypeEnumTypeWrapper): ...
|
||||
|
||||
@ -36,6 +37,7 @@ Ed25519: KeyType.ValueType # 1
|
||||
Secp256k1: KeyType.ValueType # 2
|
||||
ECDSA: KeyType.ValueType # 3
|
||||
ECC_P256: KeyType.ValueType # 4
|
||||
X25519: KeyType.ValueType # 5
|
||||
global___KeyType = KeyType
|
||||
|
||||
@typing.final
|
||||
|
||||
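The regenerated crypto stubs above add X25519 (value 5) to KeyType; a hedged smoke check against the generated module:

from libp2p.crypto.pb import crypto_pb2 as pb

# X25519 is exposed both on the KeyType wrapper and as a module-level constant.
assert pb.KeyType.Name(pb.X25519) == "X25519"
assert pb.X25519 == 5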
5
libp2p/discovery/bootstrap/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
"""Bootstrap peer discovery module for py-libp2p."""
|
||||
|
||||
from .bootstrap import BootstrapDiscovery
|
||||
|
||||
__all__ = ["BootstrapDiscovery"]
|
||||
94
libp2p/discovery/bootstrap/bootstrap.py
Normal file
@ -0,0 +1,94 @@
|
||||
import logging
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
from multiaddr.resolvers import DNSResolver
|
||||
|
||||
from libp2p.abc import ID, INetworkService, PeerInfo
|
||||
from libp2p.discovery.bootstrap.utils import validate_bootstrap_addresses
|
||||
from libp2p.discovery.events.peerDiscovery import peerDiscovery
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
|
||||
logger = logging.getLogger("libp2p.discovery.bootstrap")
|
||||
resolver = DNSResolver()
|
||||
|
||||
|
||||
class BootstrapDiscovery:
|
||||
"""
|
||||
Bootstrap-based peer discovery for py-libp2p.
|
||||
Connects to predefined bootstrap peers and adds them to peerstore.
|
||||
"""
|
||||
|
||||
def __init__(self, swarm: INetworkService, bootstrap_addrs: list[str]):
|
||||
self.swarm = swarm
|
||||
self.peerstore = swarm.peerstore
|
||||
self.bootstrap_addrs = bootstrap_addrs or []
|
||||
self.discovered_peers: set[str] = set()
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Process bootstrap addresses and emit peer discovery events."""
|
||||
logger.debug(
|
||||
f"Starting bootstrap discovery with "
|
||||
f"{len(self.bootstrap_addrs)} bootstrap addresses"
|
||||
)
|
||||
|
||||
# Validate and filter bootstrap addresses
|
||||
self.bootstrap_addrs = validate_bootstrap_addresses(self.bootstrap_addrs)
|
||||
|
||||
for addr_str in self.bootstrap_addrs:
|
||||
try:
|
||||
await self._process_bootstrap_addr(addr_str)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to process bootstrap address {addr_str}: {e}")
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Clean up bootstrap discovery resources."""
|
||||
logger.debug("Stopping bootstrap discovery")
|
||||
self.discovered_peers.clear()
|
||||
|
||||
async def _process_bootstrap_addr(self, addr_str: str) -> None:
|
||||
"""Convert string address to PeerInfo and add to peerstore."""
|
||||
try:
|
||||
multiaddr = Multiaddr(addr_str)
|
||||
except Exception as e:
|
||||
logger.debug(f"Invalid multiaddr format '{addr_str}': {e}")
|
||||
return
|
||||
if self.is_dns_addr(multiaddr):
|
||||
resolved_addrs = await resolver.resolve(multiaddr)
|
||||
peer_id_str = multiaddr.get_peer_id()
|
||||
if peer_id_str is None:
|
||||
logger.warning(f"Missing peer ID in DNS address: {addr_str}")
|
||||
return
|
||||
peer_id = ID.from_base58(peer_id_str)
|
||||
            addrs = list(resolved_addrs)
|
||||
if not addrs:
|
||||
logger.warning(f"No addresses resolved for DNS address: {addr_str}")
|
||||
return
|
||||
peer_info = PeerInfo(peer_id, addrs)
|
||||
self.add_addr(peer_info)
|
||||
else:
|
||||
self.add_addr(info_from_p2p_addr(multiaddr))
|
||||
|
||||
def is_dns_addr(self, addr: Multiaddr) -> bool:
|
||||
"""Check if the address is a DNS address."""
|
||||
return any(protocol.name == "dnsaddr" for protocol in addr.protocols())
|
||||
|
||||
def add_addr(self, peer_info: PeerInfo) -> None:
|
||||
"""Add a peer to the peerstore and emit discovery event."""
|
||||
# Skip if it's our own peer
|
||||
if peer_info.peer_id == self.swarm.get_peer_id():
|
||||
logger.debug(f"Skipping own peer ID: {peer_info.peer_id}")
|
||||
return
|
||||
|
||||
# Always add addresses to peerstore (allows multiple addresses for same peer)
|
||||
self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 10)
|
||||
|
||||
# Only emit discovery event if this is the first time we see this peer
|
||||
peer_id_str = str(peer_info.peer_id)
|
||||
if peer_id_str not in self.discovered_peers:
|
||||
# Track discovered peer
|
||||
self.discovered_peers.add(peer_id_str)
|
||||
# Emit peer discovery event
|
||||
peerDiscovery.emit_peer_discovered(peer_info)
|
||||
logger.debug(f"Peer discovered: {peer_info.peer_id}")
|
||||
else:
|
||||
logger.debug(f"Additional addresses added for peer: {peer_info.peer_id}")
|
||||
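A hedged usage sketch for the class above; the multiaddrs are placeholders rather than real bootstrap nodes, and the swarm is assumed to exist already:

from libp2p.discovery.bootstrap.bootstrap import BootstrapDiscovery

BOOTSTRAP_ADDRS = [
    "/ip4/127.0.0.1/tcp/4001/p2p/<peer-id>",         # placeholder entry
    "/dnsaddr/bootstrap.example.org/p2p/<peer-id>",  # placeholder dnsaddr entry
]


async def run_bootstrap(swarm) -> None:
    discovery = BootstrapDiscovery(swarm, BOOTSTRAP_ADDRS)
    await discovery.start()  # validate addrs, resolve dnsaddr, fill the peerstore
    ...
    discovery.stop()         # clears the set of discovered peers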
51
libp2p/discovery/bootstrap/utils.py
Normal file
@ -0,0 +1,51 @@
|
||||
"""Utility functions for bootstrap discovery."""
|
||||
|
||||
import logging
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr
|
||||
|
||||
logger = logging.getLogger("libp2p.discovery.bootstrap.utils")
|
||||
|
||||
|
||||
def validate_bootstrap_addresses(addrs: list[str]) -> list[str]:
|
||||
"""
|
||||
Validate and filter bootstrap addresses.
|
||||
|
||||
:param addrs: List of bootstrap address strings
|
||||
:return: List of valid bootstrap addresses
|
||||
"""
|
||||
valid_addrs = []
|
||||
|
||||
for addr_str in addrs:
|
||||
try:
|
||||
# Try to parse as multiaddr
|
||||
multiaddr = Multiaddr(addr_str)
|
||||
|
||||
# Try to extract peer info (this validates the p2p component)
|
||||
info_from_p2p_addr(multiaddr)
|
||||
|
||||
valid_addrs.append(addr_str)
|
||||
logger.debug(f"Valid bootstrap address: {addr_str}")
|
||||
|
||||
        except Exception as e:  # broad catch: parsing raises InvalidAddrError, ValueError and others
|
||||
logger.warning(f"Invalid bootstrap address '{addr_str}': {e}")
|
||||
continue
|
||||
|
||||
return valid_addrs
|
||||
|
||||
|
||||
def parse_bootstrap_peer_info(addr_str: str) -> PeerInfo | None:
|
||||
"""
|
||||
Parse bootstrap address string into PeerInfo.
|
||||
|
||||
:param addr_str: Bootstrap address string
|
||||
:return: PeerInfo object or None if parsing fails
|
||||
"""
|
||||
try:
|
||||
multiaddr = Multiaddr(addr_str)
|
||||
return info_from_p2p_addr(multiaddr)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse bootstrap address '{addr_str}': {e}")
|
||||
return None
|
||||
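A hedged example of the two helpers above; the address strings are placeholders, and only entries with a parseable /p2p/ peer ID survive validation:

from libp2p.discovery.bootstrap.utils import (
    parse_bootstrap_peer_info,
    validate_bootstrap_addresses,
)

candidates = [
    "/ip4/127.0.0.1/tcp/4001/p2p/<peer-id>",  # placeholder; a real peer ID is required
    "not-a-multiaddr",                        # dropped with a warning
]
valid = validate_bootstrap_addresses(candidates)
peer_infos = [parse_bootstrap_peer_info(a) for a in valid]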
17
libp2p/discovery/random_walk/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
"""Random walk discovery modules for py-libp2p."""
|
||||
|
||||
from .rt_refresh_manager import RTRefreshManager
|
||||
from .random_walk import RandomWalk
|
||||
from .exceptions import (
|
||||
RoutingTableRefreshError,
|
||||
RandomWalkError,
|
||||
PeerValidationError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"RTRefreshManager",
|
||||
"RandomWalk",
|
||||
"RoutingTableRefreshError",
|
||||
"RandomWalkError",
|
||||
"PeerValidationError",
|
||||
]
|
||||
16
libp2p/discovery/random_walk/config.py
Normal file
@ -0,0 +1,16 @@
|
||||
from typing import Final
|
||||
|
||||
# Timing constants (matching go-libp2p)
|
||||
PEER_PING_TIMEOUT: Final[float] = 10.0 # seconds
|
||||
REFRESH_QUERY_TIMEOUT: Final[float] = 60.0 # seconds
|
||||
REFRESH_INTERVAL: Final[float] = 300.0 # 5 minutes
|
||||
SUCCESSFUL_OUTBOUND_QUERY_GRACE_PERIOD: Final[float] = 60.0 # 1 minute
|
||||
|
||||
# Routing table thresholds
|
||||
MIN_RT_REFRESH_THRESHOLD: Final[int] = 4 # Minimum peers before triggering refresh
|
||||
MAX_N_BOOTSTRAPPERS: Final[int] = 2 # Maximum bootstrap peers to try
|
||||
|
||||
# Random walk specific
|
||||
RANDOM_WALK_CONCURRENCY: Final[int] = 3 # Number of concurrent random walks
|
||||
RANDOM_WALK_ENABLED: Final[bool] = True # Enable automatic random walks
|
||||
RANDOM_WALK_RT_THRESHOLD: Final[int] = 20 # RT size threshold for peerstore fallback
|
||||
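These are plain module-level constants; a hedged illustration of feeding the tunable ones into RTRefreshManager's keyword arguments (defined later in this changeset):

from libp2p.discovery.random_walk import config

refresh_kwargs = {
    "refresh_interval": config.REFRESH_INTERVAL / 2,           # refresh twice as often
    "min_refresh_threshold": config.MIN_RT_REFRESH_THRESHOLD,
}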
19
libp2p/discovery/random_walk/exceptions.py
Normal file
@ -0,0 +1,19 @@
|
||||
from libp2p.exceptions import BaseLibp2pError
|
||||
|
||||
|
||||
class RoutingTableRefreshError(BaseLibp2pError):
|
||||
"""Base exception for routing table refresh operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RandomWalkError(RoutingTableRefreshError):
|
||||
"""Exception raised during random walk operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class PeerValidationError(RoutingTableRefreshError):
|
||||
"""Exception raised when peer validation fails."""
|
||||
|
||||
pass
|
||||
218
libp2p/discovery/random_walk/random_walk.py
Normal file
@ -0,0 +1,218 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
import logging
|
||||
import secrets
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.abc import IHost
|
||||
from libp2p.discovery.random_walk.config import (
|
||||
RANDOM_WALK_CONCURRENCY,
|
||||
RANDOM_WALK_RT_THRESHOLD,
|
||||
REFRESH_QUERY_TIMEOUT,
|
||||
)
|
||||
from libp2p.discovery.random_walk.exceptions import RandomWalkError
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
|
||||
logger = logging.getLogger("libp2p.discovery.random_walk")
|
||||
|
||||
|
||||
class RandomWalk:
|
||||
"""
|
||||
Random Walk implementation for peer discovery in Kademlia DHT.
|
||||
|
||||
Generates random peer IDs and performs FIND_NODE queries to discover
|
||||
new peers and populate the routing table.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: IHost,
|
||||
local_peer_id: ID,
|
||||
query_function: Callable[[bytes], Awaitable[list[ID]]],
|
||||
):
|
||||
"""
|
||||
Initialize Random Walk module.
|
||||
|
||||
Args:
|
||||
host: The libp2p host instance
|
||||
local_peer_id: Local peer ID
|
||||
query_function: Function to query for closest peers given target key bytes
|
||||
|
||||
"""
|
||||
self.host = host
|
||||
self.local_peer_id = local_peer_id
|
||||
self.query_function = query_function
|
||||
|
||||
def generate_random_peer_id(self) -> str:
|
||||
"""
|
||||
Generate a completely random peer ID
|
||||
for random walk queries.
|
||||
|
||||
Returns:
|
||||
Random peer ID as string
|
||||
|
||||
"""
|
||||
# Generate 32 random bytes (256 bits) - same as go-libp2p
|
||||
random_bytes = secrets.token_bytes(32)
|
||||
# Convert to hex string for query
|
||||
return random_bytes.hex()
|
||||
|
||||
async def perform_random_walk(self) -> list[PeerInfo]:
|
||||
"""
|
||||
Perform a single random walk operation.
|
||||
|
||||
Returns:
|
||||
List of validated peers discovered during the walk
|
||||
|
||||
"""
|
||||
try:
|
||||
# Generate random peer ID
|
||||
random_peer_id = self.generate_random_peer_id()
|
||||
logger.info(f"Starting random walk for peer ID: {random_peer_id}")
|
||||
|
||||
# Perform FIND_NODE query
|
||||
discovered_peer_ids: list[ID] = []
|
||||
|
||||
with trio.move_on_after(REFRESH_QUERY_TIMEOUT):
|
||||
# Call the query function with target key bytes
|
||||
target_key = bytes.fromhex(random_peer_id)
|
||||
discovered_peer_ids = await self.query_function(target_key) or []
|
||||
|
||||
if not discovered_peer_ids:
|
||||
logger.debug(f"No peers discovered in random walk for {random_peer_id}")
|
||||
return []
|
||||
|
||||
logger.info(
|
||||
f"Discovered {len(discovered_peer_ids)} peers in random walk "
|
||||
f"for {random_peer_id[:8]}..." # Show only first 8 chars for brevity
|
||||
)
|
||||
|
||||
# Convert peer IDs to PeerInfo objects and validate
|
||||
validated_peers: list[PeerInfo] = []
|
||||
|
||||
for peer_id in discovered_peer_ids:
|
||||
try:
|
||||
# Get addresses from peerstore
|
||||
addrs = self.host.get_peerstore().addrs(peer_id)
|
||||
if addrs:
|
||||
peer_info = PeerInfo(peer_id, addrs)
|
||||
validated_peers.append(peer_info)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to create PeerInfo for {peer_id}: {e}")
|
||||
continue
|
||||
|
||||
return validated_peers
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Random walk failed: {e}")
|
||||
raise RandomWalkError(f"Random walk operation failed: {e}") from e
|
||||
|
||||
async def run_concurrent_random_walks(
|
||||
self, count: int = RANDOM_WALK_CONCURRENCY, current_routing_table_size: int = 0
|
||||
) -> list[PeerInfo]:
|
||||
"""
|
||||
Run multiple random walks concurrently.
|
||||
|
||||
Args:
|
||||
count: Number of concurrent random walks to perform
|
||||
current_routing_table_size: Current size of routing table (for optimization)
|
||||
|
||||
Returns:
|
||||
Combined list of all validated peers discovered
|
||||
|
||||
"""
|
||||
all_validated_peers: list[PeerInfo] = []
|
||||
logger.info(f"Starting {count} concurrent random walks")
|
||||
|
||||
# First, try to add peers from peerstore if routing table is small
|
||||
if current_routing_table_size < RANDOM_WALK_RT_THRESHOLD:
|
||||
try:
|
||||
peerstore_peers = self._get_peerstore_peers()
|
||||
if peerstore_peers:
|
||||
logger.debug(
|
||||
f"RT size ({current_routing_table_size}) below threshold, "
|
||||
f"adding {len(peerstore_peers)} peerstore peers"
|
||||
)
|
||||
all_validated_peers.extend(peerstore_peers)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing peerstore peers: {e}")
|
||||
|
||||
async def single_walk() -> None:
|
||||
try:
|
||||
peers = await self.perform_random_walk()
|
||||
all_validated_peers.extend(peers)
|
||||
except Exception as e:
|
||||
logger.warning(f"Concurrent random walk failed: {e}")
|
||||
return
|
||||
|
||||
# Run concurrent random walks
|
||||
async with trio.open_nursery() as nursery:
|
||||
for _ in range(count):
|
||||
nursery.start_soon(single_walk)
|
||||
|
||||
# Remove duplicates based on peer ID
|
||||
unique_peers = {}
|
||||
for peer in all_validated_peers:
|
||||
unique_peers[peer.peer_id] = peer
|
||||
|
||||
result = list(unique_peers.values())
|
||||
logger.info(
|
||||
f"Concurrent random walks completed: {len(result)} unique peers discovered"
|
||||
)
|
||||
return result
|
||||
|
||||
def _get_peerstore_peers(self) -> list[PeerInfo]:
|
||||
"""
|
||||
Get peer info objects from the host's peerstore.
|
||||
|
||||
Returns:
|
||||
List of PeerInfo objects from peerstore
|
||||
|
||||
"""
|
||||
try:
|
||||
peerstore = self.host.get_peerstore()
|
||||
peer_ids = peerstore.peers_with_addrs()
|
||||
|
||||
peer_infos = []
|
||||
for peer_id in peer_ids:
|
||||
try:
|
||||
# Skip local peer
|
||||
if peer_id == self.local_peer_id:
|
||||
continue
|
||||
|
||||
peer_info = peerstore.peer_info(peer_id)
|
||||
if peer_info and peer_info.addrs:
|
||||
# Filter for compatible addresses (TCP + IPv4)
|
||||
if self._has_compatible_addresses(peer_info):
|
||||
peer_infos.append(peer_info)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting peer info for {peer_id}: {e}")
|
||||
|
||||
return peer_infos
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error accessing peerstore: {e}")
|
||||
return []
|
||||
|
||||
def _has_compatible_addresses(self, peer_info: PeerInfo) -> bool:
|
||||
"""
|
||||
Check if a peer has TCP+IPv4 compatible addresses.
|
||||
|
||||
Args:
|
||||
peer_info: PeerInfo to check
|
||||
|
||||
Returns:
|
||||
True if peer has compatible addresses
|
||||
|
||||
"""
|
||||
if not peer_info.addrs:
|
||||
return False
|
||||
|
||||
for addr in peer_info.addrs:
|
||||
addr_str = str(addr)
|
||||
# Check for TCP and IPv4 compatibility, avoid QUIC
|
||||
if "/tcp/" in addr_str and "/ip4/" in addr_str and "/quic" not in addr_str:
|
||||
return True
|
||||
|
||||
return False
|
||||
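A hedged sketch driving RandomWalk on its own; the query function is a stub here, whereas the DHT integration below wraps peer_routing.find_closest_peers_network:

from libp2p.discovery.random_walk.random_walk import RandomWalk
from libp2p.peer.id import ID


async def demo_random_walk(host) -> None:
    async def query_fn(target_key: bytes) -> list[ID]:
        return []  # stub; a real implementation performs a FIND_NODE lookup

    walker = RandomWalk(host, host.get_id(), query_fn)
    peers = await walker.run_concurrent_random_walks(count=3)
    print(f"discovered {len(peers)} unique peers")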
208
libp2p/discovery/random_walk/rt_refresh_manager.py
Normal file
@ -0,0 +1,208 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
import logging
|
||||
import time
|
||||
from typing import Protocol
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.abc import IHost
|
||||
from libp2p.discovery.random_walk.config import (
|
||||
MIN_RT_REFRESH_THRESHOLD,
|
||||
RANDOM_WALK_CONCURRENCY,
|
||||
RANDOM_WALK_ENABLED,
|
||||
REFRESH_INTERVAL,
|
||||
)
|
||||
from libp2p.discovery.random_walk.exceptions import RoutingTableRefreshError
|
||||
from libp2p.discovery.random_walk.random_walk import RandomWalk
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
|
||||
|
||||
class RoutingTableProtocol(Protocol):
|
||||
"""Protocol for routing table operations needed by RT refresh manager."""
|
||||
|
||||
def size(self) -> int:
|
||||
"""Return the current size of the routing table."""
|
||||
...
|
||||
|
||||
async def add_peer(self, peer_obj: PeerInfo) -> bool:
|
||||
"""Add a peer to the routing table."""
|
||||
...
|
||||
|
||||
|
||||
logger = logging.getLogger("libp2p.discovery.random_walk.rt_refresh_manager")
|
||||
|
||||
|
||||
class RTRefreshManager:
|
||||
"""
|
||||
Routing Table Refresh Manager for py-libp2p.
|
||||
|
||||
Manages periodic routing table refreshes and random walk operations
|
||||
to maintain routing table health and discover new peers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: IHost,
|
||||
routing_table: RoutingTableProtocol,
|
||||
local_peer_id: ID,
|
||||
query_function: Callable[[bytes], Awaitable[list[ID]]],
|
||||
enable_auto_refresh: bool = RANDOM_WALK_ENABLED,
|
||||
refresh_interval: float = REFRESH_INTERVAL,
|
||||
min_refresh_threshold: int = MIN_RT_REFRESH_THRESHOLD,
|
||||
):
|
||||
"""
|
||||
Initialize RT Refresh Manager.
|
||||
|
||||
Args:
|
||||
host: The libp2p host instance
|
||||
routing_table: Routing table of host
|
||||
local_peer_id: Local peer ID
|
||||
query_function: Function to query for closest peers given target key bytes
|
||||
enable_auto_refresh: Whether to enable automatic refresh
|
||||
refresh_interval: Interval between refreshes in seconds
|
||||
min_refresh_threshold: Minimum RT size before triggering refresh
|
||||
|
||||
"""
|
||||
self.host = host
|
||||
self.routing_table = routing_table
|
||||
self.local_peer_id = local_peer_id
|
||||
self.query_function = query_function
|
||||
|
||||
self.enable_auto_refresh = enable_auto_refresh
|
||||
self.refresh_interval = refresh_interval
|
||||
self.min_refresh_threshold = min_refresh_threshold
|
||||
|
||||
# Initialize random walk module
|
||||
self.random_walk = RandomWalk(
|
||||
host=host,
|
||||
local_peer_id=self.local_peer_id,
|
||||
query_function=query_function,
|
||||
)
|
||||
|
||||
# Control variables
|
||||
self._running = False
|
||||
self._nursery: trio.Nursery | None = None
|
||||
|
||||
# Tracking
|
||||
self._last_refresh_time = 0.0
|
||||
self._refresh_done_callbacks: list[Callable[[], None]] = []
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the RT Refresh Manager."""
|
||||
if self._running:
|
||||
logger.warning("RT Refresh Manager is already running")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
logger.info("Starting RT Refresh Manager")
|
||||
|
||||
# Start the main loop
|
||||
async with trio.open_nursery() as nursery:
|
||||
self._nursery = nursery
|
||||
nursery.start_soon(self._main_loop)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the RT Refresh Manager."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
logger.info("Stopping RT Refresh Manager")
|
||||
self._running = False
|
||||
|
||||
async def _main_loop(self) -> None:
|
||||
"""Main loop for the RT Refresh Manager."""
|
||||
logger.info("RT Refresh Manager main loop started")
|
||||
|
||||
# Initial refresh if auto-refresh is enabled
|
||||
if self.enable_auto_refresh:
|
||||
await self._do_refresh(force=True)
|
||||
|
||||
try:
|
||||
while self._running:
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Schedule periodic refresh if enabled
|
||||
if self.enable_auto_refresh:
|
||||
nursery.start_soon(self._periodic_refresh_task)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"RT Refresh Manager main loop error: {e}")
|
||||
finally:
|
||||
logger.info("RT Refresh Manager main loop stopped")
|
||||
|
||||
async def _periodic_refresh_task(self) -> None:
|
||||
"""Task for periodic refreshes."""
|
||||
while self._running:
|
||||
await trio.sleep(self.refresh_interval)
|
||||
if self._running:
|
||||
await self._do_refresh()
|
||||
|
||||
async def _do_refresh(self, force: bool = False) -> None:
|
||||
"""
|
||||
Perform routing table refresh operation.
|
||||
|
||||
Args:
|
||||
force: Whether to force refresh regardless of timing
|
||||
|
||||
"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
|
||||
# Check if refresh is needed
|
||||
if not force:
|
||||
if current_time - self._last_refresh_time < self.refresh_interval:
|
||||
logger.debug("Skipping refresh: interval not elapsed")
|
||||
return
|
||||
|
||||
if self.routing_table.size() >= self.min_refresh_threshold:
|
||||
logger.debug("Skipping refresh: routing table size above threshold")
|
||||
return
|
||||
|
||||
logger.info(f"Starting routing table refresh (force={force})")
|
||||
start_time = current_time
|
||||
|
||||
# Perform random walks to discover new peers
|
||||
logger.info("Running concurrent random walks to discover new peers")
|
||||
current_rt_size = self.routing_table.size()
|
||||
discovered_peers = await self.random_walk.run_concurrent_random_walks(
|
||||
count=RANDOM_WALK_CONCURRENCY,
|
||||
current_routing_table_size=current_rt_size,
|
||||
)
|
||||
|
||||
# Add discovered peers to routing table
|
||||
added_count = 0
|
||||
for peer_info in discovered_peers:
|
||||
result = await self.routing_table.add_peer(peer_info)
|
||||
if result:
|
||||
added_count += 1
|
||||
|
||||
self._last_refresh_time = current_time
|
||||
|
||||
duration = time.time() - start_time
|
||||
logger.info(
|
||||
f"Routing table refresh completed: "
|
||||
f"{added_count}/{len(discovered_peers)} peers added, "
|
||||
f"RT size: {self.routing_table.size()}, "
|
||||
f"duration: {duration:.2f}s"
|
||||
)
|
||||
|
||||
# Notify refresh completion
|
||||
for callback in self._refresh_done_callbacks:
|
||||
try:
|
||||
callback()
|
||||
except Exception as e:
|
||||
logger.warning(f"Refresh callback error: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Routing table refresh failed: {e}")
|
||||
raise RoutingTableRefreshError(f"Refresh operation failed: {e}") from e
|
||||
|
||||
def add_refresh_done_callback(self, callback: Callable[[], None]) -> None:
|
||||
"""Add a callback to be called when refresh completes."""
|
||||
self._refresh_done_callbacks.append(callback)
|
||||
|
||||
def remove_refresh_done_callback(self, callback: Callable[[], None]) -> None:
|
||||
"""Remove a refresh completion callback."""
|
||||
if callback in self._refresh_done_callbacks:
|
||||
self._refresh_done_callbacks.remove(callback)
|
||||
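A hedged sketch of the callback hook above; `manager` is assumed to be an RTRefreshManager instance, such as the one KadDHT creates further down when enable_random_walk=True:

def on_refresh_done() -> None:
    print("routing table refresh finished")


manager.add_refresh_done_callback(on_refresh_done)
# ... later, if the callback is no longer wanted:
manager.remove_refresh_done_callback(on_refresh_done)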
@ -29,6 +29,7 @@ from libp2p.custom_types import (
|
||||
StreamHandlerFn,
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.discovery.bootstrap.bootstrap import BootstrapDiscovery
|
||||
from libp2p.discovery.mdns.mdns import MDNSDiscovery
|
||||
from libp2p.host.defaults import (
|
||||
get_default_protocols,
|
||||
@ -42,6 +43,7 @@ from libp2p.peer.id import (
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
from libp2p.peer.peerstore import create_signed_peer_record
|
||||
from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectClientError,
|
||||
MultiselectError,
|
||||
@ -92,6 +94,7 @@ class BasicHost(IHost):
|
||||
self,
|
||||
network: INetworkService,
|
||||
enable_mDNS: bool = False,
|
||||
bootstrap: list[str] | None = None,
|
||||
default_protocols: Optional["OrderedDict[TProtocol, StreamHandlerFn]"] = None,
|
||||
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
) -> None:
|
||||
@ -105,6 +108,16 @@ class BasicHost(IHost):
|
||||
self.multiselect_client = MultiselectClient()
|
||||
if enable_mDNS:
|
||||
self.mDNS = MDNSDiscovery(network)
|
||||
if bootstrap:
|
||||
self.bootstrap = BootstrapDiscovery(network, bootstrap)
|
||||
|
||||
        # Cache a signed peer record for the local node in the PeerStore
|
||||
envelope = create_signed_peer_record(
|
||||
self.get_id(),
|
||||
self.get_addrs(),
|
||||
self.get_private_key(),
|
||||
)
|
||||
self.get_peerstore().set_local_record(envelope)
|
||||
|
||||
def get_id(self) -> ID:
|
||||
"""
|
||||
@ -172,11 +185,16 @@ class BasicHost(IHost):
|
||||
if hasattr(self, "mDNS") and self.mDNS is not None:
|
||||
logger.debug("Starting mDNS Discovery")
|
||||
self.mDNS.start()
|
||||
if hasattr(self, "bootstrap") and self.bootstrap is not None:
|
||||
logger.debug("Starting Bootstrap Discovery")
|
||||
await self.bootstrap.start()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if hasattr(self, "mDNS") and self.mDNS is not None:
|
||||
self.mDNS.stop()
|
||||
if hasattr(self, "bootstrap") and self.bootstrap is not None:
|
||||
self.bootstrap.stop()
|
||||
|
||||
return _run()
|
||||
|
||||
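A hedged sketch of the new constructor parameter; the swarm/network setup is elided and the bootstrap entry is a placeholder:

from libp2p.host.basic_host import BasicHost

host = BasicHost(
    network=swarm,  # an existing INetworkService with the local key in its peerstore
    enable_mDNS=False,
    bootstrap=["/ip4/127.0.0.1/tcp/4001/p2p/<peer-id>"],  # placeholder entry
)
# Entering host.run(...) now starts BootstrapDiscovery alongside mDNS and stops it on exit.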
@ -279,6 +297,11 @@ class BasicHost(IHost):
|
||||
protocol, handler = await self.multiselect.negotiate(
|
||||
MultiselectCommunicator(net_stream), self.negotiate_timeout
|
||||
)
|
||||
if protocol is None:
|
||||
await net_stream.reset()
|
||||
raise StreamFailure(
|
||||
"Failed to negotiate protocol: no protocol selected"
|
||||
)
|
||||
except MultiselectError as error:
|
||||
peer_id = net_stream.muxed_conn.peer_id
|
||||
logger.debug(
|
||||
@ -286,6 +309,13 @@ class BasicHost(IHost):
|
||||
)
|
||||
await net_stream.reset()
|
||||
return
|
||||
if protocol is None:
|
||||
logger.debug(
|
||||
"no protocol negotiated, closing stream from peer %s",
|
||||
net_stream.muxed_conn.peer_id,
|
||||
)
|
||||
await net_stream.reset()
|
||||
return
|
||||
net_stream.set_protocol(protocol)
|
||||
if handler is None:
|
||||
logger.debug(
|
||||
@ -313,7 +343,7 @@ class BasicHost(IHost):
|
||||
:param peer_id: ID of the peer to check
|
||||
:return: True if peer has an active connection, False otherwise
|
||||
"""
|
||||
return peer_id in self._network.connections
|
||||
return len(self._network.get_connections(peer_id)) > 0
|
||||
|
||||
def get_peer_connection_info(self, peer_id: ID) -> INetConn | None:
|
||||
"""
|
||||
@ -322,4 +352,4 @@ class BasicHost(IHost):
|
||||
:param peer_id: ID of the peer to get info for
|
||||
:return: Connection object if peer is connected, None otherwise
|
||||
"""
|
||||
return self._network.connections.get(peer_id)
|
||||
return self._network.get_connection(peer_id)
|
||||
|
||||
@ -19,9 +19,13 @@ class RoutedHost(BasicHost):
|
||||
_router: IPeerRouting
|
||||
|
||||
def __init__(
|
||||
self, network: INetworkService, router: IPeerRouting, enable_mDNS: bool = False
|
||||
self,
|
||||
network: INetworkService,
|
||||
router: IPeerRouting,
|
||||
enable_mDNS: bool = False,
|
||||
bootstrap: list[str] | None = None,
|
||||
):
|
||||
super().__init__(network, enable_mDNS)
|
||||
super().__init__(network, enable_mDNS, bootstrap)
|
||||
self._router = router
|
||||
|
||||
async def connect(self, peer_info: PeerInfo) -> None:
|
||||
|
||||
@ -15,6 +15,7 @@ from libp2p.custom_types import (
|
||||
from libp2p.network.stream.exceptions import (
|
||||
StreamClosed,
|
||||
)
|
||||
from libp2p.peer.peerstore import env_to_send_in_RPC
|
||||
from libp2p.utils import (
|
||||
decode_varint_with_size,
|
||||
get_agent_version,
|
||||
@ -63,6 +64,9 @@ def _mk_identify_protobuf(
|
||||
laddrs = host.get_addrs()
|
||||
protocols = tuple(str(p) for p in host.get_mux().get_protocols() if p is not None)
|
||||
|
||||
# Create a signed peer-record for the remote peer
|
||||
envelope_bytes, _ = env_to_send_in_RPC(host)
|
||||
|
||||
observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b""
|
||||
return Identify(
|
||||
protocol_version=PROTOCOL_VERSION,
|
||||
@ -71,6 +75,7 @@ def _mk_identify_protobuf(
|
||||
listen_addrs=map(_multiaddr_to_bytes, laddrs),
|
||||
observed_addr=observed_addr,
|
||||
protocols=protocols,
|
||||
signedPeerRecord=envelope_bytes,
|
||||
)
|
||||
|
||||
|
||||
@ -113,7 +118,7 @@ def parse_identify_response(response: bytes) -> Identify:
|
||||
|
||||
|
||||
def identify_handler_for(
|
||||
host: IHost, use_varint_format: bool = False
|
||||
host: IHost, use_varint_format: bool = True
|
||||
) -> StreamHandlerFn:
|
||||
async def handle_identify(stream: INetStream) -> None:
|
||||
# get observed address from ``stream``
|
||||
|
||||
@ -9,4 +9,5 @@ message Identify {
|
||||
repeated bytes listen_addrs = 2;
|
||||
optional bytes observed_addr = 4;
|
||||
repeated string protocols = 3;
|
||||
optional bytes signedPeerRecord = 8;
|
||||
}
|
||||
|
||||
@ -13,7 +13,7 @@ _sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n*libp2p/identity/identify/pb/identify.proto\x12\x0bidentify.pb\"\x8f\x01\n\x08Identify\x12\x18\n\x10protocol_version\x18\x05 \x01(\t\x12\x15\n\ragent_version\x18\x06 \x01(\t\x12\x12\n\npublic_key\x18\x01 \x01(\x0c\x12\x14\n\x0clisten_addrs\x18\x02 \x03(\x0c\x12\x15\n\robserved_addr\x18\x04 \x01(\x0c\x12\x11\n\tprotocols\x18\x03 \x03(\t')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n*libp2p/identity/identify/pb/identify.proto\x12\x0bidentify.pb\"\xa9\x01\n\x08Identify\x12\x18\n\x10protocol_version\x18\x05 \x01(\t\x12\x15\n\ragent_version\x18\x06 \x01(\t\x12\x12\n\npublic_key\x18\x01 \x01(\x0c\x12\x14\n\x0clisten_addrs\x18\x02 \x03(\x0c\x12\x15\n\robserved_addr\x18\x04 \x01(\x0c\x12\x11\n\tprotocols\x18\x03 \x03(\t\x12\x18\n\x10signedPeerRecord\x18\x08 \x01(\x0c')
|
||||
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.identity.identify.pb.identify_pb2', globals())
|
||||
@ -21,5 +21,5 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
_IDENTIFY._serialized_start=60
|
||||
_IDENTIFY._serialized_end=203
|
||||
_IDENTIFY._serialized_end=229
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
||||
@ -22,10 +22,12 @@ class Identify(google.protobuf.message.Message):
|
||||
LISTEN_ADDRS_FIELD_NUMBER: builtins.int
|
||||
OBSERVED_ADDR_FIELD_NUMBER: builtins.int
|
||||
PROTOCOLS_FIELD_NUMBER: builtins.int
|
||||
SIGNEDPEERRECORD_FIELD_NUMBER: builtins.int
|
||||
protocol_version: builtins.str
|
||||
agent_version: builtins.str
|
||||
public_key: builtins.bytes
|
||||
observed_addr: builtins.bytes
|
||||
signedPeerRecord: builtins.bytes
|
||||
@property
|
||||
def listen_addrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
|
||||
@property
|
||||
@ -39,8 +41,9 @@ class Identify(google.protobuf.message.Message):
|
||||
listen_addrs: collections.abc.Iterable[builtins.bytes] | None = ...,
|
||||
observed_addr: builtins.bytes | None = ...,
|
||||
protocols: collections.abc.Iterable[builtins.str] | None = ...,
|
||||
signedPeerRecord: builtins.bytes | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["agent_version", b"agent_version", "observed_addr", b"observed_addr", "protocol_version", b"protocol_version", "public_key", b"public_key"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["agent_version", b"agent_version", "listen_addrs", b"listen_addrs", "observed_addr", b"observed_addr", "protocol_version", b"protocol_version", "protocols", b"protocols", "public_key", b"public_key"]) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["agent_version", b"agent_version", "observed_addr", b"observed_addr", "protocol_version", b"protocol_version", "public_key", b"public_key", "signedPeerRecord", b"signedPeerRecord"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["agent_version", b"agent_version", "listen_addrs", b"listen_addrs", "observed_addr", b"observed_addr", "protocol_version", b"protocol_version", "protocols", b"protocols", "public_key", b"public_key", "signedPeerRecord", b"signedPeerRecord"]) -> None: ...
|
||||
|
||||
global___Identify = Identify
|
||||
|
||||
@ -20,6 +20,7 @@ from libp2p.custom_types import (
|
||||
from libp2p.network.stream.exceptions import (
|
||||
StreamClosed,
|
||||
)
|
||||
from libp2p.peer.envelope import consume_envelope
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
@ -28,7 +29,7 @@ from libp2p.utils import (
|
||||
varint,
|
||||
)
|
||||
from libp2p.utils.varint import (
|
||||
decode_varint_from_bytes,
|
||||
read_length_prefixed_protobuf,
|
||||
)
|
||||
|
||||
from ..identify.identify import (
|
||||
@ -66,49 +67,8 @@ def identify_push_handler_for(
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
|
||||
try:
|
||||
if use_varint_format:
|
||||
# Read length-prefixed identify message from the stream
|
||||
# First read the varint length prefix
|
||||
length_bytes = b""
|
||||
while True:
|
||||
b = await stream.read(1)
|
||||
if not b:
|
||||
break
|
||||
length_bytes += b
|
||||
if b[0] & 0x80 == 0:
|
||||
break
|
||||
|
||||
if not length_bytes:
|
||||
logger.warning("No length prefix received from peer %s", peer_id)
|
||||
return
|
||||
|
||||
msg_length = decode_varint_from_bytes(length_bytes)
|
||||
|
||||
# Read the protobuf message
|
||||
data = await stream.read(msg_length)
|
||||
if len(data) != msg_length:
|
||||
logger.warning("Incomplete message received from peer %s", peer_id)
|
||||
return
|
||||
else:
|
||||
# Read raw protobuf message from the stream
|
||||
# For raw format, we need to read all data before the stream is closed
|
||||
data = b""
|
||||
try:
|
||||
# Read all available data in a single operation
|
||||
data = await stream.read()
|
||||
except StreamClosed:
|
||||
# Try to read any remaining data
|
||||
try:
|
||||
data = await stream.read()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# If we got no data, log a warning and return
|
||||
if not data:
|
||||
logger.warning(
|
||||
"No data received in raw format from peer %s", peer_id
|
||||
)
|
||||
return
|
||||
# Use the utility function to read the protobuf message
|
||||
data = await read_length_prefixed_protobuf(stream, use_varint_format)
|
||||
|
||||
identify_msg = Identify()
|
||||
identify_msg.ParseFromString(data)
|
||||
@ -119,6 +79,11 @@ def identify_push_handler_for(
|
||||
)
|
||||
|
||||
logger.debug("Successfully processed identify/push from peer %s", peer_id)
|
||||
|
||||
# Send acknowledgment to indicate successful processing
|
||||
# This ensures the sender knows the message was received before closing
|
||||
await stream.write(b"OK")
|
||||
|
||||
except StreamClosed:
|
||||
logger.debug(
|
||||
"Stream closed while processing identify/push from %s", peer_id
|
||||
@ -127,7 +92,10 @@ def identify_push_handler_for(
|
||||
logger.error("Error processing identify/push from %s: %s", peer_id, e)
|
||||
finally:
|
||||
# Close the stream after processing
|
||||
await stream.close()
|
||||
try:
|
||||
await stream.close()
|
||||
except Exception:
|
||||
pass # Ignore errors when closing
|
||||
|
||||
return handle_identify_push
|
||||
|
||||
@ -173,6 +141,19 @@ async def _update_peerstore_from_identify(
|
||||
except Exception as e:
|
||||
logger.error("Error updating protocols for peer %s: %s", peer_id, e)
|
||||
|
||||
if identify_msg.HasField("signedPeerRecord"):
|
||||
try:
|
||||
            # Convert the signed peer record (Envelope) from protobuf bytes
|
||||
envelope, _ = consume_envelope(
|
||||
identify_msg.signedPeerRecord, "libp2p-peer-record"
|
||||
)
|
||||
# Use a default TTL of 2 hours (7200 seconds)
|
||||
if not peerstore.consume_peer_record(envelope, 7200):
|
||||
logger.error("Updating Certified-Addr-Book was unsuccessful")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error updating the certified addr book for peer %s: %s", peer_id, e
|
||||
)
|
||||
# Update observed address if present
|
||||
if identify_msg.HasField("observed_addr") and identify_msg.observed_addr:
|
||||
try:
|
||||
@ -226,7 +207,20 @@ async def push_identify_to_peer(
|
||||
# Send raw protobuf message
|
||||
await stream.write(response)
|
||||
|
||||
# Close the stream
|
||||
# Wait for acknowledgment from the receiver with timeout
|
||||
# This ensures the message was processed before closing
|
||||
try:
|
||||
with trio.move_on_after(1.0): # 1 second timeout
|
||||
ack = await stream.read(2) # Read "OK" acknowledgment
|
||||
if ack != b"OK":
|
||||
logger.warning(
|
||||
"Unexpected acknowledgment from peer %s: %s", peer_id, ack
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("No acknowledgment received from peer %s: %s", peer_id, e)
|
||||
# Continue anyway, as the message might have been processed
|
||||
|
||||
# Close the stream after acknowledgment (or timeout)
|
||||
await stream.close()
|
||||
|
||||
logger.debug("Successfully pushed identify to peer %s", peer_id)
|
||||
|
||||
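A hedged sketch isolating the sender-side acknowledgment wait introduced above; `stream` is an open INetStream on the identify/push protocol:

import trio


async def wait_for_push_ack(stream) -> bool:
    with trio.move_on_after(1.0):   # bounded wait, mirroring the handler change
        ack = await stream.read(2)  # the receiver answers with b"OK"
        return ack == b"OK"
    return False                    # timeout: treat the push as unacknowledged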
@ -5,6 +5,7 @@ This module provides a complete Distributed Hash Table (DHT)
|
||||
implementation based on the Kademlia algorithm and protocol.
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from enum import (
|
||||
Enum,
|
||||
)
|
||||
@ -20,15 +21,19 @@ import varint
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
)
|
||||
from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager
|
||||
from libp2p.kad_dht.utils import maybe_consume_signed_record
|
||||
from libp2p.network.stream.net_stream import (
|
||||
INetStream,
|
||||
)
|
||||
from libp2p.peer.envelope import Envelope
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
from libp2p.peer.peerstore import env_to_send_in_RPC
|
||||
from libp2p.tools.async_service import (
|
||||
Service,
|
||||
)
|
||||
@ -73,14 +78,27 @@ class KadDHT(Service):
|
||||
|
||||
This class provides a DHT implementation that combines routing table management,
|
||||
peer discovery, content routing, and value storage.
|
||||
|
||||
Optional Random Walk feature enhances peer discovery by automatically
|
||||
performing periodic random queries to discover new peers and maintain
|
||||
routing table health.
|
||||
|
||||
Example:
|
||||
# Basic DHT without random walk (default)
|
||||
dht = KadDHT(host, DHTMode.SERVER)
|
||||
|
||||
# DHT with random walk enabled for enhanced peer discovery
|
||||
dht = KadDHT(host, DHTMode.SERVER, enable_random_walk=True)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, host: IHost, mode: DHTMode):
|
||||
def __init__(self, host: IHost, mode: DHTMode, enable_random_walk: bool = False):
|
||||
"""
|
||||
Initialize a new Kademlia DHT node.
|
||||
|
||||
:param host: The libp2p host.
|
||||
:param mode: The mode of host (Client or Server) - must be DHTMode enum
|
||||
:param enable_random_walk: Whether to enable automatic random walk
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@ -92,6 +110,7 @@ class KadDHT(Service):
|
||||
raise TypeError(f"mode must be DHTMode enum, got {type(mode)}")
|
||||
|
||||
self.mode = mode
|
||||
self.enable_random_walk = enable_random_walk
|
||||
|
||||
# Initialize the routing table
|
||||
self.routing_table = RoutingTable(self.local_peer_id, self.host)
|
||||
@ -108,13 +127,56 @@ class KadDHT(Service):
|
||||
# Last time we republished provider records
|
||||
self._last_provider_republish = time.time()
|
||||
|
||||
# Initialize RT Refresh Manager (only if random walk is enabled)
|
||||
self.rt_refresh_manager: RTRefreshManager | None = None
|
||||
if self.enable_random_walk:
|
||||
self.rt_refresh_manager = RTRefreshManager(
|
||||
host=self.host,
|
||||
routing_table=self.routing_table,
|
||||
local_peer_id=self.local_peer_id,
|
||||
query_function=self._create_query_function(),
|
||||
enable_auto_refresh=True,
|
||||
)
|
||||
|
||||
# Set protocol handlers
|
||||
host.set_stream_handler(PROTOCOL_ID, self.handle_stream)
|
||||
|
||||
def _create_query_function(self) -> Callable[[bytes], Awaitable[list[ID]]]:
|
||||
"""
|
||||
Create a query function that wraps peer_routing.find_closest_peers_network.
|
||||
|
||||
This function is used by the RandomWalk module to query for peers without
|
||||
directly importing PeerRouting, avoiding circular import issues.
|
||||
|
||||
Returns:
|
||||
Callable that takes target_key bytes and returns list of peer IDs
|
||||
|
||||
"""
|
||||
|
||||
async def query_function(target_key: bytes) -> list[ID]:
|
||||
"""Query for closest peers to target key."""
|
||||
return await self.peer_routing.find_closest_peers_network(target_key)
|
||||
|
||||
return query_function
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run the DHT service."""
|
||||
logger.info(f"Starting Kademlia DHT with peer ID {self.local_peer_id}")
|
||||
|
||||
# Start the RT Refresh Manager in parallel with the main DHT service
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Start the RT Refresh Manager only if random walk is enabled
|
||||
if self.rt_refresh_manager is not None:
|
||||
nursery.start_soon(self.rt_refresh_manager.start)
|
||||
logger.info("RT Refresh Manager started - Random Walk is now active")
|
||||
else:
|
||||
logger.info("Random Walk is disabled - RT Refresh Manager not started")
|
||||
|
||||
# Start the main DHT service loop
|
||||
nursery.start_soon(self._run_main_loop)
|
||||
|
||||
async def _run_main_loop(self) -> None:
|
||||
"""Run the main DHT service loop."""
|
||||
# Main service loop
|
||||
while self.manager.is_running:
|
||||
# Periodically refresh the routing table
|
||||
@ -135,6 +197,17 @@ class KadDHT(Service):
|
||||
# Wait before next maintenance cycle
|
||||
await trio.sleep(ROUTING_TABLE_REFRESH_INTERVAL)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the DHT service and cleanup resources."""
|
||||
logger.info("Stopping Kademlia DHT")
|
||||
|
||||
# Stop the RT Refresh Manager only if it was started
|
||||
if self.rt_refresh_manager is not None:
|
||||
await self.rt_refresh_manager.stop()
|
||||
logger.info("RT Refresh Manager stopped")
|
||||
else:
|
||||
logger.info("RT Refresh Manager was not running (Random Walk disabled)")
|
||||
|
||||
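A hedged end-to-end sketch of the enable_random_walk flag; the import paths and the async-service runner are assumed from the surrounding codebase:

import trio

from libp2p.kad_dht.kad_dht import DHTMode, KadDHT
from libp2p.tools.async_service import background_trio_service


async def run_dht(host) -> None:
    dht = KadDHT(host, DHTMode.SERVER, enable_random_walk=True)
    async with background_trio_service(dht):
        await trio.sleep(300)  # let periodic refreshes and random walks run
    await dht.stop()           # stops the RT refresh manager if it was started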
async def switch_mode(self, new_mode: DHTMode) -> DHTMode:
|
||||
"""
|
||||
Switch the DHT mode.
|
||||
@ -164,6 +237,9 @@ class KadDHT(Service):
|
||||
await self.add_peer(peer_id)
|
||||
logger.debug(f"Added peer {peer_id} to routing table")
|
||||
|
||||
closer_peer_envelope: Envelope | None = None
|
||||
provider_peer_envelope: Envelope | None = None
|
||||
|
||||
try:
|
||||
# Read varint-prefixed length for the message
|
||||
length_prefix = b""
|
||||
@ -204,6 +280,14 @@ class KadDHT(Service):
|
||||
)
|
||||
logger.debug(f"Found {len(closest_peers)} peers close to target")
|
||||
|
||||
# Consume the source signed_peer_record if sent
|
||||
if not maybe_consume_signed_record(message, self.host, peer_id):
|
||||
logger.error(
|
||||
"Received an invalid-signed-record, dropping the stream"
|
||||
)
|
||||
await stream.close()
|
||||
return
|
||||
|
||||
# Build response message with protobuf
|
||||
response = Message()
|
||||
response.type = Message.MessageType.FIND_NODE
|
||||
@ -228,6 +312,21 @@ class KadDHT(Service):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Add the signed-peer-record for each peer in the peer-proto
|
||||
# if cached in the peerstore
|
||||
closer_peer_envelope = (
|
||||
self.host.get_peerstore().get_peer_record(peer)
|
||||
)
|
||||
|
||||
if closer_peer_envelope is not None:
|
||||
peer_proto.signedRecord = (
|
||||
closer_peer_envelope.marshal_envelope()
|
||||
)
|
||||
|
||||
# Create sender_signed_peer_record
|
||||
envelope_bytes, _ = env_to_send_in_RPC(self.host)
|
||||
response.senderRecord = envelope_bytes
|
||||
|
||||
# Serialize and send response
|
||||
response_bytes = response.SerializeToString()
|
||||
await stream.write(varint.encode(len(response_bytes)))
|
||||
@ -242,6 +341,14 @@ class KadDHT(Service):
|
||||
key = message.key
|
||||
logger.debug(f"Received ADD_PROVIDER for key {key.hex()}")
|
||||
|
||||
# Consume the source signed-peer-record if sent
|
||||
if not maybe_consume_signed_record(message, self.host, peer_id):
|
||||
logger.error(
|
||||
"Received an invalid-signed-record, dropping the stream"
|
||||
)
|
||||
await stream.close()
|
||||
return
|
||||
|
||||
# Extract provider information
|
||||
for provider_proto in message.providerPeers:
|
||||
try:
|
||||
@ -268,6 +375,17 @@ class KadDHT(Service):
|
||||
logger.debug(
|
||||
f"Added provider {provider_id} for key {key.hex()}"
|
||||
)
|
||||
|
||||
# Process the signed-records of provider if sent
|
||||
if not maybe_consume_signed_record(
|
||||
provider_proto, self.host
|
||||
):
|
||||
logger.error(
|
||||
"Received an invalid-signed-record,"
|
||||
"dropping the stream"
|
||||
)
|
||||
await stream.close()
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process provider info: {e}")
|
||||
|
||||
@ -276,6 +394,10 @@ class KadDHT(Service):
|
||||
response.type = Message.MessageType.ADD_PROVIDER
|
||||
response.key = key
|
||||
|
||||
# Add sender's signed-peer-record
|
||||
envelope_bytes, _ = env_to_send_in_RPC(self.host)
|
||||
response.senderRecord = envelope_bytes
|
||||
|
||||
response_bytes = response.SerializeToString()
|
||||
await stream.write(varint.encode(len(response_bytes)))
|
||||
await stream.write(response_bytes)
|
||||
@ -287,6 +409,14 @@ class KadDHT(Service):
|
||||
key = message.key
|
||||
logger.debug(f"Received GET_PROVIDERS request for key {key.hex()}")
|
||||
|
||||
# Consume the source signed_peer_record if sent
|
||||
if not maybe_consume_signed_record(message, self.host, peer_id):
|
||||
logger.error(
|
||||
"Received an invalid-signed-record, dropping the stream"
|
||||
)
|
||||
await stream.close()
|
||||
return
|
||||
|
||||
# Find providers for the key
|
||||
providers = self.provider_store.get_providers(key)
|
||||
logger.debug(
|
||||
@ -298,12 +428,28 @@ class KadDHT(Service):
|
||||
response.type = Message.MessageType.GET_PROVIDERS
|
||||
response.key = key
|
||||
|
||||
# Create sender_signed_peer_record for the response
|
||||
envelope_bytes, _ = env_to_send_in_RPC(self.host)
|
||||
response.senderRecord = envelope_bytes
|
||||
|
||||
# Add provider information to response
|
||||
for provider_info in providers:
|
||||
provider_proto = response.providerPeers.add()
|
||||
provider_proto.id = provider_info.peer_id.to_bytes()
|
||||
provider_proto.connection = Message.ConnectionType.CAN_CONNECT
|
||||
|
||||
# Add provider signed-records if cached
|
||||
provider_peer_envelope = (
|
||||
self.host.get_peerstore().get_peer_record(
|
||||
provider_info.peer_id
|
||||
)
|
||||
)
|
||||
|
||||
if provider_peer_envelope is not None:
|
||||
provider_proto.signedRecord = (
|
||||
provider_peer_envelope.marshal_envelope()
|
||||
)
|
||||
|
||||
# Add addresses if available
|
||||
for addr in provider_info.addrs:
|
||||
provider_proto.addrs.append(addr.to_bytes())
|
||||
@ -327,6 +473,16 @@ class KadDHT(Service):
|
||||
peer_proto.id = peer.to_bytes()
|
||||
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
|
||||
|
||||
# Add the signed-records of closest_peers if cached
|
||||
closer_peer_envelope = (
|
||||
self.host.get_peerstore().get_peer_record(peer)
|
||||
)
|
||||
|
||||
if closer_peer_envelope is not None:
|
||||
peer_proto.signedRecord = (
|
||||
closer_peer_envelope.marshal_envelope()
|
||||
)
|
||||
|
||||
# Add addresses if available
|
||||
try:
|
||||
addrs = self.host.get_peerstore().addrs(peer)
|
||||
@ -347,6 +503,14 @@ class KadDHT(Service):
|
||||
key = message.key
|
||||
logger.debug(f"Received GET_VALUE request for key {key.hex()}")
|
||||
|
||||
# Consume the sender_signed_peer_record
|
||||
if not maybe_consume_signed_record(message, self.host, peer_id):
|
||||
logger.error(
|
||||
"Received an invalid-signed-record, dropping the stream"
|
||||
)
|
||||
await stream.close()
|
||||
return
|
||||
|
||||
value = self.value_store.get(key)
|
||||
if value:
|
||||
logger.debug(f"Found value for key {key.hex()}")
|
||||
@ -361,6 +525,10 @@ class KadDHT(Service):
|
||||
response.record.value = value
|
||||
response.record.timeReceived = str(time.time())
|
||||
|
||||
# Create sender_signed_peer_record
|
||||
envelope_bytes, _ = env_to_send_in_RPC(self.host)
|
||||
response.senderRecord = envelope_bytes
|
||||
|
||||
# Serialize and send response
|
||||
response_bytes = response.SerializeToString()
|
||||
await stream.write(varint.encode(len(response_bytes)))
|
||||
@ -374,6 +542,10 @@ class KadDHT(Service):
|
||||
response.type = Message.MessageType.GET_VALUE
|
||||
response.key = key
|
||||
|
||||
# Create sender_signed_peer_record for the response
|
||||
envelope_bytes, _ = env_to_send_in_RPC(self.host)
|
||||
response.senderRecord = envelope_bytes
|
||||
|
||||
# Add closest peers to key
|
||||
closest_peers = self.routing_table.find_local_closest_peers(
|
||||
key, 20
|
||||
@ -392,6 +564,16 @@ class KadDHT(Service):
|
||||
peer_proto.id = peer.to_bytes()
|
||||
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
|
||||
|
||||
# Add signed-records of closer-peers if cached
|
||||
closer_peer_envelope = (
|
||||
self.host.get_peerstore().get_peer_record(peer)
|
||||
)
|
||||
|
||||
if closer_peer_envelope is not None:
|
||||
peer_proto.signedRecord = (
|
||||
closer_peer_envelope.marshal_envelope()
|
||||
)
|
||||
|
||||
# Add addresses if available
|
||||
try:
|
||||
addrs = self.host.get_peerstore().addrs(peer)
|
||||
@ -414,6 +596,15 @@ class KadDHT(Service):
|
||||
key = message.record.key
|
||||
value = message.record.value
|
||||
success = False
|
||||
|
||||
# Consume the source signed_peer_record if sent
|
||||
if not maybe_consume_signed_record(message, self.host, peer_id):
|
||||
logger.error(
|
||||
"Received an invalid-signed-record, dropping the stream"
|
||||
)
|
||||
await stream.close()
|
||||
return
|
||||
|
||||
try:
|
||||
if not (key and value):
|
||||
raise ValueError(
|
||||
@ -434,6 +625,12 @@ class KadDHT(Service):
|
||||
response.type = Message.MessageType.PUT_VALUE
|
||||
if success:
|
||||
response.key = key
|
||||
|
||||
# Create sender_signed_peer_record for the response
|
||||
envelope_bytes, _ = env_to_send_in_RPC(self.host)
|
||||
response.senderRecord = envelope_bytes
|
||||
|
||||
# Serialize and send response
|
||||
response_bytes = response.SerializeToString()
|
||||
await stream.write(varint.encode(len(response_bytes)))
|
||||
await stream.write(response_bytes)
|
||||
@ -614,3 +811,15 @@ class KadDHT(Service):
|
||||
|
||||
"""
|
||||
return self.value_store.size()
|
||||
|
||||
def is_random_walk_enabled(self) -> bool:
|
||||
"""
|
||||
Check if random walk peer discovery is enabled.
|
||||
|
||||
:return: True if random walk is enabled, False otherwise.
|
||||
|
||||
"""
|
||||
return self.enable_random_walk
|
||||
|
||||
@ -27,6 +27,7 @@ message Message {
|
||||
bytes id = 1;
|
||||
repeated bytes addrs = 2;
|
||||
ConnectionType connection = 3;
|
||||
optional bytes signedRecord = 4; // Envelope(PeerRecord) encoded
|
||||
}
|
||||
|
||||
MessageType type = 1;
|
||||
@ -35,4 +36,6 @@ message Message {
|
||||
Record record = 3;
|
||||
repeated Peer closerPeers = 8;
|
||||
repeated Peer providerPeers = 9;
|
||||
|
||||
optional bytes senderRecord = 11; // Envelope(PeerRecord) encoded
|
||||
}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: libp2p/kad_dht/pb/kademlia.proto
|
||||
# Protobuf Python Version: 4.25.3
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
@ -13,7 +14,7 @@ _sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xca\x03\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x1aN\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x62\x06proto3')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xa2\x04\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x12\x19\n\x0csenderRecord\x18\x0b \x01(\x0cH\x00\x88\x01\x01\x1az\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\x12\x19\n\x0csignedRecord\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x42\x0f\n\r_signedRecord\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x42\x0f\n\r_senderRecordb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
@ -23,11 +24,11 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
_globals['_RECORD']._serialized_start=36
|
||||
_globals['_RECORD']._serialized_end=94
|
||||
_globals['_MESSAGE']._serialized_start=97
|
||||
_globals['_MESSAGE']._serialized_end=555
|
||||
_globals['_MESSAGE_PEER']._serialized_start=281
|
||||
_globals['_MESSAGE_PEER']._serialized_end=359
|
||||
_globals['_MESSAGE_MESSAGETYPE']._serialized_start=361
|
||||
_globals['_MESSAGE_MESSAGETYPE']._serialized_end=466
|
||||
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_start=468
|
||||
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_end=555
|
||||
_globals['_MESSAGE']._serialized_end=643
|
||||
_globals['_MESSAGE_PEER']._serialized_start=308
|
||||
_globals['_MESSAGE_PEER']._serialized_end=430
|
||||
_globals['_MESSAGE_MESSAGETYPE']._serialized_start=432
|
||||
_globals['_MESSAGE_MESSAGETYPE']._serialized_end=537
|
||||
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_start=539
|
||||
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_end=626
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
||||
@ -1,133 +1,70 @@
|
||||
"""
|
||||
@generated by mypy-protobuf. Do not edit manually!
|
||||
isort:skip_file
|
||||
"""
|
||||
from google.protobuf.internal import containers as _containers
|
||||
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import message as _message
|
||||
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union
|
||||
|
||||
import builtins
|
||||
import collections.abc
|
||||
import google.protobuf.descriptor
|
||||
import google.protobuf.internal.containers
|
||||
import google.protobuf.internal.enum_type_wrapper
|
||||
import google.protobuf.message
|
||||
import sys
|
||||
import typing
|
||||
DESCRIPTOR: _descriptor.FileDescriptor
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
import typing as typing_extensions
|
||||
else:
|
||||
import typing_extensions
|
||||
class Record(_message.Message):
|
||||
__slots__ = ("key", "value", "timeReceived")
|
||||
KEY_FIELD_NUMBER: _ClassVar[int]
|
||||
VALUE_FIELD_NUMBER: _ClassVar[int]
|
||||
TIMERECEIVED_FIELD_NUMBER: _ClassVar[int]
|
||||
key: bytes
|
||||
value: bytes
|
||||
timeReceived: str
|
||||
def __init__(self, key: _Optional[bytes] = ..., value: _Optional[bytes] = ..., timeReceived: _Optional[str] = ...) -> None: ...
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
|
||||
@typing.final
|
||||
class Record(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
KEY_FIELD_NUMBER: builtins.int
|
||||
VALUE_FIELD_NUMBER: builtins.int
|
||||
TIMERECEIVED_FIELD_NUMBER: builtins.int
|
||||
key: builtins.bytes
|
||||
value: builtins.bytes
|
||||
timeReceived: builtins.str
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
key: builtins.bytes = ...,
|
||||
value: builtins.bytes = ...,
|
||||
timeReceived: builtins.str = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["key", b"key", "timeReceived", b"timeReceived", "value", b"value"]) -> None: ...
|
||||
|
||||
global___Record = Record
|
||||
|
||||
@typing.final
|
||||
class Message(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
class _MessageType:
|
||||
ValueType = typing.NewType("ValueType", builtins.int)
|
||||
V: typing_extensions.TypeAlias = ValueType
|
||||
|
||||
class _MessageTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._MessageType.ValueType], builtins.type):
|
||||
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
|
||||
PUT_VALUE: Message._MessageType.ValueType # 0
|
||||
GET_VALUE: Message._MessageType.ValueType # 1
|
||||
ADD_PROVIDER: Message._MessageType.ValueType # 2
|
||||
GET_PROVIDERS: Message._MessageType.ValueType # 3
|
||||
FIND_NODE: Message._MessageType.ValueType # 4
|
||||
PING: Message._MessageType.ValueType # 5
|
||||
|
||||
class MessageType(_MessageType, metaclass=_MessageTypeEnumTypeWrapper): ...
|
||||
PUT_VALUE: Message.MessageType.ValueType # 0
|
||||
GET_VALUE: Message.MessageType.ValueType # 1
|
||||
ADD_PROVIDER: Message.MessageType.ValueType # 2
|
||||
GET_PROVIDERS: Message.MessageType.ValueType # 3
|
||||
FIND_NODE: Message.MessageType.ValueType # 4
|
||||
PING: Message.MessageType.ValueType # 5
|
||||
|
||||
class _ConnectionType:
|
||||
ValueType = typing.NewType("ValueType", builtins.int)
|
||||
V: typing_extensions.TypeAlias = ValueType
|
||||
|
||||
class _ConnectionTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._ConnectionType.ValueType], builtins.type):
|
||||
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
|
||||
NOT_CONNECTED: Message._ConnectionType.ValueType # 0
|
||||
CONNECTED: Message._ConnectionType.ValueType # 1
|
||||
CAN_CONNECT: Message._ConnectionType.ValueType # 2
|
||||
CANNOT_CONNECT: Message._ConnectionType.ValueType # 3
|
||||
|
||||
class ConnectionType(_ConnectionType, metaclass=_ConnectionTypeEnumTypeWrapper): ...
|
||||
NOT_CONNECTED: Message.ConnectionType.ValueType # 0
|
||||
CONNECTED: Message.ConnectionType.ValueType # 1
|
||||
CAN_CONNECT: Message.ConnectionType.ValueType # 2
|
||||
CANNOT_CONNECT: Message.ConnectionType.ValueType # 3
|
||||
|
||||
@typing.final
|
||||
class Peer(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
ID_FIELD_NUMBER: builtins.int
|
||||
ADDRS_FIELD_NUMBER: builtins.int
|
||||
CONNECTION_FIELD_NUMBER: builtins.int
|
||||
id: builtins.bytes
|
||||
connection: global___Message.ConnectionType.ValueType
|
||||
@property
|
||||
def addrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
id: builtins.bytes = ...,
|
||||
addrs: collections.abc.Iterable[builtins.bytes] | None = ...,
|
||||
connection: global___Message.ConnectionType.ValueType = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["addrs", b"addrs", "connection", b"connection", "id", b"id"]) -> None: ...
|
||||
|
||||
TYPE_FIELD_NUMBER: builtins.int
|
||||
CLUSTERLEVELRAW_FIELD_NUMBER: builtins.int
|
||||
KEY_FIELD_NUMBER: builtins.int
|
||||
RECORD_FIELD_NUMBER: builtins.int
|
||||
CLOSERPEERS_FIELD_NUMBER: builtins.int
|
||||
PROVIDERPEERS_FIELD_NUMBER: builtins.int
|
||||
type: global___Message.MessageType.ValueType
|
||||
clusterLevelRaw: builtins.int
|
||||
key: builtins.bytes
|
||||
@property
|
||||
def record(self) -> global___Record: ...
|
||||
@property
|
||||
def closerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ...
|
||||
@property
|
||||
def providerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
type: global___Message.MessageType.ValueType = ...,
|
||||
clusterLevelRaw: builtins.int = ...,
|
||||
key: builtins.bytes = ...,
|
||||
record: global___Record | None = ...,
|
||||
closerPeers: collections.abc.Iterable[global___Message.Peer] | None = ...,
|
||||
providerPeers: collections.abc.Iterable[global___Message.Peer] | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["record", b"record"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["closerPeers", b"closerPeers", "clusterLevelRaw", b"clusterLevelRaw", "key", b"key", "providerPeers", b"providerPeers", "record", b"record", "type", b"type"]) -> None: ...
|
||||
|
||||
global___Message = Message
|
||||
class Message(_message.Message):
|
||||
__slots__ = ("type", "clusterLevelRaw", "key", "record", "closerPeers", "providerPeers", "senderRecord")
|
||||
class MessageType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
|
||||
__slots__ = ()
|
||||
PUT_VALUE: _ClassVar[Message.MessageType]
|
||||
GET_VALUE: _ClassVar[Message.MessageType]
|
||||
ADD_PROVIDER: _ClassVar[Message.MessageType]
|
||||
GET_PROVIDERS: _ClassVar[Message.MessageType]
|
||||
FIND_NODE: _ClassVar[Message.MessageType]
|
||||
PING: _ClassVar[Message.MessageType]
|
||||
PUT_VALUE: Message.MessageType
|
||||
GET_VALUE: Message.MessageType
|
||||
ADD_PROVIDER: Message.MessageType
|
||||
GET_PROVIDERS: Message.MessageType
|
||||
FIND_NODE: Message.MessageType
|
||||
PING: Message.MessageType
|
||||
class ConnectionType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
|
||||
__slots__ = ()
|
||||
NOT_CONNECTED: _ClassVar[Message.ConnectionType]
|
||||
CONNECTED: _ClassVar[Message.ConnectionType]
|
||||
CAN_CONNECT: _ClassVar[Message.ConnectionType]
|
||||
CANNOT_CONNECT: _ClassVar[Message.ConnectionType]
|
||||
NOT_CONNECTED: Message.ConnectionType
|
||||
CONNECTED: Message.ConnectionType
|
||||
CAN_CONNECT: Message.ConnectionType
|
||||
CANNOT_CONNECT: Message.ConnectionType
|
||||
class Peer(_message.Message):
|
||||
__slots__ = ("id", "addrs", "connection", "signedRecord")
|
||||
ID_FIELD_NUMBER: _ClassVar[int]
|
||||
ADDRS_FIELD_NUMBER: _ClassVar[int]
|
||||
CONNECTION_FIELD_NUMBER: _ClassVar[int]
|
||||
SIGNEDRECORD_FIELD_NUMBER: _ClassVar[int]
|
||||
id: bytes
|
||||
addrs: _containers.RepeatedScalarFieldContainer[bytes]
|
||||
connection: Message.ConnectionType
|
||||
signedRecord: bytes
|
||||
def __init__(self, id: _Optional[bytes] = ..., addrs: _Optional[_Iterable[bytes]] = ..., connection: _Optional[_Union[Message.ConnectionType, str]] = ..., signedRecord: _Optional[bytes] = ...) -> None: ...
|
||||
TYPE_FIELD_NUMBER: _ClassVar[int]
|
||||
CLUSTERLEVELRAW_FIELD_NUMBER: _ClassVar[int]
|
||||
KEY_FIELD_NUMBER: _ClassVar[int]
|
||||
RECORD_FIELD_NUMBER: _ClassVar[int]
|
||||
CLOSERPEERS_FIELD_NUMBER: _ClassVar[int]
|
||||
PROVIDERPEERS_FIELD_NUMBER: _ClassVar[int]
|
||||
SENDERRECORD_FIELD_NUMBER: _ClassVar[int]
|
||||
type: Message.MessageType
|
||||
clusterLevelRaw: int
|
||||
key: bytes
|
||||
record: Record
|
||||
closerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer]
|
||||
providerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer]
|
||||
senderRecord: bytes
|
||||
def __init__(self, type: _Optional[_Union[Message.MessageType, str]] = ..., clusterLevelRaw: _Optional[int] = ..., key: _Optional[bytes] = ..., record: _Optional[_Union[Record, _Mapping]] = ..., closerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping]]] = ..., providerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping]]] = ..., senderRecord: _Optional[bytes] = ...) -> None: ... # type: ignore
|
||||
|
||||
@ -15,12 +15,14 @@ from libp2p.abc import (
|
||||
INetStream,
|
||||
IPeerRouting,
|
||||
)
|
||||
from libp2p.peer.envelope import Envelope
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
from libp2p.peer.peerstore import env_to_send_in_RPC
|
||||
|
||||
from .common import (
|
||||
ALPHA,
|
||||
@ -33,6 +35,7 @@ from .routing_table import (
|
||||
RoutingTable,
|
||||
)
|
||||
from .utils import (
|
||||
maybe_consume_signed_record,
|
||||
sort_peer_ids_by_distance,
|
||||
)
|
||||
|
||||
@ -170,7 +173,7 @@ class PeerRouting(IPeerRouting):
|
||||
|
||||
# Return early if we have no peers to start with
|
||||
if not closest_peers:
|
||||
logger.warning("No local peers available for network lookup")
|
||||
logger.debug("No local peers available for network lookup")
|
||||
return []
|
||||
|
||||
# Iterative lookup until convergence
|
||||
@ -255,6 +258,10 @@ class PeerRouting(IPeerRouting):
|
||||
find_node_msg.type = Message.MessageType.FIND_NODE
|
||||
find_node_msg.key = target_key # Set target key directly as bytes
|
||||
|
||||
# Create sender_signed_peer_record
|
||||
envelope_bytes, _ = env_to_send_in_RPC(self.host)
|
||||
find_node_msg.senderRecord = envelope_bytes
|
||||
|
||||
# Serialize and send the protobuf message with varint length prefix
|
||||
proto_bytes = find_node_msg.SerializeToString()
|
||||
logger.debug(
|
||||
@ -299,7 +306,22 @@ class PeerRouting(IPeerRouting):
|
||||
|
||||
# Process closest peers from response
|
||||
if response_msg.type == Message.MessageType.FIND_NODE:
|
||||
# Consume the sender_signed_peer_record
|
||||
if not maybe_consume_signed_record(response_msg, self.host, peer):
|
||||
logger.error(
|
||||
"Received an invalid-signed-record,ignoring the response"
|
||||
)
|
||||
return []
|
||||
|
||||
for peer_data in response_msg.closerPeers:
|
||||
# Consume the received closer_peers signed-records, peer-id is
|
||||
# sent with the peer-data
|
||||
if not maybe_consume_signed_record(peer_data, self.host):
|
||||
logger.error(
|
||||
"Received an invalid-signed-record,ignoring the response"
|
||||
)
|
||||
return []
|
||||
|
||||
new_peer_id = ID(peer_data.id)
|
||||
if new_peer_id not in results:
|
||||
results.append(new_peer_id)
|
||||
@ -332,6 +354,7 @@ class PeerRouting(IPeerRouting):
|
||||
"""
|
||||
try:
|
||||
# Read message length
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
length_bytes = await stream.read(4)
|
||||
if not length_bytes:
|
||||
return
|
||||
@ -345,10 +368,18 @@ class PeerRouting(IPeerRouting):
|
||||
|
||||
# Parse protobuf message
|
||||
kad_message = Message()
|
||||
closer_peer_envelope: Envelope | None = None
|
||||
try:
|
||||
kad_message.ParseFromString(message_bytes)
|
||||
|
||||
if kad_message.type == Message.MessageType.FIND_NODE:
|
||||
# Consume the sender's signed-peer-record if sent
|
||||
if not maybe_consume_signed_record(kad_message, self.host, peer_id):
|
||||
logger.error(
|
||||
"Received an invalid-signed-record, dropping the stream"
|
||||
)
|
||||
return
|
||||
|
||||
# Get target key directly from protobuf message
|
||||
target_key = kad_message.key
|
||||
|
||||
@ -361,12 +392,26 @@ class PeerRouting(IPeerRouting):
|
||||
response = Message()
|
||||
response.type = Message.MessageType.FIND_NODE
|
||||
|
||||
# Create sender_signed_peer_record for the response
|
||||
envelope_bytes, _ = env_to_send_in_RPC(self.host)
|
||||
response.senderRecord = envelope_bytes
|
||||
|
||||
# Add peer information to response
|
||||
for peer_id in closest_peers:
|
||||
peer_proto = response.closerPeers.add()
|
||||
peer_proto.id = peer_id.to_bytes()
|
||||
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
|
||||
|
||||
# Add the signed-records of closest_peers if cached
|
||||
closer_peer_envelope = (
|
||||
self.host.get_peerstore().get_peer_record(peer_id)
|
||||
)
|
||||
|
||||
if isinstance(closer_peer_envelope, Envelope):
|
||||
peer_proto.signedRecord = (
|
||||
closer_peer_envelope.marshal_envelope()
|
||||
)
|
||||
|
||||
# Add addresses if available
|
||||
try:
|
||||
addrs = self.host.get_peerstore().addrs(peer_id)
|
||||
|
||||
@ -22,12 +22,14 @@ from libp2p.abc import (
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.kad_dht.utils import maybe_consume_signed_record
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
from libp2p.peer.peerstore import env_to_send_in_RPC
|
||||
|
||||
from .common import (
|
||||
ALPHA,
|
||||
@ -240,11 +242,18 @@ class ProviderStore:
|
||||
message.type = Message.MessageType.ADD_PROVIDER
|
||||
message.key = key
|
||||
|
||||
# Create sender's signed-peer-record
|
||||
envelope_bytes, _ = env_to_send_in_RPC(self.host)
|
||||
message.senderRecord = envelope_bytes
|
||||
|
||||
# Add our provider info
|
||||
provider = message.providerPeers.add()
|
||||
provider.id = self.local_peer_id.to_bytes()
|
||||
provider.addrs.extend(addrs)
|
||||
|
||||
# Add the provider's signed-peer-record
|
||||
provider.signedRecord = envelope_bytes
|
||||
|
||||
# Serialize and send the message
|
||||
proto_bytes = message.SerializeToString()
|
||||
await stream.write(varint.encode(len(proto_bytes)))
|
||||
@ -276,10 +285,15 @@ class ProviderStore:
|
||||
response = Message()
|
||||
response.ParseFromString(response_bytes)
|
||||
|
||||
# Check response type
|
||||
response.type == Message.MessageType.ADD_PROVIDER
|
||||
if response.type:
|
||||
result = True
|
||||
if response.type == Message.MessageType.ADD_PROVIDER:
|
||||
# Consume the sender's signed-peer-record if sent
|
||||
if not maybe_consume_signed_record(response, self.host, peer_id):
|
||||
logger.error(
|
||||
"Received an invalid-signed-record, ignoring the response"
|
||||
)
|
||||
result = False
|
||||
else:
|
||||
result = True
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error sending ADD_PROVIDER to {peer_id}: {e}")
|
||||
@ -380,6 +394,10 @@ class ProviderStore:
|
||||
message.type = Message.MessageType.GET_PROVIDERS
|
||||
message.key = key
|
||||
|
||||
# Create sender's signed-peer-record
|
||||
envelope_bytes, _ = env_to_send_in_RPC(self.host)
|
||||
message.senderRecord = envelope_bytes
|
||||
|
||||
# Serialize and send the message
|
||||
proto_bytes = message.SerializeToString()
|
||||
await stream.write(varint.encode(len(proto_bytes)))
|
||||
@ -414,10 +432,26 @@ class ProviderStore:
|
||||
if response.type != Message.MessageType.GET_PROVIDERS:
|
||||
return []
|
||||
|
||||
# Consume the sender's signed-peer-record if sent
|
||||
if not maybe_consume_signed_record(response, self.host, peer_id):
|
||||
logger.error(
|
||||
"Received an invalid-signed-record, ignoring the response"
|
||||
)
|
||||
return []
|
||||
|
||||
# Extract provider information
|
||||
providers = []
|
||||
for provider_proto in response.providerPeers:
|
||||
try:
|
||||
# Consume the provider's signed-peer-record if sent, peer-id
|
||||
# already sent with the provider-proto
|
||||
if not maybe_consume_signed_record(provider_proto, self.host):
|
||||
logger.error(
|
||||
"Received an invalid-signed-record, "
|
||||
"ignoring the response"
|
||||
)
|
||||
return []
|
||||
|
||||
# Create peer ID from bytes
|
||||
provider_id = ID(provider_proto.id)
|
||||
|
||||
@ -431,6 +465,7 @@ class ProviderStore:
|
||||
|
||||
# Create PeerInfo and add to result
|
||||
providers.append(PeerInfo(provider_id, addrs))
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse provider info: {e}")
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ from collections import (
|
||||
import logging
|
||||
import time
|
||||
|
||||
import multihash
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
@ -40,6 +41,22 @@ PEER_REFRESH_INTERVAL = 60 # Interval to refresh peers in seconds
|
||||
STALE_PEER_THRESHOLD = 3600 # Time in seconds after which a peer is considered stale
|
||||
|
||||
|
||||
def peer_id_to_key(peer_id: ID) -> bytes:
|
||||
"""
|
||||
Convert a peer ID to a 256-bit key for routing table operations.
|
||||
This normalizes all peer IDs to exactly 256 bits by hashing them with SHA-256.
|
||||
|
||||
:param peer_id: The peer ID to convert
|
||||
:return: 32-byte (256-bit) key for routing table operations
|
||||
"""
|
||||
return multihash.digest(peer_id.to_bytes(), "sha2-256").digest
|
||||
|
||||
|
||||
def key_to_int(key: bytes) -> int:
|
||||
"""Convert a 256-bit key to an integer for range calculations."""
|
||||
return int.from_bytes(key, byteorder="big")
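A minimal sketch of how these two helpers compose; the base58 string is a placeholder peer ID, and the ID.from_base58 constructor is assumed from the peer module this file already imports:

# Illustrative only; run with peer_id_to_key / key_to_int in scope.
from libp2p.peer.id import ID

peer = ID.from_base58("QmYyQSo1c1Ym7orWxLYvCrM2EmxFTANf8wXmmE7DWjhx5N")  # placeholder ID
key = peer_id_to_key(peer)    # 32-byte SHA-256 digest of the peer ID bytes
assert len(key) == 32
bucket_key = key_to_int(key)  # integer compared against bucket min/max ranges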
|
||||
|
||||
|
||||
class KBucket:
|
||||
"""
|
||||
A k-bucket implementation for the Kademlia DHT.
|
||||
@ -357,9 +374,24 @@ class KBucket:
|
||||
True if the key is in range, False otherwise
|
||||
|
||||
"""
|
||||
key_int = int.from_bytes(key, byteorder="big")
|
||||
key_int = key_to_int(key)
|
||||
return self.min_range <= key_int < self.max_range
|
||||
|
||||
def peer_id_in_range(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Check if a peer ID is in the range of this bucket.
|
||||
|
||||
Parameters
----------
peer_id: The peer ID to check
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the peer ID is in range, False otherwise
|
||||
|
||||
"""
|
||||
key = peer_id_to_key(peer_id)
|
||||
return self.key_in_range(key)
|
||||
|
||||
def split(self) -> tuple["KBucket", "KBucket"]:
|
||||
"""
|
||||
Split the bucket into two buckets.
|
||||
@ -376,8 +408,9 @@ class KBucket:
|
||||
|
||||
# Redistribute peers
|
||||
for peer_id, (peer_info, timestamp) in self.peers.items():
|
||||
peer_key = int.from_bytes(peer_id.to_bytes(), byteorder="big")
|
||||
if peer_key < midpoint:
|
||||
peer_key = peer_id_to_key(peer_id)
|
||||
peer_key_int = key_to_int(peer_key)
|
||||
if peer_key_int < midpoint:
|
||||
lower_bucket.peers[peer_id] = (peer_info, timestamp)
|
||||
else:
|
||||
upper_bucket.peers[peer_id] = (peer_info, timestamp)
|
||||
@ -458,7 +491,38 @@ class RoutingTable:
|
||||
success = await bucket.add_peer(peer_info)
|
||||
if success:
|
||||
logger.debug(f"Successfully added peer {peer_id} to routing table")
|
||||
return success
|
||||
return True
|
||||
|
||||
# If bucket is full and couldn't add peer, try splitting the bucket
|
||||
# Only split if the bucket contains our Peer ID
|
||||
if self._should_split_bucket(bucket):
|
||||
logger.debug(
|
||||
f"Bucket is full, attempting to split bucket for peer {peer_id}"
|
||||
)
|
||||
split_success = self._split_bucket(bucket)
|
||||
if split_success:
|
||||
# After splitting,
|
||||
# find the appropriate bucket for the peer and try to add it
|
||||
target_bucket = self.find_bucket(peer_info.peer_id)
|
||||
success = await target_bucket.add_peer(peer_info)
|
||||
if success:
|
||||
logger.debug(
|
||||
f"Successfully added peer {peer_id} after bucket split"
|
||||
)
|
||||
return True
|
||||
else:
|
||||
logger.debug(
|
||||
f"Failed to add peer {peer_id} even after bucket split"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
logger.debug(f"Failed to split bucket for peer {peer_id}")
|
||||
return False
|
||||
else:
|
||||
logger.debug(
|
||||
f"Bucket is full and cannot be split, peer {peer_id} not added"
|
||||
)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error adding peer {peer_obj} to routing table: {e}")
|
||||
@ -480,9 +544,9 @@ class RoutingTable:
|
||||
|
||||
def find_bucket(self, peer_id: ID) -> KBucket:
|
||||
"""
|
||||
Find the bucket that would contain the given peer ID or PeerInfo.
|
||||
Find the bucket that would contain the given peer ID.
|
||||
|
||||
:param peer_obj: Either a peer ID or a PeerInfo object
|
||||
:param peer_id: The peer ID to find a bucket for
|
||||
|
||||
Returns
|
||||
-------
|
||||
@ -490,7 +554,7 @@ class RoutingTable:
|
||||
|
||||
"""
|
||||
for bucket in self.buckets:
|
||||
if bucket.key_in_range(peer_id.to_bytes()):
|
||||
if bucket.peer_id_in_range(peer_id):
|
||||
return bucket
|
||||
|
||||
return self.buckets[0]
|
||||
@ -513,7 +577,11 @@ class RoutingTable:
|
||||
all_peers.extend(bucket.peer_ids())
|
||||
|
||||
# Sort by XOR distance to the key
|
||||
all_peers.sort(key=lambda p: xor_distance(p.to_bytes(), key))
|
||||
def distance_to_key(peer_id: ID) -> int:
|
||||
peer_key = peer_id_to_key(peer_id)
|
||||
return xor_distance(peer_key, key)
|
||||
|
||||
all_peers.sort(key=distance_to_key)
|
||||
|
||||
return all_peers[:count]
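The sort above orders peers by XOR distance over their normalized 256-bit keys. A self-contained sketch of that ordering (xor_distance is re-implemented locally here purely for illustration; the real helper lives in the kad_dht utilities):

def xor_distance(a: bytes, b: bytes) -> int:
    # Interpret both keys as big-endian integers and XOR them.
    return int.from_bytes(a, "big") ^ int.from_bytes(b, "big")

target = bytes([0x12]) * 32
keys = [bytes([x]) * 32 for x in (0x10, 0x80, 0x11)]
closest = min(keys, key=lambda k: xor_distance(k, target))
assert closest == bytes([0x10]) * 32  # 0x10 ^ 0x12 == 0x02, the smallest distance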
|
||||
|
||||
@ -591,6 +659,20 @@ class RoutingTable:
|
||||
stale_peers.extend(bucket.get_stale_peers(stale_threshold_seconds))
|
||||
return stale_peers
|
||||
|
||||
def get_peer_infos(self) -> list[PeerInfo]:
|
||||
"""
|
||||
Get all PeerInfo objects in the routing table.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[PeerInfo]: List of all PeerInfo objects
|
||||
|
||||
"""
|
||||
peer_infos = []
|
||||
for bucket in self.buckets:
|
||||
peer_infos.extend(bucket.peer_infos())
|
||||
return peer_infos
|
||||
|
||||
def cleanup_routing_table(self) -> None:
|
||||
"""
|
||||
Cleanup the routing table by removing all data.
|
||||
@ -598,3 +680,66 @@ class RoutingTable:
|
||||
"""
|
||||
self.buckets = [KBucket(self.host, BUCKET_SIZE)]
|
||||
logger.info("Routing table cleaned up, all data removed.")
|
||||
|
||||
def _should_split_bucket(self, bucket: KBucket) -> bool:
|
||||
"""
|
||||
Check if a bucket should be split according to Kademlia rules.
|
||||
|
||||
:param bucket: The bucket to check
|
||||
:return: True if the bucket should be split
|
||||
"""
|
||||
# Check if we've exceeded maximum buckets
|
||||
if len(self.buckets) >= MAXIMUM_BUCKETS:
|
||||
logger.debug("Maximum number of buckets reached, cannot split")
|
||||
return False
|
||||
|
||||
# Check if the bucket contains our local ID
|
||||
local_key = peer_id_to_key(self.local_id)
|
||||
local_key_int = key_to_int(local_key)
|
||||
contains_local_id = bucket.min_range <= local_key_int < bucket.max_range
|
||||
|
||||
logger.debug(
|
||||
f"Bucket range: {bucket.min_range} - {bucket.max_range}, "
|
||||
f"local_key_int: {local_key_int}, contains_local: {contains_local_id}"
|
||||
)
|
||||
|
||||
return contains_local_id
|
||||
|
||||
def _split_bucket(self, bucket: KBucket) -> bool:
|
||||
"""
|
||||
Split a bucket into two buckets.
|
||||
|
||||
:param bucket: The bucket to split
|
||||
:return: True if the bucket was successfully split
|
||||
"""
|
||||
try:
|
||||
# Find the bucket index
|
||||
bucket_index = self.buckets.index(bucket)
|
||||
logger.debug(f"Splitting bucket at index {bucket_index}")
|
||||
|
||||
# Split the bucket
|
||||
lower_bucket, upper_bucket = bucket.split()
|
||||
|
||||
# Replace the original bucket with the two new buckets
|
||||
self.buckets[bucket_index] = lower_bucket
|
||||
self.buckets.insert(bucket_index + 1, upper_bucket)
|
||||
|
||||
logger.debug(
|
||||
f"Bucket split successful. New bucket count: {len(self.buckets)}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Lower bucket range: "
|
||||
f"{lower_bucket.min_range} - {lower_bucket.max_range}, "
|
||||
f"peers: {lower_bucket.size()}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Upper bucket range: "
|
||||
f"{upper_bucket.min_range} - {upper_bucket.max_range}, "
|
||||
f"peers: {upper_bucket.size()}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error splitting bucket: {e}")
|
||||
return False
|
||||
|
||||
@ -2,13 +2,93 @@
|
||||
Utility functions for Kademlia DHT implementation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import base58
|
||||
import multihash
|
||||
|
||||
from libp2p.abc import IHost
|
||||
from libp2p.peer.envelope import consume_envelope
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
|
||||
from .pb.kademlia_pb2 import (
|
||||
Message,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("kademlia-example.utils")
|
||||
|
||||
|
||||
def maybe_consume_signed_record(
|
||||
msg: Message | Message.Peer, host: IHost, peer_id: ID | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Attempt to parse and store a signed-peer-record (Envelope) received during
|
||||
DHT communication. If the record is invalid, the peer-id does not match, or
|
||||
updating the peerstore fails, the function logs an error and returns False.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
msg : Message | Message.Peer
|
||||
The protobuf message received during DHT communication. Can either be a
|
||||
top-level `Message` containing `senderRecord` or a `Message.Peer`
|
||||
containing `signedRecord`.
|
||||
host : IHost
|
||||
The local host instance, providing access to the peerstore for storing
|
||||
verified peer records.
|
||||
peer_id : ID | None, optional
|
||||
The expected peer ID for record validation. If provided, the peer ID
|
||||
inside the record must match this value.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if a valid signed peer record was successfully consumed and stored,
|
||||
False otherwise.
|
||||
|
||||
"""
|
||||
if isinstance(msg, Message):
|
||||
if msg.HasField("senderRecord"):
|
||||
try:
|
||||
# Convert the signed-peer-record(Envelope) from
|
||||
# protobuf bytes
|
||||
envelope, record = consume_envelope(
|
||||
msg.senderRecord,
|
||||
"libp2p-peer-record",
|
||||
)
|
||||
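# Reject unless the caller supplied the expected peer ID and it
# matches the ID embedded in the signed record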
if not (isinstance(peer_id, ID) and record.peer_id == peer_id):
|
||||
return False
|
||||
# Use the default TTL of 2 hours (7200 seconds)
|
||||
if not host.get_peerstore().consume_peer_record(envelope, 7200):
|
||||
logger.error("Failed to update the Certified-Addr-Book")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error("Failed to update the Certified-Addr-Book: %s", e)
|
||||
return False
|
||||
else:
|
||||
if msg.HasField("signedRecord"):
|
||||
try:
|
||||
# Convert the signed-peer-record(Envelope) from
|
||||
# protobuf bytes
|
||||
envelope, record = consume_envelope(
|
||||
msg.signedRecord,
|
||||
"libp2p-peer-record",
|
||||
)
|
||||
if record.peer_id.to_bytes() != msg.id:
|
||||
return False
|
||||
# Use the default TTL of 2 hours (7200 seconds)
|
||||
if not host.get_peerstore().consume_peer_record(envelope, 7200):
|
||||
logger.error("Failed to update the Certified-Addr-Book")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to update the Certified-Addr-Book: %s",
|
||||
e,
|
||||
)
|
||||
return False
|
||||
return True
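The call pattern used by the DHT handlers earlier in this diff looks roughly like this; response stands for a parsed Message and remote_peer_id for the peer ID taken from the stream:

# Sketch only: mirrors the checks in peer_routing / provider_store above.
if not maybe_consume_signed_record(response, host, remote_peer_id):
    logger.error("Received an invalid signed record, ignoring the response")
else:
    # The verified addresses are now cached in the certified addr book.
    ...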
|
||||
|
||||
|
||||
def create_key_from_binary(binary_data: bytes) -> bytes:
|
||||
"""
|
||||
|
||||
@ -15,9 +15,11 @@ from libp2p.abc import (
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.kad_dht.utils import maybe_consume_signed_record
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerstore import env_to_send_in_RPC
|
||||
|
||||
from .common import (
|
||||
DEFAULT_TTL,
|
||||
@ -110,6 +112,10 @@ class ValueStore:
|
||||
message = Message()
|
||||
message.type = Message.MessageType.PUT_VALUE
|
||||
|
||||
# Create sender's signed-peer-record
|
||||
envelope_bytes, _ = env_to_send_in_RPC(self.host)
|
||||
message.senderRecord = envelope_bytes
|
||||
|
||||
# Set message fields
|
||||
message.key = key
|
||||
message.record.key = key
|
||||
@ -155,7 +161,13 @@ class ValueStore:
|
||||
|
||||
# Check if response is valid
|
||||
if response.type == Message.MessageType.PUT_VALUE:
|
||||
if response.key:
|
||||
# Consume the sender's signed-peer-record if sent
|
||||
if not maybe_consume_signed_record(response, self.host, peer_id):
|
||||
logger.error(
|
||||
"Received an invalid-signed-record, ignoring the response"
|
||||
)
|
||||
return False
|
||||
if response.key == key:
|
||||
result = True
|
||||
return result
|
||||
|
||||
@ -231,6 +243,10 @@ class ValueStore:
|
||||
message.type = Message.MessageType.GET_VALUE
|
||||
message.key = key
|
||||
|
||||
# Create sender's signed-peer-record
|
||||
envelope_bytes, _ = env_to_send_in_RPC(self.host)
|
||||
message.senderRecord = envelope_bytes
|
||||
|
||||
# Serialize and send the protobuf message
|
||||
proto_bytes = message.SerializeToString()
|
||||
await stream.write(varint.encode(len(proto_bytes)))
|
||||
@ -275,6 +291,13 @@ class ValueStore:
|
||||
and response.HasField("record")
|
||||
and response.record.value
|
||||
):
|
||||
# Consume the sender's signed-peer-record
|
||||
if not maybe_consume_signed_record(response, self.host, peer_id):
|
||||
logger.error(
|
||||
"Received an invalid-signed-record, ignoring the response"
|
||||
)
|
||||
return None
|
||||
|
||||
logger.debug(
|
||||
f"Received value for key {key.hex()} from peer {peer_id}"
|
||||
)
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
@ -22,7 +23,8 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
"""
|
||||
Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee14e334f7caeedad722/swarm_conn.go
|
||||
Reference: https://github.com/libp2p/go-libp2p-swarm/blob/
|
||||
04c86bbdafd390651cb2ee14e334f7caeedad722/swarm_conn.go
|
||||
"""
|
||||
|
||||
|
||||
@ -42,6 +44,21 @@ class SwarmConn(INetConn):
|
||||
self.streams = set()
|
||||
self.event_closed = trio.Event()
|
||||
self.event_started = trio.Event()
|
||||
# Provide back-references/hooks expected by NetStream
|
||||
try:
|
||||
setattr(self.muxed_conn, "swarm", self.swarm)
|
||||
|
||||
# NetStream expects an awaitable remove_stream hook
|
||||
async def _remove_stream_hook(stream: NetStream) -> None:
|
||||
self.remove_stream(stream)
|
||||
|
||||
setattr(self.muxed_conn, "remove_stream", _remove_stream_hook)
|
||||
except Exception as e:
|
||||
logging.warning(
|
||||
f"Failed to set optional conveniences on muxed_conn "
|
||||
f"for peer {muxed_conn.peer_id}: {e}"
|
||||
)
|
||||
# optional conveniences
|
||||
if hasattr(muxed_conn, "on_close"):
|
||||
logging.debug(f"Setting on_close for peer {muxed_conn.peer_id}")
|
||||
setattr(muxed_conn, "on_close", self._on_muxed_conn_closed)
|
||||
@ -147,6 +164,24 @@ class SwarmConn(INetConn):
|
||||
def get_streams(self) -> tuple[NetStream, ...]:
|
||||
return tuple(self.streams)
|
||||
|
||||
def get_transport_addresses(self) -> list[Multiaddr]:
|
||||
"""
|
||||
Retrieve the transport addresses used by this connection.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[Multiaddr]
|
||||
A list of multiaddresses used by the transport.
|
||||
|
||||
"""
|
||||
# Return the addresses from the peerstore for this peer
|
||||
try:
|
||||
peer_id = self.muxed_conn.peer_id
|
||||
return self.swarm.peerstore.addrs(peer_id)
|
||||
except Exception as e:
|
||||
logging.warning(f"Error getting transport addresses: {e}")
|
||||
return []
|
||||
|
||||
def remove_stream(self, stream: NetStream) -> None:
|
||||
if stream not in self.streams:
|
||||
return
|
||||
|
||||
@ -1,4 +1,10 @@
|
||||
from collections.abc import (
|
||||
Awaitable,
|
||||
Callable,
|
||||
)
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import random
|
||||
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
@ -55,6 +61,59 @@ from .exceptions import (
|
||||
logger = logging.getLogger("libp2p.network.swarm")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetryConfig:
|
||||
"""
|
||||
Configuration for retry logic with exponential backoff.
|
||||
|
||||
This configuration controls how connection attempts are retried when they fail.
|
||||
The retry mechanism uses exponential backoff with jitter to prevent thundering
|
||||
herd problems in distributed systems.
|
||||
|
||||
Attributes:
|
||||
max_retries: Maximum number of retry attempts before giving up.
|
||||
Default: 3 attempts
|
||||
initial_delay: Initial delay in seconds before the first retry.
|
||||
Default: 0.1 seconds (100ms)
|
||||
max_delay: Maximum delay cap in seconds to prevent excessive wait times.
|
||||
Default: 30.0 seconds
|
||||
backoff_multiplier: Multiplier for exponential backoff (each retry multiplies
|
||||
the delay by this factor). Default: 2.0 (doubles each time)
|
||||
jitter_factor: Random jitter factor (0.0-1.0) to add randomness to delays
|
||||
and prevent synchronized retries. Default: 0.1 (10% jitter)
|
||||
|
||||
"""
|
||||
|
||||
max_retries: int = 3
|
||||
initial_delay: float = 0.1
|
||||
max_delay: float = 30.0
|
||||
backoff_multiplier: float = 2.0
|
||||
jitter_factor: float = 0.1
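To make the backoff concrete, a rough sketch of the schedule these defaults produce (jitter is omitted here; _calculate_backoff_delay further down adds the ±10% randomness):

cfg = RetryConfig()
for attempt in range(cfg.max_retries):
    delay = min(
        cfg.initial_delay * (cfg.backoff_multiplier ** attempt),
        cfg.max_delay,
    )
    print(f"retry {attempt + 1} after ~{delay:.1f}s")  # ~0.1s, ~0.2s, ~0.4s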
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionConfig:
|
||||
"""
|
||||
Configuration for multi-connection support.
|
||||
|
||||
This configuration controls how multiple connections per peer are managed,
|
||||
including connection limits, timeouts, and load balancing strategies.
|
||||
|
||||
Attributes:
|
||||
max_connections_per_peer: Maximum number of connections allowed to a single
|
||||
peer. Default: 3 connections
|
||||
connection_timeout: Timeout in seconds for establishing new connections.
|
||||
Default: 30.0 seconds
|
||||
load_balancing_strategy: Strategy for distributing streams across connections.
|
||||
Options: "round_robin" (default) or "least_loaded"
|
||||
|
||||
"""
|
||||
|
||||
max_connections_per_peer: int = 3
|
||||
connection_timeout: float = 30.0
|
||||
load_balancing_strategy: str = "round_robin" # or "least_loaded"
|
||||
|
||||
|
||||
def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn:
|
||||
async def stream_handler(stream: INetStream) -> None:
|
||||
await network.get_manager().wait_finished()
|
||||
@ -67,9 +126,8 @@ class Swarm(Service, INetworkService):
|
||||
peerstore: IPeerStore
|
||||
upgrader: TransportUpgrader
|
||||
transport: ITransport
|
||||
# TODO: Connection and `peer_id` are 1-1 mapping in our implementation,
|
||||
# whereas in Go one `peer_id` may point to multiple connections.
|
||||
connections: dict[ID, INetConn]
|
||||
# Enhanced: Support for multiple connections per peer
|
||||
connections: dict[ID, list[INetConn]] # Multiple connections per peer
|
||||
listeners: dict[str, IListener]
|
||||
common_stream_handler: StreamHandlerFn
|
||||
listener_nursery: trio.Nursery | None
|
||||
@ -77,18 +135,31 @@ class Swarm(Service, INetworkService):
|
||||
|
||||
notifees: list[INotifee]
|
||||
|
||||
# Enhanced: New configuration
|
||||
retry_config: RetryConfig
|
||||
connection_config: ConnectionConfig
|
||||
_round_robin_index: dict[ID, int]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
peer_id: ID,
|
||||
peerstore: IPeerStore,
|
||||
upgrader: TransportUpgrader,
|
||||
transport: ITransport,
|
||||
retry_config: RetryConfig | None = None,
|
||||
connection_config: ConnectionConfig | None = None,
|
||||
):
|
||||
self.self_id = peer_id
|
||||
self.peerstore = peerstore
|
||||
self.upgrader = upgrader
|
||||
self.transport = transport
|
||||
self.connections = dict()
|
||||
|
||||
# Enhanced: Initialize retry and connection configuration
|
||||
self.retry_config = retry_config or RetryConfig()
|
||||
self.connection_config = connection_config or ConnectionConfig()
|
||||
|
||||
# Enhanced: Initialize connections as 1:many mapping
|
||||
self.connections = {}
|
||||
self.listeners = dict()
|
||||
|
||||
# Create Notifee array
|
||||
@ -99,6 +170,9 @@ class Swarm(Service, INetworkService):
|
||||
self.listener_nursery = None
|
||||
self.event_listener_nursery_created = trio.Event()
|
||||
|
||||
# Load balancing state
|
||||
self._round_robin_index = {}
|
||||
|
||||
async def run(self) -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Create a nursery for listener tasks.
|
||||
@ -118,18 +192,74 @@ class Swarm(Service, INetworkService):
|
||||
def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None:
|
||||
self.common_stream_handler = stream_handler
|
||||
|
||||
async def dial_peer(self, peer_id: ID) -> INetConn:
|
||||
def get_connections(self, peer_id: ID | None = None) -> list[INetConn]:
|
||||
"""
|
||||
Try to create a connection to peer_id.
|
||||
Get connections for peer (like JS getConnections, Go ConnsToPeer).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID | None
|
||||
The peer ID to get connections for. If None, returns all connections.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[INetConn]
|
||||
List of connections to the specified peer, or all connections
|
||||
if peer_id is None.
|
||||
|
||||
"""
|
||||
if peer_id is not None:
|
||||
return self.connections.get(peer_id, [])
|
||||
|
||||
# Return all connections from all peers
|
||||
all_conns = []
|
||||
for conns in self.connections.values():
|
||||
all_conns.extend(conns)
|
||||
return all_conns
|
||||
|
||||
def get_connections_map(self) -> dict[ID, list[INetConn]]:
|
||||
"""
|
||||
Get all connections map (like JS getConnectionsMap).
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[ID, list[INetConn]]
|
||||
The complete mapping of peer IDs to their connection lists.
|
||||
|
||||
"""
|
||||
return self.connections.copy()
|
||||
|
||||
def get_connection(self, peer_id: ID) -> INetConn | None:
|
||||
"""
|
||||
Get single connection for backward compatibility.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer ID to get a connection for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
INetConn | None
|
||||
The first available connection, or None if no connections exist.
|
||||
|
||||
"""
|
||||
conns = self.get_connections(peer_id)
|
||||
return conns[0] if conns else None
|
||||
|
||||
async def dial_peer(self, peer_id: ID) -> list[INetConn]:
|
||||
"""
|
||||
Try to create connections to peer_id with enhanced retry logic.
|
||||
|
||||
:param peer_id: peer ID we want to dial
|
||||
:raises SwarmException: raised when an error occurs
|
||||
:return: muxed connection
|
||||
:return: list of muxed connections
|
||||
"""
|
||||
if peer_id in self.connections:
|
||||
# If muxed connection already exists for peer_id,
|
||||
# set muxed connection equal to existing muxed connection
|
||||
return self.connections[peer_id]
|
||||
# Check if we already have connections
|
||||
existing_connections = self.get_connections(peer_id)
|
||||
if existing_connections:
|
||||
logger.debug(f"Reusing existing connections to peer {peer_id}")
|
||||
return existing_connections
|
||||
|
||||
logger.debug("attempting to dial peer %s", peer_id)
|
||||
|
||||
@ -142,12 +272,19 @@ class Swarm(Service, INetworkService):
|
||||
if not addrs:
|
||||
raise SwarmException(f"No known addresses to peer {peer_id}")
|
||||
|
||||
connections = []
|
||||
exceptions: list[SwarmException] = []
|
||||
|
||||
# Try all known addresses
|
||||
# Enhanced: Try all known addresses with retry logic
|
||||
for multiaddr in addrs:
|
||||
try:
|
||||
return await self.dial_addr(multiaddr, peer_id)
|
||||
connection = await self._dial_with_retry(multiaddr, peer_id)
|
||||
connections.append(connection)
|
||||
|
||||
# Limit number of connections per peer
|
||||
if len(connections) >= self.connection_config.max_connections_per_peer:
|
||||
break
|
||||
|
||||
except SwarmException as e:
|
||||
exceptions.append(e)
|
||||
logger.debug(
|
||||
@ -157,15 +294,73 @@ class Swarm(Service, INetworkService):
|
||||
exc_info=e,
|
||||
)
|
||||
|
||||
# Tried all addresses, raising exception.
|
||||
raise SwarmException(
|
||||
f"unable to connect to {peer_id}, no addresses established a successful "
|
||||
"connection (with exceptions)"
|
||||
) from MultiError(exceptions)
|
||||
if not connections:
|
||||
# Tried all addresses, raising exception.
|
||||
raise SwarmException(
|
||||
f"unable to connect to {peer_id}, no addresses established a "
|
||||
"successful connection (with exceptions)"
|
||||
) from MultiError(exceptions)
|
||||
|
||||
async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn:
|
||||
return connections
|
||||
|
||||
async def _dial_with_retry(self, addr: Multiaddr, peer_id: ID) -> INetConn:
|
||||
"""
|
||||
Try to create a connection to peer_id with addr.
|
||||
Enhanced: Dial with retry logic and exponential backoff.
|
||||
|
||||
:param addr: the address to dial
|
||||
:param peer_id: the peer we want to connect to
|
||||
:raises SwarmException: raised when all retry attempts fail
|
||||
:return: network connection
|
||||
"""
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(self.retry_config.max_retries + 1):
|
||||
try:
|
||||
return await self._dial_addr_single_attempt(addr, peer_id)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if attempt < self.retry_config.max_retries:
|
||||
delay = self._calculate_backoff_delay(attempt)
|
||||
logger.debug(
|
||||
f"Connection attempt {attempt + 1} failed, "
|
||||
f"retrying in {delay:.2f}s: {e}"
|
||||
)
|
||||
await trio.sleep(delay)
|
||||
else:
|
||||
logger.debug(f"All {self.retry_config.max_retries} attempts failed")
|
||||
|
||||
# Convert the last exception to SwarmException for consistency
|
||||
if last_exception is not None:
|
||||
if isinstance(last_exception, SwarmException):
|
||||
raise last_exception
|
||||
else:
|
||||
raise SwarmException(
|
||||
f"Failed to connect after {self.retry_config.max_retries} attempts"
|
||||
) from last_exception
|
||||
|
||||
# This should never be reached, but mypy requires it
|
||||
raise SwarmException("Unexpected error in retry logic")
|
||||
|
||||
def _calculate_backoff_delay(self, attempt: int) -> float:
|
||||
"""
|
||||
Enhanced: Calculate backoff delay with jitter to prevent thundering herd.
|
||||
|
||||
:param attempt: the current attempt number (0-based)
|
||||
:return: delay in seconds
|
||||
"""
|
||||
delay = min(
|
||||
self.retry_config.initial_delay
|
||||
* (self.retry_config.backoff_multiplier**attempt),
|
||||
self.retry_config.max_delay,
|
||||
)
|
||||
|
||||
# Add jitter to prevent synchronized retries
|
||||
jitter = delay * self.retry_config.jitter_factor
|
||||
return delay + random.uniform(-jitter, jitter)
|
||||
|
||||
async def _dial_addr_single_attempt(self, addr: Multiaddr, peer_id: ID) -> INetConn:
|
||||
"""
|
||||
Enhanced: Single attempt to dial an address (extracted from original dial_addr).
|
||||
|
||||
:param addr: the address we want to connect with
|
||||
:param peer_id: the peer we want to connect to
|
||||
@ -212,19 +407,97 @@ class Swarm(Service, INetworkService):
|
||||
|
||||
return swarm_conn
|
||||
|
||||
async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn:
|
||||
"""
|
||||
Enhanced: Try to create a connection to peer_id with addr using retry logic.
|
||||
|
||||
:param addr: the address we want to connect with
|
||||
:param peer_id: the peer we want to connect to
|
||||
:raises SwarmException: raised when an error occurs
|
||||
:return: network connection
|
||||
"""
|
||||
return await self._dial_with_retry(addr, peer_id)
|
||||
|
||||
async def new_stream(self, peer_id: ID) -> INetStream:
|
||||
"""
|
||||
Enhanced: Create a new stream with load balancing across multiple connections.
|
||||
|
||||
:param peer_id: peer_id of destination
|
||||
:raises SwarmException: raised when an error occurs
|
||||
:return: net stream instance
|
||||
"""
|
||||
logger.debug("attempting to open a stream to peer %s", peer_id)
|
||||
|
||||
swarm_conn = await self.dial_peer(peer_id)
|
||||
# Get existing connections or dial new ones
|
||||
connections = self.get_connections(peer_id)
|
||||
if not connections:
|
||||
connections = await self.dial_peer(peer_id)
|
||||
|
||||
net_stream = await swarm_conn.new_stream()
|
||||
logger.debug("successfully opened a stream to peer %s", peer_id)
|
||||
return net_stream
|
||||
# Load balancing strategy at interface level
|
||||
connection = self._select_connection(connections, peer_id)
|
||||
|
||||
try:
|
||||
net_stream = await connection.new_stream()
|
||||
logger.debug("successfully opened a stream to peer %s", peer_id)
|
||||
return net_stream
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to create stream on connection: {e}")
|
||||
# Try other connections if available
|
||||
for other_conn in connections:
|
||||
if other_conn != connection:
|
||||
try:
|
||||
net_stream = await other_conn.new_stream()
|
||||
logger.debug(
|
||||
f"Successfully opened a stream to peer {peer_id} "
|
||||
"using alternative connection"
|
||||
)
|
||||
return net_stream
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# All connections failed, raise exception
|
||||
raise SwarmException(f"Failed to create stream to peer {peer_id}") from e
|
||||
|
||||
def _select_connection(self, connections: list[INetConn], peer_id: ID) -> INetConn:
|
||||
"""
|
||||
Select a connection using the load-balancing strategy configured in self.connection_config.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
connections : list[INetConn]
|
||||
List of available connections.
|
||||
peer_id : ID
|
||||
The peer ID for round-robin tracking.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
INetConn
|
||||
Selected connection.
|
||||
|
||||
"""
|
||||
if not connections:
|
||||
raise ValueError("No connections available")
|
||||
|
||||
strategy = self.connection_config.load_balancing_strategy
|
||||
|
||||
if strategy == "round_robin":
|
||||
# Simple round-robin selection
|
||||
if peer_id not in self._round_robin_index:
|
||||
self._round_robin_index[peer_id] = 0
|
||||
|
||||
index = self._round_robin_index[peer_id] % len(connections)
|
||||
self._round_robin_index[peer_id] += 1
|
||||
return connections[index]
|
||||
|
||||
elif strategy == "least_loaded":
|
||||
# Find connection with least streams
|
||||
return min(connections, key=lambda c: len(c.get_streams()))
|
||||
|
||||
else:
|
||||
# Default to first connection
|
||||
return connections[0]
|
||||
|
||||
async def listen(self, *multiaddrs: Multiaddr) -> bool:
|
||||
"""
|
||||
@ -245,9 +518,11 @@ class Swarm(Service, INetworkService):
|
||||
# We need to wait until `self.listener_nursery` is created.
|
||||
await self.event_listener_nursery_created.wait()
|
||||
|
||||
success_count = 0
|
||||
for maddr in multiaddrs:
|
||||
if str(maddr) in self.listeners:
|
||||
return True
|
||||
success_count += 1
|
||||
continue
|
||||
|
||||
async def conn_handler(
|
||||
read_write_closer: ReadWriteCloser, maddr: Multiaddr = maddr
|
||||
@ -298,13 +573,14 @@ class Swarm(Service, INetworkService):
|
||||
# Call notifiers since event occurred
|
||||
await self.notify_listen(maddr)
|
||||
|
||||
return True
|
||||
success_count += 1
|
||||
logger.debug("successfully started listening on: %s", maddr)
|
||||
except OSError:
|
||||
# Failed. Continue looping.
|
||||
logger.debug("fail to listen on: %s", maddr)
|
||||
|
||||
# No maddr succeeded
|
||||
return False
|
||||
# Return true if at least one address succeeded
|
||||
return success_count > 0
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
@ -317,17 +593,25 @@ class Swarm(Service, INetworkService):
|
||||
# Perform alternative cleanup if the manager isn't initialized
|
||||
# Close all connections manually
|
||||
if hasattr(self, "connections"):
|
||||
for conn_id in list(self.connections.keys()):
|
||||
conn = self.connections[conn_id]
|
||||
await conn.close()
|
||||
for peer_id, conns in list(self.connections.items()):
|
||||
for conn in conns:
|
||||
await conn.close()
|
||||
|
||||
# Clear connection tracking dictionary
|
||||
self.connections.clear()
|
||||
|
||||
# Close all listeners
|
||||
if hasattr(self, "listeners"):
|
||||
for listener in self.listeners.values():
|
||||
for maddr_str, listener in self.listeners.items():
|
||||
await listener.close()
|
||||
# Notify about listener closure
|
||||
try:
|
||||
multiaddr = Multiaddr(maddr_str)
|
||||
await self.notify_listen_close(multiaddr)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to notify listen_close for {maddr_str}: {e}"
|
||||
)
|
||||
self.listeners.clear()
|
||||
|
||||
# Close the transport if it exists and has a close method
|
||||
@ -341,12 +625,28 @@ class Swarm(Service, INetworkService):
|
||||
logger.debug("swarm successfully closed")
|
||||
|
||||
async def close_peer(self, peer_id: ID) -> None:
|
||||
if peer_id not in self.connections:
|
||||
"""
|
||||
Close all connections to the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer ID to close connections for.
|
||||
|
||||
"""
|
||||
connections = self.get_connections(peer_id)
|
||||
if not connections:
|
||||
return
|
||||
connection = self.connections[peer_id]
|
||||
# NOTE: `connection.close` will delete `peer_id` from `self.connections`
|
||||
# and `notify_disconnected` for us.
|
||||
await connection.close()
|
||||
|
||||
# Close all connections
|
||||
for connection in connections:
|
||||
try:
|
||||
await connection.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing connection to {peer_id}: {e}")
|
||||
|
||||
# Remove from connections dict
|
||||
self.connections.pop(peer_id, None)
|
||||
|
||||
logger.debug("successfully close the connection to peer %s", peer_id)
|
||||
|
||||
@ -365,21 +665,71 @@ class Swarm(Service, INetworkService):
|
||||
await muxed_conn.event_started.wait()
|
||||
self.manager.run_task(swarm_conn.start)
|
||||
await swarm_conn.event_started.wait()
|
||||
# Store muxed_conn with peer id
|
||||
self.connections[muxed_conn.peer_id] = swarm_conn
|
||||
|
||||
# Add to connections dict with deduplication
|
||||
peer_id = muxed_conn.peer_id
|
||||
if peer_id not in self.connections:
|
||||
self.connections[peer_id] = []
|
||||
|
||||
# Check for duplicate connections by comparing the underlying muxed connection
|
||||
for existing_conn in self.connections[peer_id]:
|
||||
if existing_conn.muxed_conn == muxed_conn:
|
||||
logger.debug(f"Connection already exists for peer {peer_id}")
|
||||
# existing_conn is a SwarmConn since it's stored in the connections list
|
||||
return existing_conn # type: ignore[return-value]
|
||||
|
||||
self.connections[peer_id].append(swarm_conn)
|
||||
|
||||
# Trim if we exceed max connections
|
||||
max_conns = self.connection_config.max_connections_per_peer
|
||||
if len(self.connections[peer_id]) > max_conns:
|
||||
self._trim_connections(peer_id)
|
||||
|
||||
# Call notifiers since event occurred
|
||||
await self.notify_connected(swarm_conn)
|
||||
return swarm_conn
|
||||
|
||||
def _trim_connections(self, peer_id: ID) -> None:
|
||||
"""
|
||||
Remove oldest connections when limit is exceeded.
|
||||
"""
|
||||
connections = self.connections[peer_id]
|
||||
if len(connections) <= self.connection_config.max_connections_per_peer:
|
||||
return
|
||||
|
||||
# Connections are appended in creation order, so drop the oldest
# and keep only the most recent ones
|
||||
max_conns = self.connection_config.max_connections_per_peer
|
||||
connections_to_remove = connections[:-max_conns]
|
||||
|
||||
for conn in connections_to_remove:
|
||||
logger.debug(f"Trimming old connection for peer {peer_id}")
|
||||
trio.lowlevel.spawn_system_task(self._close_connection_async, conn)
|
||||
|
||||
# Keep only the most recent connections
|
||||
max_conns = self.connection_config.max_connections_per_peer
|
||||
self.connections[peer_id] = connections[-max_conns:]
|
||||
|
||||
async def _close_connection_async(self, connection: INetConn) -> None:
|
||||
"""Close a connection asynchronously."""
|
||||
try:
|
||||
await connection.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing connection: {e}")
|
||||
|
||||
def remove_conn(self, swarm_conn: SwarmConn) -> None:
|
||||
"""
|
||||
Simply remove the connection from Swarm's records, without closing
|
||||
the connection.
|
||||
"""
|
||||
peer_id = swarm_conn.muxed_conn.peer_id
|
||||
if peer_id not in self.connections:
|
||||
return
|
||||
del self.connections[peer_id]
|
||||
|
||||
if peer_id in self.connections:
|
||||
self.connections[peer_id] = [
|
||||
conn for conn in self.connections[peer_id] if conn != swarm_conn
|
||||
]
|
||||
if not self.connections[peer_id]:
|
||||
del self.connections[peer_id]
|
||||
|
||||
# Notifee
|
||||
|
||||
@ -411,7 +761,35 @@ class Swarm(Service, INetworkService):
|
||||
nursery.start_soon(notifee.listen, self, multiaddr)
|
||||
|
||||
async def notify_closed_stream(self, stream: INetStream) -> None:
|
||||
raise NotImplementedError
|
||||
async with trio.open_nursery() as nursery:
|
||||
for notifee in self.notifees:
|
||||
nursery.start_soon(notifee.closed_stream, self, stream)
|
||||
|
||||
async def notify_listen_close(self, multiaddr: Multiaddr) -> None:
|
||||
raise NotImplementedError
|
||||
async with trio.open_nursery() as nursery:
|
||||
for notifee in self.notifees:
|
||||
nursery.start_soon(notifee.listen_close, self, multiaddr)
|
||||
|
||||
# Generic notifier used by NetStream._notify_closed
|
||||
async def notify_all(self, notifier: Callable[[INotifee], Awaitable[None]]) -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
for notifee in self.notifees:
|
||||
nursery.start_soon(notifier, notifee)
|
||||
|
||||
# Backward compatibility properties
|
||||
@property
|
||||
def connections_legacy(self) -> dict[ID, INetConn]:
|
||||
"""
|
||||
Legacy 1:1 mapping for backward compatibility.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[ID, INetConn]
|
||||
Legacy mapping with only the first connection per peer.
|
||||
|
||||
"""
|
||||
legacy_conns = {}
|
||||
for peer_id, conns in self.connections.items():
|
||||
if conns:
|
||||
legacy_conns[peer_id] = conns[0]
|
||||
return legacy_conns
|
||||
|
||||
276
libp2p/peer/envelope.py
Normal file
276
libp2p/peer/envelope.py
Normal file
@ -0,0 +1,276 @@
|
||||
from typing import Any, cast
|
||||
|
||||
import multiaddr
|
||||
|
||||
from libp2p.crypto.ed25519 import Ed25519PublicKey
|
||||
from libp2p.crypto.keys import PrivateKey, PublicKey
|
||||
from libp2p.crypto.rsa import RSAPublicKey
|
||||
from libp2p.crypto.secp256k1 import Secp256k1PublicKey
|
||||
import libp2p.peer.pb.crypto_pb2 as cryto_pb
|
||||
import libp2p.peer.pb.envelope_pb2 as pb
|
||||
import libp2p.peer.pb.peer_record_pb2 as record_pb
|
||||
from libp2p.peer.peer_record import (
|
||||
PeerRecord,
|
||||
peer_record_from_protobuf,
|
||||
unmarshal_record,
|
||||
)
|
||||
from libp2p.utils.varint import encode_uvarint
|
||||
|
||||
ENVELOPE_DOMAIN = "libp2p-peer-record"
|
||||
PEER_RECORD_CODEC = b"\x03\x01"
|
||||
|
||||
|
||||
class Envelope:
|
||||
"""
|
||||
A signed wrapper around a serialized libp2p record.
|
||||
|
||||
Envelopes are cryptographically signed by the author's private key
|
||||
and are scoped to a specific 'domain' to prevent cross-protocol replay.
|
||||
|
||||
Attributes:
|
||||
public_key: The public key that can verify the envelope's signature.
|
||||
payload_type: A multicodec code identifying the type of payload inside.
|
||||
raw_payload: The raw serialized record data.
|
||||
signature: Signature over the domain-scoped payload content.
|
||||
|
||||
"""
|
||||
|
||||
public_key: PublicKey
|
||||
payload_type: bytes
|
||||
raw_payload: bytes
|
||||
signature: bytes
|
||||
|
||||
_cached_record: PeerRecord | None = None
|
||||
_unmarshal_error: Exception | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
public_key: PublicKey,
|
||||
payload_type: bytes,
|
||||
raw_payload: bytes,
|
||||
signature: bytes,
|
||||
):
|
||||
self.public_key = public_key
|
||||
self.payload_type = payload_type
|
||||
self.raw_payload = raw_payload
|
||||
self.signature = signature
|
||||
|
||||
def marshal_envelope(self) -> bytes:
|
||||
"""
|
||||
Serialize this Envelope into its protobuf wire format.
|
||||
|
||||
Converts all envelope fields into a `pb.Envelope` protobuf message
|
||||
and returns the serialized bytes.
|
||||
|
||||
:return: Serialized envelope as bytes.
|
||||
"""
|
||||
pb_env = pb.Envelope(
|
||||
public_key=pub_key_to_protobuf(self.public_key),
|
||||
payload_type=self.payload_type,
|
||||
payload=self.raw_payload,
|
||||
signature=self.signature,
|
||||
)
|
||||
return pb_env.SerializeToString()
|
||||
|
||||
def validate(self, domain: str) -> None:
|
||||
"""
|
||||
Verify the envelope's signature within the given domain scope.
|
||||
|
||||
This ensures that the envelope has not been tampered with
|
||||
and was signed under the correct usage context.
|
||||
|
||||
:param domain: Domain string that contextualizes the signature.
|
||||
:raises ValueError: If the signature is invalid.
|
||||
"""
|
||||
unsigned = make_unsigned(domain, self.payload_type, self.raw_payload)
|
||||
if not self.public_key.verify(unsigned, self.signature):
|
||||
raise ValueError("Invalid envelope signature")
|
||||
|
||||
def record(self) -> PeerRecord:
|
||||
"""
|
||||
Lazily decode and return the embedded PeerRecord.
|
||||
|
||||
This method unmarshals the payload bytes into a `PeerRecord` instance,
|
||||
using the registered codec to identify the type. The decoded result
|
||||
is cached for future use.
|
||||
|
||||
:return: Decoded PeerRecord object.
|
||||
:raises Exception: If decoding fails or payload type is unsupported.
|
||||
"""
|
||||
if self._cached_record is not None:
|
||||
return self._cached_record
|
||||
|
||||
try:
|
||||
if self.payload_type != PEER_RECORD_CODEC:
|
||||
raise ValueError("Unsuported payload type in envelope")
|
||||
msg = record_pb.PeerRecord()
|
||||
msg.ParseFromString(self.raw_payload)
|
||||
|
||||
self._cached_record = peer_record_from_protobuf(msg)
|
||||
return self._cached_record
|
||||
except Exception as e:
|
||||
self._unmarshal_error = e
|
||||
raise
|
||||
|
||||
def equal(self, other: Any) -> bool:
|
||||
"""
|
||||
Compare this Envelope with another for structural equality.
|
||||
|
||||
Two envelopes are considered equal if:
|
||||
- They have the same public key
|
||||
- The payload type and payload bytes match
|
||||
- Their signatures are identical
|
||||
|
||||
:param other: Another object to compare.
|
||||
:return: True if equal, False otherwise.
|
||||
"""
|
||||
if isinstance(other, Envelope):
|
||||
return (
|
||||
self.public_key == other.public_key
|
||||
and self.payload_type == other.payload_type
|
||||
and self.signature == other.signature
|
||||
and self.raw_payload == other.raw_payload
|
||||
)
|
||||
return False
|
||||
|
||||
def _env_addrs_set(self) -> set[multiaddr.Multiaddr]:
|
||||
return set(self.record().addrs)
|
||||
|
||||
|
||||
def pub_key_to_protobuf(pub_key: PublicKey) -> cryto_pb.PublicKey:
|
||||
"""
|
||||
Convert a Python PublicKey object to its protobuf equivalent.
|
||||
|
||||
:param pub_key: A libp2p-compatible PublicKey instance.
|
||||
:return: Serialized protobuf PublicKey message.
|
||||
"""
|
||||
internal_key_type = pub_key.get_type()
|
||||
key_type = cast(cryto_pb.KeyType, internal_key_type.value)
|
||||
data = pub_key.to_bytes()
|
||||
protobuf_key = cryto_pb.PublicKey(Type=key_type, Data=data)
|
||||
return protobuf_key
|
||||
|
||||
|
||||
def pub_key_from_protobuf(pb_key: cryto_pb.PublicKey) -> PublicKey:
|
||||
"""
|
||||
Parse a protobuf PublicKey message into a native libp2p PublicKey.
|
||||
|
||||
Supports Ed25519, RSA, and Secp256k1 key types.
|
||||
|
||||
:param pb_key: Protobuf representation of a public key.
|
||||
:return: Parsed PublicKey object.
|
||||
:raises ValueError: If the key type is unrecognized.
|
||||
"""
|
||||
if pb_key.Type == cryto_pb.KeyType.Ed25519:
|
||||
return Ed25519PublicKey.from_bytes(pb_key.Data)
|
||||
elif pb_key.Type == cryto_pb.KeyType.RSA:
|
||||
return RSAPublicKey.from_bytes(pb_key.Data)
|
||||
elif pb_key.Type == cryto_pb.KeyType.Secp256k1:
|
||||
return Secp256k1PublicKey.from_bytes(pb_key.Data)
|
||||
# libp2p.crypto.ecdsa not implemented
|
||||
else:
|
||||
raise ValueError(f"Unknown key type: {pb_key.Type}")
|
||||
|
||||
|
||||
def seal_record(record: PeerRecord, private_key: PrivateKey) -> Envelope:
|
||||
"""
|
||||
Create and sign a new Envelope from a PeerRecord.
|
||||
|
||||
The record is serialized and signed in the scope of its domain and codec.
|
||||
The result is a self-contained, verifiable Envelope.
|
||||
|
||||
:param record: A PeerRecord to encapsulate.
|
||||
:param private_key: The signer's private key.
|
||||
:return: A signed Envelope instance.
|
||||
"""
|
||||
payload = record.marshal_record()
|
||||
|
||||
unsigned = make_unsigned(record.domain(), record.codec(), payload)
|
||||
signature = private_key.sign(unsigned)
|
||||
|
||||
return Envelope(
|
||||
public_key=private_key.get_public_key(),
|
||||
payload_type=record.codec(),
|
||||
raw_payload=payload,
|
||||
signature=signature,
|
||||
)
|
||||
|
||||
|
||||
def consume_envelope(data: bytes, domain: str) -> tuple[Envelope, PeerRecord]:
|
||||
"""
|
||||
Parse, validate, and decode an Envelope from bytes.
|
||||
|
||||
Validates the envelope's signature using the given domain and decodes
|
||||
the inner payload into a PeerRecord.
|
||||
|
||||
:param data: Serialized envelope bytes.
|
||||
:param domain: Domain string to verify signature against.
|
||||
:return: Tuple of (Envelope, PeerRecord).
|
||||
:raises ValueError: If signature validation or decoding fails.
|
||||
"""
|
||||
env = unmarshal_envelope(data)
|
||||
env.validate(domain)
|
||||
record = env.record()
|
||||
return env, record
|
||||
|
||||
|
||||
def unmarshal_envelope(data: bytes) -> Envelope:
|
||||
"""
|
||||
Deserialize an Envelope from its wire format.
|
||||
|
||||
This parses the protobuf fields without verifying the signature.
|
||||
|
||||
:param data: Serialized envelope bytes.
|
||||
:return: Parsed Envelope object.
|
||||
:raises DecodeError: If protobuf parsing fails.
|
||||
"""
|
||||
pb_env = pb.Envelope()
|
||||
pb_env.ParseFromString(data)
|
||||
pk = pub_key_from_protobuf(pb_env.public_key)
|
||||
|
||||
return Envelope(
|
||||
public_key=pk,
|
||||
payload_type=pb_env.payload_type,
|
||||
raw_payload=pb_env.payload,
|
||||
signature=pb_env.signature,
|
||||
)
|
||||
|
||||
|
||||
def make_unsigned(domain: str, payload_type: bytes, payload: bytes) -> bytes:
|
||||
"""
|
||||
Build a byte buffer to be signed for an Envelope.
|
||||
|
||||
The unsigned byte structure is:
|
||||
varint(len(domain)) || domain ||
|
||||
varint(len(payload_type)) || payload_type ||
|
||||
varint(len(payload)) || payload
|
||||
|
||||
This is the exact input used during signing and verification.
|
||||
|
||||
:param domain: Domain string for signature scoping.
|
||||
:param payload_type: Identifier for the type of payload.
|
||||
:param payload: Raw serialized payload bytes.
|
||||
:return: Byte buffer to be signed or verified.
|
||||
"""
|
||||
fields = [domain.encode(), payload_type, payload]
|
||||
buf = bytearray()
|
||||
|
||||
for field in fields:
|
||||
buf.extend(encode_uvarint(len(field)))
|
||||
buf.extend(field)
|
||||
|
||||
return bytes(buf)
|
||||
|
||||
|
||||
def debug_dump_envelope(env: Envelope) -> None:
|
||||
print("\n=== Envelope ===")
|
||||
print(f"Payload Type: {env.payload_type!r}")
|
||||
print(f"Signature: {env.signature.hex()} ({len(env.signature)} bytes)")
|
||||
print(f"Raw Payload: {env.raw_payload.hex()} ({len(env.raw_payload)} bytes)")
|
||||
|
||||
try:
|
||||
peer_record = unmarshal_record(env.raw_payload)
|
||||
print("\n=== Parsed PeerRecord ===")
|
||||
print(peer_record)
|
||||
except Exception as e:
|
||||
print("Failed to parse PeerRecord:", e)
|
||||
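Taken together, the helpers above give a simple sign/verify round trip. A minimal sketch, assuming create_new_key_pair from libp2p.crypto.ed25519 and ID.from_pubkey from the existing peer module:

from multiaddr import Multiaddr

from libp2p.crypto.ed25519 import create_new_key_pair
from libp2p.peer.envelope import consume_envelope, seal_record
from libp2p.peer.id import ID
from libp2p.peer.peer_record import PeerRecord

key_pair = create_new_key_pair()
peer_id = ID.from_pubkey(key_pair.public_key)
record = PeerRecord(peer_id, [Multiaddr("/ip4/127.0.0.1/tcp/4001")])

# Sign the record into an Envelope, serialize it, then verify and decode it
# again exactly as a receiver would.
env = seal_record(record, key_pair.private_key)
wire_bytes = env.marshal_envelope()
received_env, received_record = consume_envelope(wire_bytes, "libp2p-peer-record")
assert received_record.peer_id == peer_id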
@ -1,3 +1,4 @@
|
||||
import functools
|
||||
import hashlib
|
||||
|
||||
import base58
|
||||
@ -36,25 +37,23 @@ if ENABLE_INLINING:
|
||||
|
||||
class ID:
|
||||
_bytes: bytes
|
||||
_xor_id: int | None = None
|
||||
_b58_str: str | None = None
|
||||
|
||||
def __init__(self, peer_id_bytes: bytes) -> None:
|
||||
self._bytes = peer_id_bytes
|
||||
|
||||
@property
|
||||
@functools.cached_property
|
||||
def xor_id(self) -> int:
|
||||
if not self._xor_id:
|
||||
self._xor_id = int(sha256_digest(self._bytes).hex(), 16)
|
||||
return self._xor_id
|
||||
return int(sha256_digest(self._bytes).hex(), 16)
|
||||
|
||||
@functools.cached_property
|
||||
def base58(self) -> str:
|
||||
return base58.b58encode(self._bytes).decode()
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
return self._bytes
|
||||
|
||||
def to_base58(self) -> str:
|
||||
if not self._b58_str:
|
||||
self._b58_str = base58.b58encode(self._bytes).decode()
|
||||
return self._b58_str
|
||||
return self.base58
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<libp2p.peer.id.ID ({self!s})>"
|
||||
|
||||
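The id.py hunk above replaces the hand-rolled _xor_id/_b58_str caching with functools.cached_property; a standalone sketch of the behaviour being relied on:

import functools
import hashlib


class Example:
    def __init__(self, data: bytes) -> None:
        self._bytes = data

    @functools.cached_property
    def xor_id(self) -> int:
        # Evaluated on first access only; the result is stored on the
        # instance, which is what the old _xor_id field did by hand.
        return int(hashlib.sha256(self._bytes).hexdigest(), 16)


e = Example(b"\x01\x02")
assert e.xor_id == e.xor_id  # second access is served from the cache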
22
libp2p/peer/pb/crypto.proto
Normal file
@ -0,0 +1,22 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package libp2p.peer.pb.crypto;
|
||||
|
||||
option go_package = "github.com/libp2p/go-libp2p/core/crypto/pb";
|
||||
|
||||
enum KeyType {
|
||||
RSA = 0;
|
||||
Ed25519 = 1;
|
||||
Secp256k1 = 2;
|
||||
ECDSA = 3;
|
||||
}
|
||||
|
||||
message PublicKey {
|
||||
KeyType Type = 1;
|
||||
bytes Data = 2;
|
||||
}
|
||||
|
||||
message PrivateKey {
|
||||
KeyType Type = 1;
|
||||
bytes Data = 2;
|
||||
}
|
||||
31
libp2p/peer/pb/crypto_pb2.py
Normal file
@ -0,0 +1,31 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: libp2p/peer/pb/crypto.proto
|
||||
# Protobuf Python Version: 4.25.3
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1blibp2p/peer/pb/crypto.proto\x12\x15libp2p.peer.pb.crypto\"G\n\tPublicKey\x12,\n\x04Type\x18\x01 \x01(\x0e\x32\x1e.libp2p.peer.pb.crypto.KeyType\x12\x0c\n\x04\x44\x61ta\x18\x02 \x01(\x0c\"H\n\nPrivateKey\x12,\n\x04Type\x18\x01 \x01(\x0e\x32\x1e.libp2p.peer.pb.crypto.KeyType\x12\x0c\n\x04\x44\x61ta\x18\x02 \x01(\x0c*9\n\x07KeyType\x12\x07\n\x03RSA\x10\x00\x12\x0b\n\x07\x45\x64\x32\x35\x35\x31\x39\x10\x01\x12\r\n\tSecp256k1\x10\x02\x12\t\n\x05\x45\x43\x44SA\x10\x03\x42,Z*github.com/libp2p/go-libp2p/core/crypto/pbb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.peer.pb.crypto_pb2', _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
_globals['DESCRIPTOR']._options = None
|
||||
_globals['DESCRIPTOR']._serialized_options = b'Z*github.com/libp2p/go-libp2p/core/crypto/pb'
|
||||
_globals['_KEYTYPE']._serialized_start=201
|
||||
_globals['_KEYTYPE']._serialized_end=258
|
||||
_globals['_PUBLICKEY']._serialized_start=54
|
||||
_globals['_PUBLICKEY']._serialized_end=125
|
||||
_globals['_PRIVATEKEY']._serialized_start=127
|
||||
_globals['_PRIVATEKEY']._serialized_end=199
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
33
libp2p/peer/pb/crypto_pb2.pyi
Normal file
@ -0,0 +1,33 @@
|
||||
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import message as _message
|
||||
from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
|
||||
|
||||
DESCRIPTOR: _descriptor.FileDescriptor
|
||||
|
||||
class KeyType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
|
||||
__slots__ = ()
|
||||
RSA: _ClassVar[KeyType]
|
||||
Ed25519: _ClassVar[KeyType]
|
||||
Secp256k1: _ClassVar[KeyType]
|
||||
ECDSA: _ClassVar[KeyType]
|
||||
RSA: KeyType
|
||||
Ed25519: KeyType
|
||||
Secp256k1: KeyType
|
||||
ECDSA: KeyType
|
||||
|
||||
class PublicKey(_message.Message):
|
||||
__slots__ = ("Type", "Data")
|
||||
TYPE_FIELD_NUMBER: _ClassVar[int]
|
||||
DATA_FIELD_NUMBER: _ClassVar[int]
|
||||
Type: KeyType
|
||||
Data: bytes
|
||||
def __init__(self, Type: _Optional[_Union[KeyType, str]] = ..., Data: _Optional[bytes] = ...) -> None: ...
|
||||
|
||||
class PrivateKey(_message.Message):
|
||||
__slots__ = ("Type", "Data")
|
||||
TYPE_FIELD_NUMBER: _ClassVar[int]
|
||||
DATA_FIELD_NUMBER: _ClassVar[int]
|
||||
Type: KeyType
|
||||
Data: bytes
|
||||
def __init__(self, Type: _Optional[_Union[KeyType, str]] = ..., Data: _Optional[bytes] = ...) -> None: ...
|
||||
14
libp2p/peer/pb/envelope.proto
Normal file
@ -0,0 +1,14 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package libp2p.peer.pb.record;
|
||||
|
||||
import "libp2p/peer/pb/crypto.proto";
|
||||
|
||||
option go_package = "github.com/libp2p/go-libp2p/core/record/pb";
|
||||
|
||||
message Envelope {
|
||||
libp2p.peer.pb.crypto.PublicKey public_key = 1;
|
||||
bytes payload_type = 2;
|
||||
bytes payload = 3;
|
||||
bytes signature = 5;
|
||||
}
|
||||
28
libp2p/peer/pb/envelope_pb2.py
Normal file
@ -0,0 +1,28 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: libp2p/peer/pb/envelope.proto
|
||||
# Protobuf Python Version: 4.25.3
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
from libp2p.peer.pb import crypto_pb2 as libp2p_dot_peer_dot_pb_dot_crypto__pb2
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1dlibp2p/peer/pb/envelope.proto\x12\x15libp2p.peer.pb.record\x1a\x1blibp2p/peer/pb/crypto.proto\"z\n\x08\x45nvelope\x12\x34\n\npublic_key\x18\x01 \x01(\x0b\x32 .libp2p.peer.pb.crypto.PublicKey\x12\x14\n\x0cpayload_type\x18\x02 \x01(\x0c\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x42,Z*github.com/libp2p/go-libp2p/core/record/pbb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.peer.pb.envelope_pb2', _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
_globals['DESCRIPTOR']._options = None
|
||||
_globals['DESCRIPTOR']._serialized_options = b'Z*github.com/libp2p/go-libp2p/core/record/pb'
|
||||
_globals['_ENVELOPE']._serialized_start=85
|
||||
_globals['_ENVELOPE']._serialized_end=207
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
18
libp2p/peer/pb/envelope_pb2.pyi
Normal file
@ -0,0 +1,18 @@
|
||||
from libp2p.peer.pb import crypto_pb2 as _crypto_pb2
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import message as _message
|
||||
from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Optional, Union as _Union
|
||||
|
||||
DESCRIPTOR: _descriptor.FileDescriptor
|
||||
|
||||
class Envelope(_message.Message):
|
||||
__slots__ = ("public_key", "payload_type", "payload", "signature")
|
||||
PUBLIC_KEY_FIELD_NUMBER: _ClassVar[int]
|
||||
PAYLOAD_TYPE_FIELD_NUMBER: _ClassVar[int]
|
||||
PAYLOAD_FIELD_NUMBER: _ClassVar[int]
|
||||
SIGNATURE_FIELD_NUMBER: _ClassVar[int]
|
||||
public_key: _crypto_pb2.PublicKey
|
||||
payload_type: bytes
|
||||
payload: bytes
|
||||
signature: bytes
|
||||
def __init__(self, public_key: _Optional[_Union[_crypto_pb2.PublicKey, _Mapping]] = ..., payload_type: _Optional[bytes] = ..., payload: _Optional[bytes] = ..., signature: _Optional[bytes] = ...) -> None: ... # type: ignore[type-arg]
|
||||
31
libp2p/peer/pb/peer_record.proto
Normal file
@ -0,0 +1,31 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package peer.pb;
|
||||
|
||||
option go_package = "github.com/libp2p/go-libp2p/core/peer/pb";
|
||||
|
||||
// PeerRecord messages contain information that is useful to share with other peers.
|
||||
// Currently, a PeerRecord contains the public listen addresses for a peer, but this
|
||||
// is expected to expand to include other information in the future.
|
||||
//
|
||||
// PeerRecords are designed to be serialized to bytes and placed inside of
|
||||
// SignedEnvelopes before sharing with other peers.
|
||||
// See https://github.com/libp2p/go-libp2p/blob/master/core/record/pb/envelope.proto for
|
||||
// the SignedEnvelope definition.
|
||||
message PeerRecord {
|
||||
|
||||
// AddressInfo is a wrapper around a binary multiaddr. It is defined as a
|
||||
// separate message to allow us to add per-address metadata in the future.
|
||||
message AddressInfo {
|
||||
bytes multiaddr = 1;
|
||||
}
|
||||
|
||||
// peer_id contains a libp2p peer id in its binary representation.
|
||||
bytes peer_id = 1;
|
||||
|
||||
// seq contains a monotonically-increasing sequence counter to order PeerRecords in time.
|
||||
uint64 seq = 2;
|
||||
|
||||
// addresses is a list of public listen addresses for the peer.
|
||||
repeated AddressInfo addresses = 3;
|
||||
}
|
||||
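The generated bindings that follow can be exercised directly; a small sketch of building and re-parsing the message defined above (the multihash and multiaddr bytes are illustrative):

import libp2p.peer.pb.peer_record_pb2 as pb

msg = pb.PeerRecord()
msg.peer_id = b"\x12\x20" + b"\x00" * 32  # illustrative multihash bytes
msg.seq = 1
addr_info = msg.addresses.add()
addr_info.multiaddr = b"\x04\x7f\x00\x00\x01\x06\x0f\xa1"  # /ip4/127.0.0.1/tcp/4001 in binary

data = msg.SerializeToString()
parsed = pb.PeerRecord()
parsed.ParseFromString(data)
assert parsed.seq == 1 and len(parsed.addresses) == 1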
29
libp2p/peer/pb/peer_record_pb2.py
Normal file
@ -0,0 +1,29 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: libp2p/peer/pb/peer_record.proto
|
||||
# Protobuf Python Version: 4.25.3
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/peer/pb/peer_record.proto\x12\x07peer.pb\"\x80\x01\n\nPeerRecord\x12\x0f\n\x07peer_id\x18\x01 \x01(\x0c\x12\x0b\n\x03seq\x18\x02 \x01(\x04\x12\x32\n\taddresses\x18\x03 \x03(\x0b\x32\x1f.peer.pb.PeerRecord.AddressInfo\x1a \n\x0b\x41\x64\x64ressInfo\x12\x11\n\tmultiaddr\x18\x01 \x01(\x0c\x42*Z(github.com/libp2p/go-libp2p/core/peer/pbb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.peer.pb.peer_record_pb2', _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
_globals['DESCRIPTOR']._options = None
|
||||
_globals['DESCRIPTOR']._serialized_options = b'Z(github.com/libp2p/go-libp2p/core/peer/pb'
|
||||
_globals['_PEERRECORD']._serialized_start=46
|
||||
_globals['_PEERRECORD']._serialized_end=174
|
||||
_globals['_PEERRECORD_ADDRESSINFO']._serialized_start=142
|
||||
_globals['_PEERRECORD_ADDRESSINFO']._serialized_end=174
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
21
libp2p/peer/pb/peer_record_pb2.pyi
Normal file
@ -0,0 +1,21 @@
|
||||
from google.protobuf.internal import containers as _containers
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import message as _message
|
||||
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union
|
||||
|
||||
DESCRIPTOR: _descriptor.FileDescriptor
|
||||
|
||||
class PeerRecord(_message.Message):
|
||||
__slots__ = ("peer_id", "seq", "addresses")
|
||||
class AddressInfo(_message.Message):
|
||||
__slots__ = ("multiaddr",)
|
||||
MULTIADDR_FIELD_NUMBER: _ClassVar[int]
|
||||
multiaddr: bytes
|
||||
def __init__(self, multiaddr: _Optional[bytes] = ...) -> None: ...
|
||||
PEER_ID_FIELD_NUMBER: _ClassVar[int]
|
||||
SEQ_FIELD_NUMBER: _ClassVar[int]
|
||||
ADDRESSES_FIELD_NUMBER: _ClassVar[int]
|
||||
peer_id: bytes
|
||||
seq: int
|
||||
addresses: _containers.RepeatedCompositeFieldContainer[PeerRecord.AddressInfo]
|
||||
def __init__(self, peer_id: _Optional[bytes] = ..., seq: _Optional[int] = ..., addresses: _Optional[_Iterable[_Union[PeerRecord.AddressInfo, _Mapping]]] = ...) -> None: ... # type: ignore[type-arg]
|
||||
251
libp2p/peer/peer_record.py
Normal file
@ -0,0 +1,251 @@
|
||||
from collections.abc import Sequence
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.abc import IPeerRecord
|
||||
from libp2p.peer.id import ID
|
||||
import libp2p.peer.pb.peer_record_pb2 as pb
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
|
||||
PEER_RECORD_ENVELOPE_DOMAIN = "libp2p-peer-record"
|
||||
PEER_RECORD_ENVELOPE_PAYLOAD_TYPE = b"\x03\x01"
|
||||
|
||||
_last_timestamp_lock = threading.Lock()
|
||||
_last_timestamp: int = 0
|
||||
|
||||
|
||||
class PeerRecord(IPeerRecord):
|
||||
"""
|
||||
A record that contains metadata about a peer in the libp2p network.
|
||||
|
||||
This includes:
|
||||
- `peer_id`: The peer's globally unique identifier.
|
||||
- `addrs`: A list of the peer's publicly reachable multiaddrs.
|
||||
- `seq`: A strictly monotonically increasing timestamp used
|
||||
to order records over time.
|
||||
|
||||
PeerRecords are designed to be signed and transmitted in libp2p routing Envelopes.
|
||||
"""
|
||||
|
||||
peer_id: ID
|
||||
addrs: list[Multiaddr]
|
||||
seq: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
peer_id: ID | None = None,
|
||||
addrs: list[Multiaddr] | None = None,
|
||||
seq: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a new PeerRecord.
|
||||
If `seq` is not provided, a timestamp-based strictly increasing sequence
|
||||
number will be generated.
|
||||
|
||||
:param peer_id: ID of the peer this record refers to.
|
||||
:param addrs: Public multiaddrs of the peer.
|
||||
:param seq: Monotonic sequence number.
|
||||
|
||||
"""
|
||||
if peer_id is not None:
|
||||
self.peer_id = peer_id
|
||||
self.addrs = addrs or []
|
||||
if seq is not None:
|
||||
self.seq = seq
|
||||
else:
|
||||
self.seq = timestamp_seq()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"PeerRecord(\n"
|
||||
f" peer_id={self.peer_id},\n"
|
||||
f" multiaddrs={[str(m) for m in self.addrs]},\n"
|
||||
f" seq={self.seq}\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
def domain(self) -> str:
|
||||
"""
|
||||
Return the domain string associated with this PeerRecord.
|
||||
|
||||
Used during record signing and envelope validation to identify the record type.
|
||||
"""
|
||||
return PEER_RECORD_ENVELOPE_DOMAIN
|
||||
|
||||
def codec(self) -> bytes:
|
||||
"""
|
||||
Return the codec identifier for PeerRecords.
|
||||
|
||||
This binary prefix helps distinguish PeerRecords in serialized envelopes.
|
||||
"""
|
||||
return PEER_RECORD_ENVELOPE_PAYLOAD_TYPE
|
||||
|
||||
def to_protobuf(self) -> pb.PeerRecord:
|
||||
"""
|
||||
Convert the current PeerRecord into a ProtoBuf PeerRecord message.
|
||||
|
||||
:raises ValueError: if peer_id serialization fails.
|
||||
:return: A ProtoBuf-encoded PeerRecord message object.
|
||||
"""
|
||||
try:
|
||||
id_bytes = self.peer_id.to_bytes()
|
||||
except Exception as e:
|
||||
raise ValueError(f"failed to marshal peer_id: {e}")
|
||||
|
||||
msg = pb.PeerRecord()
|
||||
msg.peer_id = id_bytes
|
||||
msg.seq = self.seq
|
||||
msg.addresses.extend(addrs_to_protobuf(self.addrs))
|
||||
return msg
|
||||
|
||||
def marshal_record(self) -> bytes:
|
||||
"""
|
||||
Serialize a PeerRecord into raw bytes suitable for embedding in an Envelope.
|
||||
|
||||
This is typically called during the process of signing or sealing the record.
|
||||
:raises ValueError: if serialization to protobuf fails.
|
||||
:return: Serialized PeerRecord bytes.
|
||||
"""
|
||||
try:
|
||||
msg = self.to_protobuf()
|
||||
return msg.SerializeToString()
|
||||
except Exception as e:
|
||||
raise ValueError(f"failed to marshal PeerRecord: {e}")
|
||||
|
||||
def equal(self, other: Any) -> bool:
|
||||
"""
|
||||
Check if this PeerRecord is identical to another.
|
||||
|
||||
Two PeerRecords are considered equal if:
|
||||
- Their peer IDs match.
|
||||
- Their sequence numbers are identical.
|
||||
- Their address lists are identical and in the same order.
|
||||
|
||||
:param other: Another PeerRecord instance.
|
||||
:return: True if all fields match, False otherwise.
|
||||
"""
|
||||
if isinstance(other, PeerRecord):
|
||||
if self.peer_id == other.peer_id:
|
||||
if self.seq == other.seq:
|
||||
if len(self.addrs) == len(other.addrs):
|
||||
for a1, a2 in zip(self.addrs, other.addrs):
|
||||
if a1 == a2:
|
||||
continue
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def unmarshal_record(data: bytes) -> PeerRecord:
|
||||
"""
|
||||
Deserialize a PeerRecord from its serialized byte representation.
|
||||
|
||||
Typically used when receiving a PeerRecord inside a signed routing Envelope.
|
||||
|
||||
:param data: Serialized protobuf-encoded bytes.
|
||||
:raises ValueError: if parsing or conversion fails.
|
||||
:return: A valid PeerRecord instance.
|
||||
"""
|
||||
if data is None:
|
||||
raise ValueError("cannot unmarshal PeerRecord from None")
|
||||
|
||||
msg = pb.PeerRecord()
|
||||
try:
|
||||
msg.ParseFromString(data)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse PeerRecord protobuf: {e}")
|
||||
|
||||
try:
|
||||
record = peer_record_from_protobuf(msg)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to convert protobuf to PeerRecord: {e}")
|
||||
|
||||
return record
|
||||
|
||||
|
||||
def timestamp_seq() -> int:
|
||||
"""
|
||||
Generate a strictly increasing timestamp-based sequence number.
|
||||
|
||||
Ensures that even if multiple PeerRecords are generated in the same nanosecond,
|
||||
their `seq` values will still be strictly increasing by using a lock to track the
|
||||
last value.
|
||||
|
||||
:return: A strictly increasing integer timestamp.
|
||||
"""
|
||||
global _last_timestamp
|
||||
now = int(time.time_ns())
|
||||
with _last_timestamp_lock:
|
||||
if now <= _last_timestamp:
|
||||
now = _last_timestamp + 1
|
||||
_last_timestamp = now
|
||||
return now
|
||||
|
||||
|
||||
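# A quick illustration of the guarantee above (sketch, not part of the
# module): two back-to-back calls always yield strictly increasing values,
# even within the same nanosecond.
#
#     a = timestamp_seq()
#     b = timestamp_seq()
#     assert b > a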
def peer_record_from_peer_info(info: PeerInfo) -> PeerRecord:
|
||||
"""
|
||||
Create a PeerRecord from a PeerInfo object.
|
||||
|
||||
This automatically assigns a timestamp-based sequence number to the record.
|
||||
:param info: A PeerInfo instance (contains peer_id and addrs).
|
||||
:return: A PeerRecord instance.
|
||||
"""
|
||||
record = PeerRecord()
|
||||
record.peer_id = info.peer_id
|
||||
record.addrs = info.addrs
|
||||
return record
|
||||
|
||||
|
||||
def peer_record_from_protobuf(msg: pb.PeerRecord) -> PeerRecord:
|
||||
"""
|
||||
Convert a protobuf PeerRecord message into a PeerRecord object.
|
||||
|
||||
:param msg: Protobuf PeerRecord message.
|
||||
:raises ValueError: if the peer_id cannot be parsed.
|
||||
:return: A deserialized PeerRecord instance.
|
||||
"""
|
||||
try:
|
||||
peer_id = ID(msg.peer_id)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to unmarshal peer_id: {e}")
|
||||
|
||||
addrs = addrs_from_protobuf(msg.addresses)
|
||||
seq = msg.seq
|
||||
|
||||
return PeerRecord(peer_id, addrs, seq)
|
||||
|
||||
|
||||
def addrs_from_protobuf(addrs: Sequence[pb.PeerRecord.AddressInfo]) -> list[Multiaddr]:
|
||||
"""
|
||||
Convert a list of protobuf address records to Multiaddr objects.
|
||||
|
||||
:param addrs: A list of protobuf PeerRecord.AddressInfo messages.
|
||||
:return: A list of decoded Multiaddr instances (invalid ones are skipped).
|
||||
"""
|
||||
out = []
|
||||
for addr_info in addrs:
|
||||
try:
|
||||
addr = Multiaddr(addr_info.multiaddr)
|
||||
out.append(addr)
|
||||
except Exception:
|
||||
continue
|
||||
return out
|
||||
|
||||
|
||||
def addrs_to_protobuf(addrs: list[Multiaddr]) -> list[pb.PeerRecord.AddressInfo]:
|
||||
"""
|
||||
Convert a list of Multiaddr objects into their protobuf representation.
|
||||
|
||||
:param addrs: A list of Multiaddr instances.
|
||||
:return: A list of PeerRecord.AddressInfo protobuf messages.
|
||||
"""
|
||||
out = []
|
||||
for addr in addrs:
|
||||
addr_info = pb.PeerRecord.AddressInfo()
|
||||
addr_info.multiaddr = addr.to_bytes()
|
||||
out.append(addr_info)
|
||||
return out
|
||||
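Rounding off the module, the conversion helpers above compose into a simple PeerInfo-to-bytes round trip; a minimal sketch (the peer ID bytes are illustrative):

from multiaddr import Multiaddr

from libp2p.peer.id import ID
from libp2p.peer.peer_record import peer_record_from_peer_info, unmarshal_record
from libp2p.peer.peerinfo import PeerInfo

peer_id = ID(b"\x12\x20" + b"\x00" * 32)
info = PeerInfo(peer_id, [Multiaddr("/ip4/127.0.0.1/tcp/4001")])

record = peer_record_from_peer_info(info)  # seq assigned via timestamp_seq()
data = record.marshal_record()             # protobuf bytes ready for an Envelope
assert unmarshal_record(data).equal(record)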
@ -16,6 +16,7 @@ import trio
|
||||
from trio import MemoryReceiveChannel, MemorySendChannel
|
||||
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
IPeerStore,
|
||||
)
|
||||
from libp2p.crypto.keys import (
|
||||
@ -23,6 +24,8 @@ from libp2p.crypto.keys import (
|
||||
PrivateKey,
|
||||
PublicKey,
|
||||
)
|
||||
from libp2p.peer.envelope import Envelope, seal_record
|
||||
from libp2p.peer.peer_record import PeerRecord
|
||||
|
||||
from .id import (
|
||||
ID,
|
||||
@ -38,12 +41,112 @@ from .peerinfo import (
|
||||
PERMANENT_ADDR_TTL = 0
|
||||
|
||||
|
||||
def create_signed_peer_record(
|
||||
peer_id: ID, addrs: list[Multiaddr], pvt_key: PrivateKey
|
||||
) -> Envelope:
|
||||
"""Creates a signed_peer_record wrapped in an Envelope"""
|
||||
record = PeerRecord(peer_id, addrs)
|
||||
envelope = seal_record(record, pvt_key)
|
||||
return envelope
|
||||
|
||||
|
||||
def env_to_send_in_RPC(host: IHost) -> tuple[bytes, bool]:
|
||||
"""
|
||||
Return the signed peer record (Envelope) to be sent in an RPC.
|
||||
|
||||
This function checks whether the host already has a cached signed peer record
|
||||
(SPR). If one exists and its addresses match the host's current listen
|
||||
addresses, the cached envelope is reused. Otherwise, a new signed peer record
|
||||
is created, cached, and returned.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
host : IHost
|
||||
The local host instance, providing access to peer ID, listen addresses,
|
||||
private key, and the peerstore.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[bytes, bool]
|
||||
A 2-tuple where the first element is the serialized envelope (bytes)
|
||||
for the signed peer record, and the second element is a boolean flag
|
||||
indicating whether a new record was created (True) or an existing cached
|
||||
one was reused (False).
|
||||
|
||||
"""
|
||||
listen_addrs_set = {addr for addr in host.get_addrs()}
|
||||
local_env = host.get_peerstore().get_local_record()
|
||||
|
||||
if local_env is None:
|
||||
# No cached SPR yet -> create one
|
||||
return issue_and_cache_local_record(host), True
|
||||
else:
|
||||
record_addrs_set = local_env._env_addrs_set()
|
||||
if record_addrs_set == listen_addrs_set:
|
||||
# Perfect match -> reuse cached envelope
|
||||
return local_env.marshal_envelope(), False
|
||||
else:
|
||||
# Addresses changed -> issue a new SPR and cache it
|
||||
return issue_and_cache_local_record(host), True
|
||||
|
||||
|
||||
def issue_and_cache_local_record(host: IHost) -> bytes:
|
||||
"""
|
||||
Create and cache a new signed peer record (Envelope) for the host.
|
||||
|
||||
This function generates a new signed peer record from the host’s peer ID,
|
||||
listen addresses, and private key. The resulting envelope is stored in
|
||||
the peerstore as the local record for future reuse.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
host : IHost
|
||||
The local host instance, providing access to peer ID, listen addresses,
|
||||
private key, and the peerstore.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bytes
|
||||
The serialized envelope (bytes) representing the newly created signed
|
||||
peer record.
|
||||
|
||||
"""
|
||||
env = create_signed_peer_record(
|
||||
host.get_id(),
|
||||
host.get_addrs(),
|
||||
host.get_private_key(),
|
||||
)
|
||||
# Cache it for future use
|
||||
host.get_peerstore().set_local_record(env)
|
||||
return env.marshal_envelope()
|
||||
|
||||
|
||||
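# Sketch of how a pubsub router is expected to use the helper above when
# building an outgoing RPC (see the floodsub/gossipsub hunks later in this
# diff; `host` and `rpc_msg` are assumed to be in scope):
#
#     envelope_bytes, newly_created = env_to_send_in_RPC(host)
#     rpc_msg.senderRecord = envelope_bytes
#     # newly_created is True when the cached record was stale or missing.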
class PeerRecordState:
|
||||
envelope: Envelope
|
||||
seq: int
|
||||
|
||||
def __init__(self, envelope: Envelope, seq: int):
|
||||
self.envelope = envelope
|
||||
self.seq = seq
|
||||
|
||||
|
||||
class PeerStore(IPeerStore):
|
||||
peer_data_map: dict[ID, PeerData]
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, max_records: int = 10000) -> None:
|
||||
self.peer_data_map = defaultdict(PeerData)
|
||||
self.addr_update_channels: dict[ID, MemorySendChannel[Multiaddr]] = {}
|
||||
self.peer_record_map: dict[ID, PeerRecordState] = {}
|
||||
self.local_peer_record: Envelope | None = None
|
||||
self.max_records = max_records
|
||||
|
||||
def get_local_record(self) -> Envelope | None:
|
||||
"""Get the local-signed-record wrapped in Envelope"""
|
||||
return self.local_peer_record
|
||||
|
||||
def set_local_record(self, envelope: Envelope) -> None:
|
||||
"""Set the local-signed-record wrapped in Envelope"""
|
||||
self.local_peer_record = envelope
|
||||
|
||||
def peer_info(self, peer_id: ID) -> PeerInfo:
|
||||
"""
|
||||
@ -70,6 +173,10 @@ class PeerStore(IPeerStore):
|
||||
else:
|
||||
raise PeerStoreError("peer ID not found")
|
||||
|
||||
# Clear the peer records
|
||||
if peer_id in self.peer_record_map:
|
||||
self.peer_record_map.pop(peer_id, None)
|
||||
|
||||
def valid_peer_ids(self) -> list[ID]:
|
||||
"""
|
||||
:return: all of the valid peer IDs stored in peer store
|
||||
@ -82,6 +189,38 @@ class PeerStore(IPeerStore):
|
||||
peer_data.clear_addrs()
|
||||
return valid_peer_ids
|
||||
|
||||
def _enforce_record_limit(self) -> None:
|
||||
"""Enforce maximum number of stored records."""
|
||||
if len(self.peer_record_map) > self.max_records:
|
||||
# Remove the oldest records based on sequence number
|
||||
sorted_records = sorted(
|
||||
self.peer_record_map.items(), key=lambda x: x[1].seq
|
||||
)
|
||||
records_to_remove = len(self.peer_record_map) - self.max_records
|
||||
for peer_id, _ in sorted_records[:records_to_remove]:
|
||||
self.maybe_delete_peer_record(peer_id)
|
||||
self.peer_record_map.pop(peer_id, None)  # may already be removed by maybe_delete_peer_record
|
||||
|
||||
async def start_cleanup_task(self, cleanup_interval: int = 3600) -> None:
|
||||
"""Start periodic cleanup of expired peer records and addresses."""
|
||||
while True:
|
||||
await trio.sleep(cleanup_interval)
|
||||
self._cleanup_expired_records()
|
||||
|
||||
def _cleanup_expired_records(self) -> None:
|
||||
"""Remove expired peer records and addresses"""
|
||||
expired_peers = []
|
||||
|
||||
for peer_id, peer_data in self.peer_data_map.items():
|
||||
if peer_data.is_expired():
|
||||
expired_peers.append(peer_id)
|
||||
|
||||
for peer_id in expired_peers:
|
||||
self.maybe_delete_peer_record(peer_id)
|
||||
del self.peer_data_map[peer_id]
|
||||
|
||||
self._enforce_record_limit()
|
||||
|
||||
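# The cleanup loop above is meant to be driven by the node's own nursery;
# a sketch of the expected wiring (the surrounding service setup is assumed):
#
#     async with trio.open_nursery() as nursery:
#         nursery.start_soon(peerstore.start_cleanup_task, 3600)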
# --------PROTO-BOOK--------
|
||||
|
||||
def get_protocols(self, peer_id: ID) -> list[str]:
|
||||
@ -165,6 +304,84 @@ class PeerStore(IPeerStore):
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
peer_data.clear_metadata()
|
||||
|
||||
# -----CERT-ADDR-BOOK-----
|
||||
|
||||
def maybe_delete_peer_record(self, peer_id: ID) -> None:
|
||||
"""
|
||||
Delete the signed peer record for a peer if it has no known
|
||||
(non-expired) addresses.
|
||||
|
||||
This is a garbage collection mechanism: if all addresses for a peer have expired
|
||||
or been cleared, there's no point holding onto its signed `Envelope`.
|
||||
|
||||
:param peer_id: The peer whose record we may delete.
|
||||
"""
|
||||
if peer_id in self.peer_record_map:
|
||||
if not self.addrs(peer_id):
|
||||
self.peer_record_map.pop(peer_id, None)
|
||||
|
||||
def consume_peer_record(self, envelope: Envelope, ttl: int) -> bool:
|
||||
"""
|
||||
Accept and store a signed PeerRecord, unless it's older than
|
||||
the one already stored.
|
||||
|
||||
This function:
|
||||
- Extracts the peer ID and sequence number from the envelope
|
||||
- Rejects the record if it's older (lower seq)
|
||||
- Updates the stored peer record and replaces associated addresses if accepted
|
||||
|
||||
:param envelope: Signed envelope containing a PeerRecord.
|
||||
:param ttl: Time-to-live for the included multiaddrs (in seconds).
|
||||
:return: True if the record was accepted and stored; False if it was rejected.
|
||||
"""
|
||||
record = envelope.record()
|
||||
peer_id = record.peer_id
|
||||
|
||||
existing = self.peer_record_map.get(peer_id)
|
||||
if existing and existing.seq > record.seq:
|
||||
return False # reject older record
|
||||
|
||||
new_addrs = set(record.addrs)
|
||||
|
||||
self.peer_record_map[peer_id] = PeerRecordState(envelope, record.seq)
|
||||
self.peer_data_map[peer_id].clear_addrs()
|
||||
self.add_addrs(peer_id, list(new_addrs), ttl)
|
||||
|
||||
return True
|
||||
|
||||
def consume_peer_records(self, envelopes: list[Envelope], ttl: int) -> list[bool]:
|
||||
"""Consume multiple peer records in a single operation."""
|
||||
results = []
|
||||
for envelope in envelopes:
|
||||
results.append(self.consume_peer_record(envelope, ttl))
|
||||
return results
|
||||
|
||||
def get_peer_record(self, peer_id: ID) -> Envelope | None:
|
||||
"""
|
||||
Retrieve the most recent signed PeerRecord `Envelope` for a peer, if it exists
|
||||
and is still relevant.
|
||||
|
||||
First, it runs cleanup via `maybe_delete_peer_record` to purge stale data.
|
||||
Then it checks whether the peer has valid, unexpired addresses before
|
||||
returning the associated envelope.
|
||||
|
||||
:param peer_id: The peer to look up.
|
||||
:return: The signed Envelope if the peer is known and has valid
|
||||
addresses; None otherwise.
|
||||
|
||||
"""
|
||||
self.maybe_delete_peer_record(peer_id)
|
||||
|
||||
# Check if the peer has any valid addresses
|
||||
if (
|
||||
peer_id in self.peer_data_map
|
||||
and not self.peer_data_map[peer_id].is_expired()
|
||||
):
|
||||
state = self.peer_record_map.get(peer_id)
|
||||
if state is not None:
|
||||
return state.envelope
|
||||
return None
|
||||
|
||||
# -------ADDR-BOOK--------
|
||||
|
||||
def add_addr(self, peer_id: ID, addr: Multiaddr, ttl: int = 0) -> None:
|
||||
@ -193,6 +410,8 @@ class PeerStore(IPeerStore):
|
||||
except trio.WouldBlock:
|
||||
pass # Or consider logging / dropping / replacing stream
|
||||
|
||||
self.maybe_delete_peer_record(peer_id)
|
||||
|
||||
def addrs(self, peer_id: ID) -> list[Multiaddr]:
|
||||
"""
|
||||
:param peer_id: peer ID to get addrs for
|
||||
@ -216,6 +435,8 @@ class PeerStore(IPeerStore):
|
||||
if peer_id in self.peer_data_map:
|
||||
self.peer_data_map[peer_id].clear_addrs()
|
||||
|
||||
self.maybe_delete_peer_record(peer_id)
|
||||
|
||||
def peers_with_addrs(self) -> list[ID]:
|
||||
"""
|
||||
:return: all of the peer IDs which have addrs stored in peer store
|
||||
|
||||
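As a consumer-side sketch of the record APIs above (incoming_bytes stands in for the senderRecord field of a received RPC; the names and TTL are illustrative, not part of the diff):

from libp2p.peer.envelope import consume_envelope
from libp2p.peer.peerstore import PeerStore

store = PeerStore()
envelope, record = consume_envelope(incoming_bytes, "libp2p-peer-record")

# Store with a 2-hour TTL; records older than the cached one (lower seq)
# are rejected.
if store.consume_peer_record(envelope, ttl=7200):
    cached = store.get_peer_record(record.peer_id)  # -> Envelope | None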
@ -48,12 +48,11 @@ class Multiselect(IMultiselectMuxer):
|
||||
"""
|
||||
self.handlers[protocol] = handler
|
||||
|
||||
# FIXME: Make TProtocol Optional[TProtocol] to keep types consistent
|
||||
async def negotiate(
|
||||
self,
|
||||
communicator: IMultiselectCommunicator,
|
||||
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
) -> tuple[TProtocol, StreamHandlerFn | None]:
|
||||
) -> tuple[TProtocol | None, StreamHandlerFn | None]:
|
||||
"""
|
||||
Negotiate performs protocol selection.
|
||||
|
||||
@ -84,14 +83,14 @@ class Multiselect(IMultiselectMuxer):
|
||||
raise MultiselectError() from error
|
||||
|
||||
else:
|
||||
protocol = TProtocol(command)
|
||||
if protocol in self.handlers:
|
||||
protocol_to_check = None if not command else TProtocol(command)
|
||||
if protocol_to_check in self.handlers:
|
||||
try:
|
||||
await communicator.write(protocol)
|
||||
await communicator.write(command)
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectError() from error
|
||||
|
||||
return protocol, self.handlers[protocol]
|
||||
return protocol_to_check, self.handlers[protocol_to_check]
|
||||
try:
|
||||
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
|
||||
except MultiselectCommunicatorError as error:
|
||||
|
||||
@ -134,8 +134,10 @@ class MultiselectClient(IMultiselectClient):
|
||||
:raise MultiselectClientError: raised when protocol negotiation failed
|
||||
:return: selected protocol
|
||||
"""
|
||||
# Represent `None` protocol as an empty string.
|
||||
protocol_str = protocol if protocol is not None else ""
|
||||
try:
|
||||
await communicator.write(protocol)
|
||||
await communicator.write(protocol_str)
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
|
||||
@ -145,7 +147,7 @@ class MultiselectClient(IMultiselectClient):
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
|
||||
if response == protocol:
|
||||
if response == protocol_str:
|
||||
return protocol
|
||||
if response == PROTOCOL_NOT_FOUND_MSG:
|
||||
raise MultiselectClientError("protocol not supported")
|
||||
|
||||
@ -30,7 +30,10 @@ class MultiselectCommunicator(IMultiselectCommunicator):
|
||||
"""
|
||||
:raise MultiselectCommunicatorError: raised when failed to write to underlying reader
|
||||
""" # noqa: E501
|
||||
msg_bytes = encode_delim(msg_str.encode())
|
||||
if msg_str is None:
|
||||
msg_bytes = encode_delim(b"")
|
||||
else:
|
||||
msg_bytes = encode_delim(msg_str.encode())
|
||||
try:
|
||||
await self.read_writer.write(msg_bytes)
|
||||
except IOException as error:
|
||||
|
||||
@ -15,6 +15,7 @@ from libp2p.custom_types import (
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerstore import env_to_send_in_RPC
|
||||
|
||||
from .exceptions import (
|
||||
PubsubRouterError,
|
||||
@ -103,6 +104,11 @@ class FloodSub(IPubsubRouter):
|
||||
)
|
||||
rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg])
|
||||
|
||||
# Add the senderRecord of the peer in the RPC msg
|
||||
if isinstance(self.pubsub, Pubsub):
|
||||
envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host)
|
||||
rpc_msg.senderRecord = envelope_bytes
|
||||
|
||||
logger.debug("publishing message %s", pubsub_msg)
|
||||
|
||||
if self.pubsub is None:
|
||||
|
||||
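The senderRecord attached above in floodsub is validated on receipt via maybe_consume_signed_record in the gossipsub hunk below; that helper's source is not part of this diff, but it is expected to do roughly the following (a sketch; the function name, parameters, and TTL are assumptions):

from libp2p.peer.envelope import consume_envelope


def _accept_sender_record(rpc, sender_peer_id, peerstore) -> bool:
    if not rpc.HasField("senderRecord"):
        return True  # nothing attached, nothing to validate
    try:
        envelope, record = consume_envelope(rpc.senderRecord, "libp2p-peer-record")
    except Exception:
        return False  # bad signature or undecodable payload
    if record.peer_id != sender_peer_id:
        return False  # record does not belong to the sending peer
    return peerstore.consume_peer_record(envelope, ttl=7200)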
@ -34,10 +34,12 @@ from libp2p.peer.peerinfo import (
|
||||
)
|
||||
from libp2p.peer.peerstore import (
|
||||
PERMANENT_ADDR_TTL,
|
||||
env_to_send_in_RPC,
|
||||
)
|
||||
from libp2p.pubsub import (
|
||||
floodsub,
|
||||
)
|
||||
from libp2p.pubsub.utils import maybe_consume_signed_record
|
||||
from libp2p.tools.async_service import (
|
||||
Service,
|
||||
)
|
||||
@ -226,6 +228,12 @@ class GossipSub(IPubsubRouter, Service):
|
||||
:param rpc: RPC message
|
||||
:param sender_peer_id: id of the peer who sent the message
|
||||
"""
|
||||
# Process the senderRecord if sent
|
||||
if isinstance(self.pubsub, Pubsub):
|
||||
if not maybe_consume_signed_record(rpc, self.pubsub.host, sender_peer_id):
|
||||
logger.error("Received an invalid-signed-record, ignoring the message")
|
||||
return
|
||||
|
||||
control_message = rpc.control
|
||||
|
||||
# Relay each rpc control message to the appropriate handler
|
||||
@ -253,6 +261,11 @@ class GossipSub(IPubsubRouter, Service):
|
||||
)
|
||||
rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg])
|
||||
|
||||
# Add the senderRecord of the peer in the RPC msg
|
||||
if isinstance(self.pubsub, Pubsub):
|
||||
envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host)
|
||||
rpc_msg.senderRecord = envelope_bytes
|
||||
|
||||
logger.debug("publishing message %s", pubsub_msg)
|
||||
|
||||
for peer_id in peers_gen:
|
||||
@ -775,16 +788,16 @@ class GossipSub(IPubsubRouter, Service):
|
||||
# Get list of all seen (seqnos, from) from the (seqno, from) tuples in
|
||||
# seen_messages cache
|
||||
seen_seqnos_and_peers = [
|
||||
seqno_and_from for seqno_and_from in self.pubsub.seen_messages.cache.keys()
|
||||
str(seqno_and_from)
|
||||
for seqno_and_from in self.pubsub.seen_messages.cache.keys()
|
||||
]
|
||||
|
||||
# Add all unknown message ids (ids that appear in ihave_msg but not in
|
||||
# seen_seqnos) to list of messages we want to request
|
||||
# FIXME: Update type of message ID
|
||||
msg_ids_wanted: list[Any] = [
|
||||
msg_ids_wanted: list[str] = [
|
||||
msg_id
|
||||
for msg_id in ihave_msg.messageIDs
|
||||
if literal_eval(msg_id) not in seen_seqnos_and_peers
|
||||
if msg_id not in seen_seqnos_and_peers
|
||||
]
|
||||
|
||||
# Request messages with IWANT message
|
||||
@ -818,6 +831,13 @@ class GossipSub(IPubsubRouter, Service):
|
||||
# 1) Package these messages into a single packet
|
||||
packet: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||
|
||||
# Here an RPC message is being created and published in response
|
||||
# to the iwant control msg, so we will send a freshly created senderRecord
|
||||
# with the RPC msg
|
||||
if isinstance(self.pubsub, Pubsub):
|
||||
envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host)
|
||||
packet.senderRecord = envelope_bytes
|
||||
|
||||
packet.publish.extend(msgs_to_forward)
|
||||
|
||||
if self.pubsub is None:
|
||||
@ -973,6 +993,12 @@ class GossipSub(IPubsubRouter, Service):
|
||||
raise NoPubsubAttached
|
||||
# Add control message to packet
|
||||
packet: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||
|
||||
# Add the sender's peer-record in the RPC msg
|
||||
if isinstance(self.pubsub, Pubsub):
|
||||
envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host)
|
||||
packet.senderRecord = envelope_bytes
|
||||
|
||||
packet.control.CopyFrom(control_msg)
|
||||
|
||||
# Get stream for peer from pubsub
|
||||
|
||||
@ -14,6 +14,7 @@ message RPC {
|
||||
}
|
||||
|
||||
optional ControlMessage control = 3;
|
||||
optional bytes senderRecord = 4;
|
||||
}
|
||||
|
||||
message Message {
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: rpc.proto
|
||||
# source: libp2p/pubsub/pb/rpc.proto
|
||||
# Protobuf Python Version: 4.25.3
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf.internal import builder as _builder
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
@ -13,39 +14,39 @@ _sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\trpc.proto\x12\tpubsub.pb\"\xb4\x01\n\x03RPC\x12-\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x16.pubsub.pb.RPC.SubOpts\x12#\n\x07publish\x18\x02 \x03(\x0b\x32\x12.pubsub.pb.Message\x12*\n\x07\x63ontrol\x18\x03 \x01(\x0b\x32\x19.pubsub.pb.ControlMessage\x1a-\n\x07SubOpts\x12\x11\n\tsubscribe\x18\x01 \x01(\x08\x12\x0f\n\x07topicid\x18\x02 \x01(\t\"i\n\x07Message\x12\x0f\n\x07\x66rom_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\r\n\x05seqno\x18\x03 \x01(\x0c\x12\x10\n\x08topicIDs\x18\x04 \x03(\t\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x12\x0b\n\x03key\x18\x06 \x01(\x0c\"\xb0\x01\n\x0e\x43ontrolMessage\x12&\n\x05ihave\x18\x01 \x03(\x0b\x32\x17.pubsub.pb.ControlIHave\x12&\n\x05iwant\x18\x02 \x03(\x0b\x32\x17.pubsub.pb.ControlIWant\x12&\n\x05graft\x18\x03 \x03(\x0b\x32\x17.pubsub.pb.ControlGraft\x12&\n\x05prune\x18\x04 \x03(\x0b\x32\x17.pubsub.pb.ControlPrune\"3\n\x0c\x43ontrolIHave\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\x12\n\nmessageIDs\x18\x02 \x03(\t\"\"\n\x0c\x43ontrolIWant\x12\x12\n\nmessageIDs\x18\x01 \x03(\t\"\x1f\n\x0c\x43ontrolGraft\x12\x0f\n\x07topicID\x18\x01 \x01(\t\"T\n\x0c\x43ontrolPrune\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\"\n\x05peers\x18\x02 \x03(\x0b\x32\x13.pubsub.pb.PeerInfo\x12\x0f\n\x07\x62\x61\x63koff\x18\x03 \x01(\x04\"4\n\x08PeerInfo\x12\x0e\n\x06peerID\x18\x01 \x01(\x0c\x12\x18\n\x10signedPeerRecord\x18\x02 \x01(\x0c\"\x87\x03\n\x0fTopicDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\x04\x61uth\x18\x02 \x01(\x0b\x32#.pubsub.pb.TopicDescriptor.AuthOpts\x12/\n\x03\x65nc\x18\x03 \x01(\x0b\x32\".pubsub.pb.TopicDescriptor.EncOpts\x1a|\n\x08\x41uthOpts\x12:\n\x04mode\x18\x01 \x01(\x0e\x32,.pubsub.pb.TopicDescriptor.AuthOpts.AuthMode\x12\x0c\n\x04keys\x18\x02 \x03(\x0c\"&\n\x08\x41uthMode\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03KEY\x10\x01\x12\x07\n\x03WOT\x10\x02\x1a\x83\x01\n\x07\x45ncOpts\x12\x38\n\x04mode\x18\x01 \x01(\x0e\x32*.pubsub.pb.TopicDescriptor.EncOpts.EncMode\x12\x11\n\tkeyHashes\x18\x02 \x03(\x0c\"+\n\x07\x45ncMode\x12\x08\n\x04NONE\x10\x00\x12\r\n\tSHAREDKEY\x10\x01\x12\x07\n\x03WOT\x10\x02')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1alibp2p/pubsub/pb/rpc.proto\x12\tpubsub.pb\"\xca\x01\n\x03RPC\x12-\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x16.pubsub.pb.RPC.SubOpts\x12#\n\x07publish\x18\x02 \x03(\x0b\x32\x12.pubsub.pb.Message\x12*\n\x07\x63ontrol\x18\x03 \x01(\x0b\x32\x19.pubsub.pb.ControlMessage\x12\x14\n\x0csenderRecord\x18\x04 \x01(\x0c\x1a-\n\x07SubOpts\x12\x11\n\tsubscribe\x18\x01 \x01(\x08\x12\x0f\n\x07topicid\x18\x02 \x01(\t\"i\n\x07Message\x12\x0f\n\x07\x66rom_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\r\n\x05seqno\x18\x03 \x01(\x0c\x12\x10\n\x08topicIDs\x18\x04 \x03(\t\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x12\x0b\n\x03key\x18\x06 \x01(\x0c\"\xb0\x01\n\x0e\x43ontrolMessage\x12&\n\x05ihave\x18\x01 \x03(\x0b\x32\x17.pubsub.pb.ControlIHave\x12&\n\x05iwant\x18\x02 \x03(\x0b\x32\x17.pubsub.pb.ControlIWant\x12&\n\x05graft\x18\x03 \x03(\x0b\x32\x17.pubsub.pb.ControlGraft\x12&\n\x05prune\x18\x04 \x03(\x0b\x32\x17.pubsub.pb.ControlPrune\"3\n\x0c\x43ontrolIHave\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\x12\n\nmessageIDs\x18\x02 \x03(\t\"\"\n\x0c\x43ontrolIWant\x12\x12\n\nmessageIDs\x18\x01 \x03(\t\"\x1f\n\x0c\x43ontrolGraft\x12\x0f\n\x07topicID\x18\x01 \x01(\t\"T\n\x0c\x43ontrolPrune\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\"\n\x05peers\x18\x02 \x03(\x0b\x32\x13.pubsub.pb.PeerInfo\x12\x0f\n\x07\x62\x61\x63koff\x18\x03 \x01(\x04\"4\n\x08PeerInfo\x12\x0e\n\x06peerID\x18\x01 \x01(\x0c\x12\x18\n\x10signedPeerRecord\x18\x02 \x01(\x0c\"\x87\x03\n\x0fTopicDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\x04\x61uth\x18\x02 \x01(\x0b\x32#.pubsub.pb.TopicDescriptor.AuthOpts\x12/\n\x03\x65nc\x18\x03 \x01(\x0b\x32\".pubsub.pb.TopicDescriptor.EncOpts\x1a|\n\x08\x41uthOpts\x12:\n\x04mode\x18\x01 \x01(\x0e\x32,.pubsub.pb.TopicDescriptor.AuthOpts.AuthMode\x12\x0c\n\x04keys\x18\x02 \x03(\x0c\"&\n\x08\x41uthMode\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03KEY\x10\x01\x12\x07\n\x03WOT\x10\x02\x1a\x83\x01\n\x07\x45ncOpts\x12\x38\n\x04mode\x18\x01 \x01(\x0e\x32*.pubsub.pb.TopicDescriptor.EncOpts.EncMode\x12\x11\n\tkeyHashes\x18\x02 \x03(\x0c\"+\n\x07\x45ncMode\x12\x08\n\x04NONE\x10\x00\x12\r\n\tSHAREDKEY\x10\x01\x12\x07\n\x03WOT\x10\x02')
|
||||
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'rpc_pb2', globals())
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.pubsub.pb.rpc_pb2', _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
_RPC._serialized_start=25
|
||||
_RPC._serialized_end=205
|
||||
_RPC_SUBOPTS._serialized_start=160
|
||||
_RPC_SUBOPTS._serialized_end=205
|
||||
_MESSAGE._serialized_start=207
|
||||
_MESSAGE._serialized_end=312
|
||||
_CONTROLMESSAGE._serialized_start=315
|
||||
_CONTROLMESSAGE._serialized_end=491
|
||||
_CONTROLIHAVE._serialized_start=493
|
||||
_CONTROLIHAVE._serialized_end=544
|
||||
_CONTROLIWANT._serialized_start=546
|
||||
_CONTROLIWANT._serialized_end=580
|
||||
_CONTROLGRAFT._serialized_start=582
|
||||
_CONTROLGRAFT._serialized_end=613
|
||||
_CONTROLPRUNE._serialized_start=615
|
||||
_CONTROLPRUNE._serialized_end=699
|
||||
_PEERINFO._serialized_start=701
|
||||
_PEERINFO._serialized_end=753
|
||||
_TOPICDESCRIPTOR._serialized_start=756
|
||||
_TOPICDESCRIPTOR._serialized_end=1147
|
||||
_TOPICDESCRIPTOR_AUTHOPTS._serialized_start=889
|
||||
_TOPICDESCRIPTOR_AUTHOPTS._serialized_end=1013
|
||||
_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_start=975
|
||||
_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_end=1013
|
||||
_TOPICDESCRIPTOR_ENCOPTS._serialized_start=1016
|
||||
_TOPICDESCRIPTOR_ENCOPTS._serialized_end=1147
|
||||
_TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_start=1104
|
||||
_TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_end=1147
|
||||
_globals['_RPC']._serialized_start=42
|
||||
_globals['_RPC']._serialized_end=244
|
||||
_globals['_RPC_SUBOPTS']._serialized_start=199
|
||||
_globals['_RPC_SUBOPTS']._serialized_end=244
|
||||
_globals['_MESSAGE']._serialized_start=246
|
||||
_globals['_MESSAGE']._serialized_end=351
|
||||
_globals['_CONTROLMESSAGE']._serialized_start=354
|
||||
_globals['_CONTROLMESSAGE']._serialized_end=530
|
||||
_globals['_CONTROLIHAVE']._serialized_start=532
|
||||
_globals['_CONTROLIHAVE']._serialized_end=583
|
||||
_globals['_CONTROLIWANT']._serialized_start=585
|
||||
_globals['_CONTROLIWANT']._serialized_end=619
|
||||
_globals['_CONTROLGRAFT']._serialized_start=621
|
||||
_globals['_CONTROLGRAFT']._serialized_end=652
|
||||
_globals['_CONTROLPRUNE']._serialized_start=654
|
||||
_globals['_CONTROLPRUNE']._serialized_end=738
|
||||
_globals['_PEERINFO']._serialized_start=740
|
||||
_globals['_PEERINFO']._serialized_end=792
|
||||
_globals['_TOPICDESCRIPTOR']._serialized_start=795
|
||||
_globals['_TOPICDESCRIPTOR']._serialized_end=1186
|
||||
_globals['_TOPICDESCRIPTOR_AUTHOPTS']._serialized_start=928
|
||||
_globals['_TOPICDESCRIPTOR_AUTHOPTS']._serialized_end=1052
|
||||
_globals['_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE']._serialized_start=1014
|
||||
_globals['_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE']._serialized_end=1052
|
||||
_globals['_TOPICDESCRIPTOR_ENCOPTS']._serialized_start=1055
|
||||
_globals['_TOPICDESCRIPTOR_ENCOPTS']._serialized_end=1186
|
||||
_globals['_TOPICDESCRIPTOR_ENCOPTS_ENCMODE']._serialized_start=1143
|
||||
_globals['_TOPICDESCRIPTOR_ENCOPTS_ENCMODE']._serialized_end=1186
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
||||
@ -1,323 +1,132 @@
|
||||
"""
|
||||
@generated by mypy-protobuf. Do not edit manually!
|
||||
isort:skip_file
|
||||
Modified from https://github.com/libp2p/go-libp2p-pubsub/blob/master/pb/rpc.proto"""
|
||||
from google.protobuf.internal import containers as _containers
|
||||
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import message as _message
|
||||
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union
|
||||
|
||||
import builtins
|
||||
import collections.abc
|
||||
import google.protobuf.descriptor
|
||||
import google.protobuf.internal.containers
|
||||
import google.protobuf.internal.enum_type_wrapper
|
||||
import google.protobuf.message
|
||||
import sys
|
||||
import typing
|
||||
DESCRIPTOR: _descriptor.FileDescriptor
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
import typing as typing_extensions
|
||||
else:
|
||||
import typing_extensions
|
||||
class RPC(_message.Message):
|
||||
__slots__ = ("subscriptions", "publish", "control", "senderRecord")
|
||||
class SubOpts(_message.Message):
|
||||
__slots__ = ("subscribe", "topicid")
|
||||
SUBSCRIBE_FIELD_NUMBER: _ClassVar[int]
|
||||
TOPICID_FIELD_NUMBER: _ClassVar[int]
|
||||
subscribe: bool
|
||||
topicid: str
|
||||
def __init__(self, subscribe: bool = ..., topicid: _Optional[str] = ...) -> None: ...
|
||||
SUBSCRIPTIONS_FIELD_NUMBER: _ClassVar[int]
|
||||
PUBLISH_FIELD_NUMBER: _ClassVar[int]
|
||||
CONTROL_FIELD_NUMBER: _ClassVar[int]
|
||||
SENDERRECORD_FIELD_NUMBER: _ClassVar[int]
|
||||
subscriptions: _containers.RepeatedCompositeFieldContainer[RPC.SubOpts]
|
||||
publish: _containers.RepeatedCompositeFieldContainer[Message]
|
||||
control: ControlMessage
|
||||
senderRecord: bytes
|
||||
def __init__(self, subscriptions: _Optional[_Iterable[_Union[RPC.SubOpts, _Mapping]]] = ..., publish: _Optional[_Iterable[_Union[Message, _Mapping]]] = ..., control: _Optional[_Union[ControlMessage, _Mapping]] = ..., senderRecord: _Optional[bytes] = ...) -> None: ... # type: ignore
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
class Message(_message.Message):
|
||||
__slots__ = ("from_id", "data", "seqno", "topicIDs", "signature", "key")
|
||||
FROM_ID_FIELD_NUMBER: _ClassVar[int]
|
||||
DATA_FIELD_NUMBER: _ClassVar[int]
|
||||
SEQNO_FIELD_NUMBER: _ClassVar[int]
|
||||
TOPICIDS_FIELD_NUMBER: _ClassVar[int]
|
||||
SIGNATURE_FIELD_NUMBER: _ClassVar[int]
|
||||
KEY_FIELD_NUMBER: _ClassVar[int]
|
||||
from_id: bytes
|
||||
data: bytes
|
||||
seqno: bytes
|
||||
topicIDs: _containers.RepeatedScalarFieldContainer[str]
|
||||
signature: bytes
|
||||
key: bytes
|
||||
def __init__(self, from_id: _Optional[bytes] = ..., data: _Optional[bytes] = ..., seqno: _Optional[bytes] = ..., topicIDs: _Optional[_Iterable[str]] = ..., signature: _Optional[bytes] = ..., key: _Optional[bytes] = ...) -> None: ...
|
||||
|
||||
@typing.final
|
||||
class RPC(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
class ControlMessage(_message.Message):
|
||||
__slots__ = ("ihave", "iwant", "graft", "prune")
|
||||
IHAVE_FIELD_NUMBER: _ClassVar[int]
|
||||
IWANT_FIELD_NUMBER: _ClassVar[int]
|
||||
GRAFT_FIELD_NUMBER: _ClassVar[int]
|
||||
PRUNE_FIELD_NUMBER: _ClassVar[int]
|
||||
ihave: _containers.RepeatedCompositeFieldContainer[ControlIHave]
|
||||
iwant: _containers.RepeatedCompositeFieldContainer[ControlIWant]
|
||||
graft: _containers.RepeatedCompositeFieldContainer[ControlGraft]
|
||||
prune: _containers.RepeatedCompositeFieldContainer[ControlPrune]
|
||||
def __init__(self, ihave: _Optional[_Iterable[_Union[ControlIHave, _Mapping]]] = ..., iwant: _Optional[_Iterable[_Union[ControlIWant, _Mapping]]] = ..., graft: _Optional[_Iterable[_Union[ControlGraft, _Mapping]]] = ..., prune: _Optional[_Iterable[_Union[ControlPrune, _Mapping]]] = ...) -> None: ... # type: ignore
|
||||
|
||||
@typing.final
|
||||
class SubOpts(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
class ControlIHave(_message.Message):
|
||||
__slots__ = ("topicID", "messageIDs")
|
||||
TOPICID_FIELD_NUMBER: _ClassVar[int]
|
||||
MESSAGEIDS_FIELD_NUMBER: _ClassVar[int]
|
||||
topicID: str
|
||||
messageIDs: _containers.RepeatedScalarFieldContainer[str]
|
||||
def __init__(self, topicID: _Optional[str] = ..., messageIDs: _Optional[_Iterable[str]] = ...) -> None: ...
|
||||
|
||||
SUBSCRIBE_FIELD_NUMBER: builtins.int
|
||||
TOPICID_FIELD_NUMBER: builtins.int
|
||||
subscribe: builtins.bool
|
||||
"""subscribe or unsubscribe"""
|
||||
topicid: builtins.str
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
subscribe: builtins.bool | None = ...,
|
||||
topicid: builtins.str | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["subscribe", b"subscribe", "topicid", b"topicid"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["subscribe", b"subscribe", "topicid", b"topicid"]) -> None: ...
|
||||
class ControlIWant(_message.Message):
|
||||
__slots__ = ("messageIDs",)
|
||||
MESSAGEIDS_FIELD_NUMBER: _ClassVar[int]
|
||||
messageIDs: _containers.RepeatedScalarFieldContainer[str]
|
||||
def __init__(self, messageIDs: _Optional[_Iterable[str]] = ...) -> None: ...
|
||||
|
||||
SUBSCRIPTIONS_FIELD_NUMBER: builtins.int
|
||||
PUBLISH_FIELD_NUMBER: builtins.int
|
||||
CONTROL_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def subscriptions(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___RPC.SubOpts]: ...
|
||||
@property
|
||||
def publish(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message]: ...
|
||||
@property
|
||||
def control(self) -> global___ControlMessage: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
subscriptions: collections.abc.Iterable[global___RPC.SubOpts] | None = ...,
|
||||
publish: collections.abc.Iterable[global___Message] | None = ...,
|
||||
control: global___ControlMessage | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["control", b"control"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["control", b"control", "publish", b"publish", "subscriptions", b"subscriptions"]) -> None: ...
|
||||
class ControlGraft(_message.Message):
|
||||
__slots__ = ("topicID",)
|
||||
TOPICID_FIELD_NUMBER: _ClassVar[int]
|
||||
topicID: str
|
||||
def __init__(self, topicID: _Optional[str] = ...) -> None: ...
|
||||
|
||||
global___RPC = RPC
|
||||
class ControlPrune(_message.Message):
|
||||
__slots__ = ("topicID", "peers", "backoff")
|
||||
TOPICID_FIELD_NUMBER: _ClassVar[int]
|
||||
PEERS_FIELD_NUMBER: _ClassVar[int]
|
||||
BACKOFF_FIELD_NUMBER: _ClassVar[int]
|
||||
topicID: str
|
||||
peers: _containers.RepeatedCompositeFieldContainer[PeerInfo]
|
||||
backoff: int
|
||||
def __init__(self, topicID: _Optional[str] = ..., peers: _Optional[_Iterable[_Union[PeerInfo, _Mapping]]] = ..., backoff: _Optional[int] = ...) -> None: ... # type: ignore
|
||||
|
||||
@typing.final
|
||||
class Message(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
class PeerInfo(_message.Message):
|
||||
__slots__ = ("peerID", "signedPeerRecord")
|
||||
PEERID_FIELD_NUMBER: _ClassVar[int]
|
||||
SIGNEDPEERRECORD_FIELD_NUMBER: _ClassVar[int]
|
||||
peerID: bytes
|
||||
signedPeerRecord: bytes
|
||||
def __init__(self, peerID: _Optional[bytes] = ..., signedPeerRecord: _Optional[bytes] = ...) -> None: ...
|
||||
|
||||
FROM_ID_FIELD_NUMBER: builtins.int
|
||||
DATA_FIELD_NUMBER: builtins.int
|
||||
SEQNO_FIELD_NUMBER: builtins.int
|
||||
TOPICIDS_FIELD_NUMBER: builtins.int
|
||||
SIGNATURE_FIELD_NUMBER: builtins.int
|
||||
KEY_FIELD_NUMBER: builtins.int
|
||||
from_id: builtins.bytes
|
||||
data: builtins.bytes
|
||||
seqno: builtins.bytes
|
||||
signature: builtins.bytes
|
||||
key: builtins.bytes
|
||||
@property
|
||||
def topicIDs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
from_id: builtins.bytes | None = ...,
|
||||
data: builtins.bytes | None = ...,
|
||||
seqno: builtins.bytes | None = ...,
|
||||
topicIDs: collections.abc.Iterable[builtins.str] | None = ...,
|
||||
signature: builtins.bytes | None = ...,
|
||||
key: builtins.bytes | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["data", b"data", "from_id", b"from_id", "key", b"key", "seqno", b"seqno", "signature", b"signature"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["data", b"data", "from_id", b"from_id", "key", b"key", "seqno", b"seqno", "signature", b"signature", "topicIDs", b"topicIDs"]) -> None: ...
|
||||
|
||||
global___Message = Message
|
||||
|
||||
@typing.final
|
||||
class ControlMessage(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
IHAVE_FIELD_NUMBER: builtins.int
|
||||
IWANT_FIELD_NUMBER: builtins.int
|
||||
GRAFT_FIELD_NUMBER: builtins.int
|
||||
PRUNE_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def ihave(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlIHave]: ...
|
||||
@property
|
||||
def iwant(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlIWant]: ...
|
||||
@property
|
||||
def graft(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlGraft]: ...
|
||||
@property
|
||||
def prune(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlPrune]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ihave: collections.abc.Iterable[global___ControlIHave] | None = ...,
|
||||
iwant: collections.abc.Iterable[global___ControlIWant] | None = ...,
|
||||
graft: collections.abc.Iterable[global___ControlGraft] | None = ...,
|
||||
prune: collections.abc.Iterable[global___ControlPrune] | None = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["graft", b"graft", "ihave", b"ihave", "iwant", b"iwant", "prune", b"prune"]) -> None: ...
|
||||
|
||||
global___ControlMessage = ControlMessage
|
||||
|
||||
@typing.final
|
||||
class ControlIHave(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
TOPICID_FIELD_NUMBER: builtins.int
|
||||
MESSAGEIDS_FIELD_NUMBER: builtins.int
|
||||
topicID: builtins.str
|
||||
@property
|
||||
def messageIDs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
topicID: builtins.str | None = ...,
|
||||
messageIDs: collections.abc.Iterable[builtins.str] | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["topicID", b"topicID"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["messageIDs", b"messageIDs", "topicID", b"topicID"]) -> None: ...
|
||||
|
||||
global___ControlIHave = ControlIHave
|
||||
|
||||
@typing.final
|
||||
class ControlIWant(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
MESSAGEIDS_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def messageIDs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
messageIDs: collections.abc.Iterable[builtins.str] | None = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["messageIDs", b"messageIDs"]) -> None: ...
|
||||
|
||||
global___ControlIWant = ControlIWant
|
||||
|
||||
@typing.final
|
||||
class ControlGraft(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
TOPICID_FIELD_NUMBER: builtins.int
|
||||
topicID: builtins.str
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
topicID: builtins.str | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["topicID", b"topicID"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["topicID", b"topicID"]) -> None: ...
|
||||
|
||||
global___ControlGraft = ControlGraft
|
||||
|
||||
@typing.final
|
||||
class ControlPrune(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
TOPICID_FIELD_NUMBER: builtins.int
|
||||
PEERS_FIELD_NUMBER: builtins.int
|
||||
BACKOFF_FIELD_NUMBER: builtins.int
|
||||
topicID: builtins.str
|
||||
backoff: builtins.int
|
||||
@property
|
||||
def peers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___PeerInfo]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
topicID: builtins.str | None = ...,
|
||||
peers: collections.abc.Iterable[global___PeerInfo] | None = ...,
|
||||
backoff: builtins.int | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["backoff", b"backoff", "topicID", b"topicID"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["backoff", b"backoff", "peers", b"peers", "topicID", b"topicID"]) -> None: ...
|
||||
|
||||
global___ControlPrune = ControlPrune
|
||||
|
||||
@typing.final
|
||||
class PeerInfo(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
PEERID_FIELD_NUMBER: builtins.int
|
||||
SIGNEDPEERRECORD_FIELD_NUMBER: builtins.int
|
||||
peerID: builtins.bytes
|
||||
signedPeerRecord: builtins.bytes
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
peerID: builtins.bytes | None = ...,
|
||||
signedPeerRecord: builtins.bytes | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["peerID", b"peerID", "signedPeerRecord", b"signedPeerRecord"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["peerID", b"peerID", "signedPeerRecord", b"signedPeerRecord"]) -> None: ...
|
||||
|
||||
global___PeerInfo = PeerInfo
|
||||
|
||||
@typing.final
|
||||
class TopicDescriptor(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
@typing.final
|
||||
class AuthOpts(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
class _AuthMode:
|
||||
ValueType = typing.NewType("ValueType", builtins.int)
|
||||
V: typing_extensions.TypeAlias = ValueType
|
||||
|
||||
class _AuthModeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[TopicDescriptor.AuthOpts._AuthMode.ValueType], builtins.type):
|
||||
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
|
||||
NONE: TopicDescriptor.AuthOpts._AuthMode.ValueType # 0
|
||||
"""no authentication, anyone can publish"""
|
||||
KEY: TopicDescriptor.AuthOpts._AuthMode.ValueType # 1
|
||||
"""only messages signed by keys in the topic descriptor are accepted"""
|
||||
WOT: TopicDescriptor.AuthOpts._AuthMode.ValueType # 2
|
||||
"""web of trust, certificates can allow publisher set to grow"""
|
||||
|
||||
class AuthMode(_AuthMode, metaclass=_AuthModeEnumTypeWrapper): ...
|
||||
NONE: TopicDescriptor.AuthOpts.AuthMode.ValueType # 0
|
||||
"""no authentication, anyone can publish"""
|
||||
KEY: TopicDescriptor.AuthOpts.AuthMode.ValueType # 1
|
||||
"""only messages signed by keys in the topic descriptor are accepted"""
|
||||
WOT: TopicDescriptor.AuthOpts.AuthMode.ValueType # 2
|
||||
"""web of trust, certificates can allow publisher set to grow"""
|
||||
|
||||
MODE_FIELD_NUMBER: builtins.int
|
||||
KEYS_FIELD_NUMBER: builtins.int
|
||||
mode: global___TopicDescriptor.AuthOpts.AuthMode.ValueType
|
||||
@property
|
||||
def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]:
|
||||
"""root keys to trust"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
mode: global___TopicDescriptor.AuthOpts.AuthMode.ValueType | None = ...,
|
||||
keys: collections.abc.Iterable[builtins.bytes] | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["mode", b"mode"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["keys", b"keys", "mode", b"mode"]) -> None: ...
|
||||
|
||||
@typing.final
|
||||
class EncOpts(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
class _EncMode:
|
||||
ValueType = typing.NewType("ValueType", builtins.int)
|
||||
V: typing_extensions.TypeAlias = ValueType
|
||||
|
||||
class _EncModeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[TopicDescriptor.EncOpts._EncMode.ValueType], builtins.type):
|
||||
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
|
||||
NONE: TopicDescriptor.EncOpts._EncMode.ValueType # 0
|
||||
"""no encryption, anyone can read"""
|
||||
SHAREDKEY: TopicDescriptor.EncOpts._EncMode.ValueType # 1
|
||||
"""messages are encrypted with shared key"""
|
||||
WOT: TopicDescriptor.EncOpts._EncMode.ValueType # 2
|
||||
"""web of trust, certificates can allow publisher set to grow"""
|
||||
|
||||
class EncMode(_EncMode, metaclass=_EncModeEnumTypeWrapper): ...
|
||||
NONE: TopicDescriptor.EncOpts.EncMode.ValueType # 0
|
||||
"""no encryption, anyone can read"""
|
||||
SHAREDKEY: TopicDescriptor.EncOpts.EncMode.ValueType # 1
|
||||
"""messages are encrypted with shared key"""
|
||||
WOT: TopicDescriptor.EncOpts.EncMode.ValueType # 2
|
||||
"""web of trust, certificates can allow publisher set to grow"""
|
||||
|
||||
MODE_FIELD_NUMBER: builtins.int
|
||||
KEYHASHES_FIELD_NUMBER: builtins.int
|
||||
mode: global___TopicDescriptor.EncOpts.EncMode.ValueType
|
||||
@property
|
||||
def keyHashes(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]:
|
||||
"""the hashes of the shared keys used (salted)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
mode: global___TopicDescriptor.EncOpts.EncMode.ValueType | None = ...,
|
||||
keyHashes: collections.abc.Iterable[builtins.bytes] | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["mode", b"mode"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["keyHashes", b"keyHashes", "mode", b"mode"]) -> None: ...
|
||||
|
||||
NAME_FIELD_NUMBER: builtins.int
|
||||
AUTH_FIELD_NUMBER: builtins.int
|
||||
ENC_FIELD_NUMBER: builtins.int
|
||||
name: builtins.str
|
||||
@property
|
||||
def auth(self) -> global___TopicDescriptor.AuthOpts: ...
|
||||
@property
|
||||
def enc(self) -> global___TopicDescriptor.EncOpts: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: builtins.str | None = ...,
|
||||
auth: global___TopicDescriptor.AuthOpts | None = ...,
|
||||
enc: global___TopicDescriptor.EncOpts | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["auth", b"auth", "enc", b"enc", "name", b"name"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["auth", b"auth", "enc", b"enc", "name", b"name"]) -> None: ...
|
||||
|
||||
global___TopicDescriptor = TopicDescriptor
|
||||
class TopicDescriptor(_message.Message):
|
||||
__slots__ = ("name", "auth", "enc")
|
||||
class AuthOpts(_message.Message):
|
||||
__slots__ = ("mode", "keys")
|
||||
class AuthMode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
|
||||
__slots__ = ()
|
||||
NONE: _ClassVar[TopicDescriptor.AuthOpts.AuthMode]
|
||||
KEY: _ClassVar[TopicDescriptor.AuthOpts.AuthMode]
|
||||
WOT: _ClassVar[TopicDescriptor.AuthOpts.AuthMode]
|
||||
NONE: TopicDescriptor.AuthOpts.AuthMode
|
||||
KEY: TopicDescriptor.AuthOpts.AuthMode
|
||||
WOT: TopicDescriptor.AuthOpts.AuthMode
|
||||
MODE_FIELD_NUMBER: _ClassVar[int]
|
||||
KEYS_FIELD_NUMBER: _ClassVar[int]
|
||||
mode: TopicDescriptor.AuthOpts.AuthMode
|
||||
keys: _containers.RepeatedScalarFieldContainer[bytes]
|
||||
def __init__(self, mode: _Optional[_Union[TopicDescriptor.AuthOpts.AuthMode, str]] = ..., keys: _Optional[_Iterable[bytes]] = ...) -> None: ...
|
||||
class EncOpts(_message.Message):
|
||||
__slots__ = ("mode", "keyHashes")
|
||||
class EncMode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
|
||||
__slots__ = ()
|
||||
NONE: _ClassVar[TopicDescriptor.EncOpts.EncMode]
|
||||
SHAREDKEY: _ClassVar[TopicDescriptor.EncOpts.EncMode]
|
||||
WOT: _ClassVar[TopicDescriptor.EncOpts.EncMode]
|
||||
NONE: TopicDescriptor.EncOpts.EncMode
|
||||
SHAREDKEY: TopicDescriptor.EncOpts.EncMode
|
||||
WOT: TopicDescriptor.EncOpts.EncMode
|
||||
MODE_FIELD_NUMBER: _ClassVar[int]
|
||||
KEYHASHES_FIELD_NUMBER: _ClassVar[int]
|
||||
mode: TopicDescriptor.EncOpts.EncMode
|
||||
keyHashes: _containers.RepeatedScalarFieldContainer[bytes]
|
||||
def __init__(self, mode: _Optional[_Union[TopicDescriptor.EncOpts.EncMode, str]] = ..., keyHashes: _Optional[_Iterable[bytes]] = ...) -> None: ...
|
||||
NAME_FIELD_NUMBER: _ClassVar[int]
|
||||
AUTH_FIELD_NUMBER: _ClassVar[int]
|
||||
ENC_FIELD_NUMBER: _ClassVar[int]
|
||||
name: str
|
||||
auth: TopicDescriptor.AuthOpts
|
||||
enc: TopicDescriptor.EncOpts
|
||||
def __init__(self, name: _Optional[str] = ..., auth: _Optional[_Union[TopicDescriptor.AuthOpts, _Mapping]] = ..., enc: _Optional[_Union[TopicDescriptor.EncOpts, _Mapping]] = ...) -> None: ... # type: ignore
|
||||
|
||||
@@ -11,6 +11,10 @@ import functools
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from typing import (
|
||||
NamedTuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
import base58
|
||||
import trio
|
||||
@@ -26,6 +30,8 @@ from libp2p.crypto.keys import (
|
||||
PrivateKey,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
AsyncValidatorFn,
|
||||
SyncValidatorFn,
|
||||
TProtocol,
|
||||
ValidatorFn,
|
||||
)
|
||||
@@ -50,6 +56,8 @@ from libp2p.peer.id import (
|
||||
from libp2p.peer.peerdata import (
|
||||
PeerDataError,
|
||||
)
|
||||
from libp2p.peer.peerstore import env_to_send_in_RPC
|
||||
from libp2p.pubsub.utils import maybe_consume_signed_record
|
||||
from libp2p.tools.async_service import (
|
||||
Service,
|
||||
)
|
||||
@@ -71,11 +79,6 @@ from .pubsub_notifee import (
|
||||
from .subscription import (
|
||||
TrioSubscriptionAPI,
|
||||
)
|
||||
from .validation_throttler import (
|
||||
TopicValidator,
|
||||
ValidationResult,
|
||||
ValidationThrottler,
|
||||
)
|
||||
from .validators import (
|
||||
PUBSUB_SIGNING_PREFIX,
|
||||
signature_validator,
|
||||
@@ -96,6 +99,14 @@ def get_content_addressed_msg_id(msg: rpc_pb2.Message) -> bytes:
|
||||
return base64.b64encode(hashlib.sha256(msg.data).digest())
|
||||
|
||||
|
||||
class TopicValidator(NamedTuple):
|
||||
validator: ValidatorFn
|
||||
is_async: bool
|
||||
|
||||
|
||||
MAX_CONCURRENT_VALIDATORS = 10
|
||||
|
||||
|
||||
class Pubsub(Service, IPubsub):
|
||||
host: IHost
|
||||
|
||||
@@ -103,6 +114,7 @@ class Pubsub(Service, IPubsub):
|
||||
|
||||
peer_receive_channel: trio.MemoryReceiveChannel[ID]
|
||||
dead_peer_receive_channel: trio.MemoryReceiveChannel[ID]
|
||||
_validator_semaphore: trio.Semaphore
|
||||
|
||||
seen_messages: LastSeenCache
|
||||
|
||||
@@ -137,11 +149,7 @@ class Pubsub(Service, IPubsub):
|
||||
msg_id_constructor: Callable[
|
||||
[rpc_pb2.Message], bytes
|
||||
] = get_peer_and_seqno_msg_id,
|
||||
# TODO: these values have been copied from Go, but try to tune these dynamically
|
||||
validation_queue_size: int = 32,
|
||||
global_throttle_limit: int = 8192,
|
||||
default_topic_throttle_limit: int = 1024,
|
||||
validation_worker_count: int | None = None,
|
||||
max_concurrent_validator_count: int = MAX_CONCURRENT_VALIDATORS,
|
||||
) -> None:
|
||||
"""
|
||||
Construct a new Pubsub object, which is responsible for handling all
|
||||
@@ -167,6 +175,7 @@ class Pubsub(Service, IPubsub):
|
||||
# Therefore, we can only close from the receive side.
|
||||
self.peer_receive_channel = peer_receive
|
||||
self.dead_peer_receive_channel = dead_peer_receive
|
||||
self._validator_semaphore = trio.Semaphore(max_concurrent_validator_count)
|
||||
# Register a notifee
|
||||
self.host.get_network().register_notifee(
|
||||
PubsubNotifee(peer_send, dead_peer_send)
|
||||
@@ -202,15 +211,7 @@ class Pubsub(Service, IPubsub):
|
||||
# Create peers map, which maps peer_id (as string) to stream (to a given peer)
|
||||
self.peers = {}
|
||||
|
||||
# Validation Throttler
|
||||
self.validation_throttler = ValidationThrottler(
|
||||
queue_size=validation_queue_size,
|
||||
global_throttle_limit=global_throttle_limit,
|
||||
default_topic_throttle_limit=default_topic_throttle_limit,
|
||||
worker_count=validation_worker_count or 4,
|
||||
)
|
||||
|
||||
# Keep a mapping of topic -> TopicValidator for easier lookup
|
||||
# Map of topic to topic validator
|
||||
self.topic_validators = {}
|
||||
|
||||
self.counter = int(time.time())
|
||||
@@ -222,19 +223,10 @@ class Pubsub(Service, IPubsub):
|
||||
self.event_handle_dead_peer_queue_started = trio.Event()
|
||||
|
||||
async def run(self) -> None:
|
||||
self.manager.run_daemon_task(self._start_validation_throttler)
|
||||
self.manager.run_daemon_task(self.handle_peer_queue)
|
||||
self.manager.run_daemon_task(self.handle_dead_peer_queue)
|
||||
await self.manager.wait_finished()
|
||||
|
||||
async def _start_validation_throttler(self) -> None:
|
||||
"""Start validation throttler in current nursery context"""
|
||||
async with trio.open_nursery() as nursery:
|
||||
await self.validation_throttler.start(nursery)
|
||||
# Keep nursery alive until service stops
|
||||
while self.manager.is_running:
|
||||
await self.manager.wait_finished()
|
||||
|
||||
@property
|
||||
def my_id(self) -> ID:
|
||||
return self.host.get_id()
|
||||
@@ -257,6 +249,10 @@ class Pubsub(Service, IPubsub):
|
||||
packet.subscriptions.extend(
|
||||
[rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)]
|
||||
)
|
||||
# Add the sender's signedRecord in the RPC message
|
||||
envelope_bytes, _ = env_to_send_in_RPC(self.host)
|
||||
packet.senderRecord = envelope_bytes
|
||||
|
||||
return packet
|
||||
|
||||
async def continuously_read_stream(self, stream: INetStream) -> None:
|
||||
@@ -273,6 +269,14 @@ class Pubsub(Service, IPubsub):
|
||||
incoming: bytes = await read_varint_prefixed_bytes(stream)
|
||||
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||
rpc_incoming.ParseFromString(incoming)
|
||||
|
||||
# Process the sender's signed-record if sent
|
||||
if not maybe_consume_signed_record(rpc_incoming, self.host, peer_id):
|
||||
logger.error(
|
||||
"Received an invalid-signed-record, ignoring the incoming msg"
|
||||
)
|
||||
continue
|
||||
|
||||
if rpc_incoming.publish:
|
||||
# deal with RPC.publish
|
||||
for msg in rpc_incoming.publish:
|
||||
@@ -314,12 +318,7 @@ class Pubsub(Service, IPubsub):
|
||||
)
|
||||
|
||||
def set_topic_validator(
|
||||
self,
|
||||
topic: str,
|
||||
validator: ValidatorFn,
|
||||
is_async_validator: bool,
|
||||
timeout: float | None = None,
|
||||
throttle_limit: int | None = None,
|
||||
self, topic: str, validator: ValidatorFn, is_async_validator: bool
|
||||
) -> None:
|
||||
"""
|
||||
Register a validator under the given topic. One topic can only have one
|
||||
@@ -328,18 +327,8 @@ class Pubsub(Service, IPubsub):
|
||||
:param topic: the topic to register validator under
|
||||
:param validator: the validator used to validate messages published to the topic
|
||||
:param is_async_validator: indicate if the validator is an asynchronous validator
|
||||
:param timeout: optional timeout for the validator
|
||||
:param throttle_limit: optional throttle limit for the validator
|
||||
""" # noqa: E501
|
||||
# Create throttled topic validator
|
||||
topic_validator = self.validation_throttler.create_topic_validator(
|
||||
topic=topic,
|
||||
validator=validator,
|
||||
is_async=is_async_validator,
|
||||
timeout=timeout,
|
||||
throttle_limit=throttle_limit,
|
||||
)
|
||||
self.topic_validators[topic] = topic_validator
|
||||
self.topic_validators[topic] = TopicValidator(validator, is_async_validator)
|
||||
|
||||
def remove_topic_validator(self, topic: str) -> None:
|
||||
"""
|
||||
@@ -349,18 +338,17 @@ class Pubsub(Service, IPubsub):
|
||||
"""
|
||||
self.topic_validators.pop(topic, None)
|
||||
|
||||
def get_msg_validators(self, msg: rpc_pb2.Message) -> list[TopicValidator]:
|
||||
def get_msg_validators(self, msg: rpc_pb2.Message) -> tuple[TopicValidator, ...]:
|
||||
"""
|
||||
Get all validators corresponding to the topics in the message.
|
||||
|
||||
:param msg: the message published to the topic
|
||||
:return: tuple of topic validators for the message's topics
|
||||
"""
|
||||
return [
|
||||
return tuple(
|
||||
self.topic_validators[topic]
|
||||
for topic in msg.topicIDs
|
||||
if topic in self.topic_validators
|
||||
]
|
||||
)
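As a usage sketch only (assuming `pubsub` is an instance of the Pubsub class above; the topic names and validator bodies are hypothetical and not part of this change), registering one synchronous and one asynchronous validator with this API looks roughly like:

def reject_empty(msg_forwarder, msg) -> bool:
    # Synchronous check: drop messages with an empty payload.
    return len(msg.data) > 0

async def check_async(msg_forwarder, msg) -> bool:
    # Asynchronous check: placeholder for an awaitable validation step.
    return True

pubsub.set_topic_validator("example-topic", reject_empty, is_async_validator=False)
pubsub.set_topic_validator("example-topic-signed", check_async, is_async_validator=True)
# get_msg_validators(msg) then returns the registered TopicValidator entries
# for every topic listed in msg.topicIDs.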
|
||||
|
||||
def add_to_blacklist(self, peer_id: ID) -> None:
|
||||
"""
|
||||
@@ -598,6 +586,9 @@ class Pubsub(Service, IPubsub):
|
||||
[rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)]
|
||||
)
|
||||
|
||||
# Add the senderRecord of the peer in the RPC msg
|
||||
envelope_bytes, _ = env_to_send_in_RPC(self.host)
|
||||
packet.senderRecord = envelope_bytes
|
||||
# Send out subscribe message to all peers
|
||||
await self.message_all_peers(packet.SerializeToString())
|
||||
|
||||
@@ -630,6 +621,9 @@ class Pubsub(Service, IPubsub):
|
||||
packet.subscriptions.extend(
|
||||
[rpc_pb2.RPC.SubOpts(subscribe=False, topicid=topic_id)]
|
||||
)
|
||||
# Add the senderRecord of the peer in the RPC msg
|
||||
envelope_bytes, _ = env_to_send_in_RPC(self.host)
|
||||
packet.senderRecord = envelope_bytes
|
||||
|
||||
# Send out unsubscribe message to all peers
|
||||
await self.message_all_peers(packet.SerializeToString())
|
||||
@@ -689,63 +683,60 @@
|
||||
|
||||
logger.debug("successfully published message %s", msg)
|
||||
|
||||
async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
||||
async def validate_msg(
|
||||
self,
|
||||
msg_forwarder: ID,
|
||||
msg: rpc_pb2.Message,
|
||||
) -> None:
|
||||
"""
|
||||
Validate the received message.
|
||||
|
||||
:param msg_forwarder: the peer who forward us the message.
|
||||
:param msg: the message.
|
||||
"""
|
||||
# Get applicable validators for this message
|
||||
validators = self.get_msg_validators(msg)
|
||||
sync_topic_validators: list[SyncValidatorFn] = []
|
||||
async_topic_validators: list[AsyncValidatorFn] = []
|
||||
for topic_validator in self.get_msg_validators(msg):
|
||||
if topic_validator.is_async:
|
||||
async_topic_validators.append(
|
||||
cast(AsyncValidatorFn, topic_validator.validator)
|
||||
)
|
||||
else:
|
||||
sync_topic_validators.append(
|
||||
cast(SyncValidatorFn, topic_validator.validator)
|
||||
)
|
||||
|
||||
if not validators:
|
||||
# No validators, accept immediately
|
||||
return
|
||||
for validator in sync_topic_validators:
|
||||
if not validator(msg_forwarder, msg):
|
||||
raise ValidationError(f"Validation failed for msg={msg}")
|
||||
|
||||
# Use trio.Event for async coordination
|
||||
validation_event = trio.Event()
|
||||
result_container: dict[str, ValidationResult | None | Exception] = {
|
||||
"result": None,
|
||||
"error": None,
|
||||
}
|
||||
if len(async_topic_validators) > 0:
|
||||
# Appends to lists are thread safe in CPython
|
||||
results: list[bool] = []
|
||||
|
||||
def handle_validation_result(
|
||||
result: ValidationResult, error: Exception | None
|
||||
) -> None:
|
||||
result_container["result"] = result
|
||||
result_container["error"] = error
|
||||
validation_event.set()
|
||||
async with trio.open_nursery() as nursery:
|
||||
for async_validator in async_topic_validators:
|
||||
nursery.start_soon(
|
||||
self._run_async_validator,
|
||||
async_validator,
|
||||
msg_forwarder,
|
||||
msg,
|
||||
results,
|
||||
)
|
||||
|
||||
# Submit for throttled validation
|
||||
success = await self.validation_throttler.submit_validation(
|
||||
validators=validators,
|
||||
msg_forwarder=msg_forwarder,
|
||||
msg=msg,
|
||||
result_callback=handle_validation_result,
|
||||
)
|
||||
if not all(results):
|
||||
raise ValidationError(f"Validation failed for msg={msg}")
|
||||
|
||||
if not success:
|
||||
# Validation was throttled at queue level
|
||||
raise ValidationError("Validation throttled at queue level")
|
||||
|
||||
# Wait for validation result
|
||||
await validation_event.wait()
|
||||
|
||||
result = result_container["result"]
|
||||
error = result_container["error"]
|
||||
|
||||
if error:
|
||||
raise ValidationError(f"Validation error: {error}")
|
||||
|
||||
if result == ValidationResult.REJECT:
|
||||
raise ValidationError("Message validation rejected")
|
||||
elif result == ValidationResult.THROTTLED:
|
||||
raise ValidationError("Message validation throttled")
|
||||
elif result == ValidationResult.IGNORE:
|
||||
# Treat IGNORE as rejection for now, or you could silently drop
|
||||
raise ValidationError("Message validation ignored")
|
||||
# ACCEPT case - just return normally
|
||||
async def _run_async_validator(
|
||||
self,
|
||||
func: AsyncValidatorFn,
|
||||
msg_forwarder: ID,
|
||||
msg: rpc_pb2.Message,
|
||||
results: list[bool],
|
||||
) -> None:
|
||||
async with self._validator_semaphore:
|
||||
result = await func(msg_forwarder, msg)
|
||||
results.append(result)
|
||||
|
||||
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
||||
"""
|
||||
|
||||
libp2p/pubsub/utils.py (new file, 50 lines)
@@ -0,0 +1,50 @@
|
||||
import logging
|
||||
|
||||
from libp2p.abc import IHost
|
||||
from libp2p.peer.envelope import consume_envelope
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.pubsub.pb.rpc_pb2 import RPC
|
||||
|
||||
logger = logging.getLogger("libp2p.pubsub.utils")
|
||||
|
||||
|
||||
def maybe_consume_signed_record(msg: RPC, host: IHost, peer_id: ID) -> bool:
|
||||
"""
|
||||
Attempt to parse and store a signed-peer-record (Envelope) received during
|
||||
PubSub communication. If the record is invalid, the peer-id does not match, or
|
||||
updating the peerstore fails, the function logs an error and returns False.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
msg : RPC
|
||||
The protobuf message received during PubSub communication.
|
||||
host : IHost
|
||||
The local host instance, providing access to the peerstore for storing
|
||||
verified peer records.
|
||||
peer_id : ID
The expected peer ID for record validation; the peer ID inside the
record must match this value.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if a valid signed peer record was successfully consumed and stored,
|
||||
False otherwise.
|
||||
|
||||
"""
|
||||
if msg.HasField("senderRecord"):
|
||||
try:
|
||||
# Convert the signed peer record (Envelope) from protobuf bytes
|
||||
envelope, record = consume_envelope(msg.senderRecord, "libp2p-peer-record")
|
||||
if record.peer_id != peer_id:
|
||||
return False
|
||||
|
||||
# Use the default TTL of 2 hours (7200 seconds)
|
||||
if not host.get_peerstore().consume_peer_record(envelope, 7200):
|
||||
logger.error("Failed to update the Certified-Addr-Book")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error("Failed to update the Certified-Addr-Book: %s", e)
|
||||
return False
|
||||
return True
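A minimal sketch of how this helper is driven from an incoming-RPC path (mirroring the continuously_read_stream change above; rpc_bytes, host, and peer_id are assumed to come from an open pubsub stream):

incoming = RPC()
incoming.ParseFromString(rpc_bytes)
if not maybe_consume_signed_record(incoming, host, peer_id):
    # The attached signed peer record failed to verify; the caller drops the message.
    pass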
|
||||
@@ -1,314 +0,0 @@
|
||||
from collections.abc import (
|
||||
Callable,
|
||||
)
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import logging
|
||||
from typing import (
|
||||
NamedTuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.custom_types import AsyncValidatorFn, ValidatorFn
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
|
||||
from .pb import (
|
||||
rpc_pb2,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("libp2p.pubsub.validation")
|
||||
|
||||
|
||||
class ValidationResult(Enum):
|
||||
ACCEPT = "accept"
|
||||
REJECT = "reject"
|
||||
IGNORE = "ignore"
|
||||
THROTTLED = "throttled"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationRequest:
|
||||
"""Request for message validation"""
|
||||
|
||||
validators: list["TopicValidator"]
|
||||
msg_forwarder: ID # peer ID
|
||||
msg: rpc_pb2.Message # message object
|
||||
result_callback: Callable[[ValidationResult, Exception | None], None]
|
||||
|
||||
|
||||
class TopicValidator(NamedTuple):
|
||||
topic: str
|
||||
validator: ValidatorFn
|
||||
is_async: bool
|
||||
timeout: float | None = None
|
||||
# Per-topic throttle semaphore
|
||||
throttle_semaphore: trio.Semaphore | None = None
|
||||
|
||||
|
||||
class ValidationThrottler:
|
||||
"""Manages all validation throttling mechanisms"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
queue_size: int = 32,
|
||||
global_throttle_limit: int = 8192,
|
||||
default_topic_throttle_limit: int = 1024,
|
||||
worker_count: int | None = None,
|
||||
):
|
||||
# 1. Queue-level throttling - bounded memory channel
|
||||
self._validation_send, self._validation_receive = trio.open_memory_channel[
|
||||
ValidationRequest
|
||||
](queue_size)
|
||||
|
||||
# 2. Global validation throttling - limits total concurrent async validations
|
||||
self._global_throttle = trio.Semaphore(global_throttle_limit)
|
||||
|
||||
# 3. Per-topic throttling - each validator gets its own semaphore
|
||||
self._default_topic_throttle_limit = default_topic_throttle_limit
|
||||
|
||||
# Worker management
|
||||
# TODO: Find a better way to manage worker count
|
||||
self._worker_count = worker_count or 4
|
||||
self._running = False
|
||||
|
||||
async def start(self, nursery: trio.Nursery) -> None:
|
||||
"""Start the validation workers"""
|
||||
self._running = True
|
||||
|
||||
# Start validation worker tasks
|
||||
for i in range(self._worker_count):
|
||||
nursery.start_soon(self._validation_worker, f"worker-{i}")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the validation system"""
|
||||
self._running = False
|
||||
await self._validation_send.aclose()
|
||||
|
||||
def create_topic_validator(
|
||||
self,
|
||||
topic: str,
|
||||
validator: ValidatorFn,
|
||||
is_async: bool,
|
||||
timeout: float | None = None,
|
||||
throttle_limit: int | None = None,
|
||||
) -> TopicValidator:
|
||||
"""Create a new topic validator with its own throttle"""
|
||||
limit = throttle_limit or self._default_topic_throttle_limit
|
||||
throttle_sem = trio.Semaphore(limit)
|
||||
|
||||
return TopicValidator(
|
||||
topic=topic,
|
||||
validator=validator,
|
||||
is_async=is_async,
|
||||
timeout=timeout,
|
||||
throttle_semaphore=throttle_sem,
|
||||
)
|
||||
|
||||
async def submit_validation(
|
||||
self,
|
||||
validators: list[TopicValidator],
|
||||
msg_forwarder: ID,
|
||||
msg: rpc_pb2.Message,
|
||||
result_callback: Callable[[ValidationResult, Exception | None], None],
|
||||
) -> bool:
|
||||
"""
|
||||
Submit a message for validation.
|
||||
Returns True if queued successfully, False if queue is full (throttled).
|
||||
"""
|
||||
if not self._running:
|
||||
result_callback(
|
||||
ValidationResult.REJECT, Exception("Validation system not running")
|
||||
)
|
||||
return False
|
||||
|
||||
request = ValidationRequest(
|
||||
validators=validators,
|
||||
msg_forwarder=msg_forwarder,
|
||||
msg=msg,
|
||||
result_callback=result_callback,
|
||||
)
|
||||
|
||||
try:
|
||||
# This will raise trio.WouldBlock if queue is full
|
||||
self._validation_send.send_nowait(request)
|
||||
return True
|
||||
except trio.WouldBlock:
|
||||
# Queue-level throttling: drop the message
|
||||
logger.debug(
|
||||
"Validation queue full, dropping message from %s", msg_forwarder
|
||||
)
|
||||
result_callback(
|
||||
ValidationResult.THROTTLED, Exception("Validation queue full")
|
||||
)
|
||||
return False
|
||||
|
||||
async def _validation_worker(self, worker_id: str) -> None:
|
||||
"""Worker that processes validation requests"""
|
||||
logger.debug("Validation worker %s started", worker_id)
|
||||
|
||||
async with self._validation_receive:
|
||||
async for request in self._validation_receive:
|
||||
if not self._running:
|
||||
break
|
||||
|
||||
try:
|
||||
# Process the validation request
|
||||
result = await self._validate_message(request)
|
||||
request.result_callback(result, None)
|
||||
except Exception as e:
|
||||
logger.exception("Error in validation worker %s", worker_id)
|
||||
request.result_callback(ValidationResult.REJECT, e)
|
||||
|
||||
logger.debug("Validation worker %s stopped", worker_id)
|
||||
|
||||
async def _validate_message(self, request: ValidationRequest) -> ValidationResult:
|
||||
"""Core validation logic with throttling"""
|
||||
validators = request.validators
|
||||
msg_forwarder = request.msg_forwarder
|
||||
msg = request.msg
|
||||
|
||||
if not validators:
|
||||
return ValidationResult.ACCEPT
|
||||
|
||||
# Separate sync and async validators
|
||||
sync_validators = [v for v in validators if not v.is_async]
|
||||
async_validators = [v for v in validators if v.is_async]
|
||||
|
||||
# Run synchronous validators first
|
||||
for validator in sync_validators:
|
||||
try:
|
||||
# Apply per-topic throttling even for sync validators
|
||||
if validator.throttle_semaphore:
|
||||
validator.throttle_semaphore.acquire_nowait()
|
||||
try:
|
||||
result = validator.validator(msg_forwarder, msg)
|
||||
if not result:
|
||||
return ValidationResult.REJECT
|
||||
finally:
|
||||
validator.throttle_semaphore.release()
|
||||
else:
|
||||
result = validator.validator(msg_forwarder, msg)
|
||||
if not result:
|
||||
return ValidationResult.REJECT
|
||||
except trio.WouldBlock:
|
||||
# Per-topic throttling for sync validator
|
||||
logger.debug("Sync validation throttled for topic %s", validator.topic)
|
||||
return ValidationResult.THROTTLED
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Sync validator failed for topic %s: %s", validator.topic, e
|
||||
)
|
||||
return ValidationResult.REJECT
|
||||
|
||||
# Handle async validators with global + per-topic throttling
|
||||
if async_validators:
|
||||
return await self._validate_async_validators(
|
||||
async_validators, msg_forwarder, msg
|
||||
)
|
||||
|
||||
return ValidationResult.ACCEPT
|
||||
|
||||
async def _validate_async_validators(
|
||||
self, validators: list[TopicValidator], msg_forwarder: ID, msg: rpc_pb2.Message
|
||||
) -> ValidationResult:
|
||||
"""Handle async validators with proper throttling"""
|
||||
if len(validators) == 1:
|
||||
# Fast path for single validator
|
||||
return await self._validate_single_async_validator(
|
||||
validators[0], msg_forwarder, msg
|
||||
)
|
||||
|
||||
# Multiple async validators - run them concurrently
|
||||
try:
|
||||
# Try to acquire global throttle slot
|
||||
self._global_throttle.acquire_nowait()
|
||||
except trio.WouldBlock:
|
||||
logger.debug(
|
||||
"Global validation throttle exceeded, dropping message from %s",
|
||||
msg_forwarder,
|
||||
)
|
||||
return ValidationResult.THROTTLED
|
||||
|
||||
try:
|
||||
async with trio.open_nursery() as nursery:
|
||||
results = {}
|
||||
|
||||
async def run_validator(validator: TopicValidator, index: int) -> None:
|
||||
"""Run a single async validator and store the result"""
|
||||
nonlocal results
|
||||
result = await self._validate_single_async_validator(
|
||||
validator, msg_forwarder, msg
|
||||
)
|
||||
results[index] = result
|
||||
|
||||
# Start all validators concurrently
|
||||
for i, validator in enumerate(validators):
|
||||
nursery.start_soon(run_validator, validator, i)
|
||||
|
||||
# Process results - any reject or throttle causes overall failure
|
||||
final_result = ValidationResult.ACCEPT
|
||||
for result in results.values():
|
||||
if result == ValidationResult.REJECT:
|
||||
return ValidationResult.REJECT
|
||||
elif result == ValidationResult.THROTTLED:
|
||||
final_result = ValidationResult.THROTTLED
|
||||
elif (
|
||||
result == ValidationResult.IGNORE
|
||||
and final_result == ValidationResult.ACCEPT
|
||||
):
|
||||
final_result = ValidationResult.IGNORE
|
||||
|
||||
return final_result
|
||||
|
||||
finally:
|
||||
self._global_throttle.release()
|
||||
|
||||
return ValidationResult.IGNORE
|
||||
|
||||
async def _validate_single_async_validator(
|
||||
self, validator: TopicValidator, msg_forwarder: ID, msg: rpc_pb2.Message
|
||||
) -> ValidationResult:
|
||||
"""Validate with a single async validator"""
|
||||
# Apply per-topic throttling
|
||||
if validator.throttle_semaphore:
|
||||
try:
|
||||
validator.throttle_semaphore.acquire_nowait()
|
||||
except trio.WouldBlock:
|
||||
logger.debug(
|
||||
"Per-topic validation throttled for topic %s", validator.topic
|
||||
)
|
||||
return ValidationResult.THROTTLED
|
||||
else:
|
||||
# Fallback if no throttle semaphore configured
|
||||
pass
|
||||
|
||||
try:
|
||||
# Apply timeout if configured
|
||||
result: bool
|
||||
if validator.timeout:
|
||||
with trio.fail_after(validator.timeout):
|
||||
func = cast(AsyncValidatorFn, validator.validator)
|
||||
result = await func(msg_forwarder, msg)
|
||||
else:
|
||||
func = cast(AsyncValidatorFn, validator.validator)
|
||||
result = await func(msg_forwarder, msg)
|
||||
|
||||
return ValidationResult.ACCEPT if result else ValidationResult.REJECT
|
||||
|
||||
except trio.TooSlowError:
|
||||
logger.debug("Validation timeout for topic %s", validator.topic)
|
||||
return ValidationResult.IGNORE
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Async validator failed for topic %s: %s", validator.topic, e
|
||||
)
|
||||
return ValidationResult.REJECT
|
||||
finally:
|
||||
if validator.throttle_semaphore:
|
||||
validator.throttle_semaphore.release()
|
||||
|
||||
return ValidationResult.IGNORE
|
||||
@@ -15,6 +15,10 @@ from libp2p.relay.circuit_v2 import (
|
||||
RelayLimits,
|
||||
RelayResourceManager,
|
||||
Reservation,
|
||||
DCUTR_PROTOCOL_ID,
|
||||
DCUtRProtocol,
|
||||
ReachabilityChecker,
|
||||
is_private_ip,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -25,4 +29,9 @@ __all__ = [
|
||||
"RelayLimits",
|
||||
"RelayResourceManager",
|
||||
"Reservation",
|
||||
"DCUtRProtocol",
|
||||
"DCUTR_PROTOCOL_ID",
|
||||
"ReachabilityChecker",
|
||||
"is_private_ip"
|
||||
|
||||
]
|
||||
|
||||
@@ -5,6 +5,16 @@ This package implements the Circuit Relay v2 protocol as specified in:
|
||||
https://github.com/libp2p/specs/blob/master/relay/circuit-v2.md
|
||||
"""
|
||||
|
||||
from .dcutr import (
|
||||
DCUtRProtocol,
|
||||
)
|
||||
from .dcutr import PROTOCOL_ID as DCUTR_PROTOCOL_ID
|
||||
|
||||
from .nat import (
|
||||
ReachabilityChecker,
|
||||
is_private_ip,
|
||||
)
|
||||
|
||||
from .discovery import (
|
||||
RelayDiscovery,
|
||||
)
|
||||
@@ -29,4 +39,8 @@ __all__ = [
|
||||
"RelayResourceManager",
|
||||
"CircuitV2Transport",
|
||||
"RelayDiscovery",
|
||||
"DCUtRProtocol",
|
||||
"DCUTR_PROTOCOL_ID",
|
||||
"ReachabilityChecker",
|
||||
"is_private_ip",
|
||||
]
|
||||
|
||||
libp2p/relay/circuit_v2/dcutr.py (new file, 580 lines)
@@ -0,0 +1,580 @@
|
||||
"""
|
||||
Direct Connection Upgrade through Relay (DCUtR) protocol implementation.
|
||||
|
||||
This module implements the DCUtR protocol as specified in:
|
||||
https://github.com/libp2p/specs/blob/master/relay/DCUtR.md
|
||||
|
||||
DCUtR enables peers behind NAT to establish direct connections
|
||||
using hole punching techniques.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
INetConn,
|
||||
INetStream,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
from libp2p.relay.circuit_v2.nat import (
|
||||
ReachabilityChecker,
|
||||
)
|
||||
from libp2p.relay.circuit_v2.pb.dcutr_pb2 import (
|
||||
HolePunch,
|
||||
)
|
||||
from libp2p.tools.async_service import (
|
||||
Service,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Protocol ID for DCUtR
|
||||
PROTOCOL_ID = TProtocol("/libp2p/dcutr")
|
||||
|
||||
# Maximum message size for DCUtR (4KiB as per spec)
|
||||
MAX_MESSAGE_SIZE = 4 * 1024
|
||||
|
||||
# Timeouts
|
||||
STREAM_READ_TIMEOUT = 30 # seconds
|
||||
STREAM_WRITE_TIMEOUT = 30 # seconds
|
||||
DIAL_TIMEOUT = 10 # seconds
|
||||
|
||||
# Maximum number of hole punch attempts per peer
|
||||
MAX_HOLE_PUNCH_ATTEMPTS = 5
|
||||
|
||||
# Delay between retry attempts
|
||||
HOLE_PUNCH_RETRY_DELAY = 30 # seconds
|
||||
|
||||
# Maximum observed addresses to exchange
|
||||
MAX_OBSERVED_ADDRS = 20
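For reference, a minimal sketch of the two HolePunch messages exchanged by the handlers below (our_addrs stands for the bytes-encoded observed addresses returned by _get_observed_addrs(); only the field usage shown in this module is assumed):

connect = HolePunch()
connect.type = HolePunch.CONNECT
connect.ObsAddrs.extend(our_addrs)     # observed multiaddrs, as raw bytes
payload = connect.SerializeToString()  # written to the /libp2p/dcutr stream

sync = HolePunch()
sync.type = HolePunch.SYNC             # SYNC carries only the type and triggers the timed punch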
|
||||
|
||||
|
||||
class DCUtRProtocol(Service):
|
||||
"""
|
||||
DCUtRProtocol implements the Direct Connection Upgrade through Relay protocol.
|
||||
|
||||
This protocol allows two NATed peers to establish direct connections through
|
||||
hole punching, after they have established an initial connection through a relay.
|
||||
"""
|
||||
|
||||
def __init__(self, host: IHost):
|
||||
"""
|
||||
Initialize the DCUtR protocol.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
host : IHost
|
||||
The libp2p host this protocol is running on
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.host = host
|
||||
self.event_started = trio.Event()
|
||||
self._hole_punch_attempts: dict[ID, int] = {}
|
||||
self._direct_connections: set[ID] = set()
|
||||
self._in_progress: set[ID] = set()
|
||||
self._reachability_checker = ReachabilityChecker(host)
|
||||
self._nursery: trio.Nursery | None = None
|
||||
|
||||
async def run(self, *, task_status: Any = trio.TASK_STATUS_IGNORED) -> None:
|
||||
"""Run the protocol service."""
|
||||
try:
|
||||
# Register the DCUtR protocol handler
|
||||
logger.debug("Registering DCUtR protocol handler")
|
||||
self.host.set_stream_handler(PROTOCOL_ID, self._handle_dcutr_stream)
|
||||
|
||||
# Signal that we're ready
|
||||
self.event_started.set()
|
||||
|
||||
# Start the service
|
||||
async with trio.open_nursery() as nursery:
|
||||
self._nursery = nursery
|
||||
task_status.started()
|
||||
logger.debug("DCUtR protocol service started")
|
||||
|
||||
# Wait for service to be stopped
|
||||
await self.manager.wait_finished()
|
||||
finally:
|
||||
# Clean up
|
||||
try:
|
||||
# Replace the handler with a no-op async function instead of None
|
||||
async def empty_handler(_: INetStream) -> None:
|
||||
pass
|
||||
|
||||
self.host.set_stream_handler(PROTOCOL_ID, empty_handler)
|
||||
logger.debug("DCUtR protocol handler unregistered")
|
||||
except Exception as e:
|
||||
logger.error("Error unregistering DCUtR protocol handler: %s", str(e))
|
||||
|
||||
# Clear state
|
||||
self._hole_punch_attempts.clear()
|
||||
self._direct_connections.clear()
|
||||
self._in_progress.clear()
|
||||
self._nursery = None
|
||||
|
||||
async def _handle_dcutr_stream(self, stream: INetStream) -> None:
|
||||
"""
|
||||
Handle incoming DCUtR streams.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
stream : INetStream
|
||||
The incoming stream
|
||||
|
||||
"""
|
||||
try:
|
||||
# Get the remote peer ID
|
||||
remote_peer_id = stream.muxed_conn.peer_id
|
||||
logger.debug("Received DCUtR stream from peer %s", remote_peer_id)
|
||||
|
||||
# Check if we already have a direct connection
|
||||
if await self._have_direct_connection(remote_peer_id):
|
||||
logger.debug(
|
||||
"Already have direct connection to %s, closing stream",
|
||||
remote_peer_id,
|
||||
)
|
||||
await stream.close()
|
||||
return
|
||||
|
||||
# Check if there's already an active hole punch attempt
|
||||
if remote_peer_id in self._in_progress:
|
||||
logger.debug("Hole punch already in progress with %s", remote_peer_id)
|
||||
# Let the existing attempt continue
|
||||
await stream.close()
|
||||
return
|
||||
|
||||
# Mark as in progress
|
||||
self._in_progress.add(remote_peer_id)
|
||||
|
||||
try:
|
||||
# Read the CONNECT message
|
||||
with trio.fail_after(STREAM_READ_TIMEOUT):
|
||||
msg_bytes = await stream.read(MAX_MESSAGE_SIZE)
|
||||
|
||||
# Parse the message
|
||||
connect_msg = HolePunch()
|
||||
connect_msg.ParseFromString(msg_bytes)
|
||||
|
||||
# Verify it's a CONNECT message
|
||||
if connect_msg.type != HolePunch.CONNECT:
|
||||
logger.warning("Expected CONNECT message, got %s", connect_msg.type)
|
||||
await stream.close()
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
"Received CONNECT message from %s with %d addresses",
|
||||
remote_peer_id,
|
||||
len(connect_msg.ObsAddrs),
|
||||
)
|
||||
|
||||
# Process observed addresses from the peer
|
||||
peer_addrs = self._decode_observed_addrs(list(connect_msg.ObsAddrs))
|
||||
logger.debug("Decoded %d valid addresses from peer", len(peer_addrs))
|
||||
|
||||
# Store the addresses in the peerstore
|
||||
if peer_addrs:
|
||||
self.host.get_peerstore().add_addrs(
|
||||
remote_peer_id, peer_addrs, 10 * 60
|
||||
) # 10 minute TTL
|
||||
|
||||
# Send our CONNECT message with our observed addresses
|
||||
our_addrs = await self._get_observed_addrs()
|
||||
response = HolePunch()
|
||||
response.type = HolePunch.CONNECT
|
||||
response.ObsAddrs.extend(our_addrs)
|
||||
|
||||
with trio.fail_after(STREAM_WRITE_TIMEOUT):
|
||||
await stream.write(response.SerializeToString())
|
||||
|
||||
logger.debug(
|
||||
"Sent CONNECT response to %s with %d addresses",
|
||||
remote_peer_id,
|
||||
len(our_addrs),
|
||||
)
|
||||
|
||||
# Wait for SYNC message
|
||||
with trio.fail_after(STREAM_READ_TIMEOUT):
|
||||
sync_bytes = await stream.read(MAX_MESSAGE_SIZE)
|
||||
|
||||
# Parse the SYNC message
|
||||
sync_msg = HolePunch()
|
||||
sync_msg.ParseFromString(sync_bytes)
|
||||
|
||||
# Verify it's a SYNC message
|
||||
if sync_msg.type != HolePunch.SYNC:
|
||||
logger.warning("Expected SYNC message, got %s", sync_msg.type)
|
||||
await stream.close()
|
||||
return
|
||||
|
||||
logger.debug("Received SYNC message from %s", remote_peer_id)
|
||||
|
||||
# Perform hole punch
|
||||
success = await self._perform_hole_punch(remote_peer_id, peer_addrs)
|
||||
|
||||
if success:
|
||||
logger.info(
|
||||
"Successfully established direct connection with %s",
|
||||
remote_peer_id,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Failed to establish direct connection with %s", remote_peer_id
|
||||
)
|
||||
|
||||
except trio.TooSlowError:
|
||||
logger.warning("Timeout in DCUtR protocol with peer %s", remote_peer_id)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error in DCUtR protocol with peer %s: %s", remote_peer_id, str(e)
|
||||
)
|
||||
finally:
|
||||
# Clean up
|
||||
self._in_progress.discard(remote_peer_id)
|
||||
await stream.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error handling DCUtR stream: %s", str(e))
|
||||
await stream.close()
|
||||
|
||||
async def initiate_hole_punch(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Initiate a hole punch with a peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer to hole punch with
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if hole punch was successful, False otherwise
|
||||
|
||||
"""
|
||||
# Check if we already have a direct connection
|
||||
if await self._have_direct_connection(peer_id):
|
||||
logger.debug("Already have direct connection to %s", peer_id)
|
||||
return True
|
||||
|
||||
# Check if there's already an active hole punch attempt
|
||||
if peer_id in self._in_progress:
|
||||
logger.debug("Hole punch already in progress with %s", peer_id)
|
||||
return False
|
||||
|
||||
# Check if we've exceeded the maximum number of attempts
|
||||
attempts = self._hole_punch_attempts.get(peer_id, 0)
|
||||
if attempts >= MAX_HOLE_PUNCH_ATTEMPTS:
|
||||
logger.warning("Maximum hole punch attempts reached for peer %s", peer_id)
|
||||
return False
|
||||
|
||||
# Mark as in progress and increment attempt counter
|
||||
self._in_progress.add(peer_id)
|
||||
self._hole_punch_attempts[peer_id] = attempts + 1
|
||||
|
||||
try:
|
||||
# Open a DCUtR stream to the peer
|
||||
logger.debug("Opening DCUtR stream to peer %s", peer_id)
|
||||
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
|
||||
if not stream:
|
||||
logger.warning("Failed to open DCUtR stream to peer %s", peer_id)
|
||||
return False
|
||||
|
||||
try:
|
||||
# Send our CONNECT message with our observed addresses
|
||||
our_addrs = await self._get_observed_addrs()
|
||||
connect_msg = HolePunch()
|
||||
connect_msg.type = HolePunch.CONNECT
|
||||
connect_msg.ObsAddrs.extend(our_addrs)
|
||||
|
||||
start_time = time.time()
|
||||
with trio.fail_after(STREAM_WRITE_TIMEOUT):
|
||||
await stream.write(connect_msg.SerializeToString())
|
||||
|
||||
logger.debug(
|
||||
"Sent CONNECT message to %s with %d addresses",
|
||||
peer_id,
|
||||
len(our_addrs),
|
||||
)
|
||||
|
||||
# Receive the peer's CONNECT message
|
||||
with trio.fail_after(STREAM_READ_TIMEOUT):
|
||||
resp_bytes = await stream.read(MAX_MESSAGE_SIZE)
|
||||
|
||||
# Calculate RTT
|
||||
rtt = time.time() - start_time
|
||||
|
||||
# Parse the response
|
||||
resp = HolePunch()
|
||||
resp.ParseFromString(resp_bytes)
|
||||
|
||||
# Verify it's a CONNECT message
|
||||
if resp.type != HolePunch.CONNECT:
|
||||
logger.warning("Expected CONNECT message, got %s", resp.type)
|
||||
return False
|
||||
|
||||
logger.debug(
|
||||
"Received CONNECT response from %s with %d addresses",
|
||||
peer_id,
|
||||
len(resp.ObsAddrs),
|
||||
)
|
||||
|
||||
# Process observed addresses from the peer
|
||||
peer_addrs = self._decode_observed_addrs(list(resp.ObsAddrs))
|
||||
logger.debug("Decoded %d valid addresses from peer", len(peer_addrs))
|
||||
|
||||
# Store the addresses in the peerstore
|
||||
if peer_addrs:
|
||||
self.host.get_peerstore().add_addrs(
|
||||
peer_id, peer_addrs, 10 * 60
|
||||
) # 10 minute TTL
|
||||
|
||||
# Send SYNC message with timing information
|
||||
# We'll use a future time that's 2*RTT from now to ensure both sides
|
||||
# are ready
|
||||
punch_time = time.time() + (2 * rtt) + 1 # Add 1 second buffer
|
||||
|
||||
sync_msg = HolePunch()
|
||||
sync_msg.type = HolePunch.SYNC
|
||||
|
||||
with trio.fail_after(STREAM_WRITE_TIMEOUT):
|
||||
await stream.write(sync_msg.SerializeToString())
|
||||
|
||||
logger.debug("Sent SYNC message to %s", peer_id)
|
||||
|
||||
# Perform the synchronized hole punch
|
||||
success = await self._perform_hole_punch(
|
||||
peer_id, peer_addrs, punch_time
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(
|
||||
"Successfully established direct connection with %s", peer_id
|
||||
)
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
"Failed to establish direct connection with %s", peer_id
|
||||
)
|
||||
return False
|
||||
|
||||
except trio.TooSlowError:
|
||||
logger.warning("Timeout in DCUtR protocol with peer %s", peer_id)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error in DCUtR protocol with peer %s: %s", peer_id, str(e)
|
||||
)
|
||||
return False
|
||||
finally:
|
||||
await stream.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error initiating hole punch with peer %s: %s", peer_id, str(e)
|
||||
)
|
||||
return False
|
||||
finally:
|
||||
self._in_progress.discard(peer_id)
|
||||
|
||||
# This should never be reached, but add explicit return for type checking
|
||||
return False
|
||||
|
||||
async def _perform_hole_punch(
|
||||
self, peer_id: ID, addrs: list[Multiaddr], punch_time: float | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Perform a hole punch attempt with a peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer to hole punch with
|
||||
addrs : list[Multiaddr]
|
||||
List of addresses to try
|
||||
punch_time : Optional[float]
|
||||
Time to perform the punch (if None, do it immediately)
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if hole punch was successful
|
||||
|
||||
"""
|
||||
if not addrs:
|
||||
logger.warning("No addresses to try for hole punch with %s", peer_id)
|
||||
return False
|
||||
|
||||
# If punch_time is specified, wait until that time
|
||||
if punch_time is not None:
|
||||
now = time.time()
|
||||
if punch_time > now:
|
||||
wait_time = punch_time - now
|
||||
logger.debug("Waiting %.2f seconds before hole punch", wait_time)
|
||||
await trio.sleep(wait_time)
|
||||
|
||||
# Try to dial each address
|
||||
logger.debug(
|
||||
"Starting hole punch with peer %s using %d addresses", peer_id, len(addrs)
|
||||
)
|
||||
|
||||
# Filter to only include non-relay addresses
|
||||
direct_addrs = [
|
||||
addr for addr in addrs if not str(addr).startswith("/p2p-circuit")
|
||||
]
|
||||
|
||||
if not direct_addrs:
|
||||
logger.warning("No direct addresses found for peer %s", peer_id)
|
||||
return False
|
||||
|
||||
# Start dialing attempts in parallel
|
||||
async with trio.open_nursery() as nursery:
|
||||
for addr in direct_addrs[:5]: # Limit to 5 addresses to avoid too many simultaneous dials
nursery.start_soon(self._dial_peer, peer_id, addr)
|
||||
|
||||
# Check if we established a direct connection
|
||||
return await self._have_direct_connection(peer_id)
|
||||
|
||||
async def _dial_peer(self, peer_id: ID, addr: Multiaddr) -> None:
|
||||
"""
|
||||
Attempt to dial a peer at a specific address.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer to dial
|
||||
addr : Multiaddr
|
||||
The address to dial
|
||||
|
||||
"""
|
||||
try:
|
||||
logger.debug("Attempting to dial %s at %s", peer_id, addr)
|
||||
|
||||
# Create peer info
|
||||
peer_info = PeerInfo(peer_id, [addr])
|
||||
|
||||
# Try to connect with timeout
|
||||
with trio.fail_after(DIAL_TIMEOUT):
|
||||
await self.host.connect(peer_info)
|
||||
|
||||
logger.info("Successfully connected to %s at %s", peer_id, addr)
|
||||
|
||||
# Add to direct connections set
|
||||
self._direct_connections.add(peer_id)
|
||||
|
||||
except trio.TooSlowError:
|
||||
logger.debug("Timeout dialing %s at %s", peer_id, addr)
|
||||
except Exception as e:
|
||||
logger.debug("Error dialing %s at %s: %s", peer_id, addr, str(e))
|
||||
|
||||
async def _have_direct_connection(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Check if we already have a direct connection to a peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer to check
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if we have a direct connection, False otherwise
|
||||
|
||||
"""
|
||||
# Check our direct connections cache first
|
||||
if peer_id in self._direct_connections:
|
||||
return True
|
||||
|
||||
# Check if the peer is connected
|
||||
network = self.host.get_network()
|
||||
conn_or_conns = network.connections.get(peer_id)
|
||||
if not conn_or_conns:
|
||||
return False
|
||||
|
||||
# Handle both single connection and list of connections
|
||||
connections: list[INetConn] = (
|
||||
[conn_or_conns] if not isinstance(conn_or_conns, list) else conn_or_conns
|
||||
)
|
||||
|
||||
# Check if any connection is direct (not relayed)
|
||||
for conn in connections:
|
||||
# Get the transport addresses
|
||||
addrs = conn.get_transport_addresses()
|
||||
|
||||
# If any address doesn't start with /p2p-circuit, it's a direct connection
|
||||
if any(not str(addr).startswith("/p2p-circuit") for addr in addrs):
|
||||
# Cache this result
|
||||
self._direct_connections.add(peer_id)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _get_observed_addrs(self) -> list[bytes]:
|
||||
"""
|
||||
Get our observed addresses to share with the peer.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[bytes]
|
||||
List of observed addresses as bytes
|
||||
|
||||
"""
|
||||
# Get all listen addresses
|
||||
addrs = self.host.get_addrs()
|
||||
|
||||
# Filter out relay addresses
|
||||
direct_addrs = [
|
||||
addr for addr in addrs if not str(addr).startswith("/p2p-circuit")
|
||||
]
|
||||
|
||||
# Limit the number of addresses
|
||||
if len(direct_addrs) > MAX_OBSERVED_ADDRS:
|
||||
direct_addrs = direct_addrs[:MAX_OBSERVED_ADDRS]
|
||||
|
||||
# Convert to bytes
|
||||
addr_bytes = [addr.to_bytes() for addr in direct_addrs]
|
||||
|
||||
return addr_bytes
|
||||
|
||||
def _decode_observed_addrs(self, addr_bytes: list[bytes]) -> list[Multiaddr]:
|
||||
"""
|
||||
Decode observed addresses received from a peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
addr_bytes : List[bytes]
|
||||
The encoded addresses
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[Multiaddr]
|
||||
The decoded multiaddresses
|
||||
|
||||
"""
|
||||
result = []
|
||||
|
||||
for addr_byte in addr_bytes:
|
||||
try:
|
||||
addr = Multiaddr(addr_byte)
|
||||
# Validate the address (basic check)
|
||||
if str(addr).startswith("/ip"):
|
||||
result.append(addr)
|
||||
except Exception as e:
|
||||
logger.debug("Error decoding multiaddr: %s", str(e))
|
||||
|
||||
return result
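For context, a minimal usage sketch of the client side of this protocol. The `dcutr` and `peer_id` names are illustrative: `dcutr` stands for an instance of the DCUtR service defined above, registered on a running host, and `peer_id` for a peer currently reachable only over a relayed (/p2p-circuit) connection.

async def upgrade_relayed_connection(dcutr, peer_id) -> bool:
    # initiate_hole_punch() already short-circuits when a direct connection
    # exists and caps retries at MAX_HOLE_PUNCH_ATTEMPTS.
    ok = await dcutr.initiate_hole_punch(peer_id)
    if not ok:
        # Keep using the relayed connection as a fallback.
        pass
    return ok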
|
||||
300
libp2p/relay/circuit_v2/nat.py
Normal file
@ -0,0 +1,300 @@
|
||||
"""
|
||||
NAT traversal utilities for libp2p.
|
||||
|
||||
This module provides utilities for NAT traversal and reachability detection.
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
)
|
||||
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
INetConn,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("libp2p.relay.circuit_v2.nat")
|
||||
|
||||
# Timeout for reachability checks
|
||||
REACHABILITY_TIMEOUT = 10 # seconds
|
||||
|
||||
# Define private IP ranges
|
||||
PRIVATE_IP_RANGES = [
|
||||
("10.0.0.0", "10.255.255.255"), # Class A private network: 10.0.0.0/8
|
||||
("172.16.0.0", "172.31.255.255"), # Class B private network: 172.16.0.0/12
|
||||
("192.168.0.0", "192.168.255.255"), # Class C private network: 192.168.0.0/16
|
||||
]
|
||||
|
||||
# Link-local address range: 169.254.0.0/16
|
||||
LINK_LOCAL_RANGE = ("169.254.0.0", "169.254.255.255")
|
||||
|
||||
# Loopback address range: 127.0.0.0/8
|
||||
LOOPBACK_RANGE = ("127.0.0.0", "127.255.255.255")
|
||||
|
||||
|
||||
def ip_to_int(ip: str) -> int:
|
||||
"""
|
||||
Convert an IP address to an integer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ip : str
|
||||
IP address to convert
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Integer representation of the IP
|
||||
|
||||
"""
|
||||
try:
|
||||
return int(ipaddress.IPv4Address(ip))
|
||||
except ipaddress.AddressValueError:
|
||||
# Handle IPv6 addresses
|
||||
return int(ipaddress.IPv6Address(ip))
|
||||
|
||||
|
||||
def is_ip_in_range(ip: str, start_range: str, end_range: str) -> bool:
|
||||
"""
|
||||
Check if an IP address is within a range.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ip : str
|
||||
IP address to check
|
||||
start_range : str
|
||||
Start of the range
|
||||
end_range : str
|
||||
End of the range
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the IP is in the range
|
||||
|
||||
"""
|
||||
try:
|
||||
ip_int = ip_to_int(ip)
|
||||
start_int = ip_to_int(start_range)
|
||||
end_int = ip_to_int(end_range)
|
||||
return start_int <= ip_int <= end_int
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_private_ip(ip: str) -> bool:
|
||||
"""
|
||||
Check if an IP address is private.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ip : str
|
||||
IP address to check
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if IP is private
|
||||
|
||||
"""
|
||||
for start_range, end_range in PRIVATE_IP_RANGES:
|
||||
if is_ip_in_range(ip, start_range, end_range):
|
||||
return True
|
||||
|
||||
# Check for link-local addresses
|
||||
if is_ip_in_range(ip, *LINK_LOCAL_RANGE):
|
||||
return True
|
||||
|
||||
# Check for loopback addresses
|
||||
if is_ip_in_range(ip, *LOOPBACK_RANGE):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def extract_ip_from_multiaddr(addr: Multiaddr) -> str | None:
|
||||
"""
|
||||
Extract the IP address from a multiaddr.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
addr : Multiaddr
|
||||
Multiaddr to extract from
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[str]
|
||||
IP address or None if not found
|
||||
|
||||
"""
|
||||
# Convert to string representation
|
||||
addr_str = str(addr)
|
||||
|
||||
# Look for IPv4 address
|
||||
ipv4_start = addr_str.find("/ip4/")
|
||||
if ipv4_start != -1:
|
||||
# Extract the IPv4 address
|
||||
ipv4_end = addr_str.find("/", ipv4_start + 5)
|
||||
if ipv4_end != -1:
|
||||
return addr_str[ipv4_start + 5 : ipv4_end]
|
||||
|
||||
# Look for IPv6 address
|
||||
ipv6_start = addr_str.find("/ip6/")
|
||||
if ipv6_start != -1:
|
||||
# Extract the IPv6 address
|
||||
ipv6_end = addr_str.find("/", ipv6_start + 5)
|
||||
if ipv6_end != -1:
|
||||
return addr_str[ipv6_start + 5 : ipv6_end]
|
||||
|
||||
return None
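A short sketch of how these helpers compose; the multiaddr string is an arbitrary example.

from multiaddr import Multiaddr

addr = Multiaddr("/ip4/192.168.1.10/tcp/4001")
ip = extract_ip_from_multiaddr(addr)  # -> "192.168.1.10"
if ip is not None and is_private_ip(ip):
    # RFC 1918, loopback, and link-local addresses are unlikely to be
    # reachable from outside the local network.
    print(f"{ip} looks private; expect NAT in the path")

The standard-library ipaddress module offers ip_address(ip).is_private, which covers the same private, loopback, and link-local ranges; the explicit ranges above keep the check self-documenting.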
|
||||
|
||||
|
||||
class ReachabilityChecker:
|
||||
"""
|
||||
Utility class for checking peer reachability.
|
||||
|
||||
This class assesses whether a peer's addresses are likely
|
||||
to be directly reachable or behind NAT.
|
||||
"""
|
||||
|
||||
def __init__(self, host: IHost):
|
||||
"""
|
||||
Initialize the reachability checker.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
host : IHost
|
||||
The libp2p host
|
||||
|
||||
"""
|
||||
self.host = host
|
||||
self._peer_reachability: dict[ID, bool] = {}
|
||||
self._known_public_peers: set[ID] = set()
|
||||
|
||||
def is_addr_public(self, addr: Multiaddr) -> bool:
|
||||
"""
|
||||
Check if an address is likely to be publicly reachable.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
addr : Multiaddr
|
||||
The multiaddr to check
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if address is likely public
|
||||
|
||||
"""
|
||||
# Extract the IP address
|
||||
ip = extract_ip_from_multiaddr(addr)
|
||||
if not ip:
|
||||
return False
|
||||
|
||||
# Check if it's a private IP
|
||||
return not is_private_ip(ip)
|
||||
|
||||
def get_public_addrs(self, addrs: list[Multiaddr]) -> list[Multiaddr]:
|
||||
"""
|
||||
Filter a list of addresses to only include likely public ones.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
addrs : List[Multiaddr]
|
||||
List of addresses to filter
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[Multiaddr]
|
||||
List of likely public addresses
|
||||
|
||||
"""
|
||||
return [addr for addr in addrs if self.is_addr_public(addr)]
|
||||
|
||||
async def check_peer_reachability(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Check if a peer is directly reachable.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer ID to check
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if peer is likely directly reachable
|
||||
|
||||
"""
|
||||
# Check if we already know
|
||||
if peer_id in self._peer_reachability:
|
||||
return self._peer_reachability[peer_id]
|
||||
|
||||
# Check if the peer is connected
|
||||
network = self.host.get_network()
|
||||
connections: INetConn | list[INetConn] | None = network.connections.get(peer_id)
|
||||
if not connections:
|
||||
# Not connected, can't determine reachability
|
||||
return False
|
||||
|
||||
# Check if any connection is direct (not relayed)
|
||||
if isinstance(connections, list):
|
||||
for conn in connections:
|
||||
# Get the transport addresses
|
||||
addrs = conn.get_transport_addresses()
|
||||
|
||||
# If any address doesn't start with /p2p-circuit,
|
||||
# it's a direct connection
|
||||
if any(not str(addr).startswith("/p2p-circuit") for addr in addrs):
|
||||
self._peer_reachability[peer_id] = True
|
||||
return True
|
||||
else:
|
||||
# Handle single connection case
|
||||
addrs = connections.get_transport_addresses()
|
||||
if any(not str(addr).startswith("/p2p-circuit") for addr in addrs):
|
||||
self._peer_reachability[peer_id] = True
|
||||
return True
|
||||
|
||||
# Get the peer's addresses from peerstore
|
||||
try:
|
||||
addrs = self.host.get_peerstore().addrs(peer_id)
|
||||
# Check if peer has any public addresses
|
||||
public_addrs = self.get_public_addrs(addrs)
|
||||
if public_addrs:
|
||||
self._peer_reachability[peer_id] = True
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug("Error getting peer addresses: %s", str(e))
|
||||
|
||||
# Default to not directly reachable
|
||||
self._peer_reachability[peer_id] = False
|
||||
return False
|
||||
|
||||
async def check_self_reachability(self) -> tuple[bool, list[Multiaddr]]:
|
||||
"""
|
||||
Check if this host is likely directly reachable.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[bool, List[Multiaddr]]
|
||||
Tuple of (is_reachable, public_addresses)
|
||||
|
||||
"""
|
||||
# Get all host addresses
|
||||
addrs = self.host.get_addrs()
|
||||
|
||||
# Filter for public addresses
|
||||
public_addrs = self.get_public_addrs(addrs)
|
||||
|
||||
# If we have public addresses, assume we're reachable
|
||||
# This is a simplified assumption - real reachability would need
|
||||
# external checking
|
||||
is_reachable = len(public_addrs) > 0
|
||||
|
||||
return is_reachable, public_addrs
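A minimal usage sketch of the checker, assuming `host` is an already-running libp2p host and `peer_id` a connected peer (both names are illustrative).

async def report_reachability(host, peer_id) -> None:
    checker = ReachabilityChecker(host)

    # Our own side: do we advertise any address that looks public?
    reachable, public_addrs = await checker.check_self_reachability()
    print("self reachable:", reachable, "public addrs:", len(public_addrs))

    # The remote side: direct connection or public address on record?
    if await checker.check_peer_reachability(peer_id):
        print(f"{peer_id} looks directly reachable")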
|
||||
@ -5,6 +5,11 @@ Contains generated protobuf code for circuit_v2 relay protocol.
|
||||
"""
|
||||
|
||||
# Import the classes to be accessible directly from the package
|
||||
|
||||
from .dcutr_pb2 import (
|
||||
HolePunch,
|
||||
)
|
||||
|
||||
from .circuit_pb2 import (
|
||||
HopMessage,
|
||||
Limit,
|
||||
@ -13,4 +18,4 @@ from .circuit_pb2 import (
|
||||
StopMessage,
|
||||
)
|
||||
|
||||
__all__ = ["HopMessage", "Limit", "Reservation", "Status", "StopMessage"]
|
||||
__all__ = ["HopMessage", "Limit", "Reservation", "Status", "StopMessage", "HolePunch"]
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: libp2p/relay/circuit_v2/pb/circuit.proto
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf.internal import builder as _builder
|
||||
@ -12,11 +11,14 @@ from google.protobuf import symbol_database as _symbol_database
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n(libp2p/relay/circuit_v2/pb/circuit.proto\x12\rcircuit.pb.v2\"\xf3\x01\n\nHopMessage\x12,\n\x04type\x18\x01 \x01(\x0e\x32\x1e.circuit.pb.v2.HopMessage.Type\x12\x0c\n\x04peer\x18\x02 \x01(\x0c\x12/\n\x0breservation\x18\x03 \x01(\x0b\x32\x1a.circuit.pb.v2.Reservation\x12#\n\x05limit\x18\x04 \x01(\x0b\x32\x14.circuit.pb.v2.Limit\x12%\n\x06status\x18\x05 \x01(\x0b\x32\x15.circuit.pb.v2.Status\",\n\x04Type\x12\x0b\n\x07RESERVE\x10\x00\x12\x0b\n\x07\x43ONNECT\x10\x01\x12\n\n\x06STATUS\x10\x02\"\x92\x01\n\x0bStopMessage\x12-\n\x04type\x18\x01 \x01(\x0e\x32\x1f.circuit.pb.v2.StopMessage.Type\x12\x0c\n\x04peer\x18\x02 \x01(\x0c\x12%\n\x06status\x18\x03 \x01(\x0b\x32\x15.circuit.pb.v2.Status\"\x1f\n\x04Type\x12\x0b\n\x07\x43ONNECT\x10\x00\x12\n\n\x06STATUS\x10\x01\"A\n\x0bReservation\x12\x0f\n\x07voucher\x18\x01 \x01(\x0c\x12\x11\n\tsignature\x18\x02 \x01(\x0c\x12\x0e\n\x06\x65xpire\x18\x03 \x01(\x03\"\'\n\x05Limit\x12\x10\n\x08\x64uration\x18\x01 \x01(\x03\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x03\"\xf6\x01\n\x06Status\x12(\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x1a.circuit.pb.v2.Status.Code\x12\x0f\n\x07message\x18\x02 \x01(\t\"\xb0\x01\n\x04\x43ode\x12\x06\n\x02OK\x10\x00\x12\x17\n\x13RESERVATION_REFUSED\x10\x64\x12\x1b\n\x17RESOURCE_LIMIT_EXCEEDED\x10\x65\x12\x15\n\x11PERMISSION_DENIED\x10\x66\x12\x16\n\x11\x43ONNECTION_FAILED\x10\xc8\x01\x12\x11\n\x0c\x44IAL_REFUSED\x10\xc9\x01\x12\x10\n\x0bSTOP_FAILED\x10\xac\x02\x12\x16\n\x11MALFORMED_MESSAGE\x10\x90\x03\x62\x06proto3')
|
||||
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.relay.circuit_v2.pb.circuit_pb2', globals())
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
_HOPMESSAGE._serialized_start=60
|
||||
_HOPMESSAGE._serialized_end=303
|
||||
|
||||
14
libp2p/relay/circuit_v2/pb/dcutr.proto
Normal file
@ -0,0 +1,14 @@
|
||||
syntax = "proto2";
|
||||
|
||||
package holepunch.pb;
|
||||
|
||||
message HolePunch {
|
||||
enum Type {
|
||||
CONNECT = 100;
|
||||
SYNC = 300;
|
||||
}
|
||||
|
||||
required Type type = 1;
|
||||
|
||||
repeated bytes ObsAddrs = 2;
|
||||
}
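For reference, a CONNECT message built with the generated Python bindings round-trips as sketched below, assuming the package import added in the pb __init__ hunk above; the address is illustrative.

from libp2p.relay.circuit_v2.pb import HolePunch
from multiaddr import Multiaddr

msg = HolePunch()
msg.type = HolePunch.CONNECT
# ObsAddrs carries each observed multiaddr in its binary encoding.
msg.ObsAddrs.append(Multiaddr("/ip4/203.0.113.7/tcp/4001").to_bytes())

wire = msg.SerializeToString()

decoded = HolePunch()
decoded.ParseFromString(wire)
assert decoded.type == HolePunch.CONNECT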
|
||||
27
libp2p/relay/circuit_v2/pb/dcutr_pb2.py
Normal file
@ -0,0 +1,27 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: libp2p/relay/circuit_v2/pb/dcutr.proto
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf.internal import builder as _builder
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n&libp2p/relay/circuit_v2/pb/dcutr.proto\x12\x0cholepunch.pb\"i\n\tHolePunch\x12*\n\x04type\x18\x01 \x02(\x0e\x32\x1c.holepunch.pb.HolePunch.Type\x12\x10\n\x08ObsAddrs\x18\x02 \x03(\x0c\"\x1e\n\x04Type\x12\x0b\n\x07\x43ONNECT\x10\x64\x12\t\n\x04SYNC\x10\xac\x02')
|
||||
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.relay.circuit_v2.pb.dcutr_pb2', globals())
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
_HOLEPUNCH._serialized_start=56
|
||||
_HOLEPUNCH._serialized_end=161
|
||||
_HOLEPUNCH_TYPE._serialized_start=131
|
||||
_HOLEPUNCH_TYPE._serialized_end=161
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
53
libp2p/relay/circuit_v2/pb/dcutr_pb2.pyi
Normal file
@ -0,0 +1,53 @@
|
||||
"""
|
||||
@generated by mypy-protobuf. Do not edit manually!
|
||||
isort:skip_file
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import collections.abc
|
||||
import google.protobuf.descriptor
|
||||
import google.protobuf.internal.containers
|
||||
import google.protobuf.internal.enum_type_wrapper
|
||||
import google.protobuf.message
|
||||
import sys
|
||||
import typing
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
import typing as typing_extensions
|
||||
else:
|
||||
import typing_extensions
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
|
||||
@typing.final
|
||||
class HolePunch(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
class _Type:
|
||||
ValueType = typing.NewType("ValueType", builtins.int)
|
||||
V: typing_extensions.TypeAlias = ValueType
|
||||
|
||||
class _TypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[HolePunch._Type.ValueType], builtins.type):
|
||||
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
|
||||
CONNECT: HolePunch._Type.ValueType # 100
|
||||
SYNC: HolePunch._Type.ValueType # 300
|
||||
|
||||
class Type(_Type, metaclass=_TypeEnumTypeWrapper): ...
|
||||
CONNECT: HolePunch.Type.ValueType # 100
|
||||
SYNC: HolePunch.Type.ValueType # 300
|
||||
|
||||
TYPE_FIELD_NUMBER: builtins.int
|
||||
OBSADDRS_FIELD_NUMBER: builtins.int
|
||||
type: global___HolePunch.Type.ValueType
|
||||
@property
|
||||
def ObsAddrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
type: global___HolePunch.Type.ValueType | None = ...,
|
||||
ObsAddrs: collections.abc.Iterable[builtins.bytes] | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["type", b"type"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["ObsAddrs", b"ObsAddrs", "type", b"type"]) -> None: ...
|
||||
|
||||
global___HolePunch = HolePunch
|
||||
68
libp2p/security/noise/early_data.py
Normal file
@ -0,0 +1,68 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from libp2p.abc import IRawConnection
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.peer.id import ID
|
||||
|
||||
from .pb import noise_pb2 as noise_pb
|
||||
|
||||
|
||||
class EarlyDataHandler(ABC):
|
||||
"""Interface for handling early data during Noise handshake"""
|
||||
|
||||
@abstractmethod
|
||||
async def send(
|
||||
self, conn: IRawConnection, peer_id: ID
|
||||
) -> noise_pb.NoiseExtensions | None:
|
||||
"""Called to generate early data to send during handshake"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def received(
|
||||
self, conn: IRawConnection, extensions: noise_pb.NoiseExtensions | None
|
||||
) -> None:
|
||||
"""Called when early data is received during handshake"""
|
||||
pass
|
||||
|
||||
|
||||
class TransportEarlyDataHandler(EarlyDataHandler):
|
||||
"""Default early data handler for muxer negotiation"""
|
||||
|
||||
def __init__(self, supported_muxers: list[TProtocol]):
|
||||
self.supported_muxers = supported_muxers
|
||||
self.received_muxers: list[TProtocol] = []
|
||||
|
||||
async def send(
|
||||
self, conn: IRawConnection, peer_id: ID
|
||||
) -> noise_pb.NoiseExtensions | None:
|
||||
"""Send our supported muxers list"""
|
||||
if not self.supported_muxers:
|
||||
return None
|
||||
|
||||
extensions = noise_pb.NoiseExtensions()
|
||||
# Convert TProtocol to string for serialization
|
||||
extensions.stream_muxers[:] = [str(muxer) for muxer in self.supported_muxers]
|
||||
return extensions
|
||||
|
||||
async def received(
|
||||
self, conn: IRawConnection, extensions: noise_pb.NoiseExtensions | None
|
||||
) -> None:
|
||||
"""Store received muxers list"""
|
||||
if extensions and extensions.stream_muxers:
|
||||
self.received_muxers = [
|
||||
TProtocol(muxer) for muxer in extensions.stream_muxers
|
||||
]
|
||||
|
||||
def match_muxers(self, is_initiator: bool) -> TProtocol | None:
|
||||
"""Find first common muxer between local and remote"""
|
||||
if is_initiator:
|
||||
# Initiator: find first local muxer that remote supports
|
||||
for local_muxer in self.supported_muxers:
|
||||
if local_muxer in self.received_muxers:
|
||||
return local_muxer
|
||||
else:
|
||||
# Responder: find first remote muxer that we support
|
||||
for remote_muxer in self.received_muxers:
|
||||
if remote_muxer in self.supported_muxers:
|
||||
return remote_muxer
|
||||
return None
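A small sketch of the negotiation logic in isolation. The muxer protocol IDs are the usual py-libp2p ones, and the direct assignment to received_muxers stands in for what received() would populate during a real handshake.

from libp2p.custom_types import TProtocol

local = TransportEarlyDataHandler(
    [TProtocol("/yamux/1.0.0"), TProtocol("/mplex/6.7.0")]
)
# Pretend the remote peer advertised only mplex in its NoiseExtensions.
local.received_muxers = [TProtocol("/mplex/6.7.0")]

# The initiator picks the first of its own muxers that the remote supports.
assert local.match_muxers(is_initiator=True) == TProtocol("/mplex/6.7.0")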
|
||||
@ -41,7 +41,8 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
|
||||
read_writer: NoisePacketReadWriter
|
||||
noise_state: NoiseState
|
||||
|
||||
# FIXME: This prefix is added in msg#3 in Go. Check whether it's a desired behavior.
|
||||
# NOTE: This prefix is added in msg#3 in Go.
|
||||
# Support in py-libp2p is available but not used
|
||||
prefix: bytes = b"\x00" * 32
|
||||
|
||||
def __init__(self, conn: IRawConnection, noise_state: NoiseState) -> None:
|
||||
|
||||
@ -30,6 +30,9 @@ from libp2p.security.secure_session import (
|
||||
SecureSession,
|
||||
)
|
||||
|
||||
from .early_data import (
|
||||
EarlyDataHandler,
|
||||
)
|
||||
from .exceptions import (
|
||||
HandshakeHasNotFinished,
|
||||
InvalidSignature,
|
||||
@ -45,6 +48,7 @@ from .messages import (
|
||||
make_handshake_payload_sig,
|
||||
verify_handshake_payload_sig,
|
||||
)
|
||||
from .pb import noise_pb2 as noise_pb
|
||||
|
||||
|
||||
class IPattern(ABC):
|
||||
@ -62,7 +66,8 @@ class BasePattern(IPattern):
|
||||
noise_static_key: PrivateKey
|
||||
local_peer: ID
|
||||
libp2p_privkey: PrivateKey
|
||||
early_data: bytes | None
|
||||
initiator_early_data_handler: EarlyDataHandler | None
|
||||
responder_early_data_handler: EarlyDataHandler | None
|
||||
|
||||
def create_noise_state(self) -> NoiseState:
|
||||
noise_state = NoiseState.from_name(self.protocol_name)
|
||||
@ -73,11 +78,50 @@ class BasePattern(IPattern):
|
||||
raise NoiseStateError("noise_protocol is not initialized")
|
||||
return noise_state
|
||||
|
||||
def make_handshake_payload(self) -> NoiseHandshakePayload:
|
||||
async def make_handshake_payload(
|
||||
self, conn: IRawConnection, peer_id: ID, is_initiator: bool
|
||||
) -> NoiseHandshakePayload:
|
||||
signature = make_handshake_payload_sig(
|
||||
self.libp2p_privkey, self.noise_static_key.get_public_key()
|
||||
)
|
||||
return NoiseHandshakePayload(self.libp2p_privkey.get_public_key(), signature)
|
||||
|
||||
# NEW: Get early data from appropriate handler
|
||||
extensions = None
|
||||
if is_initiator and self.initiator_early_data_handler:
|
||||
extensions = await self.initiator_early_data_handler.send(conn, peer_id)
|
||||
elif not is_initiator and self.responder_early_data_handler:
|
||||
extensions = await self.responder_early_data_handler.send(conn, peer_id)
|
||||
|
||||
# NEW: Serialize extensions into early_data field
|
||||
early_data = None
|
||||
if extensions:
|
||||
early_data = extensions.SerializeToString()
|
||||
|
||||
return NoiseHandshakePayload(
|
||||
self.libp2p_privkey.get_public_key(),
|
||||
signature,
|
||||
early_data, # ← This is the key addition
|
||||
)
|
||||
|
||||
async def handle_received_payload(
|
||||
self, conn: IRawConnection, payload: NoiseHandshakePayload, is_initiator: bool
|
||||
) -> None:
|
||||
"""Process early data from received payload"""
|
||||
if not payload.early_data:
|
||||
return
|
||||
|
||||
# Deserialize the NoiseExtensions from early_data field
|
||||
try:
|
||||
extensions = noise_pb.NoiseExtensions.FromString(payload.early_data)
|
||||
except Exception:
|
||||
# Invalid extensions, ignore silently
|
||||
return
|
||||
|
||||
# Pass to appropriate handler
|
||||
if is_initiator and self.initiator_early_data_handler:
|
||||
await self.initiator_early_data_handler.received(conn, extensions)
|
||||
elif not is_initiator and self.responder_early_data_handler:
|
||||
await self.responder_early_data_handler.received(conn, extensions)
|
||||
|
||||
|
||||
class PatternXX(BasePattern):
|
||||
@ -86,13 +130,15 @@ class PatternXX(BasePattern):
|
||||
local_peer: ID,
|
||||
libp2p_privkey: PrivateKey,
|
||||
noise_static_key: PrivateKey,
|
||||
early_data: bytes | None = None,
|
||||
initiator_early_data_handler: EarlyDataHandler | None,
|
||||
responder_early_data_handler: EarlyDataHandler | None,
|
||||
) -> None:
|
||||
self.protocol_name = b"Noise_XX_25519_ChaChaPoly_SHA256"
|
||||
self.local_peer = local_peer
|
||||
self.libp2p_privkey = libp2p_privkey
|
||||
self.noise_static_key = noise_static_key
|
||||
self.early_data = early_data
|
||||
self.initiator_early_data_handler = initiator_early_data_handler
|
||||
self.responder_early_data_handler = responder_early_data_handler
|
||||
|
||||
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
|
||||
noise_state = self.create_noise_state()
|
||||
@ -106,18 +152,23 @@ class PatternXX(BasePattern):
|
||||
|
||||
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
|
||||
|
||||
# Consume msg#1.
|
||||
# 1. Consume msg#1 (just empty bytes)
|
||||
await read_writer.read_msg()
|
||||
|
||||
# Send msg#2, which should include our handshake payload.
|
||||
our_payload = self.make_handshake_payload()
|
||||
# 2. Send msg#2 with our payload INCLUDING EARLY DATA
|
||||
our_payload = await self.make_handshake_payload(
|
||||
conn,
|
||||
self.local_peer, # We send our own peer ID in responder role
|
||||
is_initiator=False,
|
||||
)
|
||||
msg_2 = our_payload.serialize()
|
||||
await read_writer.write_msg(msg_2)
|
||||
|
||||
# Receive and consume msg#3.
|
||||
# 3. Receive msg#3
|
||||
msg_3 = await read_writer.read_msg()
|
||||
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3)
|
||||
|
||||
# Extract remote pubkey from noise handshake state
|
||||
if handshake_state.rs is None:
|
||||
raise NoiseStateError(
|
||||
"something is wrong in the underlying noise `handshake_state`: "
|
||||
@ -126,14 +177,31 @@ class PatternXX(BasePattern):
|
||||
)
|
||||
remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
|
||||
|
||||
# 4. Verify signature (unchanged)
|
||||
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
|
||||
raise InvalidSignature
|
||||
|
||||
# NEW: Process early data from msg#3 AFTER signature verification
|
||||
await self.handle_received_payload(
|
||||
conn, peer_handshake_payload, is_initiator=False
|
||||
)
|
||||
|
||||
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
|
||||
|
||||
if not noise_state.handshake_finished:
|
||||
raise HandshakeHasNotFinished(
|
||||
"handshake is done but it is not marked as finished in `noise_state`"
|
||||
)
|
||||
|
||||
# NEW: Get negotiated muxer for connection state
|
||||
# negotiated_muxer = None
|
||||
if self.responder_early_data_handler and hasattr(
|
||||
self.responder_early_data_handler, "match_muxers"
|
||||
):
|
||||
# negotiated_muxer =
|
||||
# self.responder_early_data_handler.match_muxers(is_initiator=False)
|
||||
pass
|
||||
|
||||
transport_read_writer = NoiseTransportReadWriter(conn, noise_state)
|
||||
return SecureSession(
|
||||
local_peer=self.local_peer,
|
||||
@ -142,6 +210,8 @@ class PatternXX(BasePattern):
|
||||
remote_permanent_pubkey=remote_pubkey,
|
||||
is_initiator=False,
|
||||
conn=transport_read_writer,
|
||||
# NOTE: negotiated_muxer would need to be added to SecureSession constructor
|
||||
# For now, store it in connection metadata or similar
|
||||
)
|
||||
|
||||
async def handshake_outbound(
|
||||
@ -158,24 +228,27 @@ class PatternXX(BasePattern):
|
||||
if handshake_state is None:
|
||||
raise NoiseStateError("Handshake state is not initialized")
|
||||
|
||||
# Send msg#1, which is *not* encrypted.
|
||||
# 1. Send msg#1 (empty) - no early data possible in XX pattern
|
||||
msg_1 = b""
|
||||
await read_writer.write_msg(msg_1)
|
||||
|
||||
# Read msg#2 from the remote, which contains the public key of the peer.
|
||||
# 2. Read msg#2 from responder
|
||||
msg_2 = await read_writer.read_msg()
|
||||
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2)
|
||||
|
||||
# Extract remote pubkey from noise handshake state
|
||||
if handshake_state.rs is None:
|
||||
raise NoiseStateError(
|
||||
"something is wrong in the underlying noise `handshake_state`: "
|
||||
"we received and consumed msg#3, which should have included the "
|
||||
"we received and consumed msg#2, which should have included the "
|
||||
"remote static public key, but it is not present in the handshake_state"
|
||||
)
|
||||
remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
|
||||
|
||||
# Verify signature BEFORE processing early data (security)
|
||||
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
|
||||
raise InvalidSignature
|
||||
|
||||
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
|
||||
if remote_peer_id_from_pubkey != remote_peer:
|
||||
raise PeerIDMismatchesPubkey(
|
||||
@ -184,8 +257,15 @@ class PatternXX(BasePattern):
|
||||
f"remote_peer_id_from_pubkey={remote_peer_id_from_pubkey}"
|
||||
)
|
||||
|
||||
# Send msg#3, which includes our encrypted payload and our noise static key.
|
||||
our_payload = self.make_handshake_payload()
|
||||
# NEW: Process early data from msg#2 AFTER verification
|
||||
await self.handle_received_payload(
|
||||
conn, peer_handshake_payload, is_initiator=True
|
||||
)
|
||||
|
||||
# 3. Send msg#3 with our payload INCLUDING EARLY DATA
|
||||
our_payload = await self.make_handshake_payload(
|
||||
conn, remote_peer, is_initiator=True
|
||||
)
|
||||
msg_3 = our_payload.serialize()
|
||||
await read_writer.write_msg(msg_3)
|
||||
|
||||
@ -193,6 +273,16 @@ class PatternXX(BasePattern):
|
||||
raise HandshakeHasNotFinished(
|
||||
"handshake is done but it is not marked as finished in `noise_state`"
|
||||
)
|
||||
|
||||
# NEW: Get negotiated muxer
|
||||
# negotiated_muxer = None
|
||||
if self.initiator_early_data_handler and hasattr(
|
||||
self.initiator_early_data_handler, "match_muxers"
|
||||
):
|
||||
pass
|
||||
# negotiated_muxer =
|
||||
# self.initiator_early_data_handler.match_muxers(is_initiator=True)
|
||||
|
||||
transport_read_writer = NoiseTransportReadWriter(conn, noise_state)
|
||||
return SecureSession(
|
||||
local_peer=self.local_peer,
|
||||
@ -201,6 +291,8 @@ class PatternXX(BasePattern):
|
||||
remote_permanent_pubkey=remote_pubkey,
|
||||
is_initiator=True,
|
||||
conn=transport_read_writer,
|
||||
# NOTE: negotiated_muxer would need to be added to SecureSession constructor
|
||||
# For now, store it in connection metadata or similar
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -1,8 +1,13 @@
|
||||
syntax = "proto3";
|
||||
syntax = "proto2";
|
||||
package pb;
|
||||
|
||||
message NoiseHandshakePayload {
|
||||
bytes identity_key = 1;
|
||||
bytes identity_sig = 2;
|
||||
bytes data = 3;
|
||||
message NoiseExtensions {
|
||||
repeated bytes webtransport_certhashes = 1;
|
||||
repeated string stream_muxers = 2;
|
||||
}
|
||||
|
||||
message NoiseHandshakePayload {
|
||||
optional bytes identity_key = 1;
|
||||
optional bytes identity_sig = 2;
|
||||
optional bytes data = 3;
|
||||
}
|
||||
|
||||
@ -13,13 +13,15 @@ _sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$libp2p/security/noise/pb/noise.proto\x12\x02pb\"Q\n\x15NoiseHandshakePayload\x12\x14\n\x0cidentity_key\x18\x01 \x01(\x0c\x12\x14\n\x0cidentity_sig\x18\x02 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x62\x06proto3')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$libp2p/security/noise/pb/noise.proto\x12\x02pb\"I\n\x0fNoiseExtensions\x12\x1f\n\x17webtransport_certhashes\x18\x01 \x03(\x0c\x12\x15\n\rstream_muxers\x18\x02 \x03(\t\"Q\n\x15NoiseHandshakePayload\x12\x14\n\x0cidentity_key\x18\x01 \x01(\x0c\x12\x14\n\x0cidentity_sig\x18\x02 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c')
|
||||
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.security.noise.pb.noise_pb2', globals())
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
_NOISEHANDSHAKEPAYLOAD._serialized_start=44
|
||||
_NOISEHANDSHAKEPAYLOAD._serialized_end=125
|
||||
_NOISEEXTENSIONS._serialized_start=44
|
||||
_NOISEEXTENSIONS._serialized_end=117
|
||||
_NOISEHANDSHAKEPAYLOAD._serialized_start=119
|
||||
_NOISEHANDSHAKEPAYLOAD._serialized_end=200
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
||||
@ -4,12 +4,34 @@ isort:skip_file
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import collections.abc
|
||||
import google.protobuf.descriptor
|
||||
import google.protobuf.internal.containers
|
||||
import google.protobuf.message
|
||||
import typing
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
|
||||
@typing.final
|
||||
class NoiseExtensions(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
WEBTRANSPORT_CERTHASHES_FIELD_NUMBER: builtins.int
|
||||
STREAM_MUXERS_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def webtransport_certhashes(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
|
||||
@property
|
||||
def stream_muxers(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
webtransport_certhashes: collections.abc.Iterable[builtins.bytes] | None = ...,
|
||||
stream_muxers: collections.abc.Iterable[builtins.str] | None = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["stream_muxers", b"stream_muxers", "webtransport_certhashes", b"webtransport_certhashes"]) -> None: ...
|
||||
|
||||
global___NoiseExtensions = NoiseExtensions
|
||||
|
||||
@typing.final
|
||||
class NoiseHandshakePayload(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
@ -23,10 +45,11 @@ class NoiseHandshakePayload(google.protobuf.message.Message):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
identity_key: builtins.bytes = ...,
|
||||
identity_sig: builtins.bytes = ...,
|
||||
data: builtins.bytes = ...,
|
||||
identity_key: builtins.bytes | None = ...,
|
||||
identity_sig: builtins.bytes | None = ...,
|
||||
data: builtins.bytes | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["data", b"data", "identity_key", b"identity_key", "identity_sig", b"identity_sig"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["data", b"data", "identity_key", b"identity_key", "identity_sig", b"identity_sig"]) -> None: ...
|
||||
|
||||
global___NoiseHandshakePayload = NoiseHandshakePayload
|
||||
|
||||
@ -14,6 +14,7 @@ from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
|
||||
from .early_data import EarlyDataHandler, TransportEarlyDataHandler
|
||||
from .patterns import (
|
||||
IPattern,
|
||||
PatternXX,
|
||||
@ -26,40 +27,40 @@ class Transport(ISecureTransport):
|
||||
libp2p_privkey: PrivateKey
|
||||
noise_privkey: PrivateKey
|
||||
local_peer: ID
|
||||
early_data: bytes | None
|
||||
with_noise_pipes: bool
|
||||
|
||||
# NOTE: Implementations that support Noise Pipes must decide whether to use
|
||||
# an XX or IK handshake based on whether they possess a cached static
|
||||
# Noise key for the remote peer.
|
||||
# TODO: A storage of seen noise static keys for pattern IK?
|
||||
supported_muxers: list[TProtocol]
|
||||
initiator_early_data_handler: EarlyDataHandler | None
|
||||
responder_early_data_handler: EarlyDataHandler | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
libp2p_keypair: KeyPair,
|
||||
noise_privkey: PrivateKey,
|
||||
early_data: bytes | None = None,
|
||||
with_noise_pipes: bool = False,
|
||||
supported_muxers: list[TProtocol] | None = None,
|
||||
initiator_handler: EarlyDataHandler | None = None,
|
||||
responder_handler: EarlyDataHandler | None = None,
|
||||
) -> None:
|
||||
self.libp2p_privkey = libp2p_keypair.private_key
|
||||
self.noise_privkey = noise_privkey
|
||||
self.local_peer = ID.from_pubkey(libp2p_keypair.public_key)
|
||||
self.early_data = early_data
|
||||
self.with_noise_pipes = with_noise_pipes
|
||||
self.supported_muxers = supported_muxers or []
|
||||
|
||||
if self.with_noise_pipes:
|
||||
raise NotImplementedError
|
||||
# Create default handlers for muxer negotiation if none provided
|
||||
if initiator_handler is None and self.supported_muxers:
|
||||
initiator_handler = TransportEarlyDataHandler(self.supported_muxers)
|
||||
if responder_handler is None and self.supported_muxers:
|
||||
responder_handler = TransportEarlyDataHandler(self.supported_muxers)
|
||||
|
||||
self.initiator_early_data_handler = initiator_handler
|
||||
self.responder_early_data_handler = responder_handler
|
||||
|
||||
def get_pattern(self) -> IPattern:
|
||||
if self.with_noise_pipes:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
return PatternXX(
|
||||
self.local_peer,
|
||||
self.libp2p_privkey,
|
||||
self.noise_privkey,
|
||||
self.early_data,
|
||||
)
|
||||
return PatternXX(
|
||||
self.local_peer,
|
||||
self.libp2p_privkey,
|
||||
self.noise_privkey,
|
||||
self.initiator_early_data_handler,
|
||||
self.responder_early_data_handler,
|
||||
)
|
||||
|
||||
async def secure_inbound(self, conn: IRawConnection) -> ISecureConn:
|
||||
pattern = self.get_pattern()
|
||||
|
||||
@ -17,6 +17,9 @@ from libp2p.custom_types import (
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectError,
|
||||
)
|
||||
from libp2p.protocol_muxer.multiselect import (
|
||||
Multiselect,
|
||||
)
|
||||
@ -104,7 +107,7 @@ class SecurityMultistream(ABC):
|
||||
:param is_initiator: true if we are the initiator, false otherwise
|
||||
:return: selected secure transport
|
||||
"""
|
||||
protocol: TProtocol
|
||||
protocol: TProtocol | None
|
||||
communicator = MultiselectCommunicator(conn)
|
||||
if is_initiator:
|
||||
# Select protocol if initiator
|
||||
@ -114,5 +117,9 @@ class SecurityMultistream(ABC):
|
||||
else:
|
||||
# Select protocol if non-initiator
|
||||
protocol, _ = await self.multiselect.negotiate(communicator)
|
||||
if protocol is None:
|
||||
raise MultiselectError(
|
||||
"Failed to negotiate a security protocol: no protocol selected"
|
||||
)
|
||||
# Return transport from protocol
|
||||
return self.transports[protocol]
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from types import (
|
||||
TracebackType,
|
||||
)
|
||||
@ -32,6 +34,72 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
|
||||
class ReadWriteLock:
|
||||
"""
|
||||
A read-write lock that allows multiple concurrent readers
|
||||
or one exclusive writer, implemented using Trio primitives.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._readers = 0
|
||||
self._readers_lock = trio.Lock() # Protects access to _readers count
|
||||
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
|
||||
|
||||
async def acquire_read(self) -> None:
|
||||
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
|
||||
try:
|
||||
async with self._readers_lock:
|
||||
if self._readers == 0:
|
||||
await self._writer_lock.acquire()
|
||||
self._readers += 1
|
||||
except trio.Cancelled:
|
||||
raise
|
||||
|
||||
async def release_read(self) -> None:
|
||||
"""Release a read lock."""
|
||||
async with self._readers_lock:
|
||||
if self._readers == 1:
|
||||
self._writer_lock.release()
|
||||
self._readers -= 1
|
||||
|
||||
async def acquire_write(self) -> None:
|
||||
"""Acquire an exclusive write lock."""
|
||||
try:
|
||||
await self._writer_lock.acquire()
|
||||
except trio.Cancelled:
|
||||
raise
|
||||
|
||||
def release_write(self) -> None:
|
||||
"""Release the exclusive write lock."""
|
||||
self._writer_lock.release()
|
||||
|
||||
@asynccontextmanager
|
||||
async def read_lock(self) -> AsyncGenerator[None, None]:
|
||||
"""Context manager for acquiring and releasing a read lock safely."""
|
||||
acquire = False
|
||||
try:
|
||||
await self.acquire_read()
|
||||
acquire = True
|
||||
yield
|
||||
finally:
|
||||
if acquire:
|
||||
with trio.CancelScope() as scope:
|
||||
scope.shield = True
|
||||
await self.release_read()
|
||||
|
||||
@asynccontextmanager
|
||||
async def write_lock(self) -> AsyncGenerator[None, None]:
|
||||
"""Context manager for acquiring and releasing a write lock safely."""
|
||||
acquire = False
|
||||
try:
|
||||
await self.acquire_write()
|
||||
acquire = True
|
||||
yield
|
||||
finally:
|
||||
if acquire:
|
||||
self.release_write()
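A compact usage sketch under trio: reader tasks may overlap with each other but never with the writer.

import trio

async def demo_rw_lock() -> None:
    lock = ReadWriteLock()
    shared: list[int] = []

    async def reader() -> None:
        async with lock.read_lock():
            # Many readers may be inside this block at the same time.
            _ = list(shared)

    async def writer() -> None:
        async with lock.write_lock():
            # The writer holds the lock exclusively.
            shared.append(1)

    async with trio.open_nursery() as nursery:
        for _ in range(3):
            nursery.start_soon(reader)
        nursery.start_soon(writer)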
|
||||
|
||||
|
||||
class MplexStream(IMuxedStream):
|
||||
"""
|
||||
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
|
||||
@ -46,7 +114,7 @@ class MplexStream(IMuxedStream):
|
||||
read_deadline: int | None
|
||||
write_deadline: int | None
|
||||
|
||||
# TODO: Add lock for read/write to avoid interleaving receiving messages?
|
||||
rw_lock: ReadWriteLock
|
||||
close_lock: trio.Lock
|
||||
|
||||
# NOTE: `dataIn` is size of 8 in Go implementation.
|
||||
@ -80,6 +148,7 @@ class MplexStream(IMuxedStream):
|
||||
self.event_remote_closed = trio.Event()
|
||||
self.event_reset = trio.Event()
|
||||
self.close_lock = trio.Lock()
|
||||
self.rw_lock = ReadWriteLock()
|
||||
self.incoming_data_channel = incoming_data_channel
|
||||
self._buf = bytearray()
|
||||
|
||||
@ -113,48 +182,49 @@ class MplexStream(IMuxedStream):
|
||||
:param n: number of bytes to read
|
||||
:return: bytes actually read
|
||||
"""
|
||||
if n is not None and n < 0:
|
||||
raise ValueError(
|
||||
"the number of bytes to read `n` must be non-negative or "
|
||||
f"`None` to indicate read until EOF, got n={n}"
|
||||
)
|
||||
if self.event_reset.is_set():
|
||||
raise MplexStreamReset
|
||||
if n is None:
|
||||
return await self._read_until_eof()
|
||||
if len(self._buf) == 0:
|
||||
data: bytes
|
||||
# Peek whether there is data available. If yes, we just read until there is
|
||||
# no data, then return.
|
||||
try:
|
||||
data = self.incoming_data_channel.receive_nowait()
|
||||
self._buf.extend(data)
|
||||
except trio.EndOfChannel:
|
||||
raise MplexStreamEOF
|
||||
except trio.WouldBlock:
|
||||
# We know `receive` will be blocked here. Wait for data here with
|
||||
# `receive` and catch all kinds of errors here.
|
||||
async with self.rw_lock.read_lock():
|
||||
if n is not None and n < 0:
|
||||
raise ValueError(
|
||||
"the number of bytes to read `n` must be non-negative or "
|
||||
f"`None` to indicate read until EOF, got n={n}"
|
||||
)
|
||||
if self.event_reset.is_set():
|
||||
raise MplexStreamReset
|
||||
if n is None:
|
||||
return await self._read_until_eof()
|
||||
if len(self._buf) == 0:
|
||||
data: bytes
|
||||
# Peek whether there is data available. If yes, we just read until
|
||||
# there is no data, then return.
|
||||
try:
|
||||
data = await self.incoming_data_channel.receive()
|
||||
data = self.incoming_data_channel.receive_nowait()
|
||||
self._buf.extend(data)
|
||||
except trio.EndOfChannel:
|
||||
if self.event_reset.is_set():
|
||||
raise MplexStreamReset
|
||||
if self.event_remote_closed.is_set():
|
||||
raise MplexStreamEOF
|
||||
except trio.ClosedResourceError as error:
|
||||
# Probably `incoming_data_channel` is closed in `reset` when we are
|
||||
# waiting for `receive`.
|
||||
if self.event_reset.is_set():
|
||||
raise MplexStreamReset
|
||||
raise Exception(
|
||||
"`incoming_data_channel` is closed but stream is not reset. "
|
||||
"This should never happen."
|
||||
) from error
|
||||
self._buf.extend(self._read_return_when_blocked())
|
||||
payload = self._buf[:n]
|
||||
self._buf = self._buf[len(payload) :]
|
||||
return bytes(payload)
|
||||
raise MplexStreamEOF
|
||||
except trio.WouldBlock:
|
||||
# We know `receive` will be blocked here. Wait for data here with
|
||||
# `receive` and catch all kinds of errors here.
|
||||
try:
|
||||
data = await self.incoming_data_channel.receive()
|
||||
self._buf.extend(data)
|
||||
except trio.EndOfChannel:
|
||||
if self.event_reset.is_set():
|
||||
raise MplexStreamReset
|
||||
if self.event_remote_closed.is_set():
|
||||
raise MplexStreamEOF
|
||||
except trio.ClosedResourceError as error:
|
||||
# Probably `incoming_data_channel` is closed in `reset` when
|
||||
# we are waiting for `receive`.
|
||||
if self.event_reset.is_set():
|
||||
raise MplexStreamReset
|
||||
raise Exception(
"`incoming_data_channel` is closed but stream is not reset. "
"This should never happen."
|
||||
) from error
|
||||
self._buf.extend(self._read_return_when_blocked())
|
||||
payload = self._buf[:n]
|
||||
self._buf = self._buf[len(payload) :]
|
||||
return bytes(payload)
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""
|
||||
@ -162,22 +232,21 @@ class MplexStream(IMuxedStream):
|
||||
|
||||
:return: number of bytes written
|
||||
"""
|
||||
if self.event_local_closed.is_set():
|
||||
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
|
||||
flag = (
|
||||
HeaderTags.MessageInitiator
|
||||
if self.is_initiator
|
||||
else HeaderTags.MessageReceiver
|
||||
)
|
||||
await self.muxed_conn.send_message(flag, data, self.stream_id)
|
||||
async with self.rw_lock.write_lock():
|
||||
if self.event_local_closed.is_set():
|
||||
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
|
||||
flag = (
|
||||
HeaderTags.MessageInitiator
|
||||
if self.is_initiator
|
||||
else HeaderTags.MessageReceiver
|
||||
)
|
||||
await self.muxed_conn.send_message(flag, data, self.stream_id)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
Closing a stream closes it for writing and closes the remote end for
|
||||
reading but allows writing in the other direction.
|
||||
"""
|
||||
# TODO error handling with timeout
|
||||
|
||||
async with self.close_lock:
|
||||
if self.event_local_closed.is_set():
|
||||
return
|
||||
@ -185,8 +254,17 @@ class MplexStream(IMuxedStream):
|
||||
flag = (
|
||||
HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver
|
||||
)
|
||||
# TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown.
|
||||
await self.muxed_conn.send_message(flag, None, self.stream_id)
|
||||
|
||||
try:
|
||||
with trio.fail_after(5): # timeout in seconds
|
||||
await self.muxed_conn.send_message(flag, None, self.stream_id)
|
||||
except trio.TooSlowError:
|
||||
raise TimeoutError("Timeout while trying to close the stream")
|
||||
except MuxedConnUnavailable:
|
||||
if not self.muxed_conn.event_shutting_down.is_set():
|
||||
raise RuntimeError(
|
||||
"Failed to send close message and Mplex isn't shutting down"
|
||||
)
|
||||
|
||||
_is_remote_closed: bool
|
||||
async with self.close_lock:
|
||||
|
||||
@ -17,6 +17,9 @@ from libp2p.custom_types import (
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectError,
|
||||
)
|
||||
from libp2p.protocol_muxer.multiselect import (
|
||||
Multiselect,
|
||||
)
|
||||
@ -73,7 +76,7 @@ class MuxerMultistream:
|
||||
:param conn: conn to choose a transport over
|
||||
:return: selected muxer transport
|
||||
"""
|
||||
protocol: TProtocol
|
||||
protocol: TProtocol | None
|
||||
communicator = MultiselectCommunicator(conn)
|
||||
if conn.is_initiator:
|
||||
protocol = await self.multiselect_client.select_one_of(
|
||||
@ -81,6 +84,10 @@ class MuxerMultistream:
|
||||
)
|
||||
else:
|
||||
protocol, _ = await self.multiselect.negotiate(communicator)
|
||||
if protocol is None:
|
||||
raise MultiselectError(
"Failed to negotiate a stream muxer protocol: no protocol selected"
|
||||
)
|
||||
return self.transports[protocol]
|
||||
|
||||
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
|
||||
|
||||
@ -45,6 +45,9 @@ from libp2p.stream_muxer.exceptions import (
|
||||
MuxedStreamReset,
|
||||
)
|
||||
|
||||
# Configure logger for this module
|
||||
logger = logging.getLogger("libp2p.stream_muxer.yamux")
|
||||
|
||||
PROTOCOL_ID = "/yamux/1.0.0"
|
||||
TYPE_DATA = 0x0
|
||||
TYPE_WINDOW_UPDATE = 0x1
|
||||
@ -98,13 +101,13 @@ class YamuxStream(IMuxedStream):
|
||||
         # Flow control: Check if we have enough send window
         total_len = len(data)
         sent = 0
-        logging.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
+        logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
         while sent < total_len:
             # Wait for available window with timeout
             timeout = False
             async with self.window_lock:
                 if self.send_window == 0:
-                    logging.debug(
+                    logger.debug(
                         f"Stream {self.stream_id}: Window is zero, waiting for update"
                     )
                     # Release lock and wait with timeout
@@ -152,12 +155,12 @@ class YamuxStream(IMuxedStream):
         """
         if increment <= 0:
             # If increment is zero or negative, skip sending update
-            logging.debug(
+            logger.debug(
                 f"Stream {self.stream_id}: Skipping window update"
                 f"(increment={increment})"
             )
             return
-        logging.debug(
+        logger.debug(
             f"Stream {self.stream_id}: Sending window update with increment={increment}"
         )

@@ -185,7 +188,7 @@ class YamuxStream(IMuxedStream):

         # If the stream is closed for receiving and the buffer is empty, raise EOF
         if self.recv_closed and not self.conn.stream_buffers.get(self.stream_id):
-            logging.debug(
+            logger.debug(
                 f"Stream {self.stream_id}: Stream closed for receiving and buffer empty"
             )
             raise MuxedStreamEOF("Stream is closed for receiving")
@@ -198,7 +201,7 @@ class YamuxStream(IMuxedStream):

         # If buffer is not available, check if stream is closed
         if buffer is None:
-            logging.debug(f"Stream {self.stream_id}: No buffer available")
+            logger.debug(f"Stream {self.stream_id}: No buffer available")
             raise MuxedStreamEOF("Stream buffer closed")

         # If we have data in buffer, process it
@@ -210,34 +213,34 @@ class YamuxStream(IMuxedStream):
                 # Send window update for the chunk we just read
                 async with self.window_lock:
                     self.recv_window += len(chunk)
-                    logging.debug(f"Stream {self.stream_id}: Update {len(chunk)}")
+                    logger.debug(f"Stream {self.stream_id}: Update {len(chunk)}")
                     await self.send_window_update(len(chunk), skip_lock=True)

             # If stream is closed (FIN received) and buffer is empty, break
             if self.recv_closed and len(buffer) == 0:
-                logging.debug(f"Stream {self.stream_id}: Closed with empty buffer")
+                logger.debug(f"Stream {self.stream_id}: Closed with empty buffer")
                 break

             # If stream was reset, raise reset error
             if self.reset_received:
-                logging.debug(f"Stream {self.stream_id}: Stream was reset")
+                logger.debug(f"Stream {self.stream_id}: Stream was reset")
                 raise MuxedStreamReset("Stream was reset")

             # Wait for more data or stream closure
-            logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
+            logger.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
             await self.conn.stream_events[self.stream_id].wait()
             self.conn.stream_events[self.stream_id] = trio.Event()

         # After loop exit, first check if we have data to return
         if data:
-            logging.debug(
+            logger.debug(
                 f"Stream {self.stream_id}: Returning {len(data)} bytes after loop"
             )
             return data

         # No data accumulated, now check why we exited the loop
         if self.conn.event_shutting_down.is_set():
-            logging.debug(f"Stream {self.stream_id}: Connection shutting down")
+            logger.debug(f"Stream {self.stream_id}: Connection shutting down")
             raise MuxedStreamEOF("Connection shut down")

         # Return empty data
@@ -246,7 +249,7 @@ class YamuxStream(IMuxedStream):
         data = await self.conn.read_stream(self.stream_id, n)
         async with self.window_lock:
             self.recv_window += len(data)
-            logging.debug(
+            logger.debug(
                 f"Stream {self.stream_id}: Sending window update after read, "
                 f"increment={len(data)}"
             )
@@ -255,7 +258,7 @@ class YamuxStream(IMuxedStream):

     async def close(self) -> None:
         if not self.send_closed:
-            logging.debug(f"Half-closing stream {self.stream_id} (local end)")
+            logger.debug(f"Half-closing stream {self.stream_id} (local end)")
             header = struct.pack(
                 YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
             )
@@ -271,7 +274,7 @@ class YamuxStream(IMuxedStream):

     async def reset(self) -> None:
         if not self.closed:
-            logging.debug(f"Resetting stream {self.stream_id}")
+            logger.debug(f"Resetting stream {self.stream_id}")
             header = struct.pack(
                 YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
             )
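
Note on the FIN/RST frames built in the close() and reset() hunks above: the sketch below shows what those struct.pack calls put on the wire, assuming YAMUX_HEADER_FORMAT is the 12-byte big-endian header from the Yamux spec (version, type, flags, stream id, length). The constant values come from the spec, not from this diff, and the names in py-libp2p's yamux module may differ.

# Illustrative sketch only, under the assumptions stated above.
import struct

YAMUX_HEADER_FORMAT = "!BBHII"  # assumed: version(1B), type(1B), flags(2B), stream_id(4B), length(4B)
TYPE_DATA = 0x0                 # data frame type per the Yamux spec
FLAG_FIN = 0x4                  # half-close: this side will send no more data
FLAG_RST = 0x8                  # hard reset: abort the stream immediately

def fin_header(stream_id: int) -> bytes:
    # Header written on a local half-close (as in YamuxStream.close() above).
    return struct.pack(YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, stream_id, 0)

def rst_header(stream_id: int) -> bytes:
    # Header written to abort a stream (as in YamuxStream.reset() above).
    return struct.pack(YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, stream_id, 0)

assert len(fin_header(1)) == 12  # matches the fixed HEADER_SIZE read in handle_incoming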
@@ -349,7 +352,7 @@ class Yamux(IMuxedConn):
         self._nursery: Nursery | None = None

     async def start(self) -> None:
-        logging.debug(f"Starting Yamux for {self.peer_id}")
+        logger.debug(f"Starting Yamux for {self.peer_id}")
         if self.event_started.is_set():
             return
         async with trio.open_nursery() as nursery:
@@ -362,7 +365,7 @@ class Yamux(IMuxedConn):
         return self.is_initiator_value

     async def close(self, error_code: int = GO_AWAY_NORMAL) -> None:
-        logging.debug(f"Closing Yamux connection with code {error_code}")
+        logger.debug(f"Closing Yamux connection with code {error_code}")
         async with self.streams_lock:
             if not self.event_shutting_down.is_set():
                 try:
@@ -371,7 +374,7 @@ class Yamux(IMuxedConn):
                     )
                     await self.secured_conn.write(header)
                 except Exception as e:
-                    logging.debug(f"Failed to send GO_AWAY: {e}")
+                    logger.debug(f"Failed to send GO_AWAY: {e}")
             self.event_shutting_down.set()
             for stream in self.streams.values():
                 stream.closed = True
@@ -382,12 +385,12 @@ class Yamux(IMuxedConn):
             self.stream_events.clear()
         try:
             await self.secured_conn.close()
-            logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
+            logger.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
         except Exception as e:
-            logging.debug(f"Error closing secured_conn for peer {self.peer_id}: {e}")
+            logger.debug(f"Error closing secured_conn for peer {self.peer_id}: {e}")
         self.event_closed.set()
         if self.on_close:
-            logging.debug(f"Calling on_close in Yamux.close for peer {self.peer_id}")
+            logger.debug(f"Calling on_close in Yamux.close for peer {self.peer_id}")
             if inspect.iscoroutinefunction(self.on_close):
                 if self.on_close is not None:
                     await self.on_close()
@@ -416,7 +419,7 @@ class Yamux(IMuxedConn):
                 header = struct.pack(
                     YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_SYN, stream_id, 0
                 )
-                logging.debug(f"Sending SYN header for stream {stream_id}")
+                logger.debug(f"Sending SYN header for stream {stream_id}")
                 await self.secured_conn.write(header)
                 return stream
             except Exception as e:
@@ -424,32 +427,32 @@ class Yamux(IMuxedConn):
             raise e

     async def accept_stream(self) -> IMuxedStream:
-        logging.debug("Waiting for new stream")
+        logger.debug("Waiting for new stream")
         try:
             stream = await self.new_stream_receive_channel.receive()
-            logging.debug(f"Received stream {stream.stream_id}")
+            logger.debug(f"Received stream {stream.stream_id}")
             return stream
         except trio.EndOfChannel:
             raise MuxedStreamError("No new streams available")

     async def read_stream(self, stream_id: int, n: int = -1) -> bytes:
-        logging.debug(f"Reading from stream {self.peer_id}:{stream_id}, n={n}")
+        logger.debug(f"Reading from stream {self.peer_id}:{stream_id}, n={n}")
         if n is None:
             n = -1

         while True:
             async with self.streams_lock:
                 if stream_id not in self.streams:
-                    logging.debug(f"Stream {self.peer_id}:{stream_id} unknown")
+                    logger.debug(f"Stream {self.peer_id}:{stream_id} unknown")
                     raise MuxedStreamEOF("Stream closed")
                 if self.event_shutting_down.is_set():
-                    logging.debug(
+                    logger.debug(
                         f"Stream {self.peer_id}:{stream_id}: connection shutting down"
                     )
                     raise MuxedStreamEOF("Connection shut down")
                 stream = self.streams[stream_id]
                 buffer = self.stream_buffers.get(stream_id)
-                logging.debug(
+                logger.debug(
                     f"Stream {self.peer_id}:{stream_id}: "
                     f"closed={stream.closed}, "
                     f"recv_closed={stream.recv_closed}, "
@@ -457,7 +460,7 @@ class Yamux(IMuxedConn):
                     f"buffer_len={len(buffer) if buffer else 0}"
                 )
                 if buffer is None:
-                    logging.debug(
+                    logger.debug(
                         f"Stream {self.peer_id}:{stream_id}:"
                         f"Buffer gone, assuming closed"
                     )
@@ -470,7 +473,7 @@ class Yamux(IMuxedConn):
                     else:
                         data = bytes(buffer[:n])
                         del buffer[:n]
-                    logging.debug(
+                    logger.debug(
                         f"Returning {len(data)} bytes"
                         f"from stream {self.peer_id}:{stream_id}, "
                         f"buffer_len={len(buffer)}"
@@ -478,7 +481,7 @@ class Yamux(IMuxedConn):
                     return data
                 # If reset received and buffer is empty, raise reset
                 if stream.reset_received:
-                    logging.debug(
+                    logger.debug(
                         f"Stream {self.peer_id}:{stream_id}:"
                         f"reset_received=True, raising MuxedStreamReset"
                     )
@@ -491,7 +494,7 @@ class Yamux(IMuxedConn):
                     else:
                         data = bytes(buffer[:n])
                         del buffer[:n]
-                    logging.debug(
+                    logger.debug(
                         f"Returning {len(data)} bytes"
                         f"from stream {self.peer_id}:{stream_id}, "
                         f"buffer_len={len(buffer)}"
@@ -499,21 +502,21 @@ class Yamux(IMuxedConn):
                     return data
                 # Check if stream is closed
                 if stream.closed:
-                    logging.debug(
+                    logger.debug(
                         f"Stream {self.peer_id}:{stream_id}:"
                         f"closed=True, raising MuxedStreamReset"
                     )
                     raise MuxedStreamReset("Stream is reset or closed")
                 # Check if recv_closed and buffer empty
                 if stream.recv_closed:
-                    logging.debug(
+                    logger.debug(
                         f"Stream {self.peer_id}:{stream_id}:"
                         f"recv_closed=True, buffer empty, raising EOF"
                     )
                     raise MuxedStreamEOF("Stream is closed for receiving")

             # Wait for data if stream is still open
-            logging.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}")
+            logger.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}")
             try:
                 await self.stream_events[stream_id].wait()
                 self.stream_events[stream_id] = trio.Event()
@@ -528,7 +531,7 @@ class Yamux(IMuxedConn):
             try:
                 header = await self.secured_conn.read(HEADER_SIZE)
                 if not header or len(header) < HEADER_SIZE:
-                    logging.debug(
+                    logger.debug(
                         f"Connection closed orincomplete header for peer {self.peer_id}"
                     )
                     self.event_shutting_down.set()
@@ -537,7 +540,7 @@ class Yamux(IMuxedConn):
                 version, typ, flags, stream_id, length = struct.unpack(
                     YAMUX_HEADER_FORMAT, header
                 )
-                logging.debug(
+                logger.debug(
                     f"Received header for peer {self.peer_id}:"
                     f"type={typ}, flags={flags}, stream_id={stream_id},"
                     f"length={length}"
@@ -558,7 +561,7 @@ class Yamux(IMuxedConn):
                         0,
                     )
                     await self.secured_conn.write(ack_header)
-                    logging.debug(
+                    logger.debug(
                         f"Sending stream {stream_id}"
                         f"to channel for peer {self.peer_id}"
                     )
@@ -576,7 +579,7 @@ class Yamux(IMuxedConn):
                 elif typ == TYPE_DATA and flags & FLAG_RST:
                     async with self.streams_lock:
                         if stream_id in self.streams:
-                            logging.debug(
+                            logger.debug(
                                 f"Resetting stream {stream_id} for peer {self.peer_id}"
                             )
                             self.streams[stream_id].closed = True
@@ -585,27 +588,27 @@ class Yamux(IMuxedConn):
                 elif typ == TYPE_DATA and flags & FLAG_ACK:
                     async with self.streams_lock:
                         if stream_id in self.streams:
-                            logging.debug(
+                            logger.debug(
                                 f"Received ACK for stream"
                                 f"{stream_id} for peer {self.peer_id}"
                             )
                 elif typ == TYPE_GO_AWAY:
                     error_code = length
                     if error_code == GO_AWAY_NORMAL:
-                        logging.debug(
+                        logger.debug(
                             f"Received GO_AWAY for peer"
                             f"{self.peer_id}: Normal termination"
                         )
                     elif error_code == GO_AWAY_PROTOCOL_ERROR:
-                        logging.error(
+                        logger.error(
                             f"Received GO_AWAY for peer{self.peer_id}: Protocol error"
                         )
                     elif error_code == GO_AWAY_INTERNAL_ERROR:
-                        logging.error(
+                        logger.error(
                             f"Received GO_AWAY for peer {self.peer_id}: Internal error"
                         )
                     else:
-                        logging.error(
+                        logger.error(
                             f"Received GO_AWAY for peer {self.peer_id}"
                             f"with unknown error code: {error_code}"
                         )
@@ -614,7 +617,7 @@ class Yamux(IMuxedConn):
                     break
                 elif typ == TYPE_PING:
                     if flags & FLAG_SYN:
-                        logging.debug(
+                        logger.debug(
                             f"Received ping request with value"
                             f"{length} for peer {self.peer_id}"
                         )
@@ -623,7 +626,7 @@ class Yamux(IMuxedConn):
                         )
                         await self.secured_conn.write(ping_header)
                     elif flags & FLAG_ACK:
-                        logging.debug(
+                        logger.debug(
                             f"Received ping response with value"
                             f"{length} for peer {self.peer_id}"
                         )
@@ -637,7 +640,7 @@ class Yamux(IMuxedConn):
                         self.stream_buffers[stream_id].extend(data)
                         self.stream_events[stream_id].set()
                         if flags & FLAG_FIN:
-                            logging.debug(
+                            logger.debug(
                                 f"Received FIN for stream {self.peer_id}:"
                                 f"{stream_id}, marking recv_closed"
                             )
@@ -645,7 +648,7 @@ class Yamux(IMuxedConn):
                             if self.streams[stream_id].send_closed:
                                 self.streams[stream_id].closed = True
                     except Exception as e:
-                        logging.error(f"Error reading data for stream {stream_id}: {e}")
+                        logger.error(f"Error reading data for stream {stream_id}: {e}")
                         # Mark stream as closed on read error
                         async with self.streams_lock:
                             if stream_id in self.streams:
@@ -659,7 +662,7 @@ class Yamux(IMuxedConn):
                         if stream_id in self.streams:
                             stream = self.streams[stream_id]
                             async with stream.window_lock:
-                                logging.debug(
+                                logger.debug(
                                     f"Received window update for stream"
                                     f"{self.peer_id}:{stream_id},"
                                     f" increment: {increment}"
@@ -674,7 +677,7 @@ class Yamux(IMuxedConn):
                     and details.get("requested_count") == 2
                     and details.get("received_count") == 0
                 ):
-                    logging.info(
+                    logger.info(
                         f"Stream closed cleanly for peer {self.peer_id}"
                         + f" (IncompleteReadError: {details})"
                     )
@@ -682,15 +685,32 @@ class Yamux(IMuxedConn):
                     await self._cleanup_on_error()
                     break
                 else:
-                    logging.error(
-                        f"Error in handle_incoming for peer {self.peer_id}: "
-                        + f"{type(e).__name__}: {str(e)}"
-                    )
+                    # Handle RawConnError with more nuance
+                    if isinstance(e, RawConnError):
+                        error_msg = str(e)
+                        # If RawConnError is empty, it's likely normal cleanup
+                        if not error_msg.strip():
+                            logger.info(
+                                f"RawConnError (empty) during cleanup for peer "
+                                f"{self.peer_id} (normal connection shutdown)"
+                            )
+                        else:
+                            # Log non-empty RawConnError as warning
+                            logger.warning(
+                                f"RawConnError during connection handling for peer "
+                                f"{self.peer_id}: {error_msg}"
+                            )
+                    else:
+                        # Log all other errors normally
+                        logger.error(
+                            f"Error in handle_incoming for peer {self.peer_id}: "
+                            + f"{type(e).__name__}: {str(e)}"
+                        )
+
                 # Don't crash the whole connection for temporary errors
                 if self.event_shutting_down.is_set() or isinstance(
                     e, (RawConnError, OSError)
@@ -720,9 +740,9 @@ class Yamux(IMuxedConn):
         # Close the secured connection
         try:
             await self.secured_conn.close()
-            logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
+            logger.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
         except Exception as close_error:
-            logging.error(
+            logger.error(
                 f"Error closing secured_conn for peer {self.peer_id}: {close_error}"
             )

@@ -731,14 +751,14 @@ class Yamux(IMuxedConn):

         # Call on_close callback if provided
         if self.on_close:
-            logging.debug(f"Calling on_close for peer {self.peer_id}")
+            logger.debug(f"Calling on_close for peer {self.peer_id}")
             try:
                 if inspect.iscoroutinefunction(self.on_close):
                     await self.on_close()
                 else:
                     self.on_close()
             except Exception as callback_error:
-                logging.error(f"Error in on_close callback: {callback_error}")
+                logger.error(f"Error in on_close callback: {callback_error}")

         # Cancel nursery tasks
         if self._nursery:
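
For context on the logging.* → logger.* rename that runs through every hunk above: a minimal sketch of the module-level logger this change presupposes. The logger name used here is an assumption for illustration; the actual definition sits near the top of yamux.py, outside the hunks shown.

# Sketch only, assuming a conventional named logger rather than the root logger.
import logging

logger = logging.getLogger("libp2p.stream_muxer.yamux")  # assumed name

# An application can then tune verbosity for just this module, e.g.:
#     logging.getLogger("libp2p.stream_muxer.yamux").setLevel(logging.DEBUG)
# without enabling debug output for every library that logs via the root logger.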