65 Commits

Author SHA1 Message Date
8b2268fcc9 fix: improve async validator handling in Pubsub class
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
2025-06-26 19:56:53 +05:30
6a92fa26eb Merge pull request #704 from sukhman-sukh/piggyback-gossipsub
Remove piggyback TODO from gossipsub
2025-06-26 06:53:32 -07:00
e8a484f8e4 Merge branch 'main' into piggyback-gossipsub 2025-06-26 06:44:03 -07:00
74134e9b63 Remove piggyback TODO from gossipsub
Signed-off-by: sukhman <sukhmansinghsaluja@gmail.com>
2025-06-26 15:31:16 +05:30
2277d822f1 Merge pull request #690 from mystical-prog/px-backoff
Peer Exchange and Back Off
2025-06-25 22:37:33 -07:00
b736cfa333 Merge branch 'main' into px-backoff 2025-06-25 22:28:41 -07:00
73ebd27c59 added isolated_topics_test and stress_test 2025-06-26 01:28:21 +05:30
5a3adad093 Merge pull request #631 from Winter-Soren/feat/619-store-pubkey-peerid-peerstore
feat: store pubkey and peerid in peerstore
2025-06-24 14:20:22 -07:00
4e2be87c73 Merge pull request #695 from LVivona/patch-1
chore(kad_dht): centralize shared values in common.py file
2025-06-23 08:55:21 -07:00
fbee0ba2ab added newsfragment 2025-06-23 01:00:46 +05:30
ea6eef6ed5 test px and backoff 2025-06-23 00:41:13 +05:30
fd818d9102 test: added tests to ensure handshake adds pubkey to existing peer ID without one; peerstore unchanged on ID mismatch 2025-06-22 16:01:02 +05:30
2c0a6c0adb Merge branch 'main' into feat/619-store-pubkey-peerid-peerstore 2025-06-22 15:26:51 +05:30
3a4338e1df chore: eliminate self.protocol_id attribute \w in PeerRouting 2025-06-22 00:25:48 -04:00
feb8db6655 style: enforce multiline import style 2025-06-22 00:15:44 -04:00
ebdde7b5aa style: enforce multiline import style for consistency 2025-06-21 15:08:11 -04:00
24e73207d2 fixed failing demo
Co-authored-by: Khwahish Patel <khwahish.p1@ahduni.edu.in>
2025-06-21 18:54:17 +05:30
303bf3060a implemented peer exchange 2025-06-21 18:54:17 +05:30
788b4cf51a added complete back_off implementation 2025-06-21 18:54:17 +05:30
b78468ca32 added params for peer exchange and back off 2025-06-21 18:54:17 +05:30
c48618825d updated protobuf for prune message 2025-06-21 18:54:16 +05:30
811c217ee6 style: isort fix ording of imports 2025-06-20 16:01:11 -04:00
d03ca45bd8 style: fix flake8 linting errors 2025-06-20 11:57:50 -04:00
79ac01308c remove: unused custom_types TProtocol import 2025-06-19 21:38:02 -04:00
dfc0bb4ec8 chore(kad_dht): centralize shared values in common.py 2025-06-19 21:24:39 -04:00
09b4c846a4 feat: add support for sparse connect (#680)
* init

* add newsfragment

* fix
2025-06-19 06:18:45 -06:00
66bd027161 Feat/587-circuit-relay (#611)
* feat: implemented setup of circuit relay and test cases

* chore: remove test files to be rewritten

* added 1 test suite for protocol

* added 1 test suite for discovery

* fixed protocol timeouts and message types to handle reservations and stream operations.

* Resolved merge conflict in libp2p/tools/utils.py by combining timeout approach with retry mechanism

* fix: linting issues

* docs: updated documentation with circuit-relay

* chore: added enums, improved typing, security and examples

* fix: created proper __init__ file to ensure importability

* fix: replace transport_opt with listen_addrs in examples, fixed typing and improved code

* fix type checking issues across relay module and test suite

* regenerated circuit_pb2 file protobuf version 3

* fixed circuit relay example and moved imports to top in test_security_multistream

* chore: moved imports to the top

* chore: fixed linting of test_circuit_v2_transport.py

---------

Co-authored-by: Manu Sheel Gupta <manusheel.edu@gmail.com>
2025-06-18 15:39:39 -06:00
8c16b316ac added newsfragement and tests that would fail without these changes but pass with them 2025-06-18 23:32:48 +05:30
d4ed859b19 Merge branch 'main' into feat/619-store-pubkey-peerid-peerstore 2025-06-18 22:45:33 +05:30
79094d70d3 Optimize pubsub publishing to support multiple topics in single RPC message (#686)
* init

* add newsfragment

* lint

---------

Co-authored-by: Manu Sheel Gupta <manusheel.edu@gmail.com>
2025-06-17 15:23:03 -06:00
2ed2587fc9 fix: removed dummy ID(b) from upgrade_security for inbound connections (#681)
* fix: removed dummy ID(b) from upgrade_security for inbound connections

* added newsfragment

* updated newsfragment
2025-06-17 06:25:50 -06:00
d61bca78ab Kademlia DHT implementation in py-libp2p (#579)
* initialise the module

* added content routing

* added routing module

* added peer routing

* added value store

* added utilities functions

* added main kademlia file

* fixed create_key_from_binary function

* example to test kademlia dht

* added protocol ID and enhanced logging for peer store size in provider and consumer nodes

* refactor: specify stream type in handle_stream method and add peer in routing table

* removed content routing

* added default value of count for finding closest peers

* added functions to find close peers

* refactor: remove content routing and enhance peer discovery

* added put value function

* added get value function

* fix: improve logging and handle key encoding in get_value method

* refactor: remove ContentRouting import from __init__.py

* refactor: improved basic kademlia example

* added protobuf files

* replaced json with protobuf

* refactor: enhance peer discovery and routing logic in KadDHT

* refactor: enhance Kademlia routing table to use PeerInfo objects and improve peer management

* refactor: enhance peer addition logic to utilize PeerInfo objects in routing table

* feat: implement content provider functionality in Kademlia DHT

* refactor: update value store to use datetime for validity management

* refactor: update RoutingTable initialization to include host reference

* refactor: enhance KBucket and RoutingTable for improved peer management and functionality

* refactor: streamline peer discovery and value storage methods in KadDHT

* refactor: update KadDHT and related classes for async peer management and enhanced value storage

* refactor: enhance ProviderStore initialization and improve peer routing integration

* test: add tests for Kademlia DHT functionality

* fix linting issues

* pydocstyle issues fixed

* CICD pipeline issues solved

* fix: update docstring format for find_peer method

* refactor: improve logging and remove unused code in DHT implementation

* refactor: clean up logging and remove unused imports in DHT and test files

* Refactor logging setup and improve DHT stream handling with varint length prefixes

* Update bootstrap peer handling in basic_dht example and refactor peer routing to accept string addresses

* Enhance peer querying in Kademlia DHT by implementing parallel queries using Trio.

* Enhance peer querying by adding deduplication checks

* Refactor DHT implementation to use varint for length prefixes and enhance logging for better traceability

* Add base58 encoding for value storage and enhance logging in basic_dht example

* Refactor Kademlia DHT to support server/client modes

* Added unit tests

* Refactor documentation to fixsome warning

* Add unit tests and remove outdated tests

* Fixed precommit errora

* Refactor error handling test to raise StringParseError for invalid bootstrap addresses

* Add libp2p.kad_dht to the list of subpackages in documentation

* Fix expiration and republish checks to use inclusive comparison

* Add __init__.py file to libp2p.kad_dht.pb package

* Refactor get value and put value to run in parallel with query timeout

* Refactor provider message handling to use parallel processing with timeout

* Add methods for provider store in KadDHT class

* Refactor KadDHT and ProviderStore methods to improve type hints and enhance parallel processing

* Add documentation for libp2p.kad_dht.pb module.

* Update documentation for libp2p.kad_dht package to include subpackages and correct formatting

* Fix formatting in documentation for libp2p.kad_dht package by correcting the subpackage reference

* Fix header formatting in libp2p.kad_dht.pb documentation

* Change log level from info to debug for various logging statements.

* fix CICD issues (post revamp)

* fixed value store unit test

* Refactored kademlia example

* Refactor Kademlia example: enhance logging, improve bootstrap node connection, and streamline server address handling

* removed bootstrap module

* Refactor Kademlia DHT example and core modules: enhance logging, remove unused code, and improve peer handling

* Added docs of kad dht example

* Update server address log file path to use the script's directory

* Refactor: Introduce DHTMode enum for clearer mode management

* moved xor_distance function to utils.py

* Enhance logging in ValueStore and KadDHT: include decoded value in debug logs and update parameter description for validity

* Add handling for closest peers in GET_VALUE response when value is not found

* Handled failure scenario for PUT_VALUE

* Remove kademlia demo from project scripts and contributing documentation

* spelling and logging

---------

Co-authored-by: pacrob <5199899+pacrob@users.noreply.github.com>
2025-06-16 14:46:40 -06:00
a3492cf82f Merge branch 'main' into feat/619-store-pubkey-peerid-peerstore 2025-06-16 22:06:47 +05:30
733ef86e62 refactor(gossipsub.py): Add helper function to fanout and gossipsub (#678)
* fanout and gossibsub helper

* add newsfragment

* remove dub fanout check
2025-06-16 07:23:31 -06:00
0caf8647c5 Merge pull request #684 from guha-rahul/use_decapsulate
fix: replace complex logic with decapsulate
2025-06-16 04:37:14 -07:00
193e8f9cb8 add newsfragment 2025-06-15 19:58:52 +05:30
10b39dad1c replace complex logic with decapsulate 2025-06-15 19:51:55 +05:30
2248108b54 fixed types and improved code 2025-06-11 23:51:57 +05:30
a762df6042 Merge branch 'main' into feat/619-store-pubkey-peerid-peerstore 2025-06-11 23:34:12 +05:30
d2825af045 fix(examples/echo/echo.py): Add max message length to stream.read (#671)
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
2025-06-10 11:44:07 -06:00
0f483dd744 Bump version: 0.2.7 → 0.2.8 2025-06-10 11:31:46 -06:00
0197b515c1 Compile release notes for v0.2.8 2025-06-10 11:31:25 -06:00
f27f4ddd85 remove references to removed setup.py (#674) 2025-06-10 11:24:34 -06:00
7e377ede36 Merge branch 'main' into feat/619-store-pubkey-peerid-peerstore 2025-06-10 21:12:28 +05:30
630aac703d add make pr (#672) 2025-06-10 08:34:22 -06:00
286752c517 Merge pull request #658 from AkMo3/main
fix: add connection state for net stream and gracefully handle failure
2025-06-10 01:03:18 +05:30
390ac2eb26 Merge branch 'main' into main 2025-06-10 00:53:59 +05:30
13d730ae5c fix: improve types according to new typecheck 2025-06-09 19:10:15 +00:00
4e9fa87477 Updated examples to automatically use random port (#661)
* updated examples to automatically use random port

* Refactor examples to use shared utils for port selection (#1)

---------

Co-authored-by: acul71 <34693171+acul71@users.noreply.github.com>
2025-06-09 12:59:11 -06:00
47ae20d29c fix: run pytests parallely in CI and makefile 2025-06-09 18:58:51 +00:00
f7757fa726 docs: add documentation and examples for new NetStream state management 2025-06-09 18:58:17 +00:00
5bc4d01eea fix: add connection states for net stream
Other changes:
1. Add operation validation based on states
2. Gracefully handle exceptions and cleanup
2025-06-09 18:58:17 +00:00
c83fc1582d build(deps): bump fastecdsa from 1.7.5 to 2.3.2 (#669)
Bumps [fastecdsa](https://github.com/AntonKueltz/fastecdsa) from 1.7.5 to 2.3.2.
- [Release notes](https://github.com/AntonKueltz/fastecdsa/releases)
- [Changelog](https://github.com/AntonKueltz/fastecdsa/blob/main/CHANGELOG.md)
- [Commits](https://github.com/AntonKueltz/fastecdsa/compare/v1.7.5...v2.3.2)

---
updated-dependencies:
- dependency-name: fastecdsa
  dependency-version: 2.3.2
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-06-09 12:44:45 -06:00
22d93b39ae Add ttl for peer data expiration (#655)
* Add ttl and last_identified to peerdata

* Add test for ttl

Signed-off-by: sukhman <sukhmansinghsaluja@gmail.com>

* Fix lint and add newsfragments

Signed-off-by: sukhman <sukhmansinghsaluja@gmail.com>

* Fix failing ci

Signed-off-by: sukhman <sukhmansinghsaluja@gmail.com>

* fix ttl time from 600 to 120

Signed-off-by: sukhman <sukhmansinghsaluja@gmail.com>

* fix test ttl timeout and lint errors

Signed-off-by: sukhman <sukhmansinghsaluja@gmail.com>

* Fix docstrings

Signed-off-by: sukhman <sukhmansinghsaluja@gmail.com>

* rebase main

* remove print statement

---------

Signed-off-by: sukhman <sukhmansinghsaluja@gmail.com>
Co-authored-by: pacrob <5199899+pacrob@users.noreply.github.com>
2025-06-09 12:42:59 -06:00
bdadec7519 ft. modernise py-libp2p (#618)
* fix pyproject.toml , add ruff

* rm lock

* make progress

* add poetry lock ignore

* fix type issues

* fix tcp type errors

* fix text example - type error - wrong args

* add setuptools to dev

* test ci

* fix docs build

* fix type issues for new_swarm & new_host

* fix types in gossipsub

* fix type issues in noise

* wip: factories

* revert factories

* fix more type issues

* more type fixes

* fix: add null checks for noise protocol initialization and key handling

* corrected argument-errors in peerId and Multiaddr in peer tests

* fix: Noice - remove redundant type casts in BaseNoiseMsgReadWriter

* fix: update test_notify.py to use SwarmFactory.create_batch_and_listen, fix type hints, and comment out ClosedStream assertions

* Fix type checks for pubsub module

Signed-off-by: sukhman <sukhmansinghsaluja@gmail.com>

* Fix type checks for pubsub module-tests

Signed-off-by: sukhman <sukhmansinghsaluja@gmail.com>

* noise: add checks for uninitialized protocol and key states in PatternXX

Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>

* pubsub: add None checks for optional fields in FloodSub and Pubsub

Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>

* Fix type hints and improve testing

Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>

* remove redundant checks

Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>

* fix build issues

* add optional to trio service

* fix types

* fix type errors

* Fix type errors

Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>

* fixed more-type checks in crypto and peer_data files

* wip: factories

* replaced union with optional

* fix: type-error in interp-utils and peerinfo

* replace pyright with pyrefly

* add pyrefly.toml

* wip: fix multiselect issues

* try typecheck

* base check

* mcache test fixes , typecheck ci update

* fix ci

* will this work

* minor fix

* use poetry

* fix wokflow

* use cache,fix err

* fix pyrefly.toml

* fix pyrefly.toml

* fix cache in ci

* deploy commit

* add main baseline

* update to v5

* improve typecheck ci (#14)

* fix typo

* remove holepunching code (#16)

* fix gossipsub typeerrors (#17)

* fix: ensure initiator user includes remote peer id in handshake (#15)

* fix ci (#19)

* typefix: custom_types | core/peerinfo/test_peer_info | io/abc | pubsub/floodsub | protocol_muxer/multiselect (#18)

* fix: Typefixes in PeerInfo  (#21)

* fix minor type issue (#22)

* fix type errors in pubsub (#24)

* fix: Minor typefixes in tests (#23)

* Fix failing tests for type-fixed test/pubsub (#8)

* move pyrefly & ruff to pyproject.toml & rm .project-template (#28)

* move the async_context file to tests/core

* move crypto test to crypto folder

* fix: some typefixes (#25)

* fix type errors

* fix type issues

* fix: update gRPC API usage in autonat_pb2_grpc.py (#31)

* md: typecheck ci

* rm comments

* clean up : from review suggestions

* use | None over Optional as per new python standards

* drop supporto for py3.9

* newsfragments

---------

Signed-off-by: sukhman <sukhmansinghsaluja@gmail.com>
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
Co-authored-by: acul71 <luca.pisani@birdo.net>
Co-authored-by: kaneki003 <sakshamchauhan707@gmail.com>
Co-authored-by: sukhman <sukhmansinghsaluja@gmail.com>
Co-authored-by: varun-r-mallya <varunrmallya@gmail.com>
Co-authored-by: varunrmallya <100590632+varun-r-mallya@users.noreply.github.com>
Co-authored-by: lla-dane <abhinavagarwalla6@gmail.com>
Co-authored-by: Collins <ArtemisfowlX@protonmail.com>
Co-authored-by: Abhinav Agarwalla <120122716+lla-dane@users.noreply.github.com>
Co-authored-by: guha-rahul <52607971+guha-rahul@users.noreply.github.com>
Co-authored-by: Sukhman Singh <63765293+sukhman-sukh@users.noreply.github.com>
Co-authored-by: acul71 <34693171+acul71@users.noreply.github.com>
Co-authored-by: pacrob <5199899+pacrob@users.noreply.github.com>
2025-06-09 11:39:59 -06:00
d020bbc066 Add time_since_last_publish (#642)
Added `time_since_last_publish` field to gossipsub. Took reference from
https://github.com/libp2p/go-libp2p-pubsub/blob/master/gossipsub.go#L1224

Issue https://github.com/libp2p/py-libp2p/issues/636

## How was it fixed?
whenever someone publishes message to a topic or set of topics,
`time_since_last_publish` gets updated and whenever we clear fanout
peers or time exceeds ttl, we clear `time_since_last_publish` from dict.

### To-Do

Creating draft PR for now. Tests and type-binding is left for this
issue.
#### Cute Animal Picture

![put a cute animal picture link inside the
parentheses](https://i.etsystatic.com/27171676/r/il/eedb08/5303109239/il_570xN.5303109239_4o61.jpg)
2025-06-09 00:53:36 +05:30
00f10dbec3 Merge branch 'main' into add-last-publish 2025-06-08 19:19:30 +05:30
d75886b180 renamed newsfragment file causing docs ci failure 2025-06-06 17:55:40 +05:30
5ca6f26933 feat: Add blacklisting of peers (#651)
* init

* remove blacklist validation after hello packet

* add docs and newsfragment
2025-06-05 09:10:04 -06:00
a3c9ac61e6 Improve performance of read from daemon test (#646)
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
2025-06-05 07:25:59 -06:00
d4785b9e26 Add newsfragments to the PR
Signed-off-by: sukhman <sukhmansinghsaluja@gmail.com>
2025-06-05 14:00:23 +05:30
cef217358f fixed fanout_heartbeat bug and gossipsub join test 2025-06-05 13:39:07 +05:30
338672214c Add test for time_since_last_publish
Signed-off-by: sukhman <sukhmansinghsaluja@gmail.com>
2025-06-04 14:15:07 +05:30
c2046e6aa4 Add time_since_last_publish 2025-06-01 01:47:47 +05:30
30b5811d39 feat: store pubkey and peerid in peerstore 2025-05-29 20:07:48 +05:30
181 changed files with 14080 additions and 1619 deletions

View File

@ -16,10 +16,10 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python: ['3.9', '3.10', '3.11', '3.12', '3.13']
python: ["3.10", "3.11", "3.12", "3.13"]
toxenv: [core, interop, lint, wheel, demos]
include:
- python: '3.10'
- python: "3.10"
toxenv: docs
fail-fast: false
steps:
@ -46,7 +46,7 @@ jobs:
runs-on: windows-latest
strategy:
matrix:
python-version: ['3.11', '3.12', '3.13']
python-version: ["3.11", "3.12", "3.13"]
toxenv: [core, wheel]
fail-fast: false
steps:

7
.gitignore vendored
View File

@ -146,6 +146,9 @@ instance/
# PyBuilder
target/
# PyRight Config
pyrightconfig.json
# Jupyter Notebook
.ipynb_checkpoints
@ -171,3 +174,7 @@ env.bak/
# mkdocs documentation
/site
#lockfiles
uv.lock
poetry.lock

View File

@ -1,59 +1,49 @@
exclude: '.project-template|docs/conf.py|.*pb2\..*'
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: check-yaml
- id: check-toml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
- id: check-yaml
- id: check-toml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/asottile/pyupgrade
rev: v3.20.0
hooks:
- id: pyupgrade
args: [--py39-plus]
- repo: https://github.com/psf/black
rev: 23.9.1
- id: pyupgrade
args: [--py310-plus]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.10
hooks:
- id: black
- repo: https://github.com/PyCQA/flake8
rev: 6.1.0
hooks:
- id: flake8
additional_dependencies:
- flake8-bugbear==23.9.16
exclude: setup.py
- repo: https://github.com/PyCQA/autoflake
rev: v2.2.1
hooks:
- id: autoflake
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
- repo: https://github.com/pycqa/pydocstyle
rev: 6.3.0
hooks:
- id: pydocstyle
additional_dependencies:
- tomli # required until >= python311
- repo: https://github.com/executablebooks/mdformat
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.22
hooks:
- id: mdformat
- id: mdformat
additional_dependencies:
- mdformat-gfm
- repo: local
- mdformat-gfm
- repo: local
hooks:
- id: mypy-local
- id: mypy-local
name: run mypy with all dev dependencies present
entry: python -m mypy -p libp2p
entry: mypy -p libp2p
language: system
always_run: true
pass_filenames: false
- repo: local
- repo: local
hooks:
- id: check-rst-files
- id: pyrefly-local
name: run pyrefly typecheck locally
entry: pyrefly check
language: system
always_run: true
pass_filenames: false
- repo: local
hooks:
- id: check-rst-files
name: Check for .rst files in the top-level directory
entry: python -c "import glob, sys; rst_files = glob.glob('*.rst'); sys.exit(1) if rst_files else sys.exit(0)"
language: system

View File

@ -1,71 +0,0 @@
#!/usr/bin/env python3
import os
import sys
import re
from pathlib import Path
def _find_files(project_root):
path_exclude_pattern = r"\.git($|\/)|venv|_build"
file_exclude_pattern = r"fill_template_vars\.py|\.swp$"
filepaths = []
for dir_path, _dir_names, file_names in os.walk(project_root):
if not re.search(path_exclude_pattern, dir_path):
for file in file_names:
if not re.search(file_exclude_pattern, file):
filepaths.append(str(Path(dir_path, file)))
return filepaths
def _replace(pattern, replacement, project_root):
print(f"Replacing values: {pattern}")
for file in _find_files(project_root):
try:
with open(file) as f:
content = f.read()
content = re.sub(pattern, replacement, content)
with open(file, "w") as f:
f.write(content)
except UnicodeDecodeError:
pass
def main():
project_root = Path(os.path.realpath(sys.argv[0])).parent.parent
module_name = input("What is your python module name? ")
pypi_input = input(f"What is your pypi package name? (default: {module_name}) ")
pypi_name = pypi_input or module_name
repo_input = input(f"What is your github project name? (default: {pypi_name}) ")
repo_name = repo_input or pypi_name
rtd_input = input(
f"What is your readthedocs.org project name? (default: {pypi_name}) "
)
rtd_name = rtd_input or pypi_name
project_input = input(
f"What is your project name (ex: at the top of the README)? (default: {repo_name}) "
)
project_name = project_input or repo_name
short_description = input("What is a one-liner describing the project? ")
_replace("<MODULE_NAME>", module_name, project_root)
_replace("<PYPI_NAME>", pypi_name, project_root)
_replace("<REPO_NAME>", repo_name, project_root)
_replace("<RTD_NAME>", rtd_name, project_root)
_replace("<PROJECT_NAME>", project_name, project_root)
_replace("<SHORT_DESCRIPTION>", short_description, project_root)
os.makedirs(project_root / module_name, exist_ok=True)
Path(project_root / module_name / "__init__.py").touch()
Path(project_root / module_name / "py.typed").touch()
if __name__ == "__main__":
main()

View File

@ -1,39 +0,0 @@
#!/usr/bin/env python3
import os
import sys
from pathlib import Path
import subprocess
def main():
template_dir = Path(os.path.dirname(sys.argv[0]))
template_vars_file = template_dir / "template_vars.txt"
fill_template_vars_script = template_dir / "fill_template_vars.py"
with open(template_vars_file, "r") as input_file:
content_lines = input_file.readlines()
process = subprocess.Popen(
[sys.executable, str(fill_template_vars_script)],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
for line in content_lines:
process.stdin.write(line)
process.stdin.flush()
stdout, stderr = process.communicate()
if process.returncode != 0:
print(f"Error occurred: {stderr}")
sys.exit(1)
print(stdout)
if __name__ == "__main__":
main()

View File

@ -1,6 +0,0 @@
libp2p
libp2p
py-libp2p
py-libp2p
py-libp2p
The Python implementation of the libp2p networking stack

View File

@ -7,12 +7,14 @@ help:
@echo "clean-pyc - remove Python file artifacts"
@echo "clean - run clean-build and clean-pyc"
@echo "dist - build package and cat contents of the dist directory"
@echo "fix - fix formatting & linting issues with ruff"
@echo "lint - fix linting issues with pre-commit"
@echo "test - run tests quickly with the default Python"
@echo "docs - generate docs and open in browser (linux-docs for version on linux)"
@echo "package-test - build package and install it in a venv for manual testing"
@echo "notes - consume towncrier newsfragments and update release notes in docs - requires bump to be set"
@echo "release - package and upload a release (does not run notes target) - requires bump to be set"
@echo "pr - run clean, fix, lint, typecheck, and test i.e basically everything you need to do before creating a PR"
clean-build:
rm -fr build/
@ -37,8 +39,16 @@ lint:
&& pre-commit run --all-files --show-diff-on-failure \
)
fix:
python -m ruff check --fix
typecheck:
pre-commit run mypy-local --all-files && pre-commit run pyrefly-local --all-files
test:
python -m pytest tests
python -m pytest tests -n auto
pr: clean fix lint typecheck test
# protobufs management
@ -48,7 +58,10 @@ PB = libp2p/crypto/pb/crypto.proto \
libp2p/security/secio/pb/spipe.proto \
libp2p/security/noise/pb/noise.proto \
libp2p/identity/identify/pb/identify.proto \
libp2p/host/autonat/pb/autonat.proto
libp2p/host/autonat/pb/autonat.proto \
libp2p/relay/circuit_v2/pb/circuit.proto \
libp2p/kad_dht/pb/kademlia.proto
PY = $(PB:.proto=_pb2.py)
PYI = $(PB:.proto=_pb2.pyi)
@ -80,7 +93,7 @@ validate-newsfragments:
check-docs: build-docs validate-newsfragments
build-docs:
sphinx-apidoc -o docs/ . setup.py "*conftest*" tests/
sphinx-apidoc -o docs/ . "*conftest*" tests/
$(MAKE) -C docs clean
$(MAKE) -C docs html
$(MAKE) -C docs doctest

View File

@ -15,14 +15,24 @@
# documentation root, use os.path.abspath to make it absolute, like shown here.
# sys.path.insert(0, os.path.abspath('.'))
import doctest
import os
import sys
from unittest.mock import MagicMock
DIR = os.path.dirname(__file__)
with open(os.path.join(DIR, "../setup.py"), "r") as f:
for line in f:
if "version=" in line:
setup_version = line.split('"')[1]
break
try:
import tomllib
except ModuleNotFoundError:
# For Python < 3.11
import tomli as tomllib # type: ignore (In case of >3.11 Pyrefly doesnt find tomli , which is right but a false flag)
# Path to pyproject.toml (assuming conf.py is in a 'docs' subdirectory)
pyproject_path = os.path.join(os.path.dirname(__file__), "..", "pyproject.toml")
with open(pyproject_path, "rb") as f:
pyproject_data = tomllib.load(f)
setup_version = pyproject_data["project"]["version"]
# -- General configuration ------------------------------------------------
@ -302,7 +312,6 @@ intersphinx_mapping = {
# -- Doctest configuration ----------------------------------------
import doctest
doctest_default_flags = (
0
@ -317,10 +326,9 @@ doctest_default_flags = (
# Mock out dependencies that are unbuildable on readthedocs, as recommended here:
# https://docs.readthedocs.io/en/rel/faq.html#i-get-import-errors-on-libraries-that-depend-on-c-modules
import sys
from unittest.mock import MagicMock
# Add new modules to mock here (it should be the same list as those excluded in setup.py)
# Add new modules to mock here (it should be the same list
# as those excluded in pyproject.toml)
MOCK_MODULES = [
"fastecdsa",
"fastecdsa.encoding",
@ -338,4 +346,4 @@ todo_include_todos = True
# Allow duplicate object descriptions
nitpicky = False
nitpick_ignore = [("py:class", "type")]
nitpick_ignore = [("py:class", "type")]

View File

@ -0,0 +1,499 @@
Circuit Relay v2 Example
========================
This example demonstrates how to use Circuit Relay v2 in py-libp2p. It includes three components:
1. A relay node that provides relay services
2. A destination node that accepts relayed connections
3. A source node that connects to the destination through the relay
Prerequisites
-------------
First, ensure you have py-libp2p installed:
.. code-block:: console
$ python -m pip install libp2p
Collecting libp2p
...
Successfully installed libp2p-x.x.x
Relay Node
----------
Create a file named ``relay_node.py`` with the following content:
.. code-block:: python
import trio
import logging
import multiaddr
import traceback
from libp2p import new_host
from libp2p.relay.circuit_v2.protocol import CircuitV2Protocol
from libp2p.relay.circuit_v2.transport import CircuitV2Transport
from libp2p.relay.circuit_v2.config import RelayConfig
from libp2p.tools.async_service import background_trio_service
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("relay_node")
async def run_relay():
listen_addr = multiaddr.Multiaddr("/ip4/0.0.0.0/tcp/9000")
host = new_host()
config = RelayConfig(
enable_hop=True, # Act as a relay
enable_stop=True, # Accept relayed connections
enable_client=False, # Don't use other relays
max_circuit_duration=3600, # 1 hour
max_circuit_bytes=1024 * 1024 * 10, # 10MB
)
# Initialize the relay protocol with allow_hop=True to act as a relay
protocol = CircuitV2Protocol(host, limits=config.limits, allow_hop=True)
print(f"Created relay protocol with hop enabled: {protocol.allow_hop}")
# Start the protocol service
async with host.run(listen_addrs=[listen_addr]):
peer_id = host.get_id()
print("\n" + "="*50)
print(f"Relay node started with ID: {peer_id}")
print(f"Relay node multiaddr: /ip4/127.0.0.1/tcp/9000/p2p/{peer_id}")
print("="*50 + "\n")
print(f"Listening on: {host.get_addrs()}")
try:
async with background_trio_service(protocol):
print("Protocol service started")
transport = CircuitV2Transport(host, protocol, config)
print("Relay service started successfully")
print(f"Relay limits: {protocol.limits}")
while True:
await trio.sleep(10)
print("Relay node still running...")
print(f"Active connections: {len(host.get_network().connections)}")
except Exception as e:
print(f"Error in relay service: {e}")
traceback.print_exc()
if __name__ == "__main__":
try:
trio.run(run_relay)
except Exception as e:
print(f"Error running relay: {e}")
traceback.print_exc()
Destination Node
----------------
Create a file named ``destination_node.py`` with the following content:
.. code-block:: python
import trio
import logging
import multiaddr
import traceback
import sys
from libp2p import new_host
from libp2p.relay.circuit_v2.protocol import CircuitV2Protocol
from libp2p.relay.circuit_v2.transport import CircuitV2Transport
from libp2p.relay.circuit_v2.config import RelayConfig
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.tools.async_service import background_trio_service
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("destination_node")
async def handle_echo_stream(stream):
"""Handle incoming stream by echoing received data."""
try:
print(f"New echo stream from: {stream.get_protocol()}")
while True:
data = await stream.read(1024)
if not data:
print("Stream closed by remote")
break
message = data.decode('utf-8')
print(f"Received: {message}")
response = f"Echo: {message}".encode('utf-8')
await stream.write(response)
print(f"Sent response: Echo: {message}")
except Exception as e:
print(f"Error handling stream: {e}")
traceback.print_exc()
finally:
await stream.close()
print("Stream closed")
async def run_destination(relay_peer_id=None):
"""
Run a simple destination node that accepts connections.
This is a simplified version that doesn't use the relay functionality.
"""
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/9001")
host = new_host()
# Configure as a relay receiver (stop)
config = RelayConfig(
enable_stop=True, # Accept relayed connections
enable_client=True, # Use relays for outbound connections
max_circuit_duration=3600, # 1 hour
max_circuit_bytes=1024 * 1024 * 10, # 10MB
)
# Initialize the relay protocol
protocol = CircuitV2Protocol(host, limits=config.limits, allow_hop=False)
async with host.run(listen_addrs=[listen_addr]):
# Print host information
dest_peer_id = host.get_id()
print("\n" + "="*50)
print(f"Destination node started with ID: {dest_peer_id}")
print(f"Use this ID in the source node: {dest_peer_id}")
print("="*50 + "\n")
print(f"Listening on: {host.get_addrs()}")
# Set stream handler for the echo protocol
host.set_stream_handler("/echo/1.0.0", handle_echo_stream)
print("Registered echo protocol handler")
# Start the protocol service in the background
async with background_trio_service(protocol):
print("Protocol service started")
# Create and register the transport
transport = CircuitV2Transport(host, protocol, config)
print("Transport created")
# Create a listener for relayed connections
listener = transport.create_listener(handle_echo_stream)
print("Created relay listener")
# Start listening for relayed connections
async with trio.open_nursery() as nursery:
await listener.listen("/p2p-circuit", nursery)
print("Destination node ready to accept relayed connections")
if not relay_peer_id:
print("No relay peer ID provided. Please enter the relay's peer ID:")
print("Waiting for relay peer ID input...")
while True:
if sys.stdin.isatty(): # Only try to read from stdin if it's a terminal
try:
relay_peer_id = input("Enter relay peer ID: ").strip()
if relay_peer_id:
break
except EOFError:
await trio.sleep(5)
else:
print("No terminal detected. Waiting for relay peer ID as command line argument.")
await trio.sleep(10)
continue
# Connect to the relay node with the provided relay peer ID
relay_addr_str = f"/ip4/127.0.0.1/tcp/9000/p2p/{relay_peer_id}"
print(f"Connecting to relay at {relay_addr_str}")
try:
# Convert string address to multiaddr, then to peer info
relay_maddr = multiaddr.Multiaddr(relay_addr_str)
relay_peer_info = info_from_p2p_addr(relay_maddr)
await host.connect(relay_peer_info)
print("Connected to relay successfully")
# Add the relay to the transport's discovery
transport.discovery._add_relay(relay_peer_info.peer_id)
print(f"Added relay {relay_peer_info.peer_id} to discovery")
# Keep the node running
while True:
await trio.sleep(10)
print("Destination node still running...")
except Exception as e:
print(f"Failed to connect to relay: {e}")
traceback.print_exc()
if __name__ == "__main__":
print("Starting destination node...")
relay_id = None
if len(sys.argv) > 1:
relay_id = sys.argv[1]
print(f"Using provided relay ID: {relay_id}")
trio.run(run_destination, relay_id)
Source Node
-----------
Create a file named ``source_node.py`` with the following content:
.. code-block:: python
import trio
import logging
import multiaddr
import traceback
import sys
from libp2p import new_host
from libp2p.peer.peerinfo import PeerInfo
from libp2p.peer.id import ID
from libp2p.relay.circuit_v2.protocol import CircuitV2Protocol
from libp2p.relay.circuit_v2.transport import CircuitV2Transport
from libp2p.relay.circuit_v2.config import RelayConfig
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.tools.async_service import background_trio_service
from libp2p.relay.circuit_v2.discovery import RelayInfo
# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("source_node")
async def run_source(relay_peer_id=None, destination_peer_id=None):
# Create a libp2p host
listen_addr = multiaddr.Multiaddr("/ip4/0.0.0.0/tcp/9002")
host = new_host()
# Configure as a relay client
config = RelayConfig(
enable_client=True, # Use relays for outbound connections
max_circuit_duration=3600, # 1 hour
max_circuit_bytes=1024 * 1024 * 10, # 10MB
)
# Initialize the relay protocol
protocol = CircuitV2Protocol(host, limits=config.limits, allow_hop=False)
# Start the protocol service
async with host.run(listen_addrs=[listen_addr]):
# Print host information
print(f"Source node started with ID: {host.get_id()}")
print(f"Listening on: {host.get_addrs()}")
# Start the protocol service in the background
async with background_trio_service(protocol):
print("Protocol service started")
# Create and register the transport
transport = CircuitV2Transport(host, protocol, config)
# Get relay peer ID if not provided
if not relay_peer_id:
print("No relay peer ID provided. Please enter the relay's peer ID:")
while True:
if sys.stdin.isatty(): # Only try to read from stdin if it's a terminal
try:
relay_peer_id = input("Enter relay peer ID: ").strip()
if relay_peer_id:
break
except EOFError:
await trio.sleep(5)
else:
print("No terminal detected. Waiting for relay peer ID as command line argument.")
await trio.sleep(10)
continue
# Connect to the relay node with the provided relay peer ID
relay_addr_str = f"/ip4/127.0.0.1/tcp/9000/p2p/{relay_peer_id}"
print(f"Connecting to relay at {relay_addr_str}")
try:
# Convert string address to multiaddr, then to peer info
relay_maddr = multiaddr.Multiaddr(relay_addr_str)
relay_peer_info = info_from_p2p_addr(relay_maddr)
await host.connect(relay_peer_info)
print("Connected to relay successfully")
# Manually add the relay to the discovery service
relay_id = relay_peer_info.peer_id
now = trio.current_time()
# Create relay info and add it to discovery
relay_info = RelayInfo(
peer_id=relay_id,
discovered_at=now,
last_seen=now
)
transport.discovery._discovered_relays[relay_id] = relay_info
print(f"Added relay {relay_id} to discovery")
# Start relay discovery in the background
async with background_trio_service(transport.discovery):
print("Relay discovery started")
# Wait for relay discovery
await trio.sleep(5)
print("Relay discovery completed")
# Get destination peer ID if not provided
if not destination_peer_id:
print("No destination peer ID provided. Please enter the destination's peer ID:")
while True:
if sys.stdin.isatty(): # Only try to read from stdin if it's a terminal
try:
destination_peer_id = input("Enter destination peer ID: ").strip()
if destination_peer_id:
break
except EOFError:
await trio.sleep(5)
else:
print("No terminal detected. Waiting for destination peer ID as command line argument.")
await trio.sleep(10)
continue
print(f"Attempting to connect to {destination_peer_id} via relay")
# Check if we have any discovered relays
discovered_relays = list(transport.discovery._discovered_relays.keys())
print(f"Discovered relays: {discovered_relays}")
try:
# Create a circuit relay multiaddr for the destination
dest_id = ID.from_base58(destination_peer_id)
# Create a circuit multiaddr that includes the relay
# Format: /ip4/127.0.0.1/tcp/9000/p2p/RELAY_ID/p2p-circuit/p2p/DEST_ID
circuit_addr = multiaddr.Multiaddr(f"{relay_addr_str}/p2p-circuit/p2p/{destination_peer_id}")
print(f"Created circuit address: {circuit_addr}")
# Dial using the circuit address
connection = await transport.dial(circuit_addr)
print("Connection established through relay!")
# Open a stream using the echo protocol
stream = await connection.new_stream("/echo/1.0.0")
# Send messages periodically
for i in range(5):
message = f"Hello from source, message {i+1}"
print(f"Sending: {message}")
await stream.write(message.encode('utf-8'))
response = await stream.read(1024)
print(f"Received: {response.decode('utf-8')}")
await trio.sleep(1)
# Close the stream
await stream.close()
print("Stream closed")
except Exception as e:
print(f"Error connecting through relay: {e}")
print("Detailed error:")
traceback.print_exc()
# Keep the node running for a while
await trio.sleep(30)
print("Source node shutting down")
except Exception as e:
print(f"Error: {e}")
traceback.print_exc()
if __name__ == "__main__":
relay_id = None
dest_id = None
# Parse command line arguments if provided
if len(sys.argv) > 1:
relay_id = sys.argv[1]
print(f"Using provided relay ID: {relay_id}")
if len(sys.argv) > 2:
dest_id = sys.argv[2]
print(f"Using provided destination ID: {dest_id}")
trio.run(run_source, relay_id, dest_id)
Running the Example
-------------------
1. First, start the relay node:
.. code-block:: console
$ python relay_node.py
Created relay protocol with hop enabled: True
==================================================
Relay node started with ID: QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
Relay node multiaddr: /ip4/127.0.0.1/tcp/9000/p2p/QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
==================================================
Listening on: [<Multiaddr /ip4/0.0.0.0/tcp/9000/p2p/QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx>]
Protocol service started
Relay service started successfully
Relay limits: RelayLimits(duration=3600, data=10485760, max_circuit_conns=8, max_reservations=4)
Note the relay node\'s peer ID (in this example: `QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx`). You\'ll need this for the other nodes.
2. Next, start the destination node:
.. code-block:: console
$ python destination_node.py
Starting destination node...
==================================================
Destination node started with ID: QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s
Use this ID in the source node: QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s
==================================================
Listening on: [<Multiaddr /ip4/0.0.0.0/tcp/9001/p2p/QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s>]
Registered echo protocol handler
Protocol service started
Transport created
Created relay listener
Destination node ready to accept relayed connections
No relay peer ID provided. Please enter the relay\'s peer ID:
Waiting for relay peer ID input...
Enter relay peer ID: QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
Connecting to relay at /ip4/127.0.0.1/tcp/9000/p2p/QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
Connected to relay successfully
Added relay QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx to discovery
Destination node still running...
Note the destination node's peer ID (in this example: `QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s`). You'll need this for the source node.
3. Finally, start the source node:
.. code-block:: console
$ python source_node.py
Source node started with ID: QmPyM56cgmFoHTgvMgGfDWRdVRQznmxCDDDg2dJ8ygVXj3
Listening on: [<Multiaddr /ip4/0.0.0.0/tcp/9002/p2p/QmPyM56cgmFoHTgvMgGfDWRdVRQznmxCDDDg2dJ8ygVXj3>]
Protocol service started
No relay peer ID provided. Please enter the relay\'s peer ID:
Enter relay peer ID: QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
Connecting to relay at /ip4/127.0.0.1/tcp/9000/p2p/QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
Connected to relay successfully
Added relay QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx to discovery
Relay discovery started
Relay discovery completed
No destination peer ID provided. Please enter the destination\'s peer ID:
Enter destination peer ID: QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s
Attempting to connect to QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s via relay
Discovered relays: [<libp2p.peer.id.ID (QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx)>]
Created circuit address: /ip4/127.0.0.1/tcp/9000/p2p/QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx/p2p-circuit/p2p/QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s
At this point, the source node will establish a connection through the relay to the destination node and start sending messages.
4. Alternatively, you can provide the peer IDs as command-line arguments:
.. code-block:: console
# For the destination node (provide relay ID)
$ python destination_node.py QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
# For the source node (provide both relay and destination IDs)
$ python source_node.py QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s
This example demonstrates how to use Circuit Relay v2 to establish connections between peers that cannot connect directly. The peer IDs are dynamically generated for each node, and the relay facilitates communication between the source and destination nodes.

124
docs/examples.kademlia.rst Normal file
View File

@ -0,0 +1,124 @@
Kademlia DHT Demo
=================
This example demonstrates a Kademlia Distributed Hash Table (DHT) implementation with both value storage/retrieval and content provider advertisement/discovery functionality.
.. code-block:: console
$ python -m pip install libp2p
Collecting libp2p
...
Successfully installed libp2p-x.x.x
$ cd examples/kademlia
$ python kademlia.py --mode server
2025-06-13 19:51:25,424 - kademlia-example - INFO - Running in server mode on port 0
2025-06-13 19:51:25,426 - kademlia-example - INFO - Connected to bootstrap nodes: []
2025-06-13 19:51:25,426 - kademlia-example - INFO - To connect to this node, use: --bootstrap /ip4/127.0.0.1/tcp/28910/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
2025-06-13 19:51:25,426 - kademlia-example - INFO - Saved server address to log: /ip4/127.0.0.1/tcp/28910/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
2025-06-13 19:51:25,427 - kademlia-example - INFO - DHT service started in SERVER mode
2025-06-13 19:51:25,427 - kademlia-example - INFO - Stored value 'Hello message from Sumanjeet' with key: FVDjasarSFDoLPMdgnp1dHSbW2ZAfN8NU2zNbCQeczgP
2025-06-13 19:51:25,427 - kademlia-example - INFO - Successfully advertised as server for content: 361f2ed1183bca491b8aec11f0b9e5c06724759b0f7480ae7fb4894901993bc8
Copy the line that starts with ``--bootstrap``, open a new terminal in the same folder and run the client:
.. code-block:: console
$ python kademlia.py --mode client --bootstrap /ip4/127.0.0.1/tcp/28910/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
2025-06-13 19:51:37,022 - kademlia-example - INFO - Running in client mode on port 0
2025-06-13 19:51:37,026 - kademlia-example - INFO - Connected to bootstrap nodes: [<libp2p.peer.id.ID (16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef)>]
2025-06-13 19:51:37,027 - kademlia-example - INFO - DHT service started in CLIENT mode
2025-06-13 19:51:37,027 - kademlia-example - INFO - Looking up key: FVDjasarSFDoLPMdgnp1dHSbW2ZAfN8NU2zNbCQeczgP
2025-06-13 19:51:37,031 - kademlia-example - INFO - Retrieved value: Hello message from Sumanjeet
2025-06-13 19:51:37,031 - kademlia-example - INFO - Looking for servers of content: 361f2ed1183bca491b8aec11f0b9e5c06724759b0f7480ae7fb4894901993bc8
2025-06-13 19:51:37,035 - kademlia-example - INFO - Found 1 servers for content: ['16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef']
Alternatively, if you run the server first, the client can automatically extract the bootstrap address from the server log file:
.. code-block:: console
$ python kademlia.py --mode client
2025-06-13 19:51:37,022 - kademlia-example - INFO - Running in client mode on port 0
2025-06-13 19:51:37,026 - kademlia-example - INFO - Connected to bootstrap nodes: [<libp2p.peer.id.ID (16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef)>]
2025-06-13 19:51:37,027 - kademlia-example - INFO - DHT service started in CLIENT mode
2025-06-13 19:51:37,027 - kademlia-example - INFO - Looking up key: FVDjasarSFDoLPMdgnp1dHSbW2ZAfN8NU2zNbCQeczgP
2025-06-13 19:51:37,031 - kademlia-example - INFO - Retrieved value: Hello message from Sumanjeet
2025-06-13 19:51:37,031 - kademlia-example - INFO - Looking for servers of content: 361f2ed1183bca491b8aec11f0b9e5c06724759b0f7480ae7fb4894901993bc8
2025-06-13 19:51:37,035 - kademlia-example - INFO - Found 1 servers for content: ['16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef']
The demo showcases key DHT operations:
- **Value Storage & Retrieval**: The server stores a value, and the client retrieves it
- **Content Provider Discovery**: The server advertises content, and the client finds providers
- **Peer Discovery**: Automatic bootstrap and peer routing using the Kademlia algorithm
- **Network Resilience**: Distributed storage across multiple nodes (when available)
Command Line Options
--------------------
The Kademlia demo supports several command line options for customization:
.. code-block:: console
$ python kademlia.py --help
usage: kademlia.py [-h] [--mode MODE] [--port PORT] [--bootstrap [BOOTSTRAP ...]] [--verbose]
Kademlia DHT example with content server functionality
options:
-h, --help show this help message and exit
--mode MODE Run as a server or client node (default: server)
--port PORT Port to listen on (0 for random) (default: 0)
--bootstrap [BOOTSTRAP ...]
Multiaddrs of bootstrap nodes. Provide a space-separated list of addresses.
This is required for client mode.
--verbose Enable verbose logging
**Examples:**
Start server on a specific port:
.. code-block:: console
$ python kademlia.py --mode server --port 8000
Start client with verbose logging:
.. code-block:: console
$ python kademlia.py --mode client --verbose
Connect to multiple bootstrap nodes:
.. code-block:: console
$ python kademlia.py --mode client --bootstrap /ip4/127.0.0.1/tcp/8000/p2p/... /ip4/127.0.0.1/tcp/8001/p2p/...
How It Works
------------
The Kademlia DHT implementation demonstrates several key concepts:
**Server Mode:**
- Stores key-value pairs in the distributed hash table
- Advertises itself as a content provider for specific content
- Handles incoming DHT requests from other nodes
- Maintains routing table with known peers
**Client Mode:**
- Connects to bootstrap nodes to join the network
- Retrieves values by their keys from the DHT
- Discovers content providers for specific content
- Performs network lookups using the Kademlia algorithm
**Key Components:**
- **Routing Table**: Organizes peers in k-buckets based on XOR distance
- **Value Store**: Manages key-value storage with TTL (time-to-live)
- **Provider Store**: Tracks which peers provide specific content
- **Peer Routing**: Implements iterative lookups to find closest peers
The full source code for this example is below:
.. literalinclude:: ../examples/kademlia/kademlia.py
:language: python
:linenos:

View File

@ -11,3 +11,5 @@ Examples
examples.echo
examples.ping
examples.pubsub
examples.circuit_relay
examples.kademlia

View File

@ -12,10 +12,6 @@ The Python implementation of the libp2p networking stack
getting_started
release_notes
.. toctree::
:maxdepth: 1
:caption: Community
.. toctree::
:maxdepth: 1
:caption: py-libp2p

View File

@ -0,0 +1,22 @@
libp2p.kad\_dht.pb package
==========================
Submodules
----------
libp2p.kad_dht.pb.kademlia_pb2 module
-------------------------------------
.. automodule:: libp2p.kad_dht.pb.kademlia_pb2
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: libp2p.kad_dht.pb
:no-index:
:members:
:undoc-members:
:show-inheritance:

77
docs/libp2p.kad_dht.rst Normal file
View File

@ -0,0 +1,77 @@
libp2p.kad\_dht package
=======================
Subpackages
-----------
.. toctree::
:maxdepth: 4
libp2p.kad_dht.pb
Submodules
----------
libp2p.kad\_dht.kad\_dht module
-------------------------------
.. automodule:: libp2p.kad_dht.kad_dht
:members:
:undoc-members:
:show-inheritance:
libp2p.kad\_dht.peer\_routing module
------------------------------------
.. automodule:: libp2p.kad_dht.peer_routing
:members:
:undoc-members:
:show-inheritance:
libp2p.kad\_dht.provider\_store module
--------------------------------------
.. automodule:: libp2p.kad_dht.provider_store
:members:
:undoc-members:
:show-inheritance:
libp2p.kad\_dht.routing\_table module
-------------------------------------
.. automodule:: libp2p.kad_dht.routing_table
:members:
:undoc-members:
:show-inheritance:
libp2p.kad\_dht.utils module
----------------------------
.. automodule:: libp2p.kad_dht.utils
:members:
:undoc-members:
:show-inheritance:
libp2p.kad\_dht.value\_store module
-----------------------------------
.. automodule:: libp2p.kad_dht.value_store
:members:
:undoc-members:
:show-inheritance:
libp2p.kad\_dht.pb
------------------
.. automodule:: libp2p.kad_dht.pb
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: libp2p.kad_dht
:members:
:undoc-members:
:show-inheritance:

View File

@ -0,0 +1,22 @@
libp2p.relay.circuit_v2.pb package
==================================
Submodules
----------
libp2p.relay.circuit_v2.pb.circuit_pb2 module
---------------------------------------------
.. automodule:: libp2p.relay.circuit_v2.pb.circuit_pb2
:members:
:show-inheritance:
:undoc-members:
Module contents
---------------
.. automodule:: libp2p.relay.circuit_v2.pb
:members:
:show-inheritance:
:undoc-members:
:no-index:

View File

@ -0,0 +1,70 @@
libp2p.relay.circuit_v2 package
===============================
Subpackages
-----------
.. toctree::
:maxdepth: 4
libp2p.relay.circuit_v2.pb
Submodules
----------
libp2p.relay.circuit_v2.protocol module
---------------------------------------
.. automodule:: libp2p.relay.circuit_v2.protocol
:members:
:show-inheritance:
:undoc-members:
libp2p.relay.circuit_v2.transport module
----------------------------------------
.. automodule:: libp2p.relay.circuit_v2.transport
:members:
:show-inheritance:
:undoc-members:
libp2p.relay.circuit_v2.discovery module
----------------------------------------
.. automodule:: libp2p.relay.circuit_v2.discovery
:members:
:show-inheritance:
:undoc-members:
libp2p.relay.circuit_v2.resources module
----------------------------------------
.. automodule:: libp2p.relay.circuit_v2.resources
:members:
:show-inheritance:
:undoc-members:
libp2p.relay.circuit_v2.config module
-------------------------------------
.. automodule:: libp2p.relay.circuit_v2.config
:members:
:show-inheritance:
:undoc-members:
libp2p.relay.circuit_v2.protocol_buffer module
----------------------------------------------
.. automodule:: libp2p.relay.circuit_v2.protocol_buffer
:members:
:show-inheritance:
:undoc-members:
Module contents
---------------
.. automodule:: libp2p.relay.circuit_v2
:members:
:show-inheritance:
:undoc-members:
:no-index:

19
docs/libp2p.relay.rst Normal file
View File

@ -0,0 +1,19 @@
libp2p.relay package
====================
Subpackages
-----------
.. toctree::
:maxdepth: 4
libp2p.relay.circuit_v2
Module contents
---------------
.. automodule:: libp2p.relay
:members:
:show-inheritance:
:undoc-members:
:no-index:

View File

@ -11,10 +11,12 @@ Subpackages
libp2p.host
libp2p.identity
libp2p.io
libp2p.kad_dht
libp2p.network
libp2p.peer
libp2p.protocol_muxer
libp2p.pubsub
libp2p.relay
libp2p.security
libp2p.stream_muxer
libp2p.tools

View File

@ -3,6 +3,51 @@ Release Notes
.. towncrier release notes start
py-libp2p v0.2.8 (2025-06-10)
-----------------------------
Breaking Changes
~~~~~~~~~~~~~~~~
- The `NetStream.state` property is now async and requires `await`. Update any direct state access to use `await stream.state`. (`#300 <https://github.com/libp2p/py-libp2p/issues/300>`__)
Bugfixes
~~~~~~~~
- Added proper state management and resource cleanup to `NetStream`, fixing memory leaks and improved error handling. (`#300 <https://github.com/libp2p/py-libp2p/issues/300>`__)
Improved Documentation
~~~~~~~~~~~~~~~~~~~~~~
- Updated examples to automatically use random port, when `-p` flag is not given (`#661 <https://github.com/libp2p/py-libp2p/issues/661>`__)
Features
~~~~~~~~
- Allow passing `listen_addrs` to `new_swarm` to customize swarm listening behavior. (`#616 <https://github.com/libp2p/py-libp2p/issues/616>`__)
- Feature: Support for sending `ls` command over `multistream-select` to list supported protocols from remote peer.
This allows inspecting which protocol handlers a peer supports at runtime. (`#622 <https://github.com/libp2p/py-libp2p/issues/622>`__)
- implement AsyncContextManager for IMuxedStream to support async with (`#629 <https://github.com/libp2p/py-libp2p/issues/629>`__)
- feat: add method to compute time since last message published by a peer and remove fanout peers based on ttl. (`#636 <https://github.com/libp2p/py-libp2p/issues/636>`__)
- implement blacklist management for `pubsub.Pubsub` with methods to get, add, remove, check, and clear blacklisted peer IDs. (`#641 <https://github.com/libp2p/py-libp2p/issues/641>`__)
- fix: remove expired peers from peerstore based on TTL (`#650 <https://github.com/libp2p/py-libp2p/issues/650>`__)
Internal Changes - for py-libp2p Contributors
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- Modernizes several aspects of the project, notably using ``pyproject.toml`` for project info instead of ``setup.py``, using ``ruff`` to replace several separate linting tools, and ``pyrefly`` in addition to ``mypy`` for typing. Also includes changes across the codebase to conform to new linting and typing rules. (`#618 <https://github.com/libp2p/py-libp2p/issues/618>`__)
Removals
~~~~~~~~
- Removes support for python 3.9 and updates some code conventions, notably using ``|`` operator in typing instead of ``Optional`` or ``Union`` (`#618 <https://github.com/libp2p/py-libp2p/issues/618>`__)
py-libp2p v0.2.7 (2025-05-22)
-----------------------------

View File

@ -40,7 +40,6 @@ async def write_data(stream: INetStream) -> None:
async def run(port: int, destination: str) -> None:
localhost_ip = "127.0.0.1"
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
host = new_host()
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
@ -54,8 +53,8 @@ async def run(port: int, destination: str) -> None:
print(
"Run this from the same folder in another console:\n\n"
f"chat-demo -p {int(port) + 1} "
f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id().pretty()}\n"
f"chat-demo "
f"-d {host.get_addrs()[0]}\n"
)
print("Waiting for incoming connection...")
@ -87,9 +86,7 @@ def main() -> None:
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
)
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"-p", "--port", default=8000, type=int, help="source port number"
)
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
parser.add_argument(
"-d",
"--destination",
@ -98,9 +95,6 @@ def main() -> None:
)
args = parser.parse_args()
if not args.port:
raise RuntimeError("was not able to determine a local port")
try:
trio.run(run, *(args.port, args.destination))
except KeyboardInterrupt:

View File

@ -27,6 +27,9 @@ async def main():
# secure_bytes_provider: Optional function to generate secure random bytes
# (defaults to secrets.token_bytes)
secure_bytes_provider=None, # Use default implementation
# peerstore: Optional peerstore to store peer IDs and public keys
# (defaults to None)
peerstore=None,
)
# Create a security options dictionary mapping protocol ID to transport

View File

@ -9,8 +9,10 @@ from libp2p import (
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
from libp2p.security.noise.transport import Transport as NoiseTransport
from libp2p.security.noise.transport import (
PROTOCOL_ID as NOISE_PROTOCOL_ID,
Transport as NoiseTransport,
)
async def main():

View File

@ -9,8 +9,10 @@ from libp2p import (
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.security.secio.transport import ID as SECIO_PROTOCOL_ID
from libp2p.security.secio.transport import Transport as SecioTransport
from libp2p.security.secio.transport import (
ID as SECIO_PROTOCOL_ID,
Transport as SecioTransport,
)
async def main():
@ -22,9 +24,6 @@ async def main():
secio_transport = SecioTransport(
# local_key_pair: The key pair used for libp2p identity and authentication
local_key_pair=key_pair,
# secure_bytes_provider: Optional function to generate secure random bytes
# (defaults to secrets.token_bytes)
secure_bytes_provider=None, # Use default implementation
)
# Create a security options dictionary mapping protocol ID to transport

View File

@ -9,10 +9,9 @@ from libp2p import (
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
from libp2p.security.noise.transport import Transport as NoiseTransport
from libp2p.stream_muxer.mplex.mplex import (
MPLEX_PROTOCOL_ID,
from libp2p.security.noise.transport import (
PROTOCOL_ID as NOISE_PROTOCOL_ID,
Transport as NoiseTransport,
)
@ -37,14 +36,8 @@ async def main():
# Create a security options dictionary mapping protocol ID to transport
security_options = {NOISE_PROTOCOL_ID: noise_transport}
# Create a muxer options dictionary mapping protocol ID to muxer class
# We don't need to instantiate the muxer here, the host will do that for us
muxer_options = {MPLEX_PROTOCOL_ID: None}
# Create a host with the key pair, Noise security, and mplex multiplexer
host = new_host(
key_pair=key_pair, sec_opt=security_options, muxer_opt=muxer_options
)
host = new_host(key_pair=key_pair, sec_opt=security_options)
# Configure the listening address
port = 8000

View File

@ -0,0 +1,263 @@
"""
Enhanced NetStream Example for py-libp2p with State Management
This example demonstrates the new NetStream features including:
- State tracking and transitions
- Proper error handling and validation
- Resource cleanup and event notifications
- Thread-safe operations with Trio locks
Based on the standard echo demo but enhanced to show NetStream state management.
"""
import argparse
import random
import secrets
import multiaddr
import trio
from libp2p import (
new_host,
)
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.custom_types import (
TProtocol,
)
from libp2p.network.stream.exceptions import (
StreamClosed,
StreamEOF,
StreamReset,
)
from libp2p.network.stream.net_stream import (
NetStream,
StreamState,
)
from libp2p.peer.peerinfo import (
info_from_p2p_addr,
)
PROTOCOL_ID = TProtocol("/echo/1.0.0")
async def enhanced_echo_handler(stream: NetStream) -> None:
"""
Enhanced echo handler that demonstrates NetStream state management.
"""
print(f"New connection established: {stream}")
print(f"Initial stream state: {await stream.state}")
try:
# Verify stream is in expected initial state
assert await stream.state == StreamState.OPEN
assert await stream.is_readable()
assert await stream.is_writable()
print("✓ Stream initialized in OPEN state")
# Read incoming data with proper state checking
print("Waiting for client data...")
while await stream.is_readable():
try:
# Read data from client
data = await stream.read(1024)
if not data:
print("Received empty data, client may have closed")
break
print(f"Received: {data.decode('utf-8').strip()}")
# Check if we can still write before echoing
if await stream.is_writable():
await stream.write(data)
print(f"Echoed: {data.decode('utf-8').strip()}")
else:
print("Cannot echo - stream not writable")
break
except StreamEOF:
print("Client closed their write side (EOF)")
break
except StreamReset:
print("Stream was reset by client")
return
except StreamClosed as e:
print(f"Stream operation failed: {e}")
break
# Demonstrate graceful closure
current_state = await stream.state
print(f"Current state before close: {current_state}")
if current_state not in [StreamState.CLOSE_BOTH, StreamState.RESET]:
await stream.close()
print("Server closed write side")
final_state = await stream.state
print(f"Final stream state: {final_state}")
except Exception as e:
print(f"Handler error: {e}")
# Reset stream on unexpected errors
if await stream.state not in [StreamState.RESET, StreamState.CLOSE_BOTH]:
await stream.reset()
print("Stream reset due to error")
async def enhanced_client_demo(stream: NetStream) -> None:
"""
Enhanced client that demonstrates various NetStream state scenarios.
"""
print(f"Client stream established: {stream}")
print(f"Initial state: {await stream.state}")
try:
# Verify initial state
assert await stream.state == StreamState.OPEN
print("✓ Client stream in OPEN state")
# Scenario 1: Normal communication
message = b"Hello from enhanced NetStream client!\n"
if await stream.is_writable():
await stream.write(message)
print(f"Sent: {message.decode('utf-8').strip()}")
else:
print("Cannot write - stream not writable")
return
# Close write side to signal EOF to server
await stream.close()
print("Client closed write side")
# Verify state transition
state_after_close = await stream.state
print(f"State after close: {state_after_close}")
assert state_after_close == StreamState.CLOSE_WRITE
assert await stream.is_readable() # Should still be readable
assert not await stream.is_writable() # Should not be writable
# Try to write (should fail)
try:
await stream.write(b"This should fail")
print("ERROR: Write succeeded when it should have failed!")
except StreamClosed as e:
print(f"✓ Expected error when writing to closed stream: {e}")
# Read the echo response
if await stream.is_readable():
try:
response = await stream.read()
print(f"Received echo: {response.decode('utf-8').strip()}")
except StreamEOF:
print("Server closed their write side")
except StreamReset:
print("Stream was reset")
# Check final state
final_state = await stream.state
print(f"Final client state: {final_state}")
except Exception as e:
print(f"Client error: {e}")
# Reset on error
await stream.reset()
print("Client reset stream due to error")
async def run_enhanced_demo(
port: int, destination: str, seed: int | None = None
) -> None:
"""
Run enhanced echo demo with NetStream state management.
"""
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
# Generate or use provided key
if seed:
random.seed(seed)
secret_number = random.getrandbits(32 * 8)
secret = secret_number.to_bytes(length=32, byteorder="big")
else:
secret = secrets.token_bytes(32)
host = new_host(key_pair=create_new_key_pair(secret))
async with host.run(listen_addrs=[listen_addr]):
print(f"Host ID: {host.get_id().to_string()}")
print("=" * 60)
if not destination: # Server mode
print("🖥️ ENHANCED ECHO SERVER MODE")
print("=" * 60)
# type: ignore: Stream is type of NetStream
host.set_stream_handler(PROTOCOL_ID, enhanced_echo_handler)
print(
"Run client from another console:\n"
f"python3 example_net_stream.py "
f"-d {host.get_addrs()[0]}\n"
)
print("Waiting for connections...")
print("Press Ctrl+C to stop server")
await trio.sleep_forever()
else: # Client mode
print("📱 ENHANCED ECHO CLIENT MODE")
print("=" * 60)
# Connect to server
maddr = multiaddr.Multiaddr(destination)
info = info_from_p2p_addr(maddr)
await host.connect(info)
print(f"Connected to server: {info.peer_id.pretty()}")
# Create stream and run enhanced demo
stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
if isinstance(stream, NetStream):
await enhanced_client_demo(stream)
print("\n" + "=" * 60)
print("CLIENT DEMO COMPLETE")
def main() -> None:
example_maddr = (
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
)
parser = argparse.ArgumentParser(
formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
parser.add_argument(
"-d",
"--destination",
type=str,
help=f"destination multiaddr string, e.g. {example_maddr}",
)
parser.add_argument(
"-s",
"--seed",
type=int,
help="seed for deterministic peer ID generation",
)
parser.add_argument(
"--demo-states", action="store_true", help="run state transition demo only"
)
args = parser.parse_args()
try:
trio.run(run_enhanced_demo, args.port, args.destination, args.seed)
except KeyboardInterrupt:
print("\n👋 Demo interrupted by user")
except Exception as e:
print(f"❌ Demo failed: {e}")
if __name__ == "__main__":
main()

View File

@ -12,10 +12,9 @@ from libp2p.crypto.secp256k1 import (
from libp2p.peer.peerinfo import (
info_from_p2p_addr,
)
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
from libp2p.security.noise.transport import Transport as NoiseTransport
from libp2p.stream_muxer.mplex.mplex import (
MPLEX_PROTOCOL_ID,
from libp2p.security.noise.transport import (
PROTOCOL_ID as NOISE_PROTOCOL_ID,
Transport as NoiseTransport,
)
@ -40,14 +39,8 @@ async def main():
# Create a security options dictionary mapping protocol ID to transport
security_options = {NOISE_PROTOCOL_ID: noise_transport}
# Create a muxer options dictionary mapping protocol ID to muxer class
# We don't need to instantiate the muxer here, the host will do that for us
muxer_options = {MPLEX_PROTOCOL_ID: None}
# Create a host with the key pair, Noise security, and mplex multiplexer
host = new_host(
key_pair=key_pair, sec_opt=security_options, muxer_opt=muxer_options
)
host = new_host(key_pair=key_pair, sec_opt=security_options)
# Configure the listening address
port = 8000

View File

@ -9,10 +9,9 @@ from libp2p import (
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
from libp2p.security.noise.transport import Transport as NoiseTransport
from libp2p.stream_muxer.mplex.mplex import (
MPLEX_PROTOCOL_ID,
from libp2p.security.noise.transport import (
PROTOCOL_ID as NOISE_PROTOCOL_ID,
Transport as NoiseTransport,
)
@ -37,14 +36,8 @@ async def main():
# Create a security options dictionary mapping protocol ID to transport
security_options = {NOISE_PROTOCOL_ID: noise_transport}
# Create a muxer options dictionary mapping protocol ID to muxer class
# We don't need to instantiate the muxer here, the host will do that for us
muxer_options = {MPLEX_PROTOCOL_ID: None}
# Create a host with the key pair, Noise security, and mplex multiplexer
host = new_host(
key_pair=key_pair, sec_opt=security_options, muxer_opt=muxer_options
)
host = new_host(key_pair=key_pair, sec_opt=security_options)
# Configure the listening address
port = 8000

View File

@ -20,17 +20,17 @@ from libp2p.peer.peerinfo import (
)
PROTOCOL_ID = TProtocol("/echo/1.0.0")
MAX_READ_LEN = 2**32 - 1
async def _echo_stream_handler(stream: INetStream) -> None:
# Wait until EOF
msg = await stream.read()
msg = await stream.read(MAX_READ_LEN)
await stream.write(msg)
await stream.close()
async def run(port: int, destination: str, seed: int = None) -> None:
localhost_ip = "127.0.0.1"
async def run(port: int, destination: str, seed: int | None = None) -> None:
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
if seed:
@ -53,8 +53,8 @@ async def run(port: int, destination: str, seed: int = None) -> None:
print(
"Run this from the same folder in another console:\n\n"
f"echo-demo -p {int(port) + 1} "
f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id().pretty()}\n"
f"echo-demo "
f"-d {host.get_addrs()[0]}\n"
)
print("Waiting for incoming connections...")
await trio.sleep_forever()
@ -73,9 +73,8 @@ async def run(port: int, destination: str, seed: int = None) -> None:
msg = b"hi, there!\n"
await stream.write(msg)
# Notify the other side about EOF
await stream.close()
response = await stream.read()
await stream.close()
print(f"Sent: {msg.decode('utf-8')}")
print(f"Got: {response.decode('utf-8')}")
@ -94,9 +93,7 @@ def main() -> None:
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
)
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"-p", "--port", default=8000, type=int, help="source port number"
)
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
parser.add_argument(
"-d",
"--destination",
@ -110,10 +107,6 @@ def main() -> None:
help="provide a seed to the random number generator (e.g. to fix peer IDs across runs)", # noqa: E501
)
args = parser.parse_args()
if not args.port:
raise RuntimeError("was not able to determine a local port")
try:
trio.run(run, args.port, args.destination, args.seed)
except KeyboardInterrupt:

View File

@ -61,20 +61,20 @@ async def run(port: int, destination: str) -> None:
async with host_a.run(listen_addrs=[listen_addr]):
print(
"First host listening. Run this from another console:\n\n"
f"identify-demo -p {int(port) + 1} "
f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host_a.get_id().pretty()}\n"
f"identify-demo "
f"-d {host_a.get_addrs()[0]}\n"
)
print("Waiting for incoming identify request...")
await trio.sleep_forever()
else:
# Create second host (dialer)
print(f"dialer (host_b) listening on /ip4/{localhost_ip}/tcp/{port}")
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}")
host_b = new_host()
async with host_b.run(listen_addrs=[listen_addr]):
# Connect to the first host
print(f"dialer (host_b) listening on {host_b.get_addrs()[0]}")
maddr = multiaddr.Multiaddr(destination)
info = info_from_p2p_addr(maddr)
print(f"Second host connecting to peer: {info.peer_id}")
@ -104,13 +104,11 @@ def main() -> None:
"""
example_maddr = (
"/ip4/127.0.0.1/tcp/8888/p2p/QmQn4SwGkDZkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
"/ip4/127.0.0.1/tcp/8888/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
)
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"-p", "--port", default=8888, type=int, help="source port number"
)
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
parser.add_argument(
"-d",
"--destination",
@ -119,9 +117,6 @@ def main() -> None:
)
args = parser.parse_args()
if not args.port:
raise RuntimeError("failed to determine local port")
try:
trio.run(run, *(args.port, args.destination))
except KeyboardInterrupt:

View File

@ -38,17 +38,17 @@ from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.identity.identify import (
ID as ID_IDENTIFY,
identify_handler_for,
)
from libp2p.identity.identify import ID as ID_IDENTIFY
from libp2p.identity.identify.pb.identify_pb2 import (
Identify,
)
from libp2p.identity.identify_push import (
ID_PUSH as ID_IDENTIFY_PUSH,
identify_push_handler_for,
push_identify_to_peer,
)
from libp2p.identity.identify_push import ID_PUSH as ID_IDENTIFY_PUSH
from libp2p.peer.peerinfo import (
info_from_p2p_addr,
)
@ -56,9 +56,6 @@ from libp2p.peer.peerinfo import (
# Configure logging
logger = logging.getLogger("libp2p.identity.identify-push-example")
# Default port configuration
DEFAULT_PORT = 8888
def custom_identify_push_handler_for(host):
"""
@ -241,25 +238,16 @@ def main() -> None:
"""Parse arguments and start the appropriate mode."""
description = """
This program demonstrates the libp2p identify/push protocol.
Without arguments, it runs as a listener on port 8888.
With -d parameter, it runs as a dialer on port 8889.
Without arguments, it runs as a listener on random port.
With -d parameter, it runs as a dialer on random port.
"""
example = (
f"/ip4/127.0.0.1/tcp/{DEFAULT_PORT}/p2p/"
"QmQn4SwGkDZkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
)
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"-p",
"--port",
type=int,
help=(
f"port to listen on (default: {DEFAULT_PORT} for listener, "
f"{DEFAULT_PORT + 1} for dialer)"
),
)
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
parser.add_argument(
"-d",
"--destination",
@ -270,13 +258,11 @@ def main() -> None:
try:
if args.destination:
# Run in dialer mode with default port DEFAULT_PORT + 1 if not specified
port = args.port if args.port is not None else DEFAULT_PORT + 1
trio.run(run_dialer, port, args.destination)
# Run in dialer mode with random available port if not specified
trio.run(run_dialer, args.port, args.destination)
else:
# Run in listener mode with default port DEFAULT_PORT if not specified
port = args.port if args.port is not None else DEFAULT_PORT
trio.run(run_listener, port)
# Run in listener mode with random available port if not specified
trio.run(run_listener, args.port)
except KeyboardInterrupt:
print("\nInterrupted by user")
logger.info("Interrupted by user")

View File

@ -0,0 +1,300 @@
#!/usr/bin/env python
"""
A basic example of using the Kademlia DHT implementation, with all setup logic inlined.
This example demonstrates both value storage/retrieval and content server
advertisement/discovery.
"""
import argparse
import logging
import os
import random
import secrets
import sys
import base58
from multiaddr import (
Multiaddr,
)
import trio
from libp2p import (
new_host,
)
from libp2p.abc import (
IHost,
)
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.kad_dht.kad_dht import (
DHTMode,
KadDHT,
)
from libp2p.kad_dht.utils import (
create_key_from_binary,
)
from libp2p.tools.async_service import (
background_trio_service,
)
from libp2p.tools.utils import (
info_from_p2p_addr,
)
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[logging.StreamHandler()],
)
logger = logging.getLogger("kademlia-example")
# Configure DHT module loggers to inherit from the parent logger
# This ensures all kademlia-example.* loggers use the same configuration
# Get the directory where this script is located
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
SERVER_ADDR_LOG = os.path.join(SCRIPT_DIR, "server_node_addr.txt")
# Set the level for all child loggers
for module in [
"kad_dht",
"value_store",
"peer_routing",
"routing_table",
"provider_store",
]:
child_logger = logging.getLogger(f"kademlia-example.{module}")
child_logger.setLevel(logging.INFO)
child_logger.propagate = True # Allow propagation to parent
# File to store node information
bootstrap_nodes = []
# function to take bootstrap_nodes as input and connects to them
async def connect_to_bootstrap_nodes(host: IHost, bootstrap_addrs: list[str]) -> None:
"""
Connect to the bootstrap nodes provided in the list.
params: host: The host instance to connect to
bootstrap_addrs: List of bootstrap node addresses
Returns
-------
None
"""
for addr in bootstrap_addrs:
try:
peerInfo = info_from_p2p_addr(Multiaddr(addr))
host.get_peerstore().add_addrs(peerInfo.peer_id, peerInfo.addrs, 3600)
await host.connect(peerInfo)
except Exception as e:
logger.error(f"Failed to connect to bootstrap node {addr}: {e}")
def save_server_addr(addr: str) -> None:
"""Append the server's multiaddress to the log file."""
try:
with open(SERVER_ADDR_LOG, "w") as f:
f.write(addr + "\n")
logger.info(f"Saved server address to log: {addr}")
except Exception as e:
logger.error(f"Failed to save server address: {e}")
def load_server_addrs() -> list[str]:
"""Load all server multiaddresses from the log file."""
if not os.path.exists(SERVER_ADDR_LOG):
return []
try:
with open(SERVER_ADDR_LOG) as f:
return [line.strip() for line in f if line.strip()]
except Exception as e:
logger.error(f"Failed to load server addresses: {e}")
return []
async def run_node(
port: int, mode: str, bootstrap_addrs: list[str] | None = None
) -> None:
"""Run a node that serves content in the DHT with setup inlined."""
try:
if port <= 0:
port = random.randint(10000, 60000)
logger.debug(f"Using port: {port}")
# Convert string mode to DHTMode enum
if mode is None or mode.upper() == "CLIENT":
dht_mode = DHTMode.CLIENT
elif mode.upper() == "SERVER":
dht_mode = DHTMode.SERVER
else:
logger.error(f"Invalid mode: {mode}. Must be 'client' or 'server'")
sys.exit(1)
# Load server addresses for client mode
if dht_mode == DHTMode.CLIENT:
server_addrs = load_server_addrs()
if server_addrs:
logger.info(f"Loaded {len(server_addrs)} server addresses from log")
bootstrap_nodes.append(server_addrs[0]) # Use the first server address
else:
logger.warning("No server addresses found in log file")
if bootstrap_addrs:
for addr in bootstrap_addrs:
bootstrap_nodes.append(addr)
key_pair = create_new_key_pair(secrets.token_bytes(32))
host = new_host(key_pair=key_pair)
listen_addr = Multiaddr(f"/ip4/127.0.0.1/tcp/{port}")
async with host.run(listen_addrs=[listen_addr]):
peer_id = host.get_id().pretty()
addr_str = f"/ip4/127.0.0.1/tcp/{port}/p2p/{peer_id}"
await connect_to_bootstrap_nodes(host, bootstrap_nodes)
dht = KadDHT(host, dht_mode)
# take all peer ids from the host and add them to the dht
for peer_id in host.get_peerstore().peer_ids():
await dht.routing_table.add_peer(peer_id)
logger.info(f"Connected to bootstrap nodes: {host.get_connected_peers()}")
bootstrap_cmd = f"--bootstrap {addr_str}"
logger.info("To connect to this node, use: %s", bootstrap_cmd)
# Save server address in server mode
if dht_mode == DHTMode.SERVER:
save_server_addr(addr_str)
# Start the DHT service
async with background_trio_service(dht):
logger.info(f"DHT service started in {dht_mode.value} mode")
val_key = create_key_from_binary(b"py-libp2p kademlia example value")
content = b"Hello from python node "
content_key = create_key_from_binary(content)
if dht_mode == DHTMode.SERVER:
# Store a value in the DHT
msg = "Hello message from Sumanjeet"
val_data = msg.encode()
await dht.put_value(val_key, val_data)
logger.info(
f"Stored value '{val_data.decode()}'"
f"with key: {base58.b58encode(val_key).decode()}"
)
# Advertise as content server
success = await dht.provider_store.provide(content_key)
if success:
logger.info(
"Successfully advertised as server"
f"for content: {content_key.hex()}"
)
else:
logger.warning("Failed to advertise as content server")
else:
# retrieve the value
logger.info(
"Looking up key: %s", base58.b58encode(val_key).decode()
)
val_data = await dht.get_value(val_key)
if val_data:
try:
logger.info(f"Retrieved value: {val_data.decode()}")
except UnicodeDecodeError:
logger.info(f"Retrieved value (bytes): {val_data!r}")
else:
logger.warning("Failed to retrieve value")
# Also check if we can find servers for our own content
logger.info("Looking for servers of content: %s", content_key.hex())
providers = await dht.provider_store.find_providers(content_key)
if providers:
logger.info(
"Found %d servers for content: %s",
len(providers),
[p.peer_id.pretty() for p in providers],
)
else:
logger.warning(
"No servers found for content %s", content_key.hex()
)
# Keep the node running
while True:
logger.debug(
"Status - Connected peers: %d,"
"Peers in store: %d, Values in store: %d",
len(dht.host.get_connected_peers()),
len(dht.host.get_peerstore().peer_ids()),
len(dht.value_store.store),
)
await trio.sleep(10)
except Exception as e:
logger.error(f"Server node error: {e}", exc_info=True)
sys.exit(1)
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="Kademlia DHT example with content server functionality"
)
parser.add_argument(
"--mode",
default="server",
help="Run as a server or client node",
)
parser.add_argument(
"--port",
type=int,
default=0,
help="Port to listen on (0 for random)",
)
parser.add_argument(
"--bootstrap",
type=str,
nargs="*",
help=(
"Multiaddrs of bootstrap nodes. "
"Provide a space-separated list of addresses. "
"This is required for client mode."
),
)
# add option to use verbose logging
parser.add_argument(
"--verbose",
action="store_true",
help="Enable verbose logging",
)
args = parser.parse_args()
# Set logging level based on verbosity
if args.verbose:
logging.getLogger().setLevel(logging.DEBUG)
else:
logging.getLogger().setLevel(logging.INFO)
return args
def main():
"""Main entry point for the kademlia demo."""
try:
args = parse_args()
logger.info(
"Running in %s mode on port %d",
args.mode,
args.port,
)
trio.run(run_node, args.port, args.mode, args.bootstrap)
except Exception as e:
logger.critical(f"Script failed: {e}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":
main()

View File

@ -55,7 +55,6 @@ async def send_ping(stream: INetStream) -> None:
async def run(port: int, destination: str) -> None:
localhost_ip = "127.0.0.1"
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
host = new_host(listen_addrs=[listen_addr])
@ -65,8 +64,8 @@ async def run(port: int, destination: str) -> None:
print(
"Run this from the same folder in another console:\n\n"
f"ping-demo -p {int(port) + 1} "
f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id().pretty()}\n"
f"ping-demo "
f"-d {host.get_addrs()[0]}\n"
)
print("Waiting for incoming connection...")
@ -96,10 +95,8 @@ def main() -> None:
)
parser = argparse.ArgumentParser(description=description)
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
parser.add_argument(
"-p", "--port", default=8000, type=int, help="source port number"
)
parser.add_argument(
"-d",
"--destination",
@ -108,9 +105,6 @@ def main() -> None:
)
args = parser.parse_args()
if not args.port:
raise RuntimeError("failed to determine local port")
try:
trio.run(run, *(args.port, args.destination))
except KeyboardInterrupt:

View File

@ -1,9 +1,6 @@
import argparse
import logging
import socket
from typing import (
Optional,
)
import base58
import multiaddr
@ -109,7 +106,7 @@ async def monitor_peer_topics(pubsub, nursery, termination_event):
await trio.sleep(2)
async def run(topic: str, destination: Optional[str], port: Optional[int]) -> None:
async def run(topic: str, destination: str | None, port: int | None) -> None:
# Initialize network settings
localhost_ip = "127.0.0.1"

View File

@ -152,12 +152,12 @@ def get_default_muxer_options() -> TMuxerOptions:
def new_swarm(
key_pair: Optional[KeyPair] = None,
muxer_opt: Optional[TMuxerOptions] = None,
sec_opt: Optional[TSecurityOptions] = None,
peerstore_opt: Optional[IPeerStore] = None,
muxer_preference: Optional[Literal["YAMUX", "MPLEX"]] = None,
listen_addrs: Optional[Sequence[multiaddr.Multiaddr]] = None,
key_pair: KeyPair | None = None,
muxer_opt: TMuxerOptions | None = None,
sec_opt: TSecurityOptions | None = None,
peerstore_opt: IPeerStore | None = None,
muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
) -> INetworkService:
"""
Create a swarm instance based on the parameters.
@ -200,7 +200,9 @@ def new_swarm(
key_pair, noise_privkey=noise_key_pair.private_key
),
TProtocol(secio.ID): secio.Transport(key_pair),
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair),
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(
key_pair, peerstore=peerstore_opt
),
}
# Use given muxer preference if provided, otherwise use global default
@ -236,13 +238,13 @@ def new_swarm(
def new_host(
key_pair: Optional[KeyPair] = None,
muxer_opt: Optional[TMuxerOptions] = None,
sec_opt: Optional[TSecurityOptions] = None,
peerstore_opt: Optional[IPeerStore] = None,
disc_opt: Optional[IPeerRouting] = None,
muxer_preference: Optional[Literal["YAMUX", "MPLEX"]] = None,
listen_addrs: Sequence[multiaddr.Multiaddr] = None,
key_pair: KeyPair | None = None,
muxer_opt: TMuxerOptions | None = None,
sec_opt: TSecurityOptions | None = None,
peerstore_opt: IPeerStore | None = None,
disc_opt: IPeerRouting | None = None,
muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
) -> IHost:
"""
Create a new libp2p host based on the given parameters.

View File

@ -8,6 +8,7 @@ from collections.abc import (
KeysView,
Sequence,
)
from contextlib import AbstractAsyncContextManager
from types import (
TracebackType,
)
@ -15,7 +16,6 @@ from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
Optional,
)
from multiaddr import (
@ -160,7 +160,11 @@ class IMuxedConn(ABC):
event_started: trio.Event
@abstractmethod
def __init__(self, conn: ISecureConn, peer_id: ID) -> None:
def __init__(
self,
conn: ISecureConn,
peer_id: ID,
) -> None:
"""
Initialize a new multiplexed connection.
@ -260,9 +264,9 @@ class IMuxedStream(ReadWriteCloser, AsyncContextManager["IMuxedStream"]):
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Exit the async context manager and close the stream."""
await self.close()
@ -287,7 +291,7 @@ class INetStream(ReadWriteCloser):
muxed_conn: IMuxedConn
@abstractmethod
def get_protocol(self) -> TProtocol:
def get_protocol(self) -> TProtocol | None:
"""
Retrieve the protocol identifier for the stream.
@ -916,7 +920,7 @@ class INetwork(ABC):
"""
@abstractmethod
async def listen(self, *multiaddrs: Sequence[Multiaddr]) -> bool:
async def listen(self, *multiaddrs: Multiaddr) -> bool:
"""
Start listening on one or more multiaddresses.
@ -1174,7 +1178,9 @@ class IHost(ABC):
"""
@abstractmethod
def run(self, listen_addrs: Sequence[Multiaddr]) -> AsyncContextManager[None]:
def run(
self, listen_addrs: Sequence[Multiaddr]
) -> AbstractAsyncContextManager[None]:
"""
Run the host and start listening on the specified multiaddresses.
@ -1434,6 +1440,60 @@ class IPeerData(ABC):
"""
@abstractmethod
def update_last_identified(self) -> None:
"""
Updates timestamp to current time.
"""
@abstractmethod
def get_last_identified(self) -> int:
"""
Fetch the last identified timestamp
Returns
-------
last_identified_timestamp
The lastIdentified time of peer.
"""
@abstractmethod
def get_ttl(self) -> int:
"""
Get ttl value for the peer for validity check
Returns
-------
int
The ttl of the peer.
"""
@abstractmethod
def set_ttl(self, ttl: int) -> None:
"""
Set ttl value for the peer for validity check
Parameters
----------
ttl : int
The ttl for the peer.
"""
@abstractmethod
def is_expired(self) -> bool:
"""
Check if the peer is expired based on last_identified and ttl
Returns
-------
bool
True, if last_identified + ttl > current_time
"""
# ------------------ multiselect_communicator interface.py ------------------
@ -1564,7 +1624,7 @@ class IMultiselectMuxer(ABC):
and its corresponding handler for communication.
"""
handlers: dict[TProtocol, StreamHandlerFn]
handlers: dict[TProtocol | None, StreamHandlerFn | None]
@abstractmethod
def add_handler(self, protocol: TProtocol, handler: StreamHandlerFn) -> None:
@ -1580,7 +1640,7 @@ class IMultiselectMuxer(ABC):
"""
def get_protocols(self) -> tuple[TProtocol, ...]:
def get_protocols(self) -> tuple[TProtocol | None, ...]:
"""
Retrieve the protocols for which handlers have been registered.
@ -1595,7 +1655,7 @@ class IMultiselectMuxer(ABC):
@abstractmethod
async def negotiate(
self, communicator: IMultiselectCommunicator
) -> tuple[TProtocol, StreamHandlerFn]:
) -> tuple[TProtocol | None, StreamHandlerFn | None]:
"""
Negotiate a protocol selection with a multiselect client.
@ -1672,7 +1732,7 @@ class IPeerRouting(ABC):
"""
@abstractmethod
async def find_peer(self, peer_id: ID) -> PeerInfo:
async def find_peer(self, peer_id: ID) -> PeerInfo | None:
"""
Search for a peer with the specified peer ID.
@ -1840,6 +1900,11 @@ class IPubsubRouter(ABC):
"""
mesh: dict[str, set[ID]]
fanout: dict[str, set[ID]]
peer_protocol: dict[ID, TProtocol]
degree: int
@abstractmethod
def get_protocols(self) -> list[TProtocol]:
"""
@ -1865,7 +1930,7 @@ class IPubsubRouter(ABC):
"""
@abstractmethod
def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None:
def add_peer(self, peer_id: ID, protocol_id: TProtocol | None) -> None:
"""
Notify the router that a new peer has connected.
@ -2065,14 +2130,14 @@ class IPubsub(ServiceAPI):
...
@abstractmethod
async def publish(self, topic_id: str, data: bytes) -> None:
async def publish(self, topic_id: str | list[str], data: bytes) -> None:
"""
Publish a message to a topic.
Publish a message to a topic or multiple topics.
Parameters
----------
topic_id : str
The identifier of the topic.
topic_id : str | list[str]
The identifier of the topic (str) or topics (list[str]).
data : bytes
The data to publish.

View File

@ -116,15 +116,15 @@ def initialize_pair(
EncryptionParameters(
cipher_type,
hash_type,
first_half[0:iv_size],
first_half[iv_size + cipher_key_size :],
first_half[iv_size : iv_size + cipher_key_size],
bytes(first_half[0:iv_size]),
bytes(first_half[iv_size + cipher_key_size :]),
bytes(first_half[iv_size : iv_size + cipher_key_size]),
),
EncryptionParameters(
cipher_type,
hash_type,
second_half[0:iv_size],
second_half[iv_size + cipher_key_size :],
second_half[iv_size : iv_size + cipher_key_size],
bytes(second_half[0:iv_size]),
bytes(second_half[iv_size + cipher_key_size :]),
bytes(second_half[iv_size : iv_size + cipher_key_size]),
),
)

View File

@ -9,29 +9,40 @@ from libp2p.crypto.keys import (
if sys.platform != "win32":
from fastecdsa import (
curve as curve_types,
keys,
point,
)
from fastecdsa import curve as curve_types
from fastecdsa.encoding.sec1 import (
SEC1Encoder,
)
else:
from coincurve import PrivateKey as CPrivateKey
from coincurve import PublicKey as CPublicKey
from coincurve import (
PrivateKey as CPrivateKey,
PublicKey as CPublicKey,
)
def infer_local_type(curve: str) -> object:
"""
Convert a str representation of some elliptic curve to a
representation understood by the backend of this module.
"""
if curve != "P-256":
raise NotImplementedError("Only P-256 curve is supported")
if sys.platform != "win32":
if sys.platform != "win32":
def infer_local_type(curve: str) -> curve_types.Curve:
"""
Convert a str representation of some elliptic curve to a
representation understood by the backend of this module.
"""
if curve != "P-256":
raise NotImplementedError("Only P-256 curve is supported")
return curve_types.P256
return "P-256" # coincurve only supports P-256
else:
def infer_local_type(curve: str) -> str:
"""
Convert a str representation of some elliptic curve to a
representation understood by the backend of this module.
"""
if curve != "P-256":
raise NotImplementedError("Only P-256 curve is supported")
return "P-256" # coincurve only supports P-256
if sys.platform != "win32":
@ -68,7 +79,10 @@ if sys.platform != "win32":
return cls(private_key_impl, curve_type)
def to_bytes(self) -> bytes:
return keys.export_key(self.impl, self.curve)
key_str = keys.export_key(self.impl, self.curve)
if key_str is None:
raise Exception("Key not found")
return key_str.encode()
def get_type(self) -> KeyType:
return KeyType.ECC_P256

View File

@ -4,8 +4,10 @@ from Crypto.Hash import (
from nacl.exceptions import (
BadSignatureError,
)
from nacl.public import PrivateKey as PrivateKeyImpl
from nacl.public import PublicKey as PublicKeyImpl
from nacl.public import (
PrivateKey as PrivateKeyImpl,
PublicKey as PublicKeyImpl,
)
from nacl.signing import (
SigningKey,
VerifyKey,
@ -48,7 +50,7 @@ class Ed25519PrivateKey(PrivateKey):
self.impl = impl
@classmethod
def new(cls, seed: bytes = None) -> "Ed25519PrivateKey":
def new(cls, seed: bytes | None = None) -> "Ed25519PrivateKey":
if not seed:
seed = utils.random()
@ -75,7 +77,7 @@ class Ed25519PrivateKey(PrivateKey):
return Ed25519PublicKey(self.impl.public_key)
def create_new_key_pair(seed: bytes = None) -> KeyPair:
def create_new_key_pair(seed: bytes | None = None) -> KeyPair:
private_key = Ed25519PrivateKey.new(seed)
public_key = private_key.get_public_key()
return KeyPair(private_key, public_key)

View File

@ -1,6 +1,6 @@
from collections.abc import Callable
import sys
from typing import (
Callable,
cast,
)

View File

@ -81,12 +81,10 @@ class PrivateKey(Key):
"""A ``PrivateKey`` represents a cryptographic private key."""
@abstractmethod
def sign(self, data: bytes) -> bytes:
...
def sign(self, data: bytes) -> bytes: ...
@abstractmethod
def get_public_key(self) -> PublicKey:
...
def get_public_key(self) -> PublicKey: ...
def _serialize_to_protobuf(self) -> crypto_pb2.PrivateKey:
"""Return the protobuf representation of this ``Key``."""

View File

@ -37,7 +37,7 @@ class Secp256k1PrivateKey(PrivateKey):
self.impl = impl
@classmethod
def new(cls, secret: bytes = None) -> "Secp256k1PrivateKey":
def new(cls, secret: bytes | None = None) -> "Secp256k1PrivateKey":
private_key_impl = coincurve.PrivateKey(secret)
return cls(private_key_impl)
@ -65,7 +65,7 @@ class Secp256k1PrivateKey(PrivateKey):
return Secp256k1PublicKey(public_key_impl)
def create_new_key_pair(secret: bytes = None) -> KeyPair:
def create_new_key_pair(secret: bytes | None = None) -> KeyPair:
"""
Returns a new Secp256k1 keypair derived from the provided ``secret``, a
sequence of bytes corresponding to some integer between 0 and the group

View File

@ -1,13 +1,9 @@
from collections.abc import (
Awaitable,
Callable,
Mapping,
)
from typing import (
TYPE_CHECKING,
Callable,
NewType,
Union,
)
from typing import TYPE_CHECKING, NewType, Union, cast
if TYPE_CHECKING:
from libp2p.abc import (
@ -16,15 +12,9 @@ if TYPE_CHECKING:
ISecureTransport,
)
else:
class INetStream:
pass
class IMuxedConn:
pass
class ISecureTransport:
pass
IMuxedConn = cast(type, object)
INetStream = cast(type, object)
ISecureTransport = cast(type, object)
from libp2p.io.abc import (
@ -38,10 +28,10 @@ from libp2p.pubsub.pb import (
)
TProtocol = NewType("TProtocol", str)
StreamHandlerFn = Callable[["INetStream"], Awaitable[None]]
StreamHandlerFn = Callable[[INetStream], Awaitable[None]]
THandler = Callable[[ReadWriteCloser], Awaitable[None]]
TSecurityOptions = Mapping[TProtocol, "ISecureTransport"]
TMuxerClass = type["IMuxedConn"]
TSecurityOptions = Mapping[TProtocol, ISecureTransport]
TMuxerClass = type[IMuxedConn]
TMuxerOptions = Mapping[TProtocol, TMuxerClass]
SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]

View File

@ -1,7 +1,4 @@
import logging
from typing import (
Union,
)
from libp2p.custom_types import (
TProtocol,
@ -94,7 +91,7 @@ class AutoNATService:
finally:
await stream.close()
async def _handle_request(self, request: Union[bytes, Message]) -> Message:
async def _handle_request(self, request: bytes | Message) -> Message:
"""
Process an AutoNAT protocol request.

View File

@ -84,26 +84,23 @@ class AutoNAT:
request: Any,
target: str,
options: tuple[Any, ...] = (),
channel_credentials: Optional[Any] = None,
call_credentials: Optional[Any] = None,
channel_credentials: Any | None = None,
call_credentials: Any | None = None,
insecure: bool = False,
compression: Optional[Any] = None,
wait_for_ready: Optional[bool] = None,
timeout: Optional[float] = None,
metadata: Optional[list[tuple[str, str]]] = None,
compression: Any | None = None,
wait_for_ready: bool | None = None,
timeout: float | None = None,
metadata: list[tuple[str, str]] | None = None,
) -> Any:
return grpc.experimental.unary_unary(
request,
target,
channel = grpc.secure_channel(target, channel_credentials) if channel_credentials else grpc.insecure_channel(target)
return channel.unary_unary(
"/autonat.pb.AutoNAT/Dial",
autonat__pb2.Message.SerializeToString,
autonat__pb2.Message.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
request_serializer=autonat__pb2.Message.SerializeToString,
response_deserializer=autonat__pb2.Message.FromString,
_registered_method=True,
)(
request,
timeout=timeout,
metadata=metadata,
wait_for_ready=wait_for_ready,
)

View File

@ -3,6 +3,7 @@ from collections.abc import (
Sequence,
)
from contextlib import (
AbstractAsyncContextManager,
asynccontextmanager,
)
import logging
@ -88,14 +89,14 @@ class BasicHost(IHost):
def __init__(
self,
network: INetworkService,
default_protocols: "OrderedDict[TProtocol, StreamHandlerFn]" = None,
default_protocols: Optional["OrderedDict[TProtocol, StreamHandlerFn]"] = None,
) -> None:
self._network = network
self._network.set_stream_handler(self._swarm_stream_handler)
self.peerstore = self._network.peerstore
# Protocol muxing
default_protocols = default_protocols or get_default_protocols(self)
self.multiselect = Multiselect(default_protocols)
self.multiselect = Multiselect(dict(default_protocols.items()))
self.multiselect_client = MultiselectClient()
def get_id(self) -> ID:
@ -147,19 +148,23 @@ class BasicHost(IHost):
"""
return list(self._network.connections.keys())
@asynccontextmanager
async def run(
def run(
self, listen_addrs: Sequence[multiaddr.Multiaddr]
) -> AsyncIterator[None]:
) -> AbstractAsyncContextManager[None]:
"""
Run the host instance and listen to ``listen_addrs``.
:param listen_addrs: a sequence of multiaddrs that we want to listen to
"""
network = self.get_network()
async with background_trio_service(network):
await network.listen(*listen_addrs)
yield
@asynccontextmanager
async def _run() -> AsyncIterator[None]:
network = self.get_network()
async with background_trio_service(network):
await network.listen(*listen_addrs)
yield
return _run()
def set_stream_handler(
self, protocol_id: TProtocol, stream_handler: StreamHandlerFn
@ -229,7 +234,7 @@ class BasicHost(IHost):
:param peer_info: peer_info of the peer we want to connect to
:type peer_info: peer.peerinfo.PeerInfo
"""
self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 10)
self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 120)
# there is already a connection to this peer
if peer_info.peer_id in self._network.connections:
@ -258,6 +263,15 @@ class BasicHost(IHost):
await net_stream.reset()
return
net_stream.set_protocol(protocol)
if handler is None:
logger.debug(
"no handler for protocol %s, closing stream from peer %s",
protocol,
net_stream.muxed_conn.peer_id,
)
await net_stream.reset()
return
await handler(net_stream)
def get_live_peers(self) -> list[ID]:
@ -277,7 +291,7 @@ class BasicHost(IHost):
"""
return peer_id in self._network.connections
def get_peer_connection_info(self, peer_id: ID) -> Optional[INetConn]:
def get_peer_connection_info(self, peer_id: ID) -> INetConn | None:
"""
Get connection information for a specific peer if connected.

View File

@ -9,13 +9,13 @@ from libp2p.abc import (
IHost,
)
from libp2p.host.ping import (
ID as PingID,
handle_ping,
)
from libp2p.host.ping import ID as PingID
from libp2p.identity.identify.identify import (
ID as IdentifyID,
identify_handler_for,
)
from libp2p.identity.identify.identify import ID as IdentifyID
if TYPE_CHECKING:
from libp2p.custom_types import (

View File

@ -40,8 +40,8 @@ class RoutedHost(BasicHost):
found_peer_info = await self._router.find_peer(peer_info.peer_id)
if not found_peer_info:
raise ConnectionFailure("Unable to find Peer address")
self.peerstore.add_addrs(peer_info.peer_id, found_peer_info.addrs, 10)
self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 10)
self.peerstore.add_addrs(peer_info.peer_id, found_peer_info.addrs, 120)
self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 120)
# there is already a connection to this peer
if peer_info.peer_id in self._network.connections:

View File

@ -1,7 +1,4 @@
import logging
from typing import (
Optional,
)
from multiaddr import (
Multiaddr,
@ -40,8 +37,8 @@ def _multiaddr_to_bytes(maddr: Multiaddr) -> bytes:
def _remote_address_to_multiaddr(
remote_address: Optional[tuple[str, int]]
) -> Optional[Multiaddr]:
remote_address: tuple[str, int] | None,
) -> Multiaddr | None:
"""Convert a (host, port) tuple to a Multiaddr."""
if remote_address is None:
return None
@ -58,7 +55,7 @@ def _remote_address_to_multiaddr(
def _mk_identify_protobuf(
host: IHost, observed_multiaddr: Optional[Multiaddr]
host: IHost, observed_multiaddr: Multiaddr | None
) -> Identify:
public_key = host.get_public_key()
laddrs = host.get_addrs()
@ -81,15 +78,14 @@ def identify_handler_for(host: IHost) -> StreamHandlerFn:
peer_id = (
stream.muxed_conn.peer_id
) # remote peer_id is in class Mplex (mplex.py )
observed_multiaddr: Multiaddr | None = None
# Get the remote address
try:
remote_address = stream.get_remote_address()
# Convert to multiaddr
if remote_address:
observed_multiaddr = _remote_address_to_multiaddr(remote_address)
else:
observed_multiaddr = None
logger.debug(
"Connection from remote peer %s, address: %s, multiaddr: %s",
peer_id,

View File

@ -1,7 +1,4 @@
import logging
from typing import (
Optional,
)
from multiaddr import (
Multiaddr,
@ -135,7 +132,7 @@ async def _update_peerstore_from_identify(
async def push_identify_to_peer(
host: IHost, peer_id: ID, observed_multiaddr: Optional[Multiaddr] = None
host: IHost, peer_id: ID, observed_multiaddr: Multiaddr | None = None
) -> bool:
"""
Push an identify message to a specific peer.
@ -172,8 +169,8 @@ async def push_identify_to_peer(
async def push_identify_to_peers(
host: IHost,
peer_ids: Optional[set[ID]] = None,
observed_multiaddr: Optional[Multiaddr] = None,
peer_ids: set[ID] | None = None,
observed_multiaddr: Multiaddr | None = None,
) -> None:
"""
Push an identify message to multiple peers in parallel.

View File

@ -2,27 +2,22 @@ from abc import (
ABC,
abstractmethod,
)
from typing import (
Optional,
)
from typing import Any
class Closer(ABC):
@abstractmethod
async def close(self) -> None:
...
async def close(self) -> None: ...
class Reader(ABC):
@abstractmethod
async def read(self, n: int = None) -> bytes:
...
async def read(self, n: int | None = None) -> bytes: ...
class Writer(ABC):
@abstractmethod
async def write(self, data: bytes) -> None:
...
async def write(self, data: bytes) -> None: ...
class WriteCloser(Writer, Closer):
@ -39,7 +34,7 @@ class ReadWriter(Reader, Writer):
class ReadWriteCloser(Reader, Writer, Closer):
@abstractmethod
def get_remote_address(self) -> Optional[tuple[str, int]]:
def get_remote_address(self) -> tuple[str, int] | None:
"""
Return the remote address of the connected peer.
@ -50,14 +45,12 @@ class ReadWriteCloser(Reader, Writer, Closer):
class MsgReader(ABC):
@abstractmethod
async def read_msg(self) -> bytes:
...
async def read_msg(self) -> bytes: ...
class MsgWriter(ABC):
@abstractmethod
async def write_msg(self, msg: bytes) -> None:
...
async def write_msg(self, msg: bytes) -> None: ...
class MsgReadWriteCloser(MsgReader, MsgWriter, Closer):
@ -66,19 +59,26 @@ class MsgReadWriteCloser(MsgReader, MsgWriter, Closer):
class Encrypter(ABC):
@abstractmethod
def encrypt(self, data: bytes) -> bytes:
...
def encrypt(self, data: bytes) -> bytes: ...
@abstractmethod
def decrypt(self, data: bytes) -> bytes:
...
def decrypt(self, data: bytes) -> bytes: ...
class EncryptedMsgReadWriter(MsgReadWriteCloser, Encrypter):
"""Read/write message with encryption/decryption."""
def get_remote_address(self) -> Optional[tuple[str, int]]:
conn: Any | None
def __init__(self, conn: Any | None = None):
self.conn = conn
def get_remote_address(self) -> tuple[str, int] | None:
"""Get remote address if supported by the underlying connection."""
if hasattr(self, "conn") and hasattr(self.conn, "get_remote_address"):
if (
self.conn is not None
and hasattr(self, "conn")
and hasattr(self.conn, "get_remote_address")
):
return self.conn.get_remote_address()
return None

View File

@ -5,6 +5,7 @@ from that repo: "a simple package to r/w length-delimited slices."
NOTE: currently missing the capability to indicate lengths by "varint" method.
"""
from abc import (
abstractmethod,
)
@ -60,12 +61,10 @@ class BaseMsgReadWriter(MsgReadWriteCloser):
return await read_exactly(self.read_write_closer, length)
@abstractmethod
async def next_msg_len(self) -> int:
...
async def next_msg_len(self) -> int: ...
@abstractmethod
def encode_msg(self, msg: bytes) -> bytes:
...
def encode_msg(self, msg: bytes) -> bytes: ...
async def close(self) -> None:
await self.read_write_closer.close()

View File

@ -1,7 +1,4 @@
import logging
from typing import (
Optional,
)
import trio
@ -34,7 +31,7 @@ class TrioTCPStream(ReadWriteCloser):
except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
raise IOException from error
async def read(self, n: int = None) -> bytes:
async def read(self, n: int | None = None) -> bytes:
async with self.read_lock:
if n is not None and n == 0:
return b""
@ -46,7 +43,7 @@ class TrioTCPStream(ReadWriteCloser):
async def close(self) -> None:
await self.stream.aclose()
def get_remote_address(self) -> Optional[tuple[str, int]]:
def get_remote_address(self) -> tuple[str, int] | None:
"""Return the remote address as (host, port) tuple."""
try:
return self.stream.socket.getpeername()

View File

@ -14,12 +14,14 @@ async def read_exactly(
"""
NOTE: relying on exceptions to break out on erroneous conditions, like EOF
"""
data = await reader.read(n)
buffer = bytearray()
buffer.extend(await reader.read(n))
for _ in range(retry_count):
if len(data) < n:
remaining = n - len(data)
data += await reader.read(remaining)
if len(buffer) < n:
remaining = n - len(buffer)
buffer.extend(await reader.read(remaining))
else:
return data
raise IncompleteReadError({"requested_count": n, "received_count": len(data)})
return bytes(buffer)
raise IncompleteReadError({"requested_count": n, "received_count": len(buffer)})

View File

@ -0,0 +1,30 @@
"""
Kademlia DHT implementation for py-libp2p.
This module provides a Distributed Hash Table (DHT) implementation
based on the Kademlia protocol.
"""
from .kad_dht import (
KadDHT,
)
from .peer_routing import (
PeerRouting,
)
from .routing_table import (
RoutingTable,
)
from .utils import (
create_key_from_binary,
)
from .value_store import (
ValueStore,
)
__all__ = [
"KadDHT",
"RoutingTable",
"PeerRouting",
"ValueStore",
"create_key_from_binary",
]

14
libp2p/kad_dht/common.py Normal file
View File

@ -0,0 +1,14 @@
"""
Shared constants and protocol parameters for the Kademlia DHT.
"""
from libp2p.custom_types import (
TProtocol,
)
# Constants for the Kademlia algorithm
ALPHA = 3 # Concurrency parameter
PROTOCOL_ID = TProtocol("/ipfs/kad/1.0.0")
QUERY_TIMEOUT = 10
TTL = DEFAULT_TTL = 24 * 60 * 60 # 24 hours in seconds

616
libp2p/kad_dht/kad_dht.py Normal file
View File

@ -0,0 +1,616 @@
"""
Kademlia DHT implementation for py-libp2p.
This module provides a complete Distributed Hash Table (DHT)
implementation based on the Kademlia algorithm and protocol.
"""
from enum import (
Enum,
)
import logging
import time
from multiaddr import (
Multiaddr,
)
import trio
import varint
from libp2p.abc import (
IHost,
)
from libp2p.network.stream.net_stream import (
INetStream,
)
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.tools.async_service import (
Service,
)
from .common import (
ALPHA,
PROTOCOL_ID,
QUERY_TIMEOUT,
)
from .pb.kademlia_pb2 import (
Message,
)
from .peer_routing import (
PeerRouting,
)
from .provider_store import (
ProviderStore,
)
from .routing_table import (
RoutingTable,
)
from .value_store import (
ValueStore,
)
logger = logging.getLogger("kademlia-example.kad_dht")
# logger = logging.getLogger("libp2p.kademlia")
# Default parameters
ROUTING_TABLE_REFRESH_INTERVAL = 60 # 1 min in seconds for testing
class DHTMode(Enum):
"""DHT operation modes."""
CLIENT = "CLIENT"
SERVER = "SERVER"
class KadDHT(Service):
"""
Kademlia DHT implementation for libp2p.
This class provides a DHT implementation that combines routing table management,
peer discovery, content routing, and value storage.
"""
def __init__(self, host: IHost, mode: DHTMode):
"""
Initialize a new Kademlia DHT node.
:param host: The libp2p host.
:param mode: The mode of host (Client or Server) - must be DHTMode enum
"""
super().__init__()
self.host = host
self.local_peer_id = host.get_id()
# Validate that mode is a DHTMode enum
if not isinstance(mode, DHTMode):
raise TypeError(f"mode must be DHTMode enum, got {type(mode)}")
self.mode = mode
# Initialize the routing table
self.routing_table = RoutingTable(self.local_peer_id, self.host)
# Initialize peer routing
self.peer_routing = PeerRouting(host, self.routing_table)
# Initialize value store
self.value_store = ValueStore(host=host, local_peer_id=self.local_peer_id)
# Initialize provider store with host and peer_routing references
self.provider_store = ProviderStore(host=host, peer_routing=self.peer_routing)
# Last time we republished provider records
self._last_provider_republish = time.time()
# Set protocol handlers
host.set_stream_handler(PROTOCOL_ID, self.handle_stream)
async def run(self) -> None:
"""Run the DHT service."""
logger.info(f"Starting Kademlia DHT with peer ID {self.local_peer_id}")
# Main service loop
while self.manager.is_running:
# Periodically refresh the routing table
await self.refresh_routing_table()
# Check if it's time to republish provider records
current_time = time.time()
# await self._republish_provider_records()
self._last_provider_republish = current_time
# Clean up expired values and provider records
expired_values = self.value_store.cleanup_expired()
if expired_values > 0:
logger.debug(f"Cleaned up {expired_values} expired values")
self.provider_store.cleanup_expired()
# Wait before next maintenance cycle
await trio.sleep(ROUTING_TABLE_REFRESH_INTERVAL)
async def switch_mode(self, new_mode: DHTMode) -> DHTMode:
"""
Switch the DHT mode.
:param new_mode: The new mode - must be DHTMode enum
:return: The new mode as DHTMode enum
"""
# Validate that new_mode is a DHTMode enum
if not isinstance(new_mode, DHTMode):
raise TypeError(f"new_mode must be DHTMode enum, got {type(new_mode)}")
if new_mode == DHTMode.CLIENT:
self.routing_table.cleanup_routing_table()
self.mode = new_mode
logger.info(f"Switched to {new_mode.value} mode")
return self.mode
async def handle_stream(self, stream: INetStream) -> None:
"""
Handle an incoming DHT stream using varint length prefixes.
"""
if self.mode == DHTMode.CLIENT:
stream.close
return
peer_id = stream.muxed_conn.peer_id
logger.debug(f"Received DHT stream from peer {peer_id}")
await self.add_peer(peer_id)
logger.debug(f"Added peer {peer_id} to routing table")
try:
# Read varint-prefixed length for the message
length_prefix = b""
while True:
byte = await stream.read(1)
if not byte:
logger.warning("Stream closed while reading varint length")
await stream.close()
return
length_prefix += byte
if byte[0] & 0x80 == 0:
break
msg_length = varint.decode_bytes(length_prefix)
# Read the message bytes
msg_bytes = await stream.read(msg_length)
if len(msg_bytes) < msg_length:
logger.warning("Failed to read full message from stream")
await stream.close()
return
try:
# Parse as protobuf
message = Message()
message.ParseFromString(msg_bytes)
logger.debug(
f"Received DHT message from {peer_id}, type: {message.type}"
)
# Handle FIND_NODE message
if message.type == Message.MessageType.FIND_NODE:
# Get target key directly from protobuf
target_key = message.key
# Find closest peers to the target key
closest_peers = self.routing_table.find_local_closest_peers(
target_key, 20
)
logger.debug(f"Found {len(closest_peers)} peers close to target")
# Build response message with protobuf
response = Message()
response.type = Message.MessageType.FIND_NODE
# Add closest peers to response
for peer in closest_peers:
# Skip if the peer is the requester
if peer == peer_id:
continue
# Add peer to closerPeers field
peer_proto = response.closerPeers.add()
peer_proto.id = peer.to_bytes()
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add addresses if available
try:
addrs = self.host.get_peerstore().addrs(peer)
if addrs:
for addr in addrs:
peer_proto.addrs.append(addr.to_bytes())
except Exception:
pass
# Serialize and send response
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
await stream.write(response_bytes)
logger.debug(
f"Sent FIND_NODE response with{len(response.closerPeers)} peers"
)
# Handle ADD_PROVIDER message
elif message.type == Message.MessageType.ADD_PROVIDER:
# Process ADD_PROVIDER
key = message.key
logger.debug(f"Received ADD_PROVIDER for key {key.hex()}")
# Extract provider information
for provider_proto in message.providerPeers:
try:
# Validate that the provider is the sender
provider_id = ID(provider_proto.id)
if provider_id != peer_id:
logger.warning(
f"Provider ID {provider_id} doesn't"
f"match sender {peer_id}, ignoring"
)
continue
# Convert addresses to Multiaddr
addrs = []
for addr_bytes in provider_proto.addrs:
try:
addrs.append(Multiaddr(addr_bytes))
except Exception as e:
logger.warning(f"Failed to parse address: {e}")
# Add to provider store
provider_info = PeerInfo(provider_id, addrs)
self.provider_store.add_provider(key, provider_info)
logger.debug(
f"Added provider {provider_id} for key {key.hex()}"
)
except Exception as e:
logger.warning(f"Failed to process provider info: {e}")
# Send acknowledgement
response = Message()
response.type = Message.MessageType.ADD_PROVIDER
response.key = key
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
await stream.write(response_bytes)
logger.debug("Sent ADD_PROVIDER acknowledgement")
# Handle GET_PROVIDERS message
elif message.type == Message.MessageType.GET_PROVIDERS:
# Process GET_PROVIDERS
key = message.key
logger.debug(f"Received GET_PROVIDERS request for key {key.hex()}")
# Find providers for the key
providers = self.provider_store.get_providers(key)
logger.debug(
f"Found {len(providers)} providers for key {key.hex()}"
)
# Create response
response = Message()
response.type = Message.MessageType.GET_PROVIDERS
response.key = key
# Add provider information to response
for provider_info in providers:
provider_proto = response.providerPeers.add()
provider_proto.id = provider_info.peer_id.to_bytes()
provider_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add addresses if available
for addr in provider_info.addrs:
provider_proto.addrs.append(addr.to_bytes())
# Also include closest peers if we don't have providers
if not providers:
closest_peers = self.routing_table.find_local_closest_peers(
key, 20
)
logger.debug(
f"No providers found, including {len(closest_peers)}"
"closest peers"
)
for peer in closest_peers:
# Skip if peer is the requester
if peer == peer_id:
continue
peer_proto = response.closerPeers.add()
peer_proto.id = peer.to_bytes()
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add addresses if available
try:
addrs = self.host.get_peerstore().addrs(peer)
for addr in addrs:
peer_proto.addrs.append(addr.to_bytes())
except Exception:
pass
# Serialize and send response
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
await stream.write(response_bytes)
logger.debug("Sent GET_PROVIDERS response")
# Handle GET_VALUE message
elif message.type == Message.MessageType.GET_VALUE:
# Process GET_VALUE
key = message.key
logger.debug(f"Received GET_VALUE request for key {key.hex()}")
value = self.value_store.get(key)
if value:
logger.debug(f"Found value for key {key.hex()}")
# Create response using protobuf
response = Message()
response.type = Message.MessageType.GET_VALUE
# Create record
response.key = key
response.record.key = key
response.record.value = value
response.record.timeReceived = str(time.time())
# Serialize and send response
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
await stream.write(response_bytes)
logger.debug("Sent GET_VALUE response")
else:
logger.debug(f"No value found for key {key.hex()}")
# Create response with closest peers when no value is found
response = Message()
response.type = Message.MessageType.GET_VALUE
response.key = key
# Add closest peers to key
closest_peers = self.routing_table.find_local_closest_peers(
key, 20
)
logger.debug(
"No value found,"
f"including {len(closest_peers)} closest peers"
)
for peer in closest_peers:
# Skip if peer is the requester
if peer == peer_id:
continue
peer_proto = response.closerPeers.add()
peer_proto.id = peer.to_bytes()
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add addresses if available
try:
addrs = self.host.get_peerstore().addrs(peer)
for addr in addrs:
peer_proto.addrs.append(addr.to_bytes())
except Exception:
pass
# Serialize and send response
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
await stream.write(response_bytes)
logger.debug("Sent GET_VALUE response with closest peers")
# Handle PUT_VALUE message
elif message.type == Message.MessageType.PUT_VALUE and message.HasField(
"record"
):
# Process PUT_VALUE
key = message.record.key
value = message.record.value
success = False
try:
if not (key and value):
raise ValueError(
"Missing key or value in PUT_VALUE message"
)
self.value_store.put(key, value)
logger.debug(f"Stored value {value.hex()} for key {key.hex()}")
success = True
except Exception as e:
logger.warning(
f"Failed to store value {value.hex()} for key "
f"{key.hex()}: {e}"
)
finally:
# Send acknowledgement
response = Message()
response.type = Message.MessageType.PUT_VALUE
if success:
response.key = key
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
await stream.write(response_bytes)
logger.debug("Sent PUT_VALUE acknowledgement")
except Exception as proto_err:
logger.warning(f"Failed to parse protobuf message: {proto_err}")
await stream.close()
except Exception as e:
logger.error(f"Error handling DHT stream: {e}")
await stream.close()
async def refresh_routing_table(self) -> None:
"""Refresh the routing table."""
logger.debug("Refreshing routing table")
await self.peer_routing.refresh_routing_table()
# Peer routing methods
async def find_peer(self, peer_id: ID) -> PeerInfo | None:
"""
Find a peer with the given ID.
"""
logger.debug(f"Finding peer: {peer_id}")
return await self.peer_routing.find_peer(peer_id)
# Value storage and retrieval methods
async def put_value(self, key: bytes, value: bytes) -> None:
"""
Store a value in the DHT.
"""
logger.debug(f"Storing value for key {key.hex()}")
# 1. Store locally first
self.value_store.put(key, value)
try:
decoded_value = value.decode("utf-8")
except UnicodeDecodeError:
decoded_value = value.hex()
logger.debug(
f"Stored value locally for key {key.hex()} with value {decoded_value}"
)
# 2. Get closest peers, excluding self
closest_peers = [
peer
for peer in self.routing_table.find_local_closest_peers(key)
if peer != self.local_peer_id
]
logger.debug(f"Found {len(closest_peers)} peers to store value at")
# 3. Store at remote peers in batches of ALPHA, in parallel
stored_count = 0
for i in range(0, len(closest_peers), ALPHA):
batch = closest_peers[i : i + ALPHA]
batch_results = [False] * len(batch)
async def store_one(idx: int, peer: ID) -> None:
try:
with trio.move_on_after(QUERY_TIMEOUT):
success = await self.value_store._store_at_peer(
peer, key, value
)
batch_results[idx] = success
if success:
logger.debug(f"Stored value at peer {peer}")
else:
logger.debug(f"Failed to store value at peer {peer}")
except Exception as e:
logger.debug(f"Error storing value at peer {peer}: {e}")
async with trio.open_nursery() as nursery:
for idx, peer in enumerate(batch):
nursery.start_soon(store_one, idx, peer)
stored_count += sum(batch_results)
logger.info(f"Successfully stored value at {stored_count} peers")
async def get_value(self, key: bytes) -> bytes | None:
logger.debug(f"Getting value for key: {key.hex()}")
# 1. Check local store first
value = self.value_store.get(key)
if value:
logger.debug("Found value locally")
return value
# 2. Get closest peers, excluding self
closest_peers = [
peer
for peer in self.routing_table.find_local_closest_peers(key)
if peer != self.local_peer_id
]
logger.debug(f"Searching {len(closest_peers)} peers for value")
# 3. Query ALPHA peers at a time in parallel
for i in range(0, len(closest_peers), ALPHA):
batch = closest_peers[i : i + ALPHA]
found_value = None
async def query_one(peer: ID) -> None:
nonlocal found_value
try:
with trio.move_on_after(QUERY_TIMEOUT):
value = await self.value_store._get_from_peer(peer, key)
if value is not None and found_value is None:
found_value = value
logger.debug(f"Found value at peer {peer}")
except Exception as e:
logger.debug(f"Error querying peer {peer}: {e}")
async with trio.open_nursery() as nursery:
for peer in batch:
nursery.start_soon(query_one, peer)
if found_value is not None:
self.value_store.put(key, found_value)
logger.info("Successfully retrieved value from network")
return found_value
# 4. Not found
logger.warning(f"Value not found for key {key.hex()}")
return None
# Add these methods in the Utility methods section
# Utility methods
async def add_peer(self, peer_id: ID) -> bool:
"""
Add a peer to the routing table.
params: peer_id: The peer ID to add.
Returns
-------
bool
True if peer was added or updated, False otherwise.
"""
return await self.routing_table.add_peer(peer_id)
async def provide(self, key: bytes) -> bool:
"""
Reference to provider_store.provide for convenience.
"""
return await self.provider_store.provide(key)
async def find_providers(self, key: bytes, count: int = 20) -> list[PeerInfo]:
"""
Reference to provider_store.find_providers for convenience.
"""
return await self.provider_store.find_providers(key, count)
def get_routing_table_size(self) -> int:
"""
Get the number of peers in the routing table.
Returns
-------
int
Number of peers.
"""
return self.routing_table.size()
def get_value_store_size(self) -> int:
"""
Get the number of items in the value store.
Returns
-------
int
Number of items.
"""
return self.value_store.size()

View File

View File

@ -0,0 +1,38 @@
syntax = "proto3";
message Record {
bytes key = 1;
bytes value = 2;
string timeReceived = 5;
};
message Message {
enum MessageType {
PUT_VALUE = 0;
GET_VALUE = 1;
ADD_PROVIDER = 2;
GET_PROVIDERS = 3;
FIND_NODE = 4;
PING = 5;
}
enum ConnectionType {
NOT_CONNECTED = 0;
CONNECTED = 1;
CAN_CONNECT = 2;
CANNOT_CONNECT = 3;
}
message Peer {
bytes id = 1;
repeated bytes addrs = 2;
ConnectionType connection = 3;
}
MessageType type = 1;
int32 clusterLevelRaw = 10;
bytes key = 2;
Record record = 3;
repeated Peer closerPeers = 8;
repeated Peer providerPeers = 9;
}

View File

@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: libp2p/kad_dht/pb/kademlia.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xca\x03\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x1aN\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_globals['_RECORD']._serialized_start=36
_globals['_RECORD']._serialized_end=94
_globals['_MESSAGE']._serialized_start=97
_globals['_MESSAGE']._serialized_end=555
_globals['_MESSAGE_PEER']._serialized_start=281
_globals['_MESSAGE_PEER']._serialized_end=359
_globals['_MESSAGE_MESSAGETYPE']._serialized_start=361
_globals['_MESSAGE_MESSAGETYPE']._serialized_end=466
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_start=468
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_end=555
# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,133 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import sys
import typing
if sys.version_info >= (3, 10):
import typing as typing_extensions
else:
import typing_extensions
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class Record(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
TIMERECEIVED_FIELD_NUMBER: builtins.int
key: builtins.bytes
value: builtins.bytes
timeReceived: builtins.str
def __init__(
self,
*,
key: builtins.bytes = ...,
value: builtins.bytes = ...,
timeReceived: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "timeReceived", b"timeReceived", "value", b"value"]) -> None: ...
global___Record = Record
@typing.final
class Message(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class _MessageType:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _MessageTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._MessageType.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
PUT_VALUE: Message._MessageType.ValueType # 0
GET_VALUE: Message._MessageType.ValueType # 1
ADD_PROVIDER: Message._MessageType.ValueType # 2
GET_PROVIDERS: Message._MessageType.ValueType # 3
FIND_NODE: Message._MessageType.ValueType # 4
PING: Message._MessageType.ValueType # 5
class MessageType(_MessageType, metaclass=_MessageTypeEnumTypeWrapper): ...
PUT_VALUE: Message.MessageType.ValueType # 0
GET_VALUE: Message.MessageType.ValueType # 1
ADD_PROVIDER: Message.MessageType.ValueType # 2
GET_PROVIDERS: Message.MessageType.ValueType # 3
FIND_NODE: Message.MessageType.ValueType # 4
PING: Message.MessageType.ValueType # 5
class _ConnectionType:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _ConnectionTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._ConnectionType.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
NOT_CONNECTED: Message._ConnectionType.ValueType # 0
CONNECTED: Message._ConnectionType.ValueType # 1
CAN_CONNECT: Message._ConnectionType.ValueType # 2
CANNOT_CONNECT: Message._ConnectionType.ValueType # 3
class ConnectionType(_ConnectionType, metaclass=_ConnectionTypeEnumTypeWrapper): ...
NOT_CONNECTED: Message.ConnectionType.ValueType # 0
CONNECTED: Message.ConnectionType.ValueType # 1
CAN_CONNECT: Message.ConnectionType.ValueType # 2
CANNOT_CONNECT: Message.ConnectionType.ValueType # 3
@typing.final
class Peer(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
ID_FIELD_NUMBER: builtins.int
ADDRS_FIELD_NUMBER: builtins.int
CONNECTION_FIELD_NUMBER: builtins.int
id: builtins.bytes
connection: global___Message.ConnectionType.ValueType
@property
def addrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
def __init__(
self,
*,
id: builtins.bytes = ...,
addrs: collections.abc.Iterable[builtins.bytes] | None = ...,
connection: global___Message.ConnectionType.ValueType = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["addrs", b"addrs", "connection", b"connection", "id", b"id"]) -> None: ...
TYPE_FIELD_NUMBER: builtins.int
CLUSTERLEVELRAW_FIELD_NUMBER: builtins.int
KEY_FIELD_NUMBER: builtins.int
RECORD_FIELD_NUMBER: builtins.int
CLOSERPEERS_FIELD_NUMBER: builtins.int
PROVIDERPEERS_FIELD_NUMBER: builtins.int
type: global___Message.MessageType.ValueType
clusterLevelRaw: builtins.int
key: builtins.bytes
@property
def record(self) -> global___Record: ...
@property
def closerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ...
@property
def providerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ...
def __init__(
self,
*,
type: global___Message.MessageType.ValueType = ...,
clusterLevelRaw: builtins.int = ...,
key: builtins.bytes = ...,
record: global___Record | None = ...,
closerPeers: collections.abc.Iterable[global___Message.Peer] | None = ...,
providerPeers: collections.abc.Iterable[global___Message.Peer] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["record", b"record"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["closerPeers", b"closerPeers", "clusterLevelRaw", b"clusterLevelRaw", "key", b"key", "providerPeers", b"providerPeers", "record", b"record", "type", b"type"]) -> None: ...
global___Message = Message

View File

@ -0,0 +1,415 @@
"""
Peer routing implementation for Kademlia DHT.
This module implements the peer routing interface using Kademlia's algorithm
to efficiently locate peers in a distributed network.
"""
import logging
import trio
import varint
from libp2p.abc import (
IHost,
INetStream,
IPeerRouting,
)
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from .common import (
ALPHA,
PROTOCOL_ID,
)
from .pb.kademlia_pb2 import (
Message,
)
from .routing_table import (
RoutingTable,
)
from .utils import (
sort_peer_ids_by_distance,
)
# logger = logging.getLogger("libp2p.kademlia.peer_routing")
logger = logging.getLogger("kademlia-example.peer_routing")
MAX_PEER_LOOKUP_ROUNDS = 20 # Maximum number of rounds in peer lookup
class PeerRouting(IPeerRouting):
"""
Implementation of peer routing using the Kademlia algorithm.
This class provides methods to find peers in the DHT network
and helps maintain the routing table.
"""
def __init__(self, host: IHost, routing_table: RoutingTable):
"""
Initialize the peer routing service.
:param host: The libp2p host
:param routing_table: The Kademlia routing table
"""
self.host = host
self.routing_table = routing_table
async def find_peer(self, peer_id: ID) -> PeerInfo | None:
"""
Find a peer with the given ID.
:param peer_id: The ID of the peer to find
Returns
-------
Optional[PeerInfo]
The peer information if found, None otherwise
"""
# Check if this is actually our peer ID
if peer_id == self.host.get_id():
try:
# Return our own peer info
return PeerInfo(peer_id, self.host.get_addrs())
except Exception:
logger.exception("Error getting our own peer info")
return None
# First check if the peer is in our routing table
peer_info = self.routing_table.get_peer_info(peer_id)
if peer_info:
logger.debug(f"Found peer {peer_id} in routing table")
return peer_info
# Then check if the peer is in our peerstore
try:
addrs = self.host.get_peerstore().addrs(peer_id)
if addrs:
logger.debug(f"Found peer {peer_id} in peerstore")
return PeerInfo(peer_id, addrs)
except Exception:
pass
# If not found locally, search the network
try:
closest_peers = await self.find_closest_peers_network(peer_id.to_bytes())
logger.info(f"Closest peers found: {closest_peers}")
# Check if we found the peer we're looking for
for found_peer in closest_peers:
if found_peer == peer_id:
try:
addrs = self.host.get_peerstore().addrs(found_peer)
if addrs:
return PeerInfo(found_peer, addrs)
except Exception:
pass
except Exception as e:
logger.error(f"Error searching for peer {peer_id}: {e}")
# Not found
logger.info(f"Peer {peer_id} not found")
return None
async def _query_single_peer_for_closest(
self, peer: ID, target_key: bytes, new_peers: list[ID]
) -> None:
"""
Query a single peer for closest peers and append results to the shared list.
params: peer : ID
The peer to query
params: target_key : bytes
The target key to find closest peers for
params: new_peers : list[ID]
Shared list to append results to
"""
try:
result = await self._query_peer_for_closest(peer, target_key)
# Add deduplication to prevent duplicate peers
for peer_id in result:
if peer_id not in new_peers:
new_peers.append(peer_id)
logger.debug(
"Queried peer %s for closest peers, got %d results (%d unique)",
peer,
len(result),
len([p for p in result if p not in new_peers[: -len(result)]]),
)
except Exception as e:
logger.debug(f"Query to peer {peer} failed: {e}")
async def find_closest_peers_network(
self, target_key: bytes, count: int = 20
) -> list[ID]:
"""
Find the closest peers to a target key in the entire network.
Performs an iterative lookup by querying peers for their closest peers.
Returns
-------
list[ID]
Closest peer IDs
"""
# Start with closest peers from our routing table
closest_peers = self.routing_table.find_local_closest_peers(target_key, count)
logger.debug("Local closest peers: %d found", len(closest_peers))
queried_peers: set[ID] = set()
rounds = 0
# Return early if we have no peers to start with
if not closest_peers:
logger.warning("No local peers available for network lookup")
return []
# Iterative lookup until convergence
while rounds < MAX_PEER_LOOKUP_ROUNDS:
rounds += 1
logger.debug(f"Lookup round {rounds}/{MAX_PEER_LOOKUP_ROUNDS}")
# Find peers we haven't queried yet
peers_to_query = [p for p in closest_peers if p not in queried_peers]
if not peers_to_query:
logger.debug("No more unqueried peers available, ending lookup")
break # No more peers to query
# Query these peers for their closest peers to target
peers_batch = peers_to_query[:ALPHA] # Limit to ALPHA peers at a time
# Mark these peers as queried before we actually query them
for peer in peers_batch:
queried_peers.add(peer)
# Run queries in parallel for this batch using trio nursery
new_peers: list[ID] = [] # Shared array to collect all results
async with trio.open_nursery() as nursery:
for peer in peers_batch:
nursery.start_soon(
self._query_single_peer_for_closest, peer, target_key, new_peers
)
# If we got no new peers, we're done
if not new_peers:
logger.debug("No new peers discovered in this round, ending lookup")
break
# Update our list of closest peers
all_candidates = closest_peers + new_peers
old_closest_peers = closest_peers[:]
closest_peers = sort_peer_ids_by_distance(target_key, all_candidates)[
:count
]
logger.debug(f"Updated closest peers count: {len(closest_peers)}")
# Check if we made any progress (found closer peers)
if closest_peers == old_closest_peers:
logger.debug("No improvement in closest peers, ending lookup")
break
logger.info(
f"Network lookup completed after {rounds} rounds, "
f"found {len(closest_peers)} peers"
)
return closest_peers
async def _query_peer_for_closest(self, peer: ID, target_key: bytes) -> list[ID]:
"""
Query a peer for their closest peers
to the target key using varint length prefix
"""
stream = None
results = []
try:
# Add the peer to our routing table regardless of query outcome
try:
addrs = self.host.get_peerstore().addrs(peer)
if addrs:
peer_info = PeerInfo(peer, addrs)
await self.routing_table.add_peer(peer_info)
except Exception as e:
logger.debug(f"Failed to add peer {peer} to routing table: {e}")
# Open a stream to the peer using the Kademlia protocol
logger.debug(f"Opening stream to {peer} for closest peers query")
try:
stream = await self.host.new_stream(peer, [PROTOCOL_ID])
logger.debug(f"Stream opened to {peer}")
except Exception as e:
logger.warning(f"Failed to open stream to {peer}: {e}")
return []
# Create and send FIND_NODE request using protobuf
find_node_msg = Message()
find_node_msg.type = Message.MessageType.FIND_NODE
find_node_msg.key = target_key # Set target key directly as bytes
# Serialize and send the protobuf message with varint length prefix
proto_bytes = find_node_msg.SerializeToString()
logger.debug(
f"Sending FIND_NODE: {proto_bytes.hex()} (len={len(proto_bytes)})"
)
await stream.write(varint.encode(len(proto_bytes)))
await stream.write(proto_bytes)
# Read varint-prefixed response length
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
logger.warning(
"Error reading varint length from stream: connection closed"
)
return []
length_bytes += b
if b[0] & 0x80 == 0:
break
response_length = varint.decode_bytes(length_bytes)
# Read response data
response_bytes = b""
remaining = response_length
while remaining > 0:
chunk = await stream.read(remaining)
if not chunk:
logger.debug(f"Connection closed by peer {peer} while reading data")
return []
response_bytes += chunk
remaining -= len(chunk)
# Parse the protobuf response
response_msg = Message()
response_msg.ParseFromString(response_bytes)
logger.debug(
"Received response from %s with %d peers",
peer,
len(response_msg.closerPeers),
)
# Process closest peers from response
if response_msg.type == Message.MessageType.FIND_NODE:
for peer_data in response_msg.closerPeers:
new_peer_id = ID(peer_data.id)
if new_peer_id not in results:
results.append(new_peer_id)
if peer_data.addrs:
from multiaddr import (
Multiaddr,
)
addrs = [Multiaddr(addr) for addr in peer_data.addrs]
self.host.get_peerstore().add_addrs(new_peer_id, addrs, 3600)
except Exception as e:
logger.debug(f"Error querying peer {peer} for closest: {e}")
finally:
if stream:
await stream.close()
return results
async def _handle_kad_stream(self, stream: INetStream) -> None:
"""
Handle incoming Kademlia protocol streams.
params: stream: The incoming stream
Returns
-------
None
"""
try:
# Read message length
length_bytes = await stream.read(4)
if not length_bytes:
return
message_length = int.from_bytes(length_bytes, byteorder="big")
# Read message
message_bytes = await stream.read(message_length)
if not message_bytes:
return
# Parse protobuf message
kad_message = Message()
try:
kad_message.ParseFromString(message_bytes)
if kad_message.type == Message.MessageType.FIND_NODE:
# Get target key directly from protobuf message
target_key = kad_message.key
# Find closest peers to target
closest_peers = self.routing_table.find_local_closest_peers(
target_key, 20
)
# Create protobuf response
response = Message()
response.type = Message.MessageType.FIND_NODE
# Add peer information to response
for peer_id in closest_peers:
peer_proto = response.closerPeers.add()
peer_proto.id = peer_id.to_bytes()
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add addresses if available
try:
addrs = self.host.get_peerstore().addrs(peer_id)
if addrs:
for addr in addrs:
peer_proto.addrs.append(addr.to_bytes())
except Exception:
pass
# Send response
response_bytes = response.SerializeToString()
await stream.write(len(response_bytes).to_bytes(4, byteorder="big"))
await stream.write(response_bytes)
except Exception as parse_err:
logger.error(f"Failed to parse protocol buffer message: {parse_err}")
except Exception as e:
logger.debug(f"Error handling Kademlia stream: {e}")
finally:
await stream.close()
async def refresh_routing_table(self) -> None:
"""
Refresh the routing table by performing lookups for random keys.
Returns
-------
None
"""
logger.info("Refreshing routing table")
# Perform a lookup for ourselves to populate the routing table
local_id = self.host.get_id()
closest_peers = await self.find_closest_peers_network(local_id.to_bytes())
# Add discovered peers to routing table
for peer_id in closest_peers:
try:
addrs = self.host.get_peerstore().addrs(peer_id)
if addrs:
peer_info = PeerInfo(peer_id, addrs)
await self.routing_table.add_peer(peer_info)
except Exception as e:
logger.debug(f"Failed to add discovered peer {peer_id}: {e}")

View File

@ -0,0 +1,577 @@
"""
Provider record storage for Kademlia DHT.
This module implements the storage for content provider records in the Kademlia DHT.
"""
import logging
import time
from typing import (
Any,
)
from multiaddr import (
Multiaddr,
)
import trio
import varint
from libp2p.abc import (
IHost,
)
from libp2p.custom_types import (
TProtocol,
)
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from .common import (
ALPHA,
PROTOCOL_ID,
QUERY_TIMEOUT,
)
from .pb.kademlia_pb2 import (
Message,
)
# logger = logging.getLogger("libp2p.kademlia.provider_store")
logger = logging.getLogger("kademlia-example.provider_store")
# Constants for provider records (based on IPFS standards)
PROVIDER_RECORD_REPUBLISH_INTERVAL = 22 * 60 * 60 # 22 hours in seconds
PROVIDER_RECORD_EXPIRATION_INTERVAL = 48 * 60 * 60 # 48 hours in seconds
PROVIDER_ADDRESS_TTL = 30 * 60 # 30 minutes in seconds
class ProviderRecord:
"""
A record for a content provider in the DHT.
Contains the peer information and timestamp.
"""
def __init__(
self,
provider_info: PeerInfo,
timestamp: float | None = None,
) -> None:
"""
Initialize a new provider record.
:param provider_info: The provider's peer information
:param timestamp: Time this record was created/updated
(defaults to current time)
"""
self.provider_info = provider_info
self.timestamp = timestamp or time.time()
def is_expired(self) -> bool:
"""
Check if this provider record has expired.
Returns
-------
bool
True if the record has expired
"""
current_time = time.time()
return (current_time - self.timestamp) >= PROVIDER_RECORD_EXPIRATION_INTERVAL
def should_republish(self) -> bool:
"""
Check if this provider record should be republished.
Returns
-------
bool
True if the record should be republished
"""
current_time = time.time()
return (current_time - self.timestamp) >= PROVIDER_RECORD_REPUBLISH_INTERVAL
@property
def peer_id(self) -> ID:
"""Get the provider's peer ID."""
return self.provider_info.peer_id
@property
def addresses(self) -> list[Multiaddr]:
"""Get the provider's addresses."""
return self.provider_info.addrs
class ProviderStore:
"""
Store for content provider records in the Kademlia DHT.
Maps content keys to provider records, with support for expiration.
"""
def __init__(self, host: IHost, peer_routing: Any = None) -> None:
"""
Initialize a new provider store.
:param host: The libp2p host instance (optional)
:param peer_routing: The peer routing instance (optional)
"""
# Maps content keys to a dict of provider records (peer_id -> record)
self.providers: dict[bytes, dict[str, ProviderRecord]] = {}
self.host = host
self.peer_routing = peer_routing
self.providing_keys: set[bytes] = set()
self.local_peer_id = host.get_id()
async def _republish_provider_records(self) -> None:
"""Republish all provider records for content this node is providing."""
# First, republish keys we're actively providing
for key in self.providing_keys:
logger.debug(f"Republishing provider record for key {key.hex()}")
await self.provide(key)
# Also check for any records that should be republished
time.time()
for key, providers in self.providers.items():
for peer_id_str, record in providers.items():
# Only republish records for our own peer
if self.local_peer_id and str(self.local_peer_id) == peer_id_str:
if record.should_republish():
logger.debug(
f"Republishing old provider record for key {key.hex()}"
)
await self.provide(key)
async def provide(self, key: bytes) -> bool:
"""
Advertise that this node can provide a piece of content.
Finds the k closest peers to the key and sends them ADD_PROVIDER messages.
:param key: The content key (multihash) to advertise
Returns
-------
bool
True if the advertisement was successful
"""
if not self.host or not self.peer_routing:
logger.error("Host or peer_routing not initialized, cannot provide content")
return False
# Add to local provider store
local_addrs = []
for addr in self.host.get_addrs():
local_addrs.append(addr)
local_peer_info = PeerInfo(self.host.get_id(), local_addrs)
self.add_provider(key, local_peer_info)
# Track that we're providing this key
self.providing_keys.add(key)
# Find the k closest peers to the key
closest_peers = await self.peer_routing.find_closest_peers_network(key)
logger.debug(
"Found %d peers close to key %s for provider advertisement",
len(closest_peers),
key.hex(),
)
# Send ADD_PROVIDER messages to these ALPHA peers in parallel.
success_count = 0
for i in range(0, len(closest_peers), ALPHA):
batch = closest_peers[i : i + ALPHA]
results: list[bool] = [False] * len(batch)
async def send_one(
idx: int, peer_id: ID, results: list[bool] = results
) -> None:
if peer_id == self.local_peer_id:
return
try:
with trio.move_on_after(QUERY_TIMEOUT):
success = await self._send_add_provider(peer_id, key)
results[idx] = success
if not success:
logger.warning(f"Failed to send ADD_PROVIDER to {peer_id}")
except Exception as e:
logger.warning(f"Error sending ADD_PROVIDER to {peer_id}: {e}")
async with trio.open_nursery() as nursery:
for idx, peer_id in enumerate(batch):
nursery.start_soon(send_one, idx, peer_id, results)
success_count += sum(results)
logger.info(f"Successfully advertised to {success_count} peers")
return success_count > 0
async def _send_add_provider(self, peer_id: ID, key: bytes) -> bool:
"""
Send ADD_PROVIDER message to a specific peer.
:param peer_id: The peer to send the message to
:param key: The content key being provided
Returns
-------
bool
True if the message was successfully sent and acknowledged
"""
try:
result = False
# Open a stream to the peer
stream = await self.host.new_stream(peer_id, [TProtocol(PROTOCOL_ID)])
# Get our addresses to include in the message
addrs = []
for addr in self.host.get_addrs():
addrs.append(addr.to_bytes())
# Create the ADD_PROVIDER message
message = Message()
message.type = Message.MessageType.ADD_PROVIDER
message.key = key
# Add our provider info
provider = message.providerPeers.add()
provider.id = self.local_peer_id.to_bytes()
provider.addrs.extend(addrs)
# Serialize and send the message
proto_bytes = message.SerializeToString()
await stream.write(varint.encode(len(proto_bytes)))
await stream.write(proto_bytes)
logger.debug(f"Sent ADD_PROVIDER to {peer_id} for key {key.hex()}")
# Read response length prefix
length_bytes = b""
while True:
logger.debug("Reading response length prefix in add provider")
b = await stream.read(1)
if not b:
return False
length_bytes += b
if b[0] & 0x80 == 0:
break
response_length = varint.decode_bytes(length_bytes)
# Read response data
response_bytes = b""
remaining = response_length
while remaining > 0:
chunk = await stream.read(remaining)
if not chunk:
return False
response_bytes += chunk
remaining -= len(chunk)
# Parse response
response = Message()
response.ParseFromString(response_bytes)
# Check response type
response.type == Message.MessageType.ADD_PROVIDER
if response.type:
result = True
except Exception as e:
logger.warning(f"Error sending ADD_PROVIDER to {peer_id}: {e}")
finally:
await stream.close()
return result
async def find_providers(self, key: bytes, count: int = 20) -> list[PeerInfo]:
"""
Find content providers for a given key.
:param key: The content key to look for
:param count: Maximum number of providers to return
Returns
-------
List[PeerInfo]
List of content providers
"""
if not self.host or not self.peer_routing:
logger.error("Host or peer_routing not initialized, cannot find providers")
return []
# Check local provider store first
local_providers = self.get_providers(key)
if local_providers:
logger.debug(
f"Found {len(local_providers)} providers locally for {key.hex()}"
)
return local_providers[:count]
logger.debug("local providers are %s", local_providers)
# Find the closest peers to the key
closest_peers = await self.peer_routing.find_closest_peers_network(key)
logger.debug(
f"Searching {len(closest_peers)} peers for providers of {key.hex()}"
)
# Query these peers for providers in batches of ALPHA, in parallel, with timeout
all_providers = []
for i in range(0, len(closest_peers), ALPHA):
batch = closest_peers[i : i + ALPHA]
batch_results: list[list[PeerInfo]] = [[] for _ in batch]
async def get_one(
idx: int,
peer_id: ID,
batch_results: list[list[PeerInfo]] = batch_results,
) -> None:
if peer_id == self.local_peer_id:
return
try:
with trio.move_on_after(QUERY_TIMEOUT):
providers = await self._get_providers_from_peer(peer_id, key)
if providers:
for provider in providers:
self.add_provider(key, provider)
batch_results[idx] = providers
else:
logger.debug(f"No providers found at peer {peer_id}")
except Exception as e:
logger.warning(f"Failed to get providers from {peer_id}: {e}")
async with trio.open_nursery() as nursery:
for idx, peer_id in enumerate(batch):
nursery.start_soon(get_one, idx, peer_id, batch_results)
for providers in batch_results:
all_providers.extend(providers)
if len(all_providers) >= count:
return all_providers[:count]
return all_providers[:count]
async def _get_providers_from_peer(self, peer_id: ID, key: bytes) -> list[PeerInfo]:
"""
Get content providers from a specific peer.
:param peer_id: The peer to query
:param key: The content key to look for
Returns
-------
List[PeerInfo]
List of provider information
"""
providers: list[PeerInfo] = []
try:
# Open a stream to the peer
stream = await self.host.new_stream(peer_id, [TProtocol(PROTOCOL_ID)])
try:
# Create the GET_PROVIDERS message
message = Message()
message.type = Message.MessageType.GET_PROVIDERS
message.key = key
# Serialize and send the message
proto_bytes = message.SerializeToString()
await stream.write(varint.encode(len(proto_bytes)))
await stream.write(proto_bytes)
# Read response length prefix
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
return []
length_bytes += b
if b[0] & 0x80 == 0:
break
response_length = varint.decode_bytes(length_bytes)
# Read response data
response_bytes = b""
remaining = response_length
while remaining > 0:
chunk = await stream.read(remaining)
if not chunk:
return []
response_bytes += chunk
remaining -= len(chunk)
# Parse response
response = Message()
response.ParseFromString(response_bytes)
# Check response type
if response.type != Message.MessageType.GET_PROVIDERS:
return []
# Extract provider information
providers = []
for provider_proto in response.providerPeers:
try:
# Create peer ID from bytes
provider_id = ID(provider_proto.id)
# Convert addresses to Multiaddr
addrs = []
for addr_bytes in provider_proto.addrs:
try:
addrs.append(Multiaddr(addr_bytes))
except Exception:
pass # Skip invalid addresses
# Create PeerInfo and add to result
providers.append(PeerInfo(provider_id, addrs))
except Exception as e:
logger.warning(f"Failed to parse provider info: {e}")
finally:
await stream.close()
return providers
except Exception as e:
logger.warning(f"Error getting providers from {peer_id}: {e}")
return []
def add_provider(self, key: bytes, provider: PeerInfo) -> None:
"""
Add a provider for a given content key.
:param key: The content key
:param provider: The provider's peer information
Returns
-------
None
"""
# Initialize providers for this key if needed
if key not in self.providers:
self.providers[key] = {}
# Add or update the provider record
peer_id_str = str(provider.peer_id) # Use string representation as dict key
self.providers[key][peer_id_str] = ProviderRecord(
provider_info=provider, timestamp=time.time()
)
logger.debug(f"Added provider {provider.peer_id} for key {key.hex()}")
def get_providers(self, key: bytes) -> list[PeerInfo]:
"""
Get all providers for a given content key.
:param key: The content key
Returns
-------
List[PeerInfo]
List of providers for the key
"""
if key not in self.providers:
return []
# Collect valid provider records (not expired)
result = []
current_time = time.time()
expired_peers = []
for peer_id_str, record in self.providers[key].items():
# Check if the record has expired
if current_time - record.timestamp > PROVIDER_RECORD_EXPIRATION_INTERVAL:
expired_peers.append(peer_id_str)
continue
# Use addresses only if they haven't expired
addresses = []
if current_time - record.timestamp <= PROVIDER_ADDRESS_TTL:
addresses = record.addresses
# Create PeerInfo and add to results
result.append(PeerInfo(record.peer_id, addresses))
# Clean up expired records
for peer_id in expired_peers:
del self.providers[key][peer_id]
# Remove the key if no providers left
if not self.providers[key]:
del self.providers[key]
return result
def cleanup_expired(self) -> None:
"""Remove expired provider records."""
current_time = time.time()
expired_keys = []
for key, providers in self.providers.items():
expired_providers = []
for peer_id_str, record in providers.items():
if (
current_time - record.timestamp
> PROVIDER_RECORD_EXPIRATION_INTERVAL
):
expired_providers.append(peer_id_str)
logger.debug(
f"Removing expired provider {peer_id_str} for key {key.hex()}"
)
# Remove expired providers
for peer_id in expired_providers:
del providers[peer_id]
# Track empty keys for removal
if not providers:
expired_keys.append(key)
# Remove empty keys
for key in expired_keys:
del self.providers[key]
logger.debug(f"Removed key with no providers: {key.hex()}")
def get_provided_keys(self, peer_id: ID) -> list[bytes]:
"""
Get all content keys provided by a specific peer.
:param peer_id: The peer ID to look for
Returns
-------
List[bytes]
List of content keys provided by the peer
"""
peer_id_str = str(peer_id)
result = []
for key, providers in self.providers.items():
if peer_id_str in providers:
result.append(key)
return result
def size(self) -> int:
"""
Get the total number of provider records in the store.
Returns
-------
int
Total number of provider records across all keys
"""
total = 0
for providers in self.providers.values():
total += len(providers)
return total

View File

@ -0,0 +1,600 @@
"""
Kademlia DHT routing table implementation.
"""
from collections import (
OrderedDict,
)
import logging
import time
import trio
from libp2p.abc import (
IHost,
)
from libp2p.kad_dht.utils import (
xor_distance,
)
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from .common import (
PROTOCOL_ID,
)
from .pb.kademlia_pb2 import (
Message,
)
# logger = logging.getLogger("libp2p.kademlia.routing_table")
logger = logging.getLogger("kademlia-example.routing_table")
# Default parameters
BUCKET_SIZE = 20 # k in the Kademlia paper
MAXIMUM_BUCKETS = 256 # Maximum number of buckets (for 256-bit keys)
PEER_REFRESH_INTERVAL = 60 # Interval to refresh peers in seconds
STALE_PEER_THRESHOLD = 3600 # Time in seconds after which a peer is considered stale
class KBucket:
"""
A k-bucket implementation for the Kademlia DHT.
Each k-bucket stores up to k (BUCKET_SIZE) peers, sorted by least-recently seen.
"""
def __init__(
self,
host: IHost,
bucket_size: int = BUCKET_SIZE,
min_range: int = 0,
max_range: int = 2**256,
):
"""
Initialize a new k-bucket.
:param host: The host this bucket belongs to
:param bucket_size: Maximum number of peers to store in the bucket
:param min_range: Lower boundary of the bucket's key range (inclusive)
:param max_range: Upper boundary of the bucket's key range (exclusive)
"""
self.bucket_size = bucket_size
self.host = host
self.min_range = min_range
self.max_range = max_range
# Store PeerInfo objects along with last-seen timestamp
self.peers: OrderedDict[ID, tuple[PeerInfo, float]] = OrderedDict()
def peer_ids(self) -> list[ID]:
"""Get all peer IDs in the bucket."""
return list(self.peers.keys())
def peer_infos(self) -> list[PeerInfo]:
"""Get all PeerInfo objects in the bucket."""
return [info for info, _ in self.peers.values()]
def get_oldest_peer(self) -> ID | None:
"""Get the least-recently seen peer."""
if not self.peers:
return None
return next(iter(self.peers.keys()))
async def add_peer(self, peer_info: PeerInfo) -> bool:
"""
Add a peer to the bucket. Returns True if the peer was added or updated,
False if the bucket is full.
"""
current_time = time.time()
peer_id = peer_info.peer_id
# If peer is already in the bucket, move it to the end (most recently seen)
if peer_id in self.peers:
self.refresh_peer_last_seen(peer_id)
return True
# If bucket has space, add the peer
if len(self.peers) < self.bucket_size:
self.peers[peer_id] = (peer_info, current_time)
return True
# If bucket is full, we need to replace the least-recently seen peer
# Get the least-recently seen peer
oldest_peer_id = self.get_oldest_peer()
if oldest_peer_id is None:
logger.warning("No oldest peer found when bucket is full")
return False
# Check if the old peer is responsive to ping request
try:
# Try to ping the oldest peer, not the new peer
response = await self._ping_peer(oldest_peer_id)
if response:
# If the old peer is still alive, we will not add the new peer
logger.debug(
"Old peer %s is still alive, cannot add new peer %s",
oldest_peer_id,
peer_id,
)
return False
except Exception as e:
# If the old peer is unresponsive, we can replace it with the new peer
logger.debug(
"Old peer %s is unresponsive, replacing with new peer %s: %s",
oldest_peer_id,
peer_id,
str(e),
)
self.peers.popitem(last=False) # Remove oldest peer
self.peers[peer_id] = (peer_info, current_time)
return True
# If we got here, the oldest peer responded but we couldn't add the new peer
return False
def remove_peer(self, peer_id: ID) -> bool:
"""
Remove a peer from the bucket.
Returns True if the peer was in the bucket, False otherwise.
"""
if peer_id in self.peers:
del self.peers[peer_id]
return True
return False
def has_peer(self, peer_id: ID) -> bool:
"""Check if the peer is in the bucket."""
return peer_id in self.peers
def get_peer_info(self, peer_id: ID) -> PeerInfo | None:
"""Get the PeerInfo for a given peer ID if it exists in the bucket."""
if peer_id in self.peers:
return self.peers[peer_id][0]
return None
def size(self) -> int:
"""Get the number of peers in the bucket."""
return len(self.peers)
def get_stale_peers(self, stale_threshold_seconds: int = 3600) -> list[ID]:
"""
Get peers that haven't been pinged recently.
params: stale_threshold_seconds: Time in seconds
params: after which a peer is considered stale
Returns
-------
list[ID]
List of peer IDs that need to be refreshed
"""
current_time = time.time()
stale_peers = []
for peer_id, (_, last_seen) in self.peers.items():
if current_time - last_seen > stale_threshold_seconds:
stale_peers.append(peer_id)
return stale_peers
async def _periodic_peer_refresh(self) -> None:
"""Background task to periodically refresh peers"""
try:
while True:
await trio.sleep(PEER_REFRESH_INTERVAL) # Check every minute
# Find stale peers (not pinged in last hour)
stale_peers = self.get_stale_peers(
stale_threshold_seconds=STALE_PEER_THRESHOLD
)
if stale_peers:
logger.debug(f"Found {len(stale_peers)} stale peers to refresh")
for peer_id in stale_peers:
try:
# Try to ping the peer
logger.debug("Pinging stale peer %s", peer_id)
responce = await self._ping_peer(peer_id)
if responce:
# Update the last seen time
self.refresh_peer_last_seen(peer_id)
logger.debug(f"Refreshed peer {peer_id}")
else:
# If ping fails, remove the peer
logger.debug(f"Failed to ping peer {peer_id}")
self.remove_peer(peer_id)
logger.info(f"Removed unresponsive peer {peer_id}")
logger.debug(f"Successfully refreshed peer {peer_id}")
except Exception as e:
# If ping fails, remove the peer
logger.debug(
"Failed to ping peer %s: %s",
peer_id,
e,
)
self.remove_peer(peer_id)
logger.info(f"Removed unresponsive peer {peer_id}")
except trio.Cancelled:
logger.debug("Peer refresh task cancelled")
except Exception as e:
logger.error(f"Error in peer refresh task: {e}", exc_info=True)
async def _ping_peer(self, peer_id: ID) -> bool:
"""
Ping a peer using protobuf message to check
if it's still alive and update last seen time.
params: peer_id: The ID of the peer to ping
Returns
-------
bool
True if ping successful, False otherwise
"""
result = False
# Get peer info directly from the bucket
peer_info = self.get_peer_info(peer_id)
if not peer_info:
raise ValueError(f"Peer {peer_id} not in bucket")
try:
# Open a stream to the peer with the DHT protocol
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
try:
# Create ping protobuf message
ping_msg = Message()
ping_msg.type = Message.PING # Use correct enum
# Serialize and send with length prefix (4 bytes big-endian)
msg_bytes = ping_msg.SerializeToString()
logger.debug(
f"Sending PING message to {peer_id}, size: {len(msg_bytes)} bytes"
)
await stream.write(len(msg_bytes).to_bytes(4, byteorder="big"))
await stream.write(msg_bytes)
# Wait for response with timeout
with trio.move_on_after(2): # 2 second timeout
# Read response length (4 bytes)
length_bytes = await stream.read(4)
if not length_bytes or len(length_bytes) < 4:
logger.warning(f"Peer {peer_id} disconnected during ping")
return False
msg_len = int.from_bytes(length_bytes, byteorder="big")
if (
msg_len <= 0 or msg_len > 1024 * 1024
): # Sanity check on message size
logger.warning(
f"Invalid message length from {peer_id}: {msg_len}"
)
return False
logger.debug(
f"Receiving response from {peer_id}, size: {msg_len} bytes"
)
# Read full message
response_bytes = await stream.read(msg_len)
if not response_bytes:
logger.warning(f"Failed to read response from {peer_id}")
return False
# Parse protobuf response
response = Message()
try:
response.ParseFromString(response_bytes)
except Exception as e:
logger.warning(
f"Failed to parse protobuf response from {peer_id}: {e}"
)
return False
if response.type == Message.PING:
# Update the last seen timestamp for this peer
logger.debug(f"Successfully pinged peer {peer_id}")
result = True
return result
else:
logger.warning(
f"Unexpected response type from {peer_id}: {response.type}"
)
return False
# If we get here, the ping timed out
logger.warning(f"Ping to peer {peer_id} timed out")
return False
finally:
await stream.close()
return result
except Exception as e:
logger.error(f"Error pinging peer {peer_id}: {str(e)}")
return False
def refresh_peer_last_seen(self, peer_id: ID) -> bool:
"""
Update the last-seen timestamp for a peer in the bucket.
params: peer_id: The ID of the peer to refresh
Returns
-------
bool
True if the peer was found and refreshed, False otherwise
"""
if peer_id in self.peers:
# Get current peer info and update the timestamp
peer_info, _ = self.peers[peer_id]
current_time = time.time()
self.peers[peer_id] = (peer_info, current_time)
# Move to end of ordered dict to mark as most recently seen
self.peers.move_to_end(peer_id)
return True
return False
def key_in_range(self, key: bytes) -> bool:
"""
Check if a key is in the range of this bucket.
params: key: The key to check (bytes)
Returns
-------
bool
True if the key is in range, False otherwise
"""
key_int = int.from_bytes(key, byteorder="big")
return self.min_range <= key_int < self.max_range
def split(self) -> tuple["KBucket", "KBucket"]:
"""
Split the bucket into two buckets.
Returns
-------
tuple
(lower_bucket, upper_bucket)
"""
midpoint = (self.min_range + self.max_range) // 2
lower_bucket = KBucket(self.host, self.bucket_size, self.min_range, midpoint)
upper_bucket = KBucket(self.host, self.bucket_size, midpoint, self.max_range)
# Redistribute peers
for peer_id, (peer_info, timestamp) in self.peers.items():
peer_key = int.from_bytes(peer_id.to_bytes(), byteorder="big")
if peer_key < midpoint:
lower_bucket.peers[peer_id] = (peer_info, timestamp)
else:
upper_bucket.peers[peer_id] = (peer_info, timestamp)
return lower_bucket, upper_bucket
class RoutingTable:
"""
The Kademlia routing table maintains information on which peers to contact for any
given peer ID in the network.
"""
def __init__(self, local_id: ID, host: IHost) -> None:
"""
Initialize the routing table.
:param local_id: The ID of the local node.
:param host: The host this routing table belongs to.
"""
self.local_id = local_id
self.host = host
self.buckets = [KBucket(host, BUCKET_SIZE)]
async def add_peer(self, peer_obj: PeerInfo | ID) -> bool:
"""
Add a peer to the routing table.
:param peer_obj: Either PeerInfo object or peer ID to add
Returns
-------
bool: True if the peer was added or updated, False otherwise
"""
peer_id = None
peer_info = None
try:
# Handle different types of input
if isinstance(peer_obj, PeerInfo):
# Already have PeerInfo object
peer_info = peer_obj
peer_id = peer_obj.peer_id
else:
# Assume it's a peer ID
peer_id = peer_obj
# Try to get addresses from the peerstore if available
try:
addrs = self.host.get_peerstore().addrs(peer_id)
if addrs:
# Create PeerInfo object
peer_info = PeerInfo(peer_id, addrs)
else:
logger.debug(
"No addresses found for peer %s in peerstore, skipping",
peer_id,
)
return False
except Exception as peerstore_error:
# Handle case where peer is not in peerstore yet
logger.debug(
"Peer %s not found in peerstore: %s, skipping",
peer_id,
str(peerstore_error),
)
return False
# Don't add ourselves
if peer_id == self.local_id:
return False
# Find the right bucket for this peer
bucket = self.find_bucket(peer_id)
# Try to add to the bucket
success = await bucket.add_peer(peer_info)
if success:
logger.debug(f"Successfully added peer {peer_id} to routing table")
return success
except Exception as e:
logger.debug(f"Error adding peer {peer_obj} to routing table: {e}")
return False
def remove_peer(self, peer_id: ID) -> bool:
"""
Remove a peer from the routing table.
:param peer_id: The ID of the peer to remove
Returns
-------
bool: True if the peer was removed, False otherwise
"""
bucket = self.find_bucket(peer_id)
return bucket.remove_peer(peer_id)
def find_bucket(self, peer_id: ID) -> KBucket:
"""
Find the bucket that would contain the given peer ID or PeerInfo.
:param peer_obj: Either a peer ID or a PeerInfo object
Returns
-------
KBucket: The bucket for this peer
"""
for bucket in self.buckets:
if bucket.key_in_range(peer_id.to_bytes()):
return bucket
return self.buckets[0]
def find_local_closest_peers(self, key: bytes, count: int = 20) -> list[ID]:
"""
Find the closest peers to a given key.
:param key: The key to find closest peers to (bytes)
:param count: Maximum number of peers to return
Returns
-------
List[ID]: List of peer IDs closest to the key
"""
# Get all peers from all buckets
all_peers = []
for bucket in self.buckets:
all_peers.extend(bucket.peer_ids())
# Sort by XOR distance to the key
all_peers.sort(key=lambda p: xor_distance(p.to_bytes(), key))
return all_peers[:count]
def get_peer_ids(self) -> list[ID]:
"""
Get all peer IDs in the routing table.
Returns
-------
:param List[ID]: List of all peer IDs
"""
peers = []
for bucket in self.buckets:
peers.extend(bucket.peer_ids())
return peers
def get_peer_info(self, peer_id: ID) -> PeerInfo | None:
"""
Get the peer info for a specific peer.
:param peer_id: The ID of the peer to get info for
Returns
-------
PeerInfo: The peer info, or None if not found
"""
bucket = self.find_bucket(peer_id)
return bucket.get_peer_info(peer_id)
def peer_in_table(self, peer_id: ID) -> bool:
"""
Check if a peer is in the routing table.
:param peer_id: The ID of the peer to check
Returns
-------
bool: True if the peer is in the routing table, False otherwise
"""
bucket = self.find_bucket(peer_id)
return bucket.has_peer(peer_id)
def size(self) -> int:
"""
Get the number of peers in the routing table.
Returns
-------
int: Number of peers
"""
count = 0
for bucket in self.buckets:
count += bucket.size()
return count
def get_stale_peers(self, stale_threshold_seconds: int = 3600) -> list[ID]:
"""
Get all stale peers from all buckets
params: stale_threshold_seconds:
Time in seconds after which a peer is considered stale
Returns
-------
list[ID]
List of stale peer IDs
"""
stale_peers = []
for bucket in self.buckets:
stale_peers.extend(bucket.get_stale_peers(stale_threshold_seconds))
return stale_peers
def cleanup_routing_table(self) -> None:
"""
Cleanup the routing table by removing all data.
This is useful for resetting the routing table during tests or reinitialization.
"""
self.buckets = [KBucket(self.host, BUCKET_SIZE)]
logger.info("Routing table cleaned up, all data removed.")

117
libp2p/kad_dht/utils.py Normal file
View File

@ -0,0 +1,117 @@
"""
Utility functions for Kademlia DHT implementation.
"""
import base58
import multihash
from libp2p.peer.id import (
ID,
)
def create_key_from_binary(binary_data: bytes) -> bytes:
"""
Creates a key for the DHT by hashing binary data with SHA-256.
params: binary_data: The binary data to hash.
Returns
-------
bytes: The resulting key.
"""
return multihash.digest(binary_data, "sha2-256").digest
def xor_distance(key1: bytes, key2: bytes) -> int:
"""
Calculate the XOR distance between two keys.
params: key1: First key (bytes)
params: key2: Second key (bytes)
Returns
-------
int: The XOR distance between the keys
"""
# Ensure the inputs are bytes
if not isinstance(key1, bytes) or not isinstance(key2, bytes):
raise TypeError("Both key1 and key2 must be bytes objects")
# Convert to integers
k1 = int.from_bytes(key1, byteorder="big")
k2 = int.from_bytes(key2, byteorder="big")
# Calculate XOR distance
return k1 ^ k2
def bytes_to_base58(data: bytes) -> str:
"""
Convert bytes to base58 encoded string.
params: data: Input bytes
Returns
-------
str: Base58 encoded string
"""
return base58.b58encode(data).decode("utf-8")
def sort_peer_ids_by_distance(target_key: bytes, peer_ids: list[ID]) -> list[ID]:
"""
Sort a list of peer IDs by their distance to the target key.
params: target_key: The target key to measure distance from
params: peer_ids: List of peer IDs to sort
Returns
-------
List[ID]: Sorted list of peer IDs from closest to furthest
"""
def get_distance(peer_id: ID) -> int:
# Hash the peer ID bytes to get a key for distance calculation
peer_hash = multihash.digest(peer_id.to_bytes(), "sha2-256").digest
return xor_distance(target_key, peer_hash)
return sorted(peer_ids, key=get_distance)
def shared_prefix_len(first: bytes, second: bytes) -> int:
"""
Calculate the number of prefix bits shared by two byte sequences.
params: first: First byte sequence
params: second: Second byte sequence
Returns
-------
int: Number of shared prefix bits
"""
# Compare each byte to find the first bit difference
common_length = 0
for i in range(min(len(first), len(second))):
byte_first = first[i]
byte_second = second[i]
if byte_first == byte_second:
common_length += 8
else:
# Find specific bit where they differ
xor = byte_first ^ byte_second
# Count leading zeros in the xor result
for j in range(7, -1, -1):
if (xor >> j) & 1 == 1:
return common_length + (7 - j)
# This shouldn't be reached if xor != 0
return common_length + 8
return common_length

View File

@ -0,0 +1,393 @@
"""
Value store implementation for Kademlia DHT.
Provides a way to store and retrieve key-value pairs with optional expiration.
"""
import logging
import time
import varint
from libp2p.abc import (
IHost,
)
from libp2p.custom_types import (
TProtocol,
)
from libp2p.peer.id import (
ID,
)
from .common import (
DEFAULT_TTL,
PROTOCOL_ID,
)
from .pb.kademlia_pb2 import (
Message,
)
# logger = logging.getLogger("libp2p.kademlia.value_store")
logger = logging.getLogger("kademlia-example.value_store")
class ValueStore:
"""
Store for key-value pairs in a Kademlia DHT.
Values are stored with a timestamp and optional expiration time.
"""
def __init__(self, host: IHost, local_peer_id: ID):
"""
Initialize an empty value store.
:param host: The libp2p host instance.
:param local_peer_id: The local peer ID to ignore in peer requests.
"""
# Store format: {key: (value, validity)}
self.store: dict[bytes, tuple[bytes, float]] = {}
# Store references to the host and local peer ID for making requests
self.host = host
self.local_peer_id = local_peer_id
def put(self, key: bytes, value: bytes, validity: float = 0.0) -> None:
"""
Store a value in the DHT.
:param key: The key to store the value under
:param value: The value to store
:param validity: validity in seconds before the value expires.
Defaults to `DEFAULT_TTL` if set to 0.0.
Returns
-------
None
"""
if validity == 0.0:
validity = time.time() + DEFAULT_TTL
logger.debug(
"Storing value for key %s... with validity %s", key.hex(), validity
)
self.store[key] = (value, validity)
logger.debug(f"Stored value for key {key.hex()}")
async def _store_at_peer(self, peer_id: ID, key: bytes, value: bytes) -> bool:
"""
Store a value at a specific peer.
params: peer_id: The ID of the peer to store the value at
params: key: The key to store
params: value: The value to store
Returns
-------
bool
True if the value was successfully stored, False otherwise
"""
result = False
stream = None
try:
# Don't try to store at ourselves
if self.local_peer_id and peer_id == self.local_peer_id:
result = True
return result
if not self.host:
logger.error("Host not initialized, cannot store value at peer")
return False
logger.debug(f"Storing value for key {key.hex()} at peer {peer_id}")
# Open a stream to the peer
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
logger.debug(f"Opened stream to peer {peer_id}")
# Create the PUT_VALUE message with protobuf
message = Message()
message.type = Message.MessageType.PUT_VALUE
# Set message fields
message.key = key
message.record.key = key
message.record.value = value
message.record.timeReceived = str(time.time())
# Serialize and send the protobuf message with length prefix
proto_bytes = message.SerializeToString()
await stream.write(varint.encode(len(proto_bytes)))
await stream.write(proto_bytes)
logger.debug("Sent PUT_VALUE protobuf message with varint length")
# Read varint-prefixed response length
length_bytes = b""
while True:
logger.debug("Reading varint length prefix for response...")
b = await stream.read(1)
if not b:
logger.warning("Connection closed while reading varint length")
return False
length_bytes += b
if b[0] & 0x80 == 0:
break
logger.debug(f"Received varint length bytes: {length_bytes.hex()}")
response_length = varint.decode_bytes(length_bytes)
logger.debug("Response length: %d bytes", response_length)
# Read response data
response_bytes = b""
remaining = response_length
while remaining > 0:
chunk = await stream.read(remaining)
if not chunk:
logger.debug(
f"Connection closed by peer {peer_id} while reading data"
)
return False
response_bytes += chunk
remaining -= len(chunk)
# Parse protobuf response
response = Message()
response.ParseFromString(response_bytes)
# Check if response is valid
if response.type == Message.MessageType.PUT_VALUE:
if response.key:
result = True
return result
except Exception as e:
logger.warning(f"Failed to store value at peer {peer_id}: {e}")
return False
finally:
if stream:
await stream.close()
return result
def get(self, key: bytes) -> bytes | None:
"""
Retrieve a value from the DHT.
params: key: The key to look up
Returns
-------
Optional[bytes]
The stored value, or None if not found or expired
"""
logger.debug("Retrieving value for key %s...", key.hex()[:8])
if key not in self.store:
return None
value, validity = self.store[key]
logger.debug(
"Found value for key %s... with validity %s",
key.hex(),
validity,
)
# Check if the value has expired
if validity is not None and validity < time.time():
logger.debug(
"Value for key %s... has expired, removing it",
key.hex()[:8],
)
self.remove(key)
return None
return value
async def _get_from_peer(self, peer_id: ID, key: bytes) -> bytes | None:
"""
Retrieve a value from a specific peer.
params: peer_id: The ID of the peer to retrieve the value from
params: key: The key to retrieve
Returns
-------
Optional[bytes]
The value if found, None otherwise
"""
stream = None
try:
# Don't try to get from ourselves
if peer_id == self.local_peer_id:
return None
logger.debug(f"Getting value for key {key.hex()} from peer {peer_id}")
# Open a stream to the peer
stream = await self.host.new_stream(peer_id, [TProtocol(PROTOCOL_ID)])
logger.debug(f"Opened stream to peer {peer_id} for GET_VALUE")
# Create the GET_VALUE message using protobuf
message = Message()
message.type = Message.MessageType.GET_VALUE
message.key = key
# Serialize and send the protobuf message
proto_bytes = message.SerializeToString()
await stream.write(varint.encode(len(proto_bytes)))
await stream.write(proto_bytes)
# Read response length
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
logger.warning("Connection closed while reading length")
return None
length_bytes += b
if b[0] & 0x80 == 0:
break
response_length = varint.decode_bytes(length_bytes)
# Read response data
response_bytes = b""
remaining = response_length
while remaining > 0:
chunk = await stream.read(remaining)
if not chunk:
logger.debug(
f"Connection closed by peer {peer_id} while reading data"
)
return None
response_bytes += chunk
remaining -= len(chunk)
# Parse protobuf response
try:
response = Message()
response.ParseFromString(response_bytes)
logger.debug(
f"Received protobuf response from peer"
f" {peer_id}, type: {response.type}"
)
# Process protobuf response
if (
response.type == Message.MessageType.GET_VALUE
and response.HasField("record")
and response.record.value
):
logger.debug(
f"Received value for key {key.hex()} from peer {peer_id}"
)
return response.record.value
# Handle case where value is not found but peer infos are returned
else:
logger.debug(
f"Value not found for key {key.hex()} from peer {peer_id},"
f" received {len(response.closerPeers)} closer peers"
)
return None
except Exception as proto_err:
logger.warning(f"Failed to parse as protobuf: {proto_err}")
return None
except Exception as e:
logger.warning(f"Failed to get value from peer {peer_id}: {e}")
return None
finally:
if stream:
await stream.close()
def remove(self, key: bytes) -> bool:
"""
Remove a value from the DHT.
params: key: The key to remove
Returns
-------
bool
True if the key was found and removed, False otherwise
"""
if key in self.store:
del self.store[key]
logger.debug(f"Removed value for key {key.hex()[:8]}...")
return True
return False
def has(self, key: bytes) -> bool:
"""
Check if a key exists in the store and hasn't expired.
params: key: The key to check
Returns
-------
bool
True if the key exists and hasn't expired, False otherwise
"""
if key not in self.store:
return False
_, validity = self.store[key]
if validity is not None and time.time() > validity:
self.remove(key)
return False
return True
def cleanup_expired(self) -> int:
"""
Remove all expired values from the store.
Returns
-------
int
The number of expired values that were removed
"""
current_time = time.time()
expired_keys = [
key for key, (_, validity) in self.store.items() if current_time > validity
]
for key in expired_keys:
del self.store[key]
if expired_keys:
logger.debug(f"Cleaned up {len(expired_keys)} expired values")
return len(expired_keys)
def get_keys(self) -> list[bytes]:
"""
Get all non-expired keys in the store.
Returns
-------
list[bytes]
List of keys
"""
# Clean up expired values first
self.cleanup_expired()
return list(self.store.keys())
def size(self) -> int:
"""
Get the number of items in the store (after removing expired entries).
Returns
-------
int
Number of items
"""
self.cleanup_expired()
return len(self.store)

View File

@ -1,7 +1,3 @@
from typing import (
Optional,
)
from libp2p.abc import (
IRawConnection,
)
@ -32,7 +28,7 @@ class RawConnection(IRawConnection):
except IOException as error:
raise RawConnError from error
async def read(self, n: int = None) -> bytes:
async def read(self, n: int | None = None) -> bytes:
"""
Read up to ``n`` bytes from the underlying stream. This call is
delegated directly to the underlying ``self.reader``.
@ -47,6 +43,6 @@ class RawConnection(IRawConnection):
async def close(self) -> None:
await self.stream.close()
def get_remote_address(self) -> Optional[tuple[str, int]]:
def get_remote_address(self) -> tuple[str, int] | None:
"""Delegate to the underlying stream's get_remote_address method."""
return self.stream.get_remote_address()

View File

@ -22,7 +22,7 @@ if TYPE_CHECKING:
"""
Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee14e334f7caeedad722/swarm_conn.go # noqa: E501
Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee14e334f7caeedad722/swarm_conn.go
"""
@ -32,7 +32,11 @@ class SwarmConn(INetConn):
streams: set[NetStream]
event_closed: trio.Event
def __init__(self, muxed_conn: IMuxedConn, swarm: "Swarm") -> None:
def __init__(
self,
muxed_conn: IMuxedConn,
swarm: "Swarm",
) -> None:
self.muxed_conn = muxed_conn
self.swarm = swarm
self.streams = set()
@ -40,7 +44,7 @@ class SwarmConn(INetConn):
self.event_started = trio.Event()
if hasattr(muxed_conn, "on_close"):
logging.debug(f"Setting on_close for peer {muxed_conn.peer_id}")
muxed_conn.on_close = self._on_muxed_conn_closed
setattr(muxed_conn, "on_close", self._on_muxed_conn_closed)
else:
logging.error(
f"muxed_conn for peer {muxed_conn.peer_id} has no on_close attribute"

View File

@ -1,7 +1,9 @@
from typing import (
Optional,
from enum import (
Enum,
)
import trio
from libp2p.abc import (
IMuxedStream,
INetStream,
@ -23,19 +25,103 @@ from .exceptions import (
)
# TODO: Handle exceptions from `muxed_stream`
# TODO: Add stream state
# - Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501
class NetStream(INetStream):
muxed_stream: IMuxedStream
protocol_id: Optional[TProtocol]
class StreamState(Enum):
"""NetStream States"""
OPEN = "open"
CLOSE_READ = "close_read"
CLOSE_WRITE = "close_write"
CLOSE_BOTH = "close_both"
RESET = "reset"
class NetStream(INetStream):
"""
Summary
_______
A Network stream implementation.
NetStream wraps a muxed stream and provides proper state tracking, resource cleanup,
and event notification capabilities.
State Machine
_____________
.. code:: markdown
[CREATED] → OPEN → CLOSE_READ → CLOSE_BOTH → [CLEANUP]
↓ ↗ ↗
CLOSE_WRITE → ← ↗
↓ ↗
RESET → → → → → → → →
State Transitions
_________________
- OPEN → CLOSE_READ: EOF encountered during read()
- OPEN → CLOSE_WRITE: Explicit close() call
- OPEN → RESET: reset() call or critical stream error
- CLOSE_READ → CLOSE_BOTH: Explicit close() call
- CLOSE_WRITE → CLOSE_BOTH: EOF encountered during read()
- Any state → RESET: reset() call
Terminal States (trigger cleanup)
_________________________________
- CLOSE_BOTH: Stream fully closed, triggers resource cleanup
- RESET: Stream reset/terminated, triggers resource cleanup
Operation Validity by State
___________________________
OPEN: read() ✓ write() ✓ close() ✓ reset() ✓
CLOSE_READ: read() ✗ write() ✓ close() ✓ reset() ✓
CLOSE_WRITE: read() ✓ write() ✗ close() ✓ reset() ✓
CLOSE_BOTH: read() ✗ write() ✗ close() ✓ reset() ✓
RESET: read() ✗ write() ✗ close() ✓ reset() ✓
Cleanup Process (triggered by CLOSE_BOTH or RESET)
__________________________________________________
1. Remove stream from SwarmConn
2. Notify all listeners with ClosedStream event
3. Decrement reference counter
4. Background cleanup via nursery (if provided)
Thread Safety
_____________
All state operations are protected by trio.Lock() for safe concurrent access.
State checks and modifications are atomic operations.
Example: See :file:`examples/doc-examples/example_net_stream.py`
:param muxed_stream (IMuxedStream): The underlying muxed stream
:param nursery (Optional[trio.Nursery]): Nursery for background cleanup tasks
:raises StreamClosed: When attempting invalid operations on closed streams
:raises StreamEOF: When EOF is encountered during read operations
:raises StreamReset: When the underlying stream has been reset
"""
muxed_stream: IMuxedStream
protocol_id: TProtocol | None
__stream_state: StreamState
def __init__(
self, muxed_stream: IMuxedStream, nursery: trio.Nursery | None = None
) -> None:
super().__init__()
def __init__(self, muxed_stream: IMuxedStream) -> None:
self.muxed_stream = muxed_stream
self.muxed_conn = muxed_stream.muxed_conn
self.protocol_id = None
def get_protocol(self) -> TProtocol:
# For background tasks
self._nursery = nursery
# State management
self.__stream_state = StreamState.OPEN
self._state_lock = trio.Lock()
# For notification handling
self._notify_lock = trio.Lock()
def get_protocol(self) -> TProtocol | None:
"""
:return: protocol id that stream runs on
"""
@ -47,42 +133,176 @@ class NetStream(INetStream):
"""
self.protocol_id = protocol_id
async def read(self, n: int = None) -> bytes:
@property
async def state(self) -> StreamState:
"""Get current stream state."""
async with self._state_lock:
return self.__stream_state
async def read(self, n: int | None = None) -> bytes:
"""
Read from stream.
:param n: number of bytes to read
:return: bytes of input
:raises StreamClosed: If `NetStream` is closed for reading
:raises StreamReset: If `NetStream` is reset
:raises StreamEOF: If trying to read after reaching end of file
:return: Bytes read from the stream
"""
async with self._state_lock:
if self.__stream_state in [
StreamState.CLOSE_READ,
StreamState.CLOSE_BOTH,
]:
raise StreamClosed("Stream is closed for reading")
if self.__stream_state == StreamState.RESET:
raise StreamReset("Stream is reset, cannot be used to read")
try:
return await self.muxed_stream.read(n)
data = await self.muxed_stream.read(n)
return data
except MuxedStreamEOF as error:
async with self._state_lock:
if self.__stream_state == StreamState.CLOSE_WRITE:
self.__stream_state = StreamState.CLOSE_BOTH
await self._remove()
elif self.__stream_state == StreamState.OPEN:
self.__stream_state = StreamState.CLOSE_READ
raise StreamEOF() from error
except MuxedStreamReset as error:
async with self._state_lock:
if self.__stream_state in [
StreamState.OPEN,
StreamState.CLOSE_READ,
StreamState.CLOSE_WRITE,
]:
self.__stream_state = StreamState.RESET
await self._remove()
raise StreamReset() from error
async def write(self, data: bytes) -> None:
"""
Write to stream.
:return: number of bytes written
:param data: bytes to write
:raises StreamClosed: If `NetStream` is closed for writing or reset
:raises StreamClosed: If `StreamError` occurred while writing
"""
async with self._state_lock:
if self.__stream_state in [
StreamState.CLOSE_WRITE,
StreamState.CLOSE_BOTH,
StreamState.RESET,
]:
raise StreamClosed("Stream is closed for writing")
try:
await self.muxed_stream.write(data)
except (MuxedStreamClosed, MuxedStreamError) as error:
async with self._state_lock:
if self.__stream_state == StreamState.OPEN:
self.__stream_state = StreamState.CLOSE_WRITE
elif self.__stream_state == StreamState.CLOSE_READ:
self.__stream_state = StreamState.CLOSE_BOTH
await self._remove()
raise StreamClosed() from error
async def close(self) -> None:
"""Close stream."""
"""Close stream for writing."""
async with self._state_lock:
if self.__stream_state in [
StreamState.CLOSE_BOTH,
StreamState.RESET,
StreamState.CLOSE_WRITE,
]:
return
await self.muxed_stream.close()
async with self._state_lock:
if self.__stream_state == StreamState.CLOSE_READ:
self.__stream_state = StreamState.CLOSE_BOTH
await self._remove()
elif self.__stream_state == StreamState.OPEN:
self.__stream_state = StreamState.CLOSE_WRITE
async def reset(self) -> None:
"""Reset stream, closing both ends."""
async with self._state_lock:
if self.__stream_state == StreamState.RESET:
return
await self.muxed_stream.reset()
def get_remote_address(self) -> Optional[tuple[str, int]]:
async with self._state_lock:
if self.__stream_state in [
StreamState.OPEN,
StreamState.CLOSE_READ,
StreamState.CLOSE_WRITE,
]:
self.__stream_state = StreamState.RESET
await self._remove()
async def _remove(self) -> None:
"""
Remove stream from connection and notify listeners.
This is called when the stream is fully closed or reset.
"""
if hasattr(self.muxed_conn, "remove_stream"):
remove_stream = getattr(self.muxed_conn, "remove_stream")
await remove_stream(self)
# Notify in background using Trio nursery if available
if self._nursery:
self._nursery.start_soon(self._notify_closed)
else:
await self._notify_closed()
async def _notify_closed(self) -> None:
"""
Notify all listeners that the stream has been closed.
This runs in a separate task to avoid blocking the main flow.
"""
async with self._notify_lock:
if hasattr(self.muxed_conn, "swarm"):
swarm = getattr(self.muxed_conn, "swarm")
if hasattr(swarm, "notify_all"):
await swarm.notify_all(
lambda notifiee: notifiee.closed_stream(swarm, self)
)
if hasattr(swarm, "refs") and hasattr(swarm.refs, "done"):
swarm.refs.done()
def get_remote_address(self) -> tuple[str, int] | None:
"""Delegate to the underlying muxed stream."""
return self.muxed_stream.get_remote_address()
# TODO: `remove`: Called by close and write when the stream is in specific states.
# It notifies `ClosedStream` after `SwarmConn.remove_stream` is called.
# Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501
async def is_closed(self) -> bool:
"""Check if stream is closed."""
current_state = await self.state
return current_state in [StreamState.CLOSE_BOTH, StreamState.RESET]
async def is_readable(self) -> bool:
"""Check if stream is readable."""
current_state = await self.state
return current_state not in [
StreamState.CLOSE_READ,
StreamState.CLOSE_BOTH,
StreamState.RESET,
]
async def is_writable(self) -> bool:
"""Check if stream is writable."""
current_state = await self.state
return current_state not in [
StreamState.CLOSE_WRITE,
StreamState.CLOSE_BOTH,
StreamState.RESET,
]
def __str__(self) -> str:
"""String representation of the stream."""
return f"<NetStream[{self.__stream_state.value}] protocol={self.protocol_id}>"

View File

@ -1,7 +1,4 @@
import logging
from typing import (
Optional,
)
from multiaddr import (
Multiaddr,
@ -75,7 +72,7 @@ class Swarm(Service, INetworkService):
connections: dict[ID, INetConn]
listeners: dict[str, IListener]
common_stream_handler: StreamHandlerFn
listener_nursery: Optional[trio.Nursery]
listener_nursery: trio.Nursery | None
event_listener_nursery_created: trio.Event
notifees: list[INotifee]
@ -190,7 +187,7 @@ class Swarm(Service, INetworkService):
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure
# the conn and then mux the conn
try:
secured_conn = await self.upgrader.upgrade_security(raw_conn, peer_id, True)
secured_conn = await self.upgrader.upgrade_security(raw_conn, True, peer_id)
except SecurityUpgradeFailure as error:
logger.debug("failed to upgrade security for peer %s", peer_id)
await raw_conn.close()
@ -260,10 +257,7 @@ class Swarm(Service, INetworkService):
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first
# secure the conn and then mux the conn
try:
# FIXME: This dummy `ID(b"")` for the remote peer is useless.
secured_conn = await self.upgrader.upgrade_security(
raw_conn, ID(b""), False
)
secured_conn = await self.upgrader.upgrade_security(raw_conn, False)
except SecurityUpgradeFailure as error:
logger.debug("failed to upgrade security for peer at %s", maddr)
await raw_conn.close()
@ -340,7 +334,9 @@ class Swarm(Service, INetworkService):
if hasattr(self, "transport") and self.transport is not None:
# Check if transport has close method before calling it
if hasattr(self.transport, "close"):
await self.transport.close()
await self.transport.close() # type: ignore
# Ignoring the type above since `transport` may not have a close method
# and we have already checked it with hasattr
logger.debug("swarm successfully closed")
@ -360,7 +356,11 @@ class Swarm(Service, INetworkService):
and start to monitor the connection for its new streams and
disconnection.
"""
swarm_conn = SwarmConn(muxed_conn, self)
swarm_conn = SwarmConn(
muxed_conn,
self,
)
self.manager.run_task(muxed_conn.start)
await muxed_conn.event_started.wait()
self.manager.run_task(swarm_conn.start)

View File

@ -1,7 +1,4 @@
import hashlib
from typing import (
Union,
)
import base58
import multihash
@ -24,7 +21,7 @@ if ENABLE_INLINING:
_digest: bytes
def __init__(self) -> None:
self._digest = bytearray()
self._digest = b""
def update(self, input: bytes) -> None:
self._digest += input
@ -39,8 +36,8 @@ if ENABLE_INLINING:
class ID:
_bytes: bytes
_xor_id: int = None
_b58_str: str = None
_xor_id: int | None = None
_b58_str: str | None = None
def __init__(self, peer_id_bytes: bytes) -> None:
self._bytes = peer_id_bytes
@ -93,7 +90,7 @@ class ID:
return cls(mh_digest.encode())
def sha256_digest(data: Union[str, bytes]) -> bytes:
def sha256_digest(data: str | bytes) -> bytes:
if isinstance(data, str):
data = data.encode("utf8")
return hashlib.sha256(data).digest()

View File

@ -1,6 +1,7 @@
from collections.abc import (
Sequence,
)
import time
from typing import (
Any,
)
@ -19,11 +20,13 @@ from libp2p.crypto.keys import (
class PeerData(IPeerData):
pubkey: PublicKey
privkey: PrivateKey
pubkey: PublicKey | None
privkey: PrivateKey | None
metadata: dict[Any, Any]
protocols: list[str]
addrs: list[Multiaddr]
last_identified: int
ttl: int # Keep ttl=0 by default for always valid
def __init__(self) -> None:
self.pubkey = None
@ -31,6 +34,8 @@ class PeerData(IPeerData):
self.metadata = {}
self.protocols = []
self.addrs = []
self.last_identified = int(time.time())
self.ttl = 0
def get_protocols(self) -> list[str]:
"""
@ -115,6 +120,36 @@ class PeerData(IPeerData):
raise PeerDataError("private key not found")
return self.privkey
def update_last_identified(self) -> None:
self.last_identified = int(time.time())
def get_last_identified(self) -> int:
"""
:return: last identified timestamp
"""
return self.last_identified
def get_ttl(self) -> int:
"""
:return: ttl for current peer
"""
return self.ttl
def set_ttl(self, ttl: int) -> None:
"""
:param ttl: ttl to set
"""
self.ttl = ttl
def is_expired(self) -> bool:
"""
:return: true, if last_identified+ttl > current_time
"""
# for ttl = 0; peer_data is always valid
if self.ttl > 0 and self.last_identified + self.ttl < int(time.time()):
return True
return False
class PeerDataError(KeyError):
"""Raised when a key is not found in peer metadata."""

View File

@ -32,21 +32,31 @@ def info_from_p2p_addr(addr: multiaddr.Multiaddr) -> PeerInfo:
if not addr:
raise InvalidAddrError("`addr` should not be `None`")
parts = addr.split()
parts: list[multiaddr.Multiaddr] = addr.split()
if not parts:
raise InvalidAddrError(
f"`parts`={parts} should at least have a protocol `P_P2P`"
)
p2p_part = parts[-1]
last_protocol_code = p2p_part.protocols()[0].code
if last_protocol_code != multiaddr.protocols.P_P2P:
p2p_protocols = p2p_part.protocols()
if not p2p_protocols:
raise InvalidAddrError("The last part of the address has no protocols")
last_protocol = p2p_protocols[0]
if last_protocol is None:
raise InvalidAddrError("The last protocol is None")
last_protocol_code = last_protocol.code
if last_protocol_code != multiaddr.multiaddr.protocols.P_P2P:
raise InvalidAddrError(
f"The last protocol should be `P_P2P` instead of `{last_protocol_code}`"
)
# make sure the /p2p value parses as a peer.ID
peer_id_str: str = p2p_part.value_for_protocol(multiaddr.protocols.P_P2P)
peer_id_str = p2p_part.value_for_protocol(multiaddr.multiaddr.protocols.P_P2P)
if peer_id_str is None:
raise InvalidAddrError("Missing value for /p2p protocol in multiaddr")
peer_id: ID = ID.from_base58(peer_id_str)
# we might have received just an / p2p part, which means there's no addr.
@ -56,5 +66,23 @@ def info_from_p2p_addr(addr: multiaddr.Multiaddr) -> PeerInfo:
return PeerInfo(peer_id, [addr])
def peer_info_to_bytes(peer_info: PeerInfo) -> bytes:
lines = [str(peer_info.peer_id)] + [str(addr) for addr in peer_info.addrs]
return "\n".join(lines).encode("utf-8")
def peer_info_from_bytes(data: bytes) -> PeerInfo:
try:
lines = data.decode("utf-8").splitlines()
if not lines:
raise InvalidAddrError("no data to decode PeerInfo")
peer_id = ID.from_base58(lines[0])
addrs = [multiaddr.Multiaddr(addr_str) for addr_str in lines[1:]]
return PeerInfo(peer_id, addrs)
except Exception as e:
raise InvalidAddrError(f"failed to decode PeerInfo: {e}")
class InvalidAddrError(ValueError):
pass

View File

@ -4,7 +4,6 @@ from collections import (
from collections.abc import (
Sequence,
)
import sys
from typing import (
Any,
)
@ -33,7 +32,7 @@ from .peerinfo import (
PeerInfo,
)
PERMANENT_ADDR_TTL = sys.maxsize
PERMANENT_ADDR_TTL = 0
class PeerStore(IPeerStore):
@ -49,6 +48,8 @@ class PeerStore(IPeerStore):
"""
if peer_id in self.peer_data_map:
peer_data = self.peer_data_map[peer_id]
if peer_data.is_expired():
peer_data.clear_addrs()
return PeerInfo(peer_id, peer_data.get_addrs())
raise PeerStoreError("peer ID not found")
@ -84,6 +85,18 @@ class PeerStore(IPeerStore):
"""
return list(self.peer_data_map.keys())
def valid_peer_ids(self) -> list[ID]:
"""
:return: all of the valid peer IDs stored in peer store
"""
valid_peer_ids: list[ID] = []
for peer_id, peer_data in self.peer_data_map.items():
if not peer_data.is_expired():
valid_peer_ids.append(peer_id)
else:
peer_data.clear_addrs()
return valid_peer_ids
def get(self, peer_id: ID, key: str) -> Any:
"""
:param peer_id: peer ID to get peer data for
@ -108,7 +121,7 @@ class PeerStore(IPeerStore):
peer_data = self.peer_data_map[peer_id]
peer_data.put_metadata(key, val)
def add_addr(self, peer_id: ID, addr: Multiaddr, ttl: int) -> None:
def add_addr(self, peer_id: ID, addr: Multiaddr, ttl: int = 0) -> None:
"""
:param peer_id: peer ID to add address for
:param addr:
@ -116,24 +129,30 @@ class PeerStore(IPeerStore):
"""
self.add_addrs(peer_id, [addr], ttl)
def add_addrs(self, peer_id: ID, addrs: Sequence[Multiaddr], ttl: int) -> None:
def add_addrs(self, peer_id: ID, addrs: Sequence[Multiaddr], ttl: int = 0) -> None:
"""
:param peer_id: peer ID to add address for
:param addrs:
:param ttl: time-to-live for the this record
"""
# Ignore ttl for now
peer_data = self.peer_data_map[peer_id]
peer_data.add_addrs(list(addrs))
peer_data.set_ttl(ttl)
peer_data.update_last_identified()
def addrs(self, peer_id: ID) -> list[Multiaddr]:
"""
:param peer_id: peer ID to get addrs for
:return: list of addrs
:return: list of addrs of a valid peer.
:raise PeerStoreError: if peer ID not found
"""
if peer_id in self.peer_data_map:
return self.peer_data_map[peer_id].get_addrs()
peer_data = self.peer_data_map[peer_id]
if not peer_data.is_expired():
return peer_data.get_addrs()
else:
peer_data.clear_addrs()
raise PeerStoreError("peer ID is expired")
raise PeerStoreError("peer ID not found")
def clear_addrs(self, peer_id: ID) -> None:
@ -153,7 +172,11 @@ class PeerStore(IPeerStore):
for peer_id in self.peer_data_map:
if len(self.peer_data_map[peer_id].get_addrs()) >= 1:
output.append(peer_id)
peer_data = self.peer_data_map[peer_id]
if not peer_data.is_expired():
output.append(peer_id)
else:
peer_data.clear_addrs()
return output
def add_pubkey(self, peer_id: ID, pubkey: PublicKey) -> None:

View File

@ -23,16 +23,20 @@ class Multiselect(IMultiselectMuxer):
communication.
"""
handlers: dict[TProtocol, StreamHandlerFn]
handlers: dict[TProtocol | None, StreamHandlerFn | None]
def __init__(
self, default_handlers: dict[TProtocol, StreamHandlerFn] = None
self,
default_handlers: None
| (dict[TProtocol | None, StreamHandlerFn | None]) = None,
) -> None:
if not default_handlers:
default_handlers = {}
self.handlers = default_handlers
def add_handler(self, protocol: TProtocol, handler: StreamHandlerFn) -> None:
def add_handler(
self, protocol: TProtocol | None, handler: StreamHandlerFn | None
) -> None:
"""
Store the handler with the given protocol.
@ -41,9 +45,10 @@ class Multiselect(IMultiselectMuxer):
"""
self.handlers[protocol] = handler
# FIXME: Make TProtocol Optional[TProtocol] to keep types consistent
async def negotiate(
self, communicator: IMultiselectCommunicator
) -> tuple[TProtocol, StreamHandlerFn]:
) -> tuple[TProtocol, StreamHandlerFn | None]:
"""
Negotiate performs protocol selection.
@ -60,7 +65,7 @@ class Multiselect(IMultiselectMuxer):
raise MultiselectError() from error
if command == "ls":
supported_protocols = list(self.handlers.keys())
supported_protocols = [p for p in self.handlers.keys() if p is not None]
response = "\n".join(supported_protocols) + "\n"
try:
@ -82,6 +87,8 @@ class Multiselect(IMultiselectMuxer):
except MultiselectCommunicatorError as error:
raise MultiselectError() from error
raise MultiselectError("Negotiation failed: no matching protocol")
async def handshake(self, communicator: IMultiselectCommunicator) -> None:
"""
Perform handshake to agree on multiselect protocol.

View File

@ -22,6 +22,9 @@ from libp2p.utils import (
encode_varint_prefixed,
)
from .exceptions import (
PubsubRouterError,
)
from .pb import (
rpc_pb2,
)
@ -37,7 +40,7 @@ logger = logging.getLogger("libp2p.pubsub.floodsub")
class FloodSub(IPubsubRouter):
protocols: list[TProtocol]
pubsub: Pubsub
pubsub: Pubsub | None
def __init__(self, protocols: Sequence[TProtocol]) -> None:
self.protocols = list(protocols)
@ -58,7 +61,7 @@ class FloodSub(IPubsubRouter):
"""
self.pubsub = pubsub
def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None:
def add_peer(self, peer_id: ID, protocol_id: TProtocol | None) -> None:
"""
Notifies the router that a new peer has been connected.
@ -108,17 +111,22 @@ class FloodSub(IPubsubRouter):
logger.debug("publishing message %s", pubsub_msg)
if self.pubsub is None:
raise PubsubRouterError("pubsub not attached to this instance")
else:
pubsub = self.pubsub
for peer_id in peers_gen:
if peer_id not in self.pubsub.peers:
if peer_id not in pubsub.peers:
continue
stream = self.pubsub.peers[peer_id]
stream = pubsub.peers[peer_id]
# FIXME: We should add a `WriteMsg` similar to write delimited messages.
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
try:
await stream.write(encode_varint_prefixed(rpc_msg.SerializeToString()))
except StreamClosed:
logger.debug("Fail to publish message to %s: stream closed", peer_id)
self.pubsub._handle_dead_peer(peer_id)
pubsub._handle_dead_peer(peer_id)
async def join(self, topic: str) -> None:
"""
@ -150,12 +158,16 @@ class FloodSub(IPubsubRouter):
:param origin: peer id of the peer the message originate from.
:return: a generator of the peer ids who we send data to.
"""
if self.pubsub is None:
raise PubsubRouterError("pubsub not attached to this instance")
else:
pubsub = self.pubsub
for topic in topic_ids:
if topic not in self.pubsub.peer_topics:
if topic not in pubsub.peer_topics:
continue
for peer_id in self.pubsub.peer_topics[topic]:
for peer_id in pubsub.peer_topics[topic]:
if peer_id in (msg_forwarder, origin):
continue
if peer_id not in self.pubsub.peers:
if peer_id not in pubsub.peers:
continue
yield peer_id

View File

@ -10,6 +10,7 @@ from collections.abc import (
)
import logging
import random
import time
from typing import (
Any,
DefaultDict,
@ -31,6 +32,8 @@ from libp2p.peer.id import (
)
from libp2p.peer.peerinfo import (
PeerInfo,
peer_info_from_bytes,
peer_info_to_bytes,
)
from libp2p.peer.peerstore import (
PERMANENT_ADDR_TTL,
@ -66,7 +69,7 @@ logger = logging.getLogger("libp2p.pubsub.gossipsub")
class GossipSub(IPubsubRouter, Service):
protocols: list[TProtocol]
pubsub: Pubsub
pubsub: Pubsub | None
degree: int
degree_high: int
@ -80,8 +83,7 @@ class GossipSub(IPubsubRouter, Service):
# The protocol peer supports
peer_protocol: dict[ID, TProtocol]
# TODO: Add `time_since_last_publish`
# Create topic --> time since last publish map.
time_since_last_publish: dict[str, int]
mcache: MessageCache
@ -92,13 +94,19 @@ class GossipSub(IPubsubRouter, Service):
direct_connect_initial_delay: float
direct_connect_interval: int
do_px: bool
px_peers_count: int
back_off: dict[str, dict[ID, int]]
prune_back_off: int
unsubscribe_back_off: int
def __init__(
self,
protocols: Sequence[TProtocol],
degree: int,
degree_low: int,
degree_high: int,
direct_peers: Sequence[PeerInfo] = None,
direct_peers: Sequence[PeerInfo] | None = None,
time_to_live: int = 60,
gossip_window: int = 3,
gossip_history: int = 5,
@ -106,6 +114,10 @@ class GossipSub(IPubsubRouter, Service):
heartbeat_interval: int = 120,
direct_connect_initial_delay: float = 0.1,
direct_connect_interval: int = 300,
do_px: bool = False,
px_peers_count: int = 16,
prune_back_off: int = 60,
unsubscribe_back_off: int = 10,
) -> None:
self.protocols = list(protocols)
self.pubsub = None
@ -138,10 +150,15 @@ class GossipSub(IPubsubRouter, Service):
self.direct_peers[direct_peer.peer_id] = direct_peer
self.direct_connect_interval = direct_connect_interval
self.direct_connect_initial_delay = direct_connect_initial_delay
self.time_since_last_publish = {}
self.do_px = do_px
self.px_peers_count = px_peers_count
self.back_off = dict()
self.prune_back_off = prune_back_off
self.unsubscribe_back_off = unsubscribe_back_off
async def run(self) -> None:
if self.pubsub is None:
raise NoPubsubAttached
self.manager.run_daemon_task(self.heartbeat)
if len(self.direct_peers) > 0:
self.manager.run_daemon_task(self.direct_connect_heartbeat)
@ -172,7 +189,7 @@ class GossipSub(IPubsubRouter, Service):
logger.debug("attached to pusub")
def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None:
def add_peer(self, peer_id: ID, protocol_id: TProtocol | None) -> None:
"""
Notifies the router that a new peer has been connected.
@ -181,6 +198,9 @@ class GossipSub(IPubsubRouter, Service):
"""
logger.debug("adding peer %s with protocol %s", peer_id, protocol_id)
if protocol_id is None:
raise ValueError("Protocol cannot be None")
if protocol_id not in (PROTOCOL_ID, floodsub.PROTOCOL_ID):
# We should never enter here. Becuase the `protocol_id` is registered by
# your pubsub instance in multistream-select, but it is not the protocol
@ -242,17 +262,20 @@ class GossipSub(IPubsubRouter, Service):
logger.debug("publishing message %s", pubsub_msg)
for peer_id in peers_gen:
if self.pubsub is None:
raise NoPubsubAttached
if peer_id not in self.pubsub.peers:
continue
stream = self.pubsub.peers[peer_id]
# FIXME: We should add a `WriteMsg` similar to write delimited messages.
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
# TODO: Go use `sendRPC`, which possibly piggybacks gossip/control messages.
try:
await stream.write(encode_varint_prefixed(rpc_msg.SerializeToString()))
except StreamClosed:
logger.debug("Fail to publish message to %s: stream closed", peer_id)
self.pubsub._handle_dead_peer(peer_id)
for topic in pubsub_msg.topicIDs:
self.time_since_last_publish[topic] = int(time.time())
def _get_peers_to_send(
self, topic_ids: Iterable[str], msg_forwarder: ID, origin: ID
@ -266,6 +289,8 @@ class GossipSub(IPubsubRouter, Service):
"""
send_to: set[ID] = set()
for topic in topic_ids:
if self.pubsub is None:
raise NoPubsubAttached
if topic not in self.pubsub.peer_topics:
continue
@ -315,6 +340,9 @@ class GossipSub(IPubsubRouter, Service):
:param topic: topic to join
"""
if self.pubsub is None:
raise NoPubsubAttached
logger.debug("joining topic %s", topic)
if topic in self.mesh:
@ -323,15 +351,22 @@ class GossipSub(IPubsubRouter, Service):
self.mesh[topic] = set()
topic_in_fanout: bool = topic in self.fanout
fanout_peers: set[ID] = self.fanout[topic] if topic_in_fanout else set()
fanout_peers: set[ID] = set()
if topic_in_fanout:
for peer in self.fanout[topic]:
if self._check_back_off(peer, topic):
continue
fanout_peers.add(peer)
fanout_size = len(fanout_peers)
if not topic_in_fanout or (topic_in_fanout and fanout_size < self.degree):
# There are less than D peers (let this number be x)
# in the fanout for a topic (or the topic is not in the fanout).
# Selects the remaining number of peers (D-x) from peers.gossipsub[topic].
if topic in self.pubsub.peer_topics:
if self.pubsub is not None and topic in self.pubsub.peer_topics:
selected_peers = self._get_in_topic_gossipsub_peers_from_minus(
topic, self.degree - fanout_size, fanout_peers
topic, self.degree - fanout_size, fanout_peers, True
)
# Combine fanout peers with selected peers
fanout_peers.update(selected_peers)
@ -342,6 +377,7 @@ class GossipSub(IPubsubRouter, Service):
await self.emit_graft(topic, peer)
self.fanout.pop(topic, None)
self.time_since_last_publish.pop(topic, None)
async def leave(self, topic: str) -> None:
# Note: the comments here are the near-exact algorithm description from the spec
@ -357,7 +393,8 @@ class GossipSub(IPubsubRouter, Service):
return
# Notify the peers in mesh[topic] with a PRUNE(topic) message
for peer in self.mesh[topic]:
await self.emit_prune(topic, peer)
await self.emit_prune(topic, peer, self.do_px, True)
self._add_back_off(peer, topic, True)
# Forget mesh[topic]
self.mesh.pop(topic, None)
@ -447,8 +484,8 @@ class GossipSub(IPubsubRouter, Service):
self.fanout_heartbeat()
# Get the peers to send IHAVE to
peers_to_gossip = self.gossip_heartbeat()
# Pack GRAFT, PRUNE and IHAVE for the same peer into one control message and
# send it
# Pack(piggyback) GRAFT, PRUNE and IHAVE for the same peer into
# one control message and send it
await self._emit_control_msgs(
peers_to_graft, peers_to_prune, peers_to_gossip
)
@ -464,6 +501,8 @@ class GossipSub(IPubsubRouter, Service):
await trio.sleep(self.direct_connect_initial_delay)
while True:
for direct_peer in self.direct_peers:
if self.pubsub is None:
raise NoPubsubAttached
if direct_peer not in self.pubsub.peers:
try:
await self.pubsub.host.connect(self.direct_peers[direct_peer])
@ -481,6 +520,8 @@ class GossipSub(IPubsubRouter, Service):
peers_to_graft: DefaultDict[ID, list[str]] = defaultdict(list)
peers_to_prune: DefaultDict[ID, list[str]] = defaultdict(list)
for topic in self.mesh:
if self.pubsub is None:
raise NoPubsubAttached
# Skip if no peers have subscribed to the topic
if topic not in self.pubsub.peer_topics:
continue
@ -489,7 +530,7 @@ class GossipSub(IPubsubRouter, Service):
if num_mesh_peers_in_topic < self.degree_low:
# Select D - |mesh[topic]| peers from peers.gossipsub[topic] - mesh[topic] # noqa: E501
selected_peers = self._get_in_topic_gossipsub_peers_from_minus(
topic, self.degree - num_mesh_peers_in_topic, self.mesh[topic]
topic, self.degree - num_mesh_peers_in_topic, self.mesh[topic], True
)
for peer in selected_peers:
@ -512,72 +553,97 @@ class GossipSub(IPubsubRouter, Service):
peers_to_prune[peer].append(topic)
return peers_to_graft, peers_to_prune
def _handle_topic_heartbeat(
self,
topic: str,
current_peers: set[ID],
is_fanout: bool = False,
peers_to_gossip: DefaultDict[ID, dict[str, list[str]]] | None = None,
) -> tuple[set[ID], bool]:
"""
Helper method to handle heartbeat for a single topic,
supporting both fanout and gossip.
:param topic: The topic to handle
:param current_peers: Current set of peers in the topic
:param is_fanout: Whether this is a fanout topic (affects expiration check)
:param peers_to_gossip: Optional dictionary to store peers to gossip to
:return: Tuple of (updated_peers, should_remove_topic)
"""
if self.pubsub is None:
raise NoPubsubAttached
# Skip if no peers have subscribed to the topic
if topic not in self.pubsub.peer_topics:
return current_peers, False
# For fanout topics, check if we should remove the topic
if is_fanout:
if self.time_since_last_publish.get(topic, 0) + self.time_to_live < int(
time.time()
):
return set(), True
# Check if peers are still in the topic and remove the ones that are not
in_topic_peers: set[ID] = {
peer for peer in current_peers if peer in self.pubsub.peer_topics[topic]
}
# If we need more peers to reach target degree
if len(in_topic_peers) < self.degree:
# Select additional peers from peers.gossipsub[topic]
selected_peers = self._get_in_topic_gossipsub_peers_from_minus(
topic, self.degree - len(in_topic_peers), in_topic_peers, True
)
# Add the selected peers
in_topic_peers.update(selected_peers)
# Handle gossip if requested
if peers_to_gossip is not None:
msg_ids = self.mcache.window(topic)
if msg_ids:
# Select D peers from peers.gossipsub[topic] excluding current peers
peers_to_emit_ihave_to = self._get_in_topic_gossipsub_peers_from_minus(
topic, self.degree, current_peers, True
)
msg_id_strs = [str(msg_id) for msg_id in msg_ids]
for peer in peers_to_emit_ihave_to:
peers_to_gossip[peer][topic] = msg_id_strs
return in_topic_peers, False
def fanout_heartbeat(self) -> None:
# Note: the comments here are the exact pseudocode from the spec
for topic in self.fanout:
# Delete topic entry if it's not in `pubsub.peer_topics`
# or (TODO) if it's time-since-last-published > ttl
if topic not in self.pubsub.peer_topics:
# Remove topic from fanout
"""
Maintain fanout topics by:
1. Removing expired topics
2. Removing peers that are no longer in the topic
3. Adding new peers if needed to maintain the target degree
"""
for topic in list(self.fanout):
updated_peers, should_remove = self._handle_topic_heartbeat(
topic, self.fanout[topic], is_fanout=True
)
if should_remove:
del self.fanout[topic]
else:
# Check if fanout peers are still in the topic and remove the ones that are not # noqa: E501
# ref: https://github.com/libp2p/go-libp2p-pubsub/blob/01b9825fbee1848751d90a8469e3f5f43bac8466/gossipsub.go#L498-L504 # noqa: E501
in_topic_fanout_peers = [
peer
for peer in self.fanout[topic]
if peer in self.pubsub.peer_topics[topic]
]
self.fanout[topic] = set(in_topic_fanout_peers)
num_fanout_peers_in_topic = len(self.fanout[topic])
# If |fanout[topic]| < D
if num_fanout_peers_in_topic < self.degree:
# Select D - |fanout[topic]| peers from peers.gossipsub[topic] - fanout[topic] # noqa: E501
selected_peers = self._get_in_topic_gossipsub_peers_from_minus(
topic,
self.degree - num_fanout_peers_in_topic,
self.fanout[topic],
)
# Add the peers to fanout[topic]
self.fanout[topic].update(selected_peers)
self.fanout[topic] = updated_peers
def gossip_heartbeat(self) -> DefaultDict[ID, dict[str, list[str]]]:
peers_to_gossip: DefaultDict[ID, dict[str, list[str]]] = defaultdict(dict)
# Handle mesh topics
for topic in self.mesh:
msg_ids = self.mcache.window(topic)
if msg_ids:
# Get all pubsub peers in a topic and only add them if they are
# gossipsub peers too
if topic in self.pubsub.peer_topics:
# Select D peers from peers.gossipsub[topic]
peers_to_emit_ihave_to = (
self._get_in_topic_gossipsub_peers_from_minus(
topic, self.degree, self.mesh[topic]
)
)
self._handle_topic_heartbeat(
topic, self.mesh[topic], peers_to_gossip=peers_to_gossip
)
msg_id_strs = [str(msg_id) for msg_id in msg_ids]
for peer in peers_to_emit_ihave_to:
peers_to_gossip[peer][topic] = msg_id_strs
# TODO: Refactor and Dedup. This section is the roughly the same as the above.
# Do the same for fanout, for all topics not already hit in mesh
# Handle fanout topics that aren't in mesh
for topic in self.fanout:
msg_ids = self.mcache.window(topic)
if msg_ids:
# Get all pubsub peers in topic and only add if they are
# gossipsub peers also
if topic in self.pubsub.peer_topics:
# Select D peers from peers.gossipsub[topic]
peers_to_emit_ihave_to = (
self._get_in_topic_gossipsub_peers_from_minus(
topic, self.degree, self.fanout[topic]
)
)
msg_id_strs = [str(msg) for msg in msg_ids]
for peer in peers_to_emit_ihave_to:
peers_to_gossip[peer][topic] = msg_id_strs
if topic not in self.mesh:
self._handle_topic_heartbeat(
topic, self.fanout[topic], peers_to_gossip=peers_to_gossip
)
return peers_to_gossip
@staticmethod
@ -612,21 +678,109 @@ class GossipSub(IPubsubRouter, Service):
return selection
def _get_in_topic_gossipsub_peers_from_minus(
self, topic: str, num_to_select: int, minus: Iterable[ID]
self,
topic: str,
num_to_select: int,
minus: Iterable[ID],
backoff_check: bool = False,
) -> list[ID]:
if self.pubsub is None:
raise NoPubsubAttached
gossipsub_peers_in_topic = {
peer_id
for peer_id in self.pubsub.peer_topics[topic]
if self.peer_protocol[peer_id] == PROTOCOL_ID
}
if backoff_check:
# filter out peers that are in back off for this topic
gossipsub_peers_in_topic = {
peer_id
for peer_id in gossipsub_peers_in_topic
if self._check_back_off(peer_id, topic) is False
}
return self.select_from_minus(num_to_select, gossipsub_peers_in_topic, minus)
def _add_back_off(
self, peer: ID, topic: str, is_unsubscribe: bool, backoff_duration: int = 0
) -> None:
"""
Add back off for a peer in a topic.
:param peer: peer to add back off for
:param topic: topic to add back off for
:param is_unsubscribe: whether this is an unsubscribe operation
:param backoff_duration: duration of back off in seconds, if 0, use default
"""
if topic not in self.back_off:
self.back_off[topic] = dict()
backoff_till = int(time.time())
if backoff_duration > 0:
backoff_till += backoff_duration
else:
if is_unsubscribe:
backoff_till += self.unsubscribe_back_off
else:
backoff_till += self.prune_back_off
if peer not in self.back_off[topic]:
self.back_off[topic][peer] = backoff_till
else:
self.back_off[topic][peer] = max(self.back_off[topic][peer], backoff_till)
def _check_back_off(self, peer: ID, topic: str) -> bool:
"""
Check if a peer is in back off for a topic and cleanup expired back off entries.
:param peer: peer to check
:param topic: topic to check
:return: True if the peer is in back off, False otherwise
"""
if topic not in self.back_off or peer not in self.back_off[topic]:
return False
if self.back_off[topic].get(peer, 0) > int(time.time()):
return True
else:
del self.back_off[topic][peer]
return False
async def _do_px(self, px_peers: list[rpc_pb2.PeerInfo]) -> None:
if len(px_peers) > self.px_peers_count:
px_peers = px_peers[: self.px_peers_count]
for peer in px_peers:
peer_id: ID = ID(peer.peerID)
if self.pubsub and peer_id in self.pubsub.peers:
continue
try:
peer_info = peer_info_from_bytes(peer.signedPeerRecord)
try:
if self.pubsub is None:
raise NoPubsubAttached
await self.pubsub.host.connect(peer_info)
except Exception as e:
logger.warning(
"failed to connect to px peer %s: %s",
peer_id,
e,
)
continue
except Exception as e:
logger.warning(
"failed to parse peer info from px peer %s: %s",
peer_id,
e,
)
continue
# RPC handlers
async def handle_ihave(
self, ihave_msg: rpc_pb2.ControlIHave, sender_peer_id: ID
) -> None:
"""Checks the seen set and requests unknown messages with an IWANT message."""
if self.pubsub is None:
raise NoPubsubAttached
# Get list of all seen (seqnos, from) from the (seqno, from) tuples in
# seen_messages cache
seen_seqnos_and_peers = [
@ -659,7 +813,7 @@ class GossipSub(IPubsubRouter, Service):
msgs_to_forward: list[rpc_pb2.Message] = []
for msg_id_iwant in msg_ids:
# Check if the wanted message ID is present in mcache
msg: rpc_pb2.Message = self.mcache.get(msg_id_iwant)
msg: rpc_pb2.Message | None = self.mcache.get(msg_id_iwant)
# Cache hit
if msg:
@ -677,6 +831,8 @@ class GossipSub(IPubsubRouter, Service):
# 2) Serialize that packet
rpc_msg: bytes = packet.SerializeToString()
if self.pubsub is None:
raise NoPubsubAttached
# 3) Get the stream to this peer
if sender_peer_id not in self.pubsub.peers:
@ -709,31 +865,53 @@ class GossipSub(IPubsubRouter, Service):
logger.warning(
"GRAFT: ignoring request from direct peer %s", sender_peer_id
)
await self.emit_prune(topic, sender_peer_id)
await self.emit_prune(topic, sender_peer_id, False, False)
return
if self._check_back_off(sender_peer_id, topic):
logger.warning(
"GRAFT: ignoring request from %s, back off until %d",
sender_peer_id,
self.back_off[topic][sender_peer_id],
)
self._add_back_off(sender_peer_id, topic, False)
await self.emit_prune(topic, sender_peer_id, False, False)
return
if sender_peer_id not in self.mesh[topic]:
self.mesh[topic].add(sender_peer_id)
else:
# Respond with PRUNE if not subscribed to the topic
await self.emit_prune(topic, sender_peer_id)
await self.emit_prune(topic, sender_peer_id, self.do_px, False)
async def handle_prune(
self, prune_msg: rpc_pb2.ControlPrune, sender_peer_id: ID
) -> None:
topic: str = prune_msg.topicID
backoff_till: int = prune_msg.backoff
px_peers: list[rpc_pb2.PeerInfo] = []
for peer in prune_msg.peers:
px_peers.append(peer)
# Remove peer from mesh for topic
if topic in self.mesh:
if backoff_till > 0:
self._add_back_off(sender_peer_id, topic, False, backoff_till)
else:
self._add_back_off(sender_peer_id, topic, False)
self.mesh[topic].discard(sender_peer_id)
if px_peers:
await self._do_px(px_peers)
# RPC emitters
def pack_control_msgs(
self,
ihave_msgs: list[rpc_pb2.ControlIHave],
graft_msgs: list[rpc_pb2.ControlGraft],
prune_msgs: list[rpc_pb2.ControlPrune],
ihave_msgs: list[rpc_pb2.ControlIHave] | None,
graft_msgs: list[rpc_pb2.ControlGraft] | None,
prune_msgs: list[rpc_pb2.ControlPrune] | None,
) -> rpc_pb2.ControlMessage:
control_msg: rpc_pb2.ControlMessage = rpc_pb2.ControlMessage()
if ihave_msgs:
@ -765,7 +943,7 @@ class GossipSub(IPubsubRouter, Service):
await self.emit_control_message(control_msg, to_peer)
async def emit_graft(self, topic: str, to_peer: ID) -> None:
async def emit_graft(self, topic: str, id: ID) -> None:
"""Emit graft message, sent to to_peer, for topic."""
graft_msg: rpc_pb2.ControlGraft = rpc_pb2.ControlGraft()
graft_msg.topicID = topic
@ -773,13 +951,34 @@ class GossipSub(IPubsubRouter, Service):
control_msg: rpc_pb2.ControlMessage = rpc_pb2.ControlMessage()
control_msg.graft.extend([graft_msg])
await self.emit_control_message(control_msg, to_peer)
await self.emit_control_message(control_msg, id)
async def emit_prune(self, topic: str, to_peer: ID) -> None:
async def emit_prune(
self, topic: str, to_peer: ID, do_px: bool, is_unsubscribe: bool
) -> None:
"""Emit graft message, sent to to_peer, for topic."""
prune_msg: rpc_pb2.ControlPrune = rpc_pb2.ControlPrune()
prune_msg.topicID = topic
back_off_duration = self.prune_back_off
if is_unsubscribe:
back_off_duration = self.unsubscribe_back_off
prune_msg.backoff = back_off_duration
if do_px:
exchange_peers = self._get_in_topic_gossipsub_peers_from_minus(
topic, self.px_peers_count, [to_peer]
)
for peer in exchange_peers:
if self.pubsub is None:
raise NoPubsubAttached
peer_info = self.pubsub.host.get_peerstore().peer_info(peer)
signed_peer_record: rpc_pb2.PeerInfo = rpc_pb2.PeerInfo()
signed_peer_record.peerID = peer.to_bytes()
signed_peer_record.signedPeerRecord = peer_info_to_bytes(peer_info)
prune_msg.peers.append(signed_peer_record)
control_msg: rpc_pb2.ControlMessage = rpc_pb2.ControlMessage()
control_msg.prune.extend([prune_msg])
@ -788,6 +987,8 @@ class GossipSub(IPubsubRouter, Service):
async def emit_control_message(
self, control_msg: rpc_pb2.ControlMessage, to_peer: ID
) -> None:
if self.pubsub is None:
raise NoPubsubAttached
# Add control message to packet
packet: rpc_pb2.RPC = rpc_pb2.RPC()
packet.control.CopyFrom(control_msg)

View File

@ -1,9 +1,6 @@
from collections.abc import (
Sequence,
)
from typing import (
Optional,
)
from .pb import (
rpc_pb2,
@ -66,7 +63,7 @@ class MessageCache:
self.history[0].append(CacheEntry(mid, msg.topicIDs))
def get(self, mid: tuple[bytes, bytes]) -> Optional[rpc_pb2.Message]:
def get(self, mid: tuple[bytes, bytes]) -> rpc_pb2.Message | None:
"""
Get a message from the mcache.

View File

@ -47,6 +47,13 @@ message ControlGraft {
message ControlPrune {
optional string topicID = 1;
repeated PeerInfo peers = 2;
optional uint64 backoff = 3;
}
message PeerInfo {
optional bytes peerID = 1;
optional bytes signedPeerRecord = 2;
}
message TopicDescriptor {

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: libp2p/pubsub/pb/rpc.proto
# source: rpc.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
@ -13,37 +13,39 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1alibp2p/pubsub/pb/rpc.proto\x12\tpubsub.pb\"\xb4\x01\n\x03RPC\x12-\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x16.pubsub.pb.RPC.SubOpts\x12#\n\x07publish\x18\x02 \x03(\x0b\x32\x12.pubsub.pb.Message\x12*\n\x07\x63ontrol\x18\x03 \x01(\x0b\x32\x19.pubsub.pb.ControlMessage\x1a-\n\x07SubOpts\x12\x11\n\tsubscribe\x18\x01 \x01(\x08\x12\x0f\n\x07topicid\x18\x02 \x01(\t\"i\n\x07Message\x12\x0f\n\x07\x66rom_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\r\n\x05seqno\x18\x03 \x01(\x0c\x12\x10\n\x08topicIDs\x18\x04 \x03(\t\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x12\x0b\n\x03key\x18\x06 \x01(\x0c\"\xb0\x01\n\x0e\x43ontrolMessage\x12&\n\x05ihave\x18\x01 \x03(\x0b\x32\x17.pubsub.pb.ControlIHave\x12&\n\x05iwant\x18\x02 \x03(\x0b\x32\x17.pubsub.pb.ControlIWant\x12&\n\x05graft\x18\x03 \x03(\x0b\x32\x17.pubsub.pb.ControlGraft\x12&\n\x05prune\x18\x04 \x03(\x0b\x32\x17.pubsub.pb.ControlPrune\"3\n\x0c\x43ontrolIHave\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\x12\n\nmessageIDs\x18\x02 \x03(\t\"\"\n\x0c\x43ontrolIWant\x12\x12\n\nmessageIDs\x18\x01 \x03(\t\"\x1f\n\x0c\x43ontrolGraft\x12\x0f\n\x07topicID\x18\x01 \x01(\t\"\x1f\n\x0c\x43ontrolPrune\x12\x0f\n\x07topicID\x18\x01 \x01(\t\"\x87\x03\n\x0fTopicDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\x04\x61uth\x18\x02 \x01(\x0b\x32#.pubsub.pb.TopicDescriptor.AuthOpts\x12/\n\x03\x65nc\x18\x03 \x01(\x0b\x32\".pubsub.pb.TopicDescriptor.EncOpts\x1a|\n\x08\x41uthOpts\x12:\n\x04mode\x18\x01 \x01(\x0e\x32,.pubsub.pb.TopicDescriptor.AuthOpts.AuthMode\x12\x0c\n\x04keys\x18\x02 \x03(\x0c\"&\n\x08\x41uthMode\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03KEY\x10\x01\x12\x07\n\x03WOT\x10\x02\x1a\x83\x01\n\x07\x45ncOpts\x12\x38\n\x04mode\x18\x01 \x01(\x0e\x32*.pubsub.pb.TopicDescriptor.EncOpts.EncMode\x12\x11\n\tkeyHashes\x18\x02 \x03(\x0c\"+\n\x07\x45ncMode\x12\x08\n\x04NONE\x10\x00\x12\r\n\tSHAREDKEY\x10\x01\x12\x07\n\x03WOT\x10\x02')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\trpc.proto\x12\tpubsub.pb\"\xb4\x01\n\x03RPC\x12-\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x16.pubsub.pb.RPC.SubOpts\x12#\n\x07publish\x18\x02 \x03(\x0b\x32\x12.pubsub.pb.Message\x12*\n\x07\x63ontrol\x18\x03 \x01(\x0b\x32\x19.pubsub.pb.ControlMessage\x1a-\n\x07SubOpts\x12\x11\n\tsubscribe\x18\x01 \x01(\x08\x12\x0f\n\x07topicid\x18\x02 \x01(\t\"i\n\x07Message\x12\x0f\n\x07\x66rom_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\r\n\x05seqno\x18\x03 \x01(\x0c\x12\x10\n\x08topicIDs\x18\x04 \x03(\t\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x12\x0b\n\x03key\x18\x06 \x01(\x0c\"\xb0\x01\n\x0e\x43ontrolMessage\x12&\n\x05ihave\x18\x01 \x03(\x0b\x32\x17.pubsub.pb.ControlIHave\x12&\n\x05iwant\x18\x02 \x03(\x0b\x32\x17.pubsub.pb.ControlIWant\x12&\n\x05graft\x18\x03 \x03(\x0b\x32\x17.pubsub.pb.ControlGraft\x12&\n\x05prune\x18\x04 \x03(\x0b\x32\x17.pubsub.pb.ControlPrune\"3\n\x0c\x43ontrolIHave\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\x12\n\nmessageIDs\x18\x02 \x03(\t\"\"\n\x0c\x43ontrolIWant\x12\x12\n\nmessageIDs\x18\x01 \x03(\t\"\x1f\n\x0c\x43ontrolGraft\x12\x0f\n\x07topicID\x18\x01 \x01(\t\"T\n\x0c\x43ontrolPrune\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\"\n\x05peers\x18\x02 \x03(\x0b\x32\x13.pubsub.pb.PeerInfo\x12\x0f\n\x07\x62\x61\x63koff\x18\x03 \x01(\x04\"4\n\x08PeerInfo\x12\x0e\n\x06peerID\x18\x01 \x01(\x0c\x12\x18\n\x10signedPeerRecord\x18\x02 \x01(\x0c\"\x87\x03\n\x0fTopicDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\x04\x61uth\x18\x02 \x01(\x0b\x32#.pubsub.pb.TopicDescriptor.AuthOpts\x12/\n\x03\x65nc\x18\x03 \x01(\x0b\x32\".pubsub.pb.TopicDescriptor.EncOpts\x1a|\n\x08\x41uthOpts\x12:\n\x04mode\x18\x01 \x01(\x0e\x32,.pubsub.pb.TopicDescriptor.AuthOpts.AuthMode\x12\x0c\n\x04keys\x18\x02 \x03(\x0c\"&\n\x08\x41uthMode\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03KEY\x10\x01\x12\x07\n\x03WOT\x10\x02\x1a\x83\x01\n\x07\x45ncOpts\x12\x38\n\x04mode\x18\x01 \x01(\x0e\x32*.pubsub.pb.TopicDescriptor.EncOpts.EncMode\x12\x11\n\tkeyHashes\x18\x02 \x03(\x0c\"+\n\x07\x45ncMode\x12\x08\n\x04NONE\x10\x00\x12\r\n\tSHAREDKEY\x10\x01\x12\x07\n\x03WOT\x10\x02')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.pubsub.pb.rpc_pb2', globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'rpc_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_RPC._serialized_start=42
_RPC._serialized_end=222
_RPC_SUBOPTS._serialized_start=177
_RPC_SUBOPTS._serialized_end=222
_MESSAGE._serialized_start=224
_MESSAGE._serialized_end=329
_CONTROLMESSAGE._serialized_start=332
_CONTROLMESSAGE._serialized_end=508
_CONTROLIHAVE._serialized_start=510
_CONTROLIHAVE._serialized_end=561
_CONTROLIWANT._serialized_start=563
_CONTROLIWANT._serialized_end=597
_CONTROLGRAFT._serialized_start=599
_CONTROLGRAFT._serialized_end=630
_CONTROLPRUNE._serialized_start=632
_CONTROLPRUNE._serialized_end=663
_TOPICDESCRIPTOR._serialized_start=666
_TOPICDESCRIPTOR._serialized_end=1057
_TOPICDESCRIPTOR_AUTHOPTS._serialized_start=799
_TOPICDESCRIPTOR_AUTHOPTS._serialized_end=923
_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_start=885
_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_end=923
_TOPICDESCRIPTOR_ENCOPTS._serialized_start=926
_TOPICDESCRIPTOR_ENCOPTS._serialized_end=1057
_TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_start=1014
_TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_end=1057
_RPC._serialized_start=25
_RPC._serialized_end=205
_RPC_SUBOPTS._serialized_start=160
_RPC_SUBOPTS._serialized_end=205
_MESSAGE._serialized_start=207
_MESSAGE._serialized_end=312
_CONTROLMESSAGE._serialized_start=315
_CONTROLMESSAGE._serialized_end=491
_CONTROLIHAVE._serialized_start=493
_CONTROLIHAVE._serialized_end=544
_CONTROLIWANT._serialized_start=546
_CONTROLIWANT._serialized_end=580
_CONTROLGRAFT._serialized_start=582
_CONTROLGRAFT._serialized_end=613
_CONTROLPRUNE._serialized_start=615
_CONTROLPRUNE._serialized_end=699
_PEERINFO._serialized_start=701
_PEERINFO._serialized_end=753
_TOPICDESCRIPTOR._serialized_start=756
_TOPICDESCRIPTOR._serialized_end=1147
_TOPICDESCRIPTOR_AUTHOPTS._serialized_start=889
_TOPICDESCRIPTOR_AUTHOPTS._serialized_end=1013
_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_start=975
_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_end=1013
_TOPICDESCRIPTOR_ENCOPTS._serialized_start=1016
_TOPICDESCRIPTOR_ENCOPTS._serialized_end=1147
_TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_start=1104
_TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_end=1147
# @@protoc_insertion_point(module_scope)

View File

@ -179,17 +179,43 @@ class ControlPrune(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TOPICID_FIELD_NUMBER: builtins.int
PEERS_FIELD_NUMBER: builtins.int
BACKOFF_FIELD_NUMBER: builtins.int
topicID: builtins.str
backoff: builtins.int
@property
def peers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___PeerInfo]: ...
def __init__(
self,
*,
topicID: builtins.str | None = ...,
peers: collections.abc.Iterable[global___PeerInfo] | None = ...,
backoff: builtins.int | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["topicID", b"topicID"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["topicID", b"topicID"]) -> None: ...
def HasField(self, field_name: typing.Literal["backoff", b"backoff", "topicID", b"topicID"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["backoff", b"backoff", "peers", b"peers", "topicID", b"topicID"]) -> None: ...
global___ControlPrune = ControlPrune
@typing.final
class PeerInfo(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
PEERID_FIELD_NUMBER: builtins.int
SIGNEDPEERRECORD_FIELD_NUMBER: builtins.int
peerID: builtins.bytes
signedPeerRecord: builtins.bytes
def __init__(
self,
*,
peerID: builtins.bytes | None = ...,
signedPeerRecord: builtins.bytes | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["peerID", b"peerID", "signedPeerRecord", b"signedPeerRecord"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["peerID", b"peerID", "signedPeerRecord", b"signedPeerRecord"]) -> None: ...
global___PeerInfo = PeerInfo
@typing.final
class TopicDescriptor(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

View File

@ -4,6 +4,7 @@ from __future__ import (
import base64
from collections.abc import (
Callable,
KeysView,
)
import functools
@ -11,7 +12,6 @@ import hashlib
import logging
import time
from typing import (
Callable,
NamedTuple,
cast,
)
@ -53,6 +53,9 @@ from libp2p.network.stream.exceptions import (
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerdata import (
PeerDataError,
)
from libp2p.tools.async_service import (
Service,
)
@ -120,7 +123,10 @@ class Pubsub(Service, IPubsub):
# Indicate if we should enforce signature verification
strict_signing: bool
sign_key: PrivateKey
sign_key: PrivateKey | None
# Set of blacklisted peer IDs
blacklisted_peers: set[ID]
event_handle_peer_queue_started: trio.Event
event_handle_dead_peer_queue_started: trio.Event
@ -129,7 +135,7 @@ class Pubsub(Service, IPubsub):
self,
host: IHost,
router: IPubsubRouter,
cache_size: int = None,
cache_size: int | None = None,
seen_ttl: int = 120,
sweep_interval: int = 60,
strict_signing: bool = True,
@ -201,6 +207,9 @@ class Pubsub(Service, IPubsub):
self.counter = int(time.time())
# Set of blacklisted peer IDs
self.blacklisted_peers = set()
self.event_handle_peer_queue_started = trio.Event()
self.event_handle_dead_peer_queue_started = trio.Event()
@ -320,6 +329,82 @@ class Pubsub(Service, IPubsub):
if topic in self.topic_validators
)
def add_to_blacklist(self, peer_id: ID) -> None:
"""
Add a peer to the blacklist.
When a peer is blacklisted:
- Any existing connection to that peer is immediately closed and removed
- The peer is removed from all topic subscription mappings
- Future connection attempts from this peer will be rejected
- Messages forwarded by or originating from this peer will be dropped
- The peer will not be able to participate in pubsub communication
:param peer_id: the peer ID to blacklist
"""
self.blacklisted_peers.add(peer_id)
logger.debug("Added peer %s to blacklist", peer_id)
self.manager.run_task(self._teardown_if_connected, peer_id)
async def _teardown_if_connected(self, peer_id: ID) -> None:
"""Close their stream and remove them if connected"""
stream = self.peers.get(peer_id)
if stream is not None:
try:
await stream.reset()
except Exception:
pass
del self.peers[peer_id]
# Also remove from any subscription maps:
for _topic, peerset in self.peer_topics.items():
if peer_id in peerset:
peerset.discard(peer_id)
def remove_from_blacklist(self, peer_id: ID) -> None:
"""
Remove a peer from the blacklist.
Once removed from the blacklist:
- The peer can establish new connections to this node
- Messages from this peer will be processed normally
- The peer can participate in topic subscriptions and message forwarding
:param peer_id: the peer ID to remove from blacklist
"""
self.blacklisted_peers.discard(peer_id)
logger.debug("Removed peer %s from blacklist", peer_id)
def is_peer_blacklisted(self, peer_id: ID) -> bool:
"""
Check if a peer is blacklisted.
:param peer_id: the peer ID to check
:return: True if peer is blacklisted, False otherwise
"""
return peer_id in self.blacklisted_peers
def clear_blacklist(self) -> None:
"""
Clear all peers from the blacklist.
This removes all blacklist restrictions, allowing previously blacklisted
peers to:
- Establish new connections
- Send and forward messages
- Participate in topic subscriptions
"""
self.blacklisted_peers.clear()
logger.debug("Cleared all peers from blacklist")
def get_blacklisted_peers(self) -> set[ID]:
"""
Get a copy of the current blacklisted peers.
Returns a snapshot of all currently blacklisted peer IDs. These peers
are completely isolated from pubsub communication - their connections
are rejected and their messages are dropped.
:return: a set containing all blacklisted peer IDs
"""
return self.blacklisted_peers.copy()
async def stream_handler(self, stream: INetStream) -> None:
"""
Stream handler for pubsub. Gets invoked whenever a new stream is
@ -346,6 +431,10 @@ class Pubsub(Service, IPubsub):
await self.event_handle_dead_peer_queue_started.wait()
async def _handle_new_peer(self, peer_id: ID) -> None:
if self.is_peer_blacklisted(peer_id):
logger.debug("Rejecting blacklisted peer %s", peer_id)
return
try:
stream: INetStream = await self.host.new_stream(peer_id, self.protocols)
except SwarmException as error:
@ -359,7 +448,6 @@ class Pubsub(Service, IPubsub):
except StreamClosed:
logger.debug("Fail to add new peer %s: stream closed", peer_id)
return
# TODO: Check if the peer in black list.
try:
self.router.add_peer(peer_id, stream.get_protocol())
except Exception as error:
@ -532,16 +620,22 @@ class Pubsub(Service, IPubsub):
logger.debug("Fail to message peer %s: stream closed", peer_id)
self._handle_dead_peer(peer_id)
async def publish(self, topic_id: str, data: bytes) -> None:
async def publish(self, topic_id: str | list[str], data: bytes) -> None:
"""
Publish data to a topic.
Publish data to a topic or multiple topics.
:param topic_id: topic which we are going to publish the data to
:param topic_id: topic (str) or topics (list[str]) to publish the data to
:param data: data which we are publishing
"""
# Handle both single topic (str) and multiple topics (list[str])
if isinstance(topic_id, str):
topic_ids = [topic_id]
else:
topic_ids = topic_id
msg = rpc_pb2.Message(
data=data,
topicIDs=[topic_id],
topicIDs=topic_ids,
# Origin is ourself.
from_id=self.my_id.to_bytes(),
seqno=self._next_seqno(),
@ -549,6 +643,9 @@ class Pubsub(Service, IPubsub):
if self.strict_signing:
priv_key = self.sign_key
if priv_key is None:
raise PeerDataError("private key not found")
signature = priv_key.sign(
PUBSUB_SIGNING_PREFIX.encode() + msg.SerializeToString()
)
@ -585,19 +682,18 @@ class Pubsub(Service, IPubsub):
# TODO: Implement throttle on async validators
if len(async_topic_validators) > 0:
# TODO: Use a better pattern
final_result: bool = True
# Appends to lists are thread safe in CPython
results = []
async def run_async_validator(func: AsyncValidatorFn) -> None:
nonlocal final_result
result = await func(msg_forwarder, msg)
final_result = final_result and result
results.append(result)
async with trio.open_nursery() as nursery:
for async_validator in async_topic_validators:
nursery.start_soon(run_async_validator, async_validator)
if not final_result:
if not all(results):
raise ValidationError(f"Validation failed for msg={msg}")
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
@ -609,9 +705,20 @@ class Pubsub(Service, IPubsub):
"""
logger.debug("attempting to publish message %s", msg)
# TODO: Check if the `source` is in the blacklist. If yes, reject.
# Check if the message forwarder (source) is in the blacklist. If yes, reject.
if self.is_peer_blacklisted(msg_forwarder):
logger.debug(
"Rejecting message from blacklisted source peer %s", msg_forwarder
)
return
# TODO: Check if the `from` is in the blacklist. If yes, reject.
# Check if the message originator (from) is in the blacklist. If yes, reject.
msg_from_peer = ID(msg.from_id)
if self.is_peer_blacklisted(msg_from_peer):
logger.debug(
"Rejecting message from blacklisted originator peer %s", msg_from_peer
)
return
# If the message is processed before, return(i.e., don't further process the message) # noqa: E501
if self._is_msg_seen(msg):

28
libp2p/relay/__init__.py Normal file
View File

@ -0,0 +1,28 @@
"""
Relay module for libp2p.
This package includes implementations of circuit relay protocols
for enabling connectivity between peers behind NATs or firewalls.
"""
# Import the circuit_v2 module to make it accessible
# through the relay package
from libp2p.relay.circuit_v2 import (
PROTOCOL_ID,
CircuitV2Protocol,
CircuitV2Transport,
RelayDiscovery,
RelayLimits,
RelayResourceManager,
Reservation,
)
__all__ = [
"CircuitV2Protocol",
"CircuitV2Transport",
"PROTOCOL_ID",
"RelayDiscovery",
"RelayLimits",
"RelayResourceManager",
"Reservation",
]

View File

@ -0,0 +1,32 @@
"""
Circuit Relay v2 implementation for libp2p.
This package implements the Circuit Relay v2 protocol as specified in:
https://github.com/libp2p/specs/blob/master/relay/circuit-v2.md
"""
from .discovery import (
RelayDiscovery,
)
from .protocol import (
PROTOCOL_ID,
CircuitV2Protocol,
)
from .resources import (
RelayLimits,
RelayResourceManager,
Reservation,
)
from .transport import (
CircuitV2Transport,
)
__all__ = [
"CircuitV2Protocol",
"PROTOCOL_ID",
"RelayLimits",
"Reservation",
"RelayResourceManager",
"CircuitV2Transport",
"RelayDiscovery",
]

View File

@ -0,0 +1,92 @@
"""
Configuration management for Circuit Relay v2.
This module handles configuration for relay roles, resource limits,
and discovery settings.
"""
from dataclasses import (
dataclass,
field,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from .resources import (
RelayLimits,
)
@dataclass
class RelayConfig:
"""Configuration for Circuit Relay v2."""
# Role configuration
enable_hop: bool = False # Whether to act as a relay (hop)
enable_stop: bool = True # Whether to accept relayed connections (stop)
enable_client: bool = True # Whether to use relays for dialing
# Resource limits
limits: RelayLimits | None = None
# Discovery configuration
bootstrap_relays: list[PeerInfo] = field(default_factory=list)
min_relays: int = 3
max_relays: int = 20
discovery_interval: int = 300 # seconds
# Connection configuration
reservation_ttl: int = 3600 # seconds
max_circuit_duration: int = 3600 # seconds
max_circuit_bytes: int = 1024 * 1024 * 1024 # 1GB
def __post_init__(self) -> None:
"""Initialize default values."""
if self.limits is None:
self.limits = RelayLimits(
duration=self.max_circuit_duration,
data=self.max_circuit_bytes,
max_circuit_conns=8,
max_reservations=4,
)
@dataclass
class HopConfig:
"""Configuration specific to relay (hop) nodes."""
# Resource limits per IP
max_reservations_per_ip: int = 8
max_circuits_per_ip: int = 16
# Rate limiting
reservation_rate_per_ip: int = 4 # per minute
circuit_rate_per_ip: int = 8 # per minute
# Resource quotas
max_circuits_total: int = 64
max_reservations_total: int = 32
# Bandwidth limits
max_bandwidth_per_circuit: int = 1024 * 1024 # 1MB/s
max_bandwidth_total: int = 10 * 1024 * 1024 # 10MB/s
@dataclass
class ClientConfig:
"""Configuration specific to relay clients."""
# Relay selection
min_relay_score: float = 0.5
max_relay_latency: float = 1.0 # seconds
# Auto-relay settings
enable_auto_relay: bool = True
auto_relay_timeout: int = 30 # seconds
max_auto_relay_attempts: int = 3
# Reservation management
reservation_refresh_threshold: float = 0.8 # Refresh at 80% of TTL
max_concurrent_reservations: int = 2

View File

@ -0,0 +1,537 @@
"""
Discovery module for Circuit Relay v2.
This module handles discovering and tracking relay nodes in the network.
"""
from dataclasses import (
dataclass,
)
import logging
import time
from typing import (
Any,
Protocol as TypingProtocol,
cast,
runtime_checkable,
)
import trio
from libp2p.abc import (
IHost,
)
from libp2p.custom_types import (
TProtocol,
)
from libp2p.peer.id import (
ID,
)
from libp2p.tools.async_service import (
Service,
)
from .pb.circuit_pb2 import (
HopMessage,
)
from .protocol import (
PROTOCOL_ID,
)
from .protocol_buffer import (
StatusCode,
)
logger = logging.getLogger("libp2p.relay.circuit_v2.discovery")
# Constants
MAX_RELAYS_TO_TRACK = 10
DEFAULT_DISCOVERY_INTERVAL = 60 # seconds
STREAM_TIMEOUT = 10 # seconds
# Extended interfaces for type checking
@runtime_checkable
class IHostWithMultiselect(TypingProtocol):
"""Extended host interface with multiselect attribute."""
@property
def multiselect(self) -> Any:
"""Get the multiselect component."""
...
@dataclass
class RelayInfo:
"""Information about a discovered relay."""
peer_id: ID
discovered_at: float
last_seen: float
has_reservation: bool = False
reservation_expires_at: float | None = None
reservation_data_limit: int | None = None
class RelayDiscovery(Service):
"""
Discovery service for Circuit Relay v2 nodes.
This service discovers and keeps track of available relay nodes, and optionally
makes reservations with them.
"""
def __init__(
self,
host: IHost,
auto_reserve: bool = False,
discovery_interval: int = DEFAULT_DISCOVERY_INTERVAL,
max_relays: int = MAX_RELAYS_TO_TRACK,
) -> None:
"""
Initialize the discovery service.
Parameters
----------
host : IHost
The libp2p host this discovery service is running on
auto_reserve : bool
Whether to automatically make reservations with discovered relays
discovery_interval : int
How often to run discovery, in seconds
max_relays : int
Maximum number of relays to track
"""
super().__init__()
self.host = host
self.auto_reserve = auto_reserve
self.discovery_interval = discovery_interval
self.max_relays = max_relays
self._discovered_relays: dict[ID, RelayInfo] = {}
self._protocol_cache: dict[
ID, set[str]
] = {} # Cache protocol info to reduce queries
self.event_started = trio.Event()
self.is_running = False
async def run(self, *, task_status: Any = trio.TASK_STATUS_IGNORED) -> None:
"""Run the discovery service."""
try:
self.is_running = True
self.event_started.set()
task_status.started()
# Main discovery loop
async with trio.open_nursery() as nursery:
# Run initial discovery
nursery.start_soon(self.discover_relays)
# Set up periodic discovery
while True:
await trio.sleep(self.discovery_interval)
if not self.manager.is_running:
break
nursery.start_soon(self.discover_relays)
# Cleanup expired relays and reservations
await self._cleanup_expired()
finally:
self.is_running = False
async def discover_relays(self) -> None:
r"""
Discover relay nodes in the network.
This method queries the network for peers that support the
Circuit Relay v2 protocol.
"""
logger.debug("Starting relay discovery")
try:
# Get connected peers
connected_peers = self.host.get_connected_peers()
logger.debug(
"Checking %d connected peers for relay support", len(connected_peers)
)
# Check each peer if they support the relay protocol
for peer_id in connected_peers:
if peer_id == self.host.get_id():
continue # Skip ourselves
if peer_id in self._discovered_relays:
# Update last seen time for existing relay
self._discovered_relays[peer_id].last_seen = time.time()
continue
# Check if peer supports the relay protocol
with trio.move_on_after(5): # Don't wait too long for protocol info
if await self._supports_relay_protocol(peer_id):
await self._add_relay(peer_id)
# Limit number of relays we track
if len(self._discovered_relays) > self.max_relays:
# Sort by last seen time and keep only the most recent ones
sorted_relays = sorted(
self._discovered_relays.items(),
key=lambda x: x[1].last_seen,
reverse=True,
)
to_remove = sorted_relays[self.max_relays :]
for peer_id, _ in to_remove:
del self._discovered_relays[peer_id]
logger.debug(
"Discovery completed, tracking %d relays", len(self._discovered_relays)
)
except Exception as e:
logger.error("Error during relay discovery: %s", str(e))
async def _supports_relay_protocol(self, peer_id: ID) -> bool:
"""
Check if a peer supports the relay protocol.
Parameters
----------
peer_id : ID
The ID of the peer to check
Returns
-------
bool
True if the peer supports the relay protocol, False otherwise
"""
# Check cache first
if peer_id in self._protocol_cache:
return PROTOCOL_ID in self._protocol_cache[peer_id]
# Method 1: Try peerstore
result = await self._check_via_peerstore(peer_id)
if result is not None:
return result
# Method 2: Try direct stream connection
result = await self._check_via_direct_connection(peer_id)
if result is not None:
return result
# Method 3: Try protocols from mux
result = await self._check_via_mux(peer_id)
if result is not None:
return result
# Default: Cannot determine, assume false
return False
async def _check_via_peerstore(self, peer_id: ID) -> bool | None:
"""Check protocol support via peerstore."""
try:
peerstore = self.host.get_peerstore()
proto_getter = peerstore.get_protocols
if not callable(proto_getter):
return None
try:
# Try to get protocols
proto_result = proto_getter(peer_id)
# Get protocols list
protocols_list = []
if hasattr(proto_result, "__await__"):
protocols_list = await cast(Any, proto_result)
else:
protocols_list = proto_result
# Check result
if protocols_list is not None:
protocols = set(protocols_list)
self._protocol_cache[peer_id] = protocols
return PROTOCOL_ID in protocols
return False
except Exception as e:
logger.debug("Error getting protocols: %s", str(e))
return None
except Exception as e:
logger.debug("Error accessing peerstore: %s", str(e))
return None
async def _check_via_direct_connection(self, peer_id: ID) -> bool | None:
"""Check protocol support via direct connection."""
try:
with trio.fail_after(STREAM_TIMEOUT):
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
if stream:
await stream.close()
self._protocol_cache[peer_id] = {PROTOCOL_ID}
return True
return False
except Exception as e:
logger.debug(
"Failed to open relay protocol stream to %s: %s", peer_id, str(e)
)
return None
async def _check_via_mux(self, peer_id: ID) -> bool | None:
"""Check protocol support via mux protocols."""
try:
if not (hasattr(self.host, "get_mux") and self.host.get_mux() is not None):
return None
mux = self.host.get_mux()
if not hasattr(mux, "protocols"):
return None
peer_protocols = set()
# Get protocols from mux with proper type safety
available_protocols = []
if hasattr(mux, "get_protocols"):
# Get protocols with proper typing
mux_protocols = mux.get_protocols()
if isinstance(mux_protocols, (list, tuple)):
available_protocols = list(mux_protocols)
for protocol in available_protocols:
try:
with trio.fail_after(2): # Quick check
# Ensure we have a proper protocol object
# Use string representation since we can't use isinstance
is_tprotocol = str(type(protocol)) == str(type(TProtocol))
protocol_obj = (
protocol if is_tprotocol else TProtocol(str(protocol))
)
stream = await self.host.new_stream(peer_id, [protocol_obj])
if stream:
peer_protocols.add(str(protocol_obj))
await stream.close()
except Exception:
pass # Ignore errors when closing the stream
self._protocol_cache[peer_id] = peer_protocols
protocol_str = str(PROTOCOL_ID)
for protocol in peer_protocols:
if protocol == protocol_str:
return True
return False
except Exception as e:
logger.debug("Error checking protocols via mux: %s", str(e))
return None
async def _add_relay(self, peer_id: ID) -> None:
"""
Add a peer as a relay and optionally make a reservation.
Parameters
----------
peer_id : ID
The ID of the peer to add as a relay
"""
now = time.time()
relay_info = RelayInfo(
peer_id=peer_id,
discovered_at=now,
last_seen=now,
)
self._discovered_relays[peer_id] = relay_info
logger.debug("Added relay %s to discovered relays", peer_id)
# If auto-reserve is enabled, make a reservation with this relay
if self.auto_reserve:
await self.make_reservation(peer_id)
async def make_reservation(self, peer_id: ID) -> bool:
"""
Make a reservation with a relay.
Parameters
----------
peer_id : ID
The ID of the relay to make a reservation with
Returns
-------
bool
True if reservation succeeded, False otherwise
"""
if peer_id not in self._discovered_relays:
logger.error("Cannot make reservation with unknown relay %s", peer_id)
return False
stream = None
try:
logger.debug("Making reservation with relay %s", peer_id)
# Open a stream to the relay with timeout
try:
with trio.fail_after(STREAM_TIMEOUT):
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
if not stream:
logger.error("Failed to open stream to relay %s", peer_id)
return False
except trio.TooSlowError:
logger.error("Timeout opening stream to relay %s", peer_id)
return False
try:
# Create and send reservation request
request = HopMessage(
type=HopMessage.RESERVE,
peer=self.host.get_id().to_bytes(),
)
with trio.fail_after(STREAM_TIMEOUT):
await stream.write(request.SerializeToString())
# Wait for response
response_bytes = await stream.read()
if not response_bytes:
logger.error("No response received from relay %s", peer_id)
return False
# Parse response
response = HopMessage()
response.ParseFromString(response_bytes)
# Check if reservation was successful
if response.type == HopMessage.RESERVE and response.HasField(
"status"
):
# Access status code directly from protobuf object
status_code = getattr(response.status, "code", StatusCode.OK)
if status_code == StatusCode.OK:
# Update relay info with reservation details
relay_info = self._discovered_relays[peer_id]
relay_info.has_reservation = True
if response.HasField("reservation") and response.HasField(
"limit"
):
relay_info.reservation_expires_at = (
response.reservation.expire
)
relay_info.reservation_data_limit = response.limit.data
logger.debug(
"Successfully made reservation with relay %s", peer_id
)
return True
# Reservation failed
error_message = "Unknown error"
if response.HasField("status"):
# Access message directly from protobuf object
error_message = getattr(response.status, "message", "")
logger.warning(
"Reservation request rejected by relay %s: %s",
peer_id,
error_message,
)
return False
except trio.TooSlowError:
logger.error(
"Timeout during reservation process with relay %s", peer_id
)
return False
except Exception as e:
logger.error("Error making reservation with relay %s: %s", peer_id, str(e))
return False
finally:
# Always close the stream
if stream:
try:
await stream.close()
except Exception:
pass # Ignore errors when closing the stream
return False
async def _cleanup_expired(self) -> None:
"""Clean up expired relays and reservations."""
now = time.time()
to_remove = []
for peer_id, relay_info in self._discovered_relays.items():
# Check if relay hasn't been seen in a while (3x discovery interval)
if now - relay_info.last_seen > self.discovery_interval * 3:
to_remove.append(peer_id)
continue
# Check if reservation has expired
if (
relay_info.has_reservation
and relay_info.reservation_expires_at
and now > relay_info.reservation_expires_at
):
relay_info.has_reservation = False
relay_info.reservation_expires_at = None
relay_info.reservation_data_limit = None
# If auto-reserve is enabled, try to renew
if self.auto_reserve:
await self.make_reservation(peer_id)
# Remove expired relays
for peer_id in to_remove:
del self._discovered_relays[peer_id]
if peer_id in self._protocol_cache:
del self._protocol_cache[peer_id]
def get_relays(self) -> list[ID]:
"""
Get a list of discovered relay peer IDs.
Returns
-------
list[ID]
List of discovered relay peer IDs
"""
return list(self._discovered_relays.keys())
def get_relay_info(self, peer_id: ID) -> RelayInfo | None:
"""
Get information about a specific relay.
Parameters
----------
peer_id : ID
The ID of the relay to get information about
Returns
-------
Optional[RelayInfo]
Information about the relay, or None if not found
"""
return self._discovered_relays.get(peer_id)
def get_relay(self) -> ID | None:
"""
Get a single relay peer ID for connection purposes.
Prioritizes relays with active reservations.
Returns
-------
Optional[ID]
ID of a discovered relay, or None if no relays found
"""
if not self._discovered_relays:
return None
# First try to find a relay with an active reservation
for peer_id, relay_info in self._discovered_relays.items():
if relay_info and relay_info.has_reservation:
return peer_id
return next(iter(self._discovered_relays.keys()), None)

View File

@ -0,0 +1,16 @@
"""
Protocol buffer package for circuit_v2.
Contains generated protobuf code for circuit_v2 relay protocol.
"""
# Import the classes to be accessible directly from the package
from .circuit_pb2 import (
HopMessage,
Limit,
Reservation,
Status,
StopMessage,
)
__all__ = ["HopMessage", "Limit", "Reservation", "Status", "StopMessage"]

View File

@ -0,0 +1,55 @@
syntax = "proto3";
package circuit.pb.v2;
// Circuit v2 message types
message HopMessage {
enum Type {
RESERVE = 0;
CONNECT = 1;
STATUS = 2;
}
Type type = 1;
bytes peer = 2;
Reservation reservation = 3;
Limit limit = 4;
Status status = 5;
}
message StopMessage {
enum Type {
CONNECT = 0;
STATUS = 1;
}
Type type = 1;
bytes peer = 2;
Status status = 3;
}
message Reservation {
bytes voucher = 1;
bytes signature = 2;
int64 expire = 3;
}
message Limit {
int64 duration = 1;
int64 data = 2;
}
message Status {
enum Code {
OK = 0;
RESERVATION_REFUSED = 100;
RESOURCE_LIMIT_EXCEEDED = 101;
PERMISSION_DENIED = 102;
CONNECTION_FAILED = 200;
DIAL_REFUSED = 201;
STOP_FAILED = 300;
MALFORMED_MESSAGE = 400;
}
Code code = 1;
string message = 2;
}

View File

@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: libp2p/relay/circuit_v2/pb/circuit.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n(libp2p/relay/circuit_v2/pb/circuit.proto\x12\rcircuit.pb.v2\"\xf3\x01\n\nHopMessage\x12,\n\x04type\x18\x01 \x01(\x0e\x32\x1e.circuit.pb.v2.HopMessage.Type\x12\x0c\n\x04peer\x18\x02 \x01(\x0c\x12/\n\x0breservation\x18\x03 \x01(\x0b\x32\x1a.circuit.pb.v2.Reservation\x12#\n\x05limit\x18\x04 \x01(\x0b\x32\x14.circuit.pb.v2.Limit\x12%\n\x06status\x18\x05 \x01(\x0b\x32\x15.circuit.pb.v2.Status\",\n\x04Type\x12\x0b\n\x07RESERVE\x10\x00\x12\x0b\n\x07\x43ONNECT\x10\x01\x12\n\n\x06STATUS\x10\x02\"\x92\x01\n\x0bStopMessage\x12-\n\x04type\x18\x01 \x01(\x0e\x32\x1f.circuit.pb.v2.StopMessage.Type\x12\x0c\n\x04peer\x18\x02 \x01(\x0c\x12%\n\x06status\x18\x03 \x01(\x0b\x32\x15.circuit.pb.v2.Status\"\x1f\n\x04Type\x12\x0b\n\x07\x43ONNECT\x10\x00\x12\n\n\x06STATUS\x10\x01\"A\n\x0bReservation\x12\x0f\n\x07voucher\x18\x01 \x01(\x0c\x12\x11\n\tsignature\x18\x02 \x01(\x0c\x12\x0e\n\x06\x65xpire\x18\x03 \x01(\x03\"\'\n\x05Limit\x12\x10\n\x08\x64uration\x18\x01 \x01(\x03\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x03\"\xf6\x01\n\x06Status\x12(\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x1a.circuit.pb.v2.Status.Code\x12\x0f\n\x07message\x18\x02 \x01(\t\"\xb0\x01\n\x04\x43ode\x12\x06\n\x02OK\x10\x00\x12\x17\n\x13RESERVATION_REFUSED\x10\x64\x12\x1b\n\x17RESOURCE_LIMIT_EXCEEDED\x10\x65\x12\x15\n\x11PERMISSION_DENIED\x10\x66\x12\x16\n\x11\x43ONNECTION_FAILED\x10\xc8\x01\x12\x11\n\x0c\x44IAL_REFUSED\x10\xc9\x01\x12\x10\n\x0bSTOP_FAILED\x10\xac\x02\x12\x16\n\x11MALFORMED_MESSAGE\x10\x90\x03\x62\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.relay.circuit_v2.pb.circuit_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_HOPMESSAGE._serialized_start=60
_HOPMESSAGE._serialized_end=303
_HOPMESSAGE_TYPE._serialized_start=259
_HOPMESSAGE_TYPE._serialized_end=303
_STOPMESSAGE._serialized_start=306
_STOPMESSAGE._serialized_end=452
_STOPMESSAGE_TYPE._serialized_start=421
_STOPMESSAGE_TYPE._serialized_end=452
_RESERVATION._serialized_start=454
_RESERVATION._serialized_end=519
_LIMIT._serialized_start=521
_LIMIT._serialized_end=560
_STATUS._serialized_start=563
_STATUS._serialized_end=809
_STATUS_CODE._serialized_start=633
_STATUS_CODE._serialized_end=809
# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,184 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import builtins
import google.protobuf.descriptor
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import sys
import typing
if sys.version_info >= (3, 10):
import typing as typing_extensions
else:
import typing_extensions
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class HopMessage(google.protobuf.message.Message):
"""Circuit v2 message types"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class _Type:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _TypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[HopMessage._Type.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
RESERVE: HopMessage._Type.ValueType # 0
CONNECT: HopMessage._Type.ValueType # 1
STATUS: HopMessage._Type.ValueType # 2
class Type(_Type, metaclass=_TypeEnumTypeWrapper): ...
RESERVE: HopMessage.Type.ValueType # 0
CONNECT: HopMessage.Type.ValueType # 1
STATUS: HopMessage.Type.ValueType # 2
TYPE_FIELD_NUMBER: builtins.int
PEER_FIELD_NUMBER: builtins.int
RESERVATION_FIELD_NUMBER: builtins.int
LIMIT_FIELD_NUMBER: builtins.int
STATUS_FIELD_NUMBER: builtins.int
type: global___HopMessage.Type.ValueType
peer: builtins.bytes
@property
def reservation(self) -> global___Reservation: ...
@property
def limit(self) -> global___Limit: ...
@property
def status(self) -> global___Status: ...
def __init__(
self,
*,
type: global___HopMessage.Type.ValueType = ...,
peer: builtins.bytes = ...,
reservation: global___Reservation | None = ...,
limit: global___Limit | None = ...,
status: global___Status | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["limit", b"limit", "reservation", b"reservation", "status", b"status"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["limit", b"limit", "peer", b"peer", "reservation", b"reservation", "status", b"status", "type", b"type"]) -> None: ...
global___HopMessage = HopMessage
@typing.final
class StopMessage(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class _Type:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _TypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[StopMessage._Type.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
CONNECT: StopMessage._Type.ValueType # 0
STATUS: StopMessage._Type.ValueType # 1
class Type(_Type, metaclass=_TypeEnumTypeWrapper): ...
CONNECT: StopMessage.Type.ValueType # 0
STATUS: StopMessage.Type.ValueType # 1
TYPE_FIELD_NUMBER: builtins.int
PEER_FIELD_NUMBER: builtins.int
STATUS_FIELD_NUMBER: builtins.int
type: global___StopMessage.Type.ValueType
peer: builtins.bytes
@property
def status(self) -> global___Status: ...
def __init__(
self,
*,
type: global___StopMessage.Type.ValueType = ...,
peer: builtins.bytes = ...,
status: global___Status | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["status", b"status"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["peer", b"peer", "status", b"status", "type", b"type"]) -> None: ...
global___StopMessage = StopMessage
@typing.final
class Reservation(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
VOUCHER_FIELD_NUMBER: builtins.int
SIGNATURE_FIELD_NUMBER: builtins.int
EXPIRE_FIELD_NUMBER: builtins.int
voucher: builtins.bytes
signature: builtins.bytes
expire: builtins.int
def __init__(
self,
*,
voucher: builtins.bytes = ...,
signature: builtins.bytes = ...,
expire: builtins.int = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["expire", b"expire", "signature", b"signature", "voucher", b"voucher"]) -> None: ...
global___Reservation = Reservation
@typing.final
class Limit(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
DURATION_FIELD_NUMBER: builtins.int
DATA_FIELD_NUMBER: builtins.int
duration: builtins.int
data: builtins.int
def __init__(
self,
*,
duration: builtins.int = ...,
data: builtins.int = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["data", b"data", "duration", b"duration"]) -> None: ...
global___Limit = Limit
@typing.final
class Status(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class _Code:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _CodeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Status._Code.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
OK: Status._Code.ValueType # 0
RESERVATION_REFUSED: Status._Code.ValueType # 100
RESOURCE_LIMIT_EXCEEDED: Status._Code.ValueType # 101
PERMISSION_DENIED: Status._Code.ValueType # 102
CONNECTION_FAILED: Status._Code.ValueType # 200
DIAL_REFUSED: Status._Code.ValueType # 201
STOP_FAILED: Status._Code.ValueType # 300
MALFORMED_MESSAGE: Status._Code.ValueType # 400
class Code(_Code, metaclass=_CodeEnumTypeWrapper): ...
OK: Status.Code.ValueType # 0
RESERVATION_REFUSED: Status.Code.ValueType # 100
RESOURCE_LIMIT_EXCEEDED: Status.Code.ValueType # 101
PERMISSION_DENIED: Status.Code.ValueType # 102
CONNECTION_FAILED: Status.Code.ValueType # 200
DIAL_REFUSED: Status.Code.ValueType # 201
STOP_FAILED: Status.Code.ValueType # 300
MALFORMED_MESSAGE: Status.Code.ValueType # 400
CODE_FIELD_NUMBER: builtins.int
MESSAGE_FIELD_NUMBER: builtins.int
code: global___Status.Code.ValueType
message: builtins.str
def __init__(
self,
*,
code: global___Status.Code.ValueType = ...,
message: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["code", b"code", "message", b"message"]) -> None: ...
global___Status = Status

View File

@ -0,0 +1,800 @@
"""
Circuit Relay v2 protocol implementation.
This module implements the Circuit Relay v2 protocol as specified in:
https://github.com/libp2p/specs/blob/master/relay/circuit-v2.md
"""
import logging
import time
from typing import (
Any,
Protocol as TypingProtocol,
cast,
runtime_checkable,
)
import trio
from libp2p.abc import (
IHost,
INetStream,
)
from libp2p.custom_types import (
TProtocol,
)
from libp2p.io.abc import (
ReadWriteCloser,
)
from libp2p.peer.id import (
ID,
)
from libp2p.stream_muxer.mplex.exceptions import (
MplexStreamEOF,
MplexStreamReset,
)
from libp2p.tools.async_service import (
Service,
)
from .pb.circuit_pb2 import (
HopMessage,
Limit,
Reservation,
Status as PbStatus,
StopMessage,
)
from .protocol_buffer import (
StatusCode,
create_status,
)
from .resources import (
RelayLimits,
RelayResourceManager,
)
logger = logging.getLogger("libp2p.relay.circuit_v2")
PROTOCOL_ID = TProtocol("/libp2p/circuit/relay/2.0.0")
STOP_PROTOCOL_ID = TProtocol("/libp2p/circuit/relay/2.0.0/stop")
# Default limits for relay resources
DEFAULT_RELAY_LIMITS = RelayLimits(
duration=60 * 60, # 1 hour
data=1024 * 1024 * 1024, # 1GB
max_circuit_conns=8,
max_reservations=4,
)
# Stream operation timeouts
STREAM_READ_TIMEOUT = 15 # seconds
STREAM_WRITE_TIMEOUT = 15 # seconds
STREAM_CLOSE_TIMEOUT = 10 # seconds
MAX_READ_RETRIES = 5 # Maximum number of read retries
# Extended interfaces for type checking
@runtime_checkable
class IHostWithStreamHandlers(TypingProtocol):
"""Extended host interface with stream handler methods."""
def remove_stream_handler(self, protocol_id: TProtocol) -> None:
"""Remove a stream handler for a protocol."""
...
@runtime_checkable
class INetStreamWithExtras(TypingProtocol):
"""Extended net stream interface with additional methods."""
def get_remote_peer_id(self) -> ID:
"""Get the remote peer ID."""
...
def is_open(self) -> bool:
"""Check if the stream is open."""
...
def is_closed(self) -> bool:
"""Check if the stream is closed."""
...
class CircuitV2Protocol(Service):
"""
CircuitV2Protocol implements the Circuit Relay v2 protocol.
This protocol allows peers to establish connections through relay nodes
when direct connections are not possible (e.g., due to NAT).
"""
def __init__(
self,
host: IHost,
limits: RelayLimits | None = None,
allow_hop: bool = False,
) -> None:
"""
Initialize a Circuit Relay v2 protocol instance.
Parameters
----------
host : IHost
The libp2p host instance
limits : RelayLimits | None
Resource limits for the relay
allow_hop : bool
Whether to allow this node to act as a relay
"""
self.host = host
self.limits = limits or DEFAULT_RELAY_LIMITS
self.allow_hop = allow_hop
self.resource_manager = RelayResourceManager(self.limits)
self._active_relays: dict[ID, tuple[INetStream, INetStream | None]] = {}
self.event_started = trio.Event()
async def run(self, *, task_status: Any = trio.TASK_STATUS_IGNORED) -> None:
"""Run the protocol service."""
try:
# Register protocol handlers
if self.allow_hop:
logger.debug("Registering stream handlers for relay protocol")
self.host.set_stream_handler(PROTOCOL_ID, self._handle_hop_stream)
self.host.set_stream_handler(STOP_PROTOCOL_ID, self._handle_stop_stream)
logger.debug("Stream handlers registered successfully")
# Signal that we're ready
self.event_started.set()
task_status.started()
logger.debug("Protocol service started")
# Wait for service to be stopped
await self.manager.wait_finished()
finally:
# Clean up any active relay connections
for src_stream, dst_stream in self._active_relays.values():
await self._close_stream(src_stream)
await self._close_stream(dst_stream)
self._active_relays.clear()
# Unregister protocol handlers
if self.allow_hop:
try:
# Cast host to extended interface with remove_stream_handler
host_with_handlers = cast(IHostWithStreamHandlers, self.host)
host_with_handlers.remove_stream_handler(PROTOCOL_ID)
host_with_handlers.remove_stream_handler(STOP_PROTOCOL_ID)
except Exception as e:
logger.error("Error unregistering stream handlers: %s", str(e))
async def _close_stream(self, stream: INetStream | None) -> None:
"""Helper function to safely close a stream."""
if stream is None:
return
try:
with trio.fail_after(STREAM_CLOSE_TIMEOUT):
await stream.close()
except Exception:
try:
await stream.reset()
except Exception:
pass
async def _read_stream_with_retry(
self,
stream: INetStream,
max_retries: int = MAX_READ_RETRIES,
) -> bytes | None:
"""
Helper function to read from a stream with retries.
Parameters
----------
stream : INetStream
The stream to read from
max_retries : int
Maximum number of read retries
Returns
-------
Optional[bytes]
The data read from the stream, or None if the stream is closed/reset
Raises
------
trio.TooSlowError
If read timeout occurs after all retries
Exception
For other unexpected errors
"""
retries = 0
last_error: Any = None
backoff_time = 0.2 # Base backoff time in seconds
while retries < max_retries:
try:
with trio.fail_after(STREAM_READ_TIMEOUT):
# Try reading with timeout
logger.debug(
"Attempting to read from stream (attempt %d/%d)",
retries + 1,
max_retries,
)
data = await stream.read()
if not data: # EOF
logger.debug("Stream EOF detected")
return None
logger.debug("Successfully read %d bytes from stream", len(data))
return data
except trio.WouldBlock:
# Just retry immediately if we would block
retries += 1
logger.debug(
"Stream would block (attempt %d/%d), retrying...",
retries,
max_retries,
)
await trio.sleep(backoff_time * retries) # Increased backoff time
continue
except (MplexStreamEOF, MplexStreamReset):
# Stream closed/reset - no point retrying
logger.debug("Stream closed/reset during read")
return None
except trio.TooSlowError as e:
last_error = e
retries += 1
logger.debug(
"Read timeout (attempt %d/%d), retrying...", retries, max_retries
)
if retries < max_retries:
# Wait longer before retry with increasing backoff
await trio.sleep(backoff_time * retries) # Increased backoff
continue
except Exception as e:
logger.error("Unexpected error reading from stream: %s", str(e))
last_error = e
retries += 1
if retries < max_retries:
await trio.sleep(backoff_time * retries) # Increased backoff
continue
raise
if last_error:
if isinstance(last_error, trio.TooSlowError):
logger.error("Read timed out after %d retries", max_retries)
raise last_error
return None
async def _handle_hop_stream(self, stream: INetStream) -> None:
"""
Handle incoming HOP streams.
This handler processes relay requests from other peers.
"""
try:
# Try to get peer ID first
try:
# Cast to extended interface with get_remote_peer_id
stream_with_peer_id = cast(INetStreamWithExtras, stream)
remote_peer_id = stream_with_peer_id.get_remote_peer_id()
remote_id = str(remote_peer_id)
except Exception:
# Fall back to address if peer ID not available
remote_addr = stream.get_remote_address()
remote_id = f"peer at {remote_addr}" if remote_addr else "unknown peer"
logger.debug("Handling hop stream from %s", remote_id)
# First, handle the read timeout gracefully
try:
with trio.fail_after(
STREAM_READ_TIMEOUT * 2
): # Double the timeout for reading
msg_bytes = await stream.read()
if not msg_bytes:
logger.error(
"Empty read from stream from %s",
remote_id,
)
# Create a proto Status directly
pb_status = PbStatus()
pb_status.code = cast(Any, int(StatusCode.MALFORMED_MESSAGE))
pb_status.message = "Empty message received"
response = HopMessage(
type=HopMessage.STATUS,
status=pb_status,
)
await stream.write(response.SerializeToString())
await trio.sleep(0.5) # Longer wait to ensure message is sent
return
except trio.TooSlowError:
logger.error(
"Timeout reading from hop stream from %s",
remote_id,
)
# Create a proto Status directly
pb_status = PbStatus()
pb_status.code = cast(Any, int(StatusCode.CONNECTION_FAILED))
pb_status.message = "Stream read timeout"
response = HopMessage(
type=HopMessage.STATUS,
status=pb_status,
)
await stream.write(response.SerializeToString())
await trio.sleep(0.5) # Longer wait to ensure the message is sent
return
except Exception as e:
logger.error(
"Error reading from hop stream from %s: %s",
remote_id,
str(e),
)
# Create a proto Status directly
pb_status = PbStatus()
pb_status.code = cast(Any, int(StatusCode.MALFORMED_MESSAGE))
pb_status.message = f"Read error: {str(e)}"
response = HopMessage(
type=HopMessage.STATUS,
status=pb_status,
)
await stream.write(response.SerializeToString())
await trio.sleep(0.5) # Longer wait to ensure the message is sent
return
# Parse the message
try:
hop_msg = HopMessage()
hop_msg.ParseFromString(msg_bytes)
except Exception as e:
logger.error(
"Error parsing hop message from %s: %s",
remote_id,
str(e),
)
# Create a proto Status directly
pb_status = PbStatus()
pb_status.code = cast(Any, int(StatusCode.MALFORMED_MESSAGE))
pb_status.message = f"Parse error: {str(e)}"
response = HopMessage(
type=HopMessage.STATUS,
status=pb_status,
)
await stream.write(response.SerializeToString())
await trio.sleep(0.5) # Longer wait to ensure the message is sent
return
# Process based on message type
if hop_msg.type == HopMessage.RESERVE:
logger.debug("Handling RESERVE message from %s", remote_id)
await self._handle_reserve(stream, hop_msg)
# For RESERVE requests, let the client close the stream
return
elif hop_msg.type == HopMessage.CONNECT:
logger.debug("Handling CONNECT message from %s", remote_id)
await self._handle_connect(stream, hop_msg)
else:
logger.error("Invalid message type %d from %s", hop_msg.type, remote_id)
# Send a nice error response using _send_status method
await self._send_status(
stream,
StatusCode.MALFORMED_MESSAGE,
f"Invalid message type: {hop_msg.type}",
)
except Exception as e:
logger.error(
"Unexpected error handling hop stream from %s: %s", remote_id, str(e)
)
try:
# Send a nice error response using _send_status method
await self._send_status(
stream,
StatusCode.MALFORMED_MESSAGE,
f"Internal error: {str(e)}",
)
except Exception as e2:
logger.error(
"Failed to send error response to %s: %s", remote_id, str(e2)
)
async def _handle_stop_stream(self, stream: INetStream) -> None:
"""
Handle incoming STOP streams.
This handler processes incoming relay connections from the destination side.
"""
try:
# Read the incoming message with timeout
with trio.fail_after(STREAM_READ_TIMEOUT):
msg_bytes = await stream.read()
stop_msg = StopMessage()
stop_msg.ParseFromString(msg_bytes)
if stop_msg.type != StopMessage.CONNECT:
# Use direct attribute access to create status object for error response
await self._send_stop_status(
stream,
StatusCode.MALFORMED_MESSAGE,
"Invalid message type",
)
await self._close_stream(stream)
return
# Get the source stream from active relays
peer_id = ID(stop_msg.peer)
if peer_id not in self._active_relays:
# Use direct attribute access to create status object for error response
await self._send_stop_status(
stream,
StatusCode.CONNECTION_FAILED,
"No pending relay connection",
)
await self._close_stream(stream)
return
src_stream, _ = self._active_relays[peer_id]
self._active_relays[peer_id] = (src_stream, stream)
# Send success status to both sides
await self._send_status(
src_stream,
StatusCode.OK,
"Connection established",
)
await self._send_stop_status(
stream,
StatusCode.OK,
"Connection established",
)
# Start relaying data
async with trio.open_nursery() as nursery:
nursery.start_soon(self._relay_data, src_stream, stream, peer_id)
nursery.start_soon(self._relay_data, stream, src_stream, peer_id)
except trio.TooSlowError:
logger.error("Timeout reading from stop stream")
await self._send_stop_status(
stream,
StatusCode.CONNECTION_FAILED,
"Stream read timeout",
)
await self._close_stream(stream)
except Exception as e:
logger.error("Error handling stop stream: %s", str(e))
try:
await self._send_stop_status(
stream,
StatusCode.MALFORMED_MESSAGE,
str(e),
)
await self._close_stream(stream)
except Exception:
pass
async def _handle_reserve(self, stream: INetStream, msg: Any) -> None:
"""Handle a reservation request."""
peer_id = None
try:
peer_id = ID(msg.peer)
logger.debug("Handling reservation request from peer %s", peer_id)
# Check if we can accept more reservations
if not self.resource_manager.can_accept_reservation(peer_id):
logger.debug("Reservation limit exceeded for peer %s", peer_id)
# Send status message with STATUS type
status = create_status(
code=StatusCode.RESOURCE_LIMIT_EXCEEDED,
message="Reservation limit exceeded",
)
status_msg = HopMessage(
type=HopMessage.STATUS,
status=status.to_pb(),
)
await stream.write(status_msg.SerializeToString())
return
# Accept reservation
logger.debug("Accepting reservation from peer %s", peer_id)
ttl = self.resource_manager.reserve(peer_id)
# Send reservation success response
with trio.fail_after(STREAM_WRITE_TIMEOUT):
status = create_status(
code=StatusCode.OK, message="Reservation accepted"
)
response = HopMessage(
type=HopMessage.STATUS,
status=status.to_pb(),
reservation=Reservation(
expire=int(time.time() + ttl),
voucher=b"", # We don't use vouchers yet
signature=b"", # We don't use signatures yet
),
limit=Limit(
duration=self.limits.duration,
data=self.limits.data,
),
)
# Log the response message details for debugging
logger.debug(
"Sending reservation response: type=%s, status=%s, ttl=%d",
response.type,
getattr(response.status, "code", "unknown"),
ttl,
)
# Send the response with increased timeout
await stream.write(response.SerializeToString())
# Add a small wait to ensure the message is fully sent
await trio.sleep(0.1)
logger.debug("Reservation response sent successfully")
except Exception as e:
logger.error("Error handling reservation request: %s", str(e))
if cast(INetStreamWithExtras, stream).is_open():
try:
# Send error response
await self._send_status(
stream,
StatusCode.INTERNAL_ERROR,
f"Failed to process reservation: {str(e)}",
)
except Exception as send_err:
logger.error("Failed to send error response: %s", str(send_err))
finally:
# Always close the stream when done with reservation
if cast(INetStreamWithExtras, stream).is_open():
try:
with trio.fail_after(STREAM_CLOSE_TIMEOUT):
await stream.close()
except Exception as close_err:
logger.error("Error closing stream: %s", str(close_err))
async def _handle_connect(self, stream: INetStream, msg: Any) -> None:
"""Handle a connect request."""
peer_id = ID(msg.peer)
dst_stream: INetStream | None = None
# Verify reservation if provided
if msg.HasField("reservation"):
if not self.resource_manager.verify_reservation(peer_id, msg.reservation):
await self._send_status(
stream,
StatusCode.PERMISSION_DENIED,
"Invalid reservation",
)
await stream.reset()
return
# Check resource limits
if not self.resource_manager.can_accept_connection(peer_id):
await self._send_status(
stream,
StatusCode.RESOURCE_LIMIT_EXCEEDED,
"Connection limit exceeded",
)
await stream.reset()
return
try:
# Store the source stream with properly typed None
self._active_relays[peer_id] = (stream, None)
# Try to connect to the destination with timeout
with trio.fail_after(STREAM_READ_TIMEOUT):
dst_stream = await self.host.new_stream(peer_id, [STOP_PROTOCOL_ID])
if not dst_stream:
raise ConnectionError("Could not connect to destination")
# Send STOP CONNECT message
stop_msg = StopMessage(
type=StopMessage.CONNECT,
# Cast to extended interface with get_remote_peer_id
peer=cast(INetStreamWithExtras, stream)
.get_remote_peer_id()
.to_bytes(),
)
await dst_stream.write(stop_msg.SerializeToString())
# Wait for response from destination
resp_bytes = await dst_stream.read()
resp = StopMessage()
resp.ParseFromString(resp_bytes)
# Handle status attributes from the response
if resp.HasField("status"):
# Get code and message attributes with defaults
status_code = getattr(resp.status, "code", StatusCode.OK)
# Get message with default
status_msg = getattr(resp.status, "message", "Unknown error")
else:
status_code = StatusCode.OK
status_msg = "No status provided"
if status_code != StatusCode.OK:
raise ConnectionError(
f"Destination rejected connection: {status_msg}"
)
# Update active relays with destination stream
self._active_relays[peer_id] = (stream, dst_stream)
# Update reservation connection count
reservation = self.resource_manager._reservations.get(peer_id)
if reservation:
reservation.active_connections += 1
# Send success status
await self._send_status(
stream,
StatusCode.OK,
"Connection established",
)
# Start relaying data
async with trio.open_nursery() as nursery:
nursery.start_soon(self._relay_data, stream, dst_stream, peer_id)
nursery.start_soon(self._relay_data, dst_stream, stream, peer_id)
except (trio.TooSlowError, ConnectionError) as e:
logger.error("Error establishing relay connection: %s", str(e))
await self._send_status(
stream,
StatusCode.CONNECTION_FAILED,
str(e),
)
if peer_id in self._active_relays:
del self._active_relays[peer_id]
# Clean up reservation connection count on failure
reservation = self.resource_manager._reservations.get(peer_id)
if reservation:
reservation.active_connections -= 1
await stream.reset()
if dst_stream and not cast(INetStreamWithExtras, dst_stream).is_closed():
await dst_stream.reset()
except Exception as e:
logger.error("Unexpected error in connect handler: %s", str(e))
await self._send_status(
stream,
StatusCode.CONNECTION_FAILED,
"Internal error",
)
if peer_id in self._active_relays:
del self._active_relays[peer_id]
await stream.reset()
if dst_stream and not cast(INetStreamWithExtras, dst_stream).is_closed():
await dst_stream.reset()
async def _relay_data(
self,
src_stream: INetStream,
dst_stream: INetStream,
peer_id: ID,
) -> None:
"""
Relay data between two streams.
Parameters
----------
src_stream : INetStream
Source stream to read from
dst_stream : INetStream
Destination stream to write to
peer_id : ID
ID of the peer being relayed
"""
try:
while True:
# Read data with retries
data = await self._read_stream_with_retry(src_stream)
if not data:
logger.info("Source stream closed/reset")
break
# Write data with timeout
try:
with trio.fail_after(STREAM_WRITE_TIMEOUT):
await dst_stream.write(data)
except trio.TooSlowError:
logger.error("Timeout writing to destination stream")
break
except Exception as e:
logger.error("Error writing to destination stream: %s", str(e))
break
# Update resource usage
reservation = self.resource_manager._reservations.get(peer_id)
if reservation:
reservation.data_used += len(data)
if reservation.data_used >= reservation.limits.data:
logger.warning("Data limit exceeded for peer %s", peer_id)
break
except Exception as e:
logger.error("Error relaying data: %s", str(e))
finally:
# Clean up streams and remove from active relays
await src_stream.reset()
await dst_stream.reset()
if peer_id in self._active_relays:
del self._active_relays[peer_id]
async def _send_status(
self,
stream: ReadWriteCloser,
code: int,
message: str,
) -> None:
"""Send a status message."""
try:
logger.debug("Sending status message with code %s: %s", code, message)
with trio.fail_after(STREAM_WRITE_TIMEOUT * 2): # Double the timeout
# Create a proto Status directly
pb_status = PbStatus()
pb_status.code = cast(
Any, int(code)
) # Cast to Any to avoid type errors
pb_status.message = message
status_msg = HopMessage(
type=HopMessage.STATUS,
status=pb_status,
)
msg_bytes = status_msg.SerializeToString()
logger.debug("Status message serialized (%d bytes)", len(msg_bytes))
await stream.write(msg_bytes)
logger.debug("Status message sent, waiting for processing")
# Wait longer to ensure the message is sent
await trio.sleep(1.5)
logger.debug("Status message sending completed")
except trio.TooSlowError:
logger.error(
"Timeout sending status message: code=%s, message=%s", code, message
)
except Exception as e:
logger.error("Error sending status message: %s", str(e))
async def _send_stop_status(
self,
stream: ReadWriteCloser,
code: int,
message: str,
) -> None:
"""Send a status message on a STOP stream."""
try:
logger.debug("Sending stop status message with code %s: %s", code, message)
with trio.fail_after(STREAM_WRITE_TIMEOUT * 2): # Double the timeout
# Create a proto Status directly
pb_status = PbStatus()
pb_status.code = cast(
Any, int(code)
) # Cast to Any to avoid type errors
pb_status.message = message
status_msg = StopMessage(
type=StopMessage.STATUS,
status=pb_status,
)
await stream.write(status_msg.SerializeToString())
await trio.sleep(0.5) # Ensure message is sent
except Exception as e:
logger.error("Error sending stop status message: %s", str(e))

View File

@ -0,0 +1,55 @@
"""
Protocol buffer wrapper classes for Circuit Relay v2.
This module provides wrapper classes for protocol buffer generated objects
to make them easier to work with in type-checked code.
"""
from enum import (
IntEnum,
)
from typing import (
Any,
)
from .pb.circuit_pb2 import Status as PbStatus
# Define Status codes as an Enum for better type safety and organization
class StatusCode(IntEnum):
OK = 0
RESERVATION_REFUSED = 100
RESOURCE_LIMIT_EXCEEDED = 101
PERMISSION_DENIED = 102
CONNECTION_FAILED = 200
DIAL_REFUSED = 201
STOP_FAILED = 300
MALFORMED_MESSAGE = 400
INTERNAL_ERROR = 500
def create_status(code: int = StatusCode.OK, message: str = "") -> Any:
"""
Create a protocol buffer Status object.
Parameters
----------
code : int
The status code
message : str
The status message
Returns
-------
Any
The protocol buffer Status object
"""
# Create status object
pb_obj = PbStatus()
# Convert the integer status code to the protobuf enum value type
pb_obj.code = PbStatus.Code.ValueType(code)
pb_obj.message = message
return pb_obj

View File

@ -0,0 +1,254 @@
"""
Resource management for Circuit Relay v2.
This module handles managing resources for relay operations,
including reservations and connection limits.
"""
from dataclasses import (
dataclass,
)
import hashlib
import os
import time
from libp2p.peer.id import (
ID,
)
# Import the protobuf definitions
from .pb.circuit_pb2 import Reservation as PbReservation
@dataclass
class RelayLimits:
"""Configuration for relay resource limits."""
duration: int # Maximum duration of a relay connection in seconds
data: int # Maximum data transfer allowed in bytes
max_circuit_conns: int # Maximum number of concurrent circuit connections
max_reservations: int # Maximum number of active reservations
class Reservation:
"""Represents a relay reservation."""
def __init__(self, peer_id: ID, limits: RelayLimits):
"""
Initialize a new reservation.
Parameters
----------
peer_id : ID
The peer ID this reservation is for
limits : RelayLimits
The resource limits for this reservation
"""
self.peer_id = peer_id
self.limits = limits
self.created_at = time.time()
self.expires_at = self.created_at + limits.duration
self.data_used = 0
self.active_connections = 0
self.voucher = self._generate_voucher()
def _generate_voucher(self) -> bytes:
"""
Generate a unique cryptographically secure voucher for this reservation.
Returns
-------
bytes
A secure voucher token
"""
# Create a random token using a combination of:
# - Random bytes for unpredictability
# - Peer ID to bind it to the specific peer
# - Timestamp for uniqueness
# - Hash everything for a fixed size output
random_bytes = os.urandom(16) # 128 bits of randomness
timestamp = str(int(self.created_at * 1000000)).encode()
peer_bytes = self.peer_id.to_bytes()
# Combine all elements and hash them
h = hashlib.sha256()
h.update(random_bytes)
h.update(timestamp)
h.update(peer_bytes)
return h.digest()
def is_expired(self) -> bool:
"""Check if the reservation has expired."""
return time.time() > self.expires_at
def can_accept_connection(self) -> bool:
"""Check if a new connection can be accepted."""
return (
not self.is_expired()
and self.active_connections < self.limits.max_circuit_conns
and self.data_used < self.limits.data
)
def to_proto(self) -> PbReservation:
"""Convert the reservation to its protobuf representation."""
# TODO: For production use, implement proper signature generation
# The signature should be created by signing the voucher with the
# peer's private key. The current implementation with an empty signature
# is intended for development and testing only.
return PbReservation(
expire=int(self.expires_at),
voucher=self.voucher,
signature=b"",
)
class RelayResourceManager:
"""
Manages resources and reservations for relay operations.
This class handles:
- Tracking active reservations
- Enforcing resource limits
- Managing connection quotas
"""
def __init__(self, limits: RelayLimits):
"""
Initialize the resource manager.
Parameters
----------
limits : RelayLimits
The resource limits to enforce
"""
self.limits = limits
self._reservations: dict[ID, Reservation] = {}
def can_accept_reservation(self, peer_id: ID) -> bool:
"""
Check if a new reservation can be accepted for the given peer.
Parameters
----------
peer_id : ID
The peer ID requesting the reservation
Returns
-------
bool
True if the reservation can be accepted
"""
# Clean expired reservations
self._clean_expired()
# Check if peer already has a valid reservation
existing = self._reservations.get(peer_id)
if existing and not existing.is_expired():
return True
# Check if we're at the reservation limit
return len(self._reservations) < self.limits.max_reservations
def create_reservation(self, peer_id: ID) -> Reservation:
"""
Create a new reservation for the given peer.
Parameters
----------
peer_id : ID
The peer ID to create the reservation for
Returns
-------
Reservation
The newly created reservation
"""
reservation = Reservation(peer_id, self.limits)
self._reservations[peer_id] = reservation
return reservation
def verify_reservation(self, peer_id: ID, proto_res: PbReservation) -> bool:
"""
Verify a reservation from a protobuf message.
Parameters
----------
peer_id : ID
The peer ID the reservation is for
proto_res : PbReservation
The protobuf reservation message
Returns
-------
bool
True if the reservation is valid
"""
# TODO: Implement voucher and signature verification
reservation = self._reservations.get(peer_id)
return (
reservation is not None
and not reservation.is_expired()
and reservation.expires_at == proto_res.expire
)
def can_accept_connection(self, peer_id: ID) -> bool:
"""
Check if a new connection can be accepted for the given peer.
Parameters
----------
peer_id : ID
The peer ID requesting the connection
Returns
-------
bool
True if the connection can be accepted
"""
reservation = self._reservations.get(peer_id)
return reservation is not None and reservation.can_accept_connection()
def _clean_expired(self) -> None:
"""Remove expired reservations."""
now = time.time()
expired = [
peer_id
for peer_id, res in self._reservations.items()
if now > res.expires_at
]
for peer_id in expired:
del self._reservations[peer_id]
def reserve(self, peer_id: ID) -> int:
"""
Create or update a reservation for a peer and return the TTL.
Parameters
----------
peer_id : ID
The peer ID to reserve for
Returns
-------
int
The TTL of the reservation in seconds
"""
# Check for existing reservation
existing = self._reservations.get(peer_id)
if existing and not existing.is_expired():
# Return remaining time for existing reservation
remaining = max(0, int(existing.expires_at - time.time()))
return remaining
# Create new reservation
self.create_reservation(peer_id)
return self.limits.duration

View File

@ -0,0 +1,427 @@
"""
Transport implementation for Circuit Relay v2.
This module implements the transport layer for Circuit Relay v2,
allowing peers to establish connections through relay nodes.
"""
from collections.abc import Awaitable, Callable
import logging
import multiaddr
import trio
from libp2p.abc import (
IHost,
IListener,
INetStream,
ITransport,
ReadWriteCloser,
)
from libp2p.network.connection.raw_connection import (
RawConnection,
)
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.tools.async_service import (
Service,
)
from .config import (
ClientConfig,
RelayConfig,
)
from .discovery import (
RelayDiscovery,
)
from .pb.circuit_pb2 import (
HopMessage,
StopMessage,
)
from .protocol import (
PROTOCOL_ID,
CircuitV2Protocol,
)
from .protocol_buffer import (
StatusCode,
)
logger = logging.getLogger("libp2p.relay.circuit_v2.transport")
class CircuitV2Transport(ITransport):
"""
CircuitV2Transport implements the transport interface for Circuit Relay v2.
This transport allows peers to establish connections through relay nodes
when direct connections are not possible.
"""
def __init__(
self,
host: IHost,
protocol: CircuitV2Protocol,
config: RelayConfig,
) -> None:
"""
Initialize the Circuit v2 transport.
Parameters
----------
host : IHost
The libp2p host this transport is running on
protocol : CircuitV2Protocol
The Circuit v2 protocol instance
config : RelayConfig
Relay configuration
"""
self.host = host
self.protocol = protocol
self.config = config
self.client_config = ClientConfig()
self.discovery = RelayDiscovery(
host=host,
auto_reserve=config.enable_client,
discovery_interval=config.discovery_interval,
max_relays=config.max_relays,
)
async def dial(
self,
maddr: multiaddr.Multiaddr,
) -> RawConnection:
"""
Dial a peer using the multiaddr.
Parameters
----------
maddr : multiaddr.Multiaddr
The multiaddr to dial
Returns
-------
RawConnection
The established connection
Raises
------
ConnectionError
If the connection cannot be established
"""
# Extract peer ID from multiaddr - P_P2P code is 0x01A5 (421)
peer_id_str = maddr.value_for_protocol("p2p")
if not peer_id_str:
raise ConnectionError("Multiaddr does not contain peer ID")
peer_id = ID.from_base58(peer_id_str)
peer_info = PeerInfo(peer_id, [maddr])
# Use the internal dial_peer_info method
return await self.dial_peer_info(peer_info)
async def dial_peer_info(
self,
peer_info: PeerInfo,
*,
relay_peer_id: ID | None = None,
) -> RawConnection:
"""
Dial a peer through a relay.
Parameters
----------
peer_info : PeerInfo
The peer to dial
relay_peer_id : Optional[ID], optional
Optional specific relay peer to use
Returns
-------
RawConnection
The established connection
Raises
------
ConnectionError
If the connection cannot be established
"""
# If no specific relay is provided, try to find one
if relay_peer_id is None:
relay_peer_id = await self._select_relay(peer_info)
if not relay_peer_id:
raise ConnectionError("No suitable relay found")
# Get a stream to the relay
relay_stream = await self.host.new_stream(relay_peer_id, [PROTOCOL_ID])
if not relay_stream:
raise ConnectionError(f"Could not open stream to relay {relay_peer_id}")
try:
# First try to make a reservation if enabled
if self.config.enable_client:
success = await self._make_reservation(relay_stream, relay_peer_id)
if not success:
logger.warning(
"Failed to make reservation with relay %s", relay_peer_id
)
# Send HOP CONNECT message
hop_msg = HopMessage(
type=HopMessage.CONNECT,
peer=peer_info.peer_id.to_bytes(),
)
await relay_stream.write(hop_msg.SerializeToString())
# Read response
resp_bytes = await relay_stream.read()
resp = HopMessage()
resp.ParseFromString(resp_bytes)
# Access status attributes directly
status_code = getattr(resp.status, "code", StatusCode.OK)
status_msg = getattr(resp.status, "message", "Unknown error")
if status_code != StatusCode.OK:
raise ConnectionError(f"Relay connection failed: {status_msg}")
# Create raw connection from stream
return RawConnection(stream=relay_stream, initiator=True)
except Exception as e:
await relay_stream.close()
raise ConnectionError(f"Failed to establish relay connection: {str(e)}")
async def _select_relay(self, peer_info: PeerInfo) -> ID | None:
"""
Select an appropriate relay for the given peer.
Parameters
----------
peer_info : PeerInfo
The peer to connect to
Returns
-------
Optional[ID]
Selected relay peer ID, or None if no suitable relay found
"""
# Try to find a relay
attempts = 0
while attempts < self.client_config.max_auto_relay_attempts:
# Get a relay from the list of discovered relays
relays = self.discovery.get_relays()
if relays:
# TODO: Implement more sophisticated relay selection
# For now, just return the first available relay
return relays[0]
# Wait and try discovery
await trio.sleep(1)
attempts += 1
return None
async def _make_reservation(
self,
stream: INetStream,
relay_peer_id: ID,
) -> bool:
"""
Make a reservation with a relay.
Parameters
----------
stream : INetStream
Stream to the relay
relay_peer_id : ID
The relay's peer ID
Returns
-------
bool
True if reservation was successful
"""
try:
# Send reservation request
reserve_msg = HopMessage(
type=HopMessage.RESERVE,
peer=self.host.get_id().to_bytes(),
)
await stream.write(reserve_msg.SerializeToString())
# Read response
resp_bytes = await stream.read()
resp = HopMessage()
resp.ParseFromString(resp_bytes)
# Access status attributes directly
status_code = getattr(resp.status, "code", StatusCode.OK)
status_msg = getattr(resp.status, "message", "Unknown error")
if status_code != StatusCode.OK:
logger.warning(
"Reservation failed with relay %s: %s",
relay_peer_id,
status_msg,
)
return False
# Store reservation info
# TODO: Implement reservation storage and refresh mechanism
return True
except Exception as e:
logger.error("Error making reservation: %s", str(e))
return False
def create_listener(
self,
handler_function: Callable[[ReadWriteCloser], Awaitable[None]],
) -> IListener:
"""
Create a listener for incoming relay connections.
Parameters
----------
handler_function : Callable[[ReadWriteCloser], Awaitable[None]]
The handler function for new connections
Returns
-------
IListener
The created listener
"""
return CircuitV2Listener(self.host, self.protocol, self.config)
class CircuitV2Listener(Service, IListener):
"""Listener for incoming relay connections."""
def __init__(
self,
host: IHost,
protocol: CircuitV2Protocol,
config: RelayConfig,
) -> None:
"""
Initialize the Circuit v2 listener.
Parameters
----------
host : IHost
The libp2p host this listener is running on
protocol : CircuitV2Protocol
The Circuit v2 protocol instance
config : RelayConfig
Relay configuration
"""
super().__init__()
self.host = host
self.protocol = protocol
self.config = config
self.multiaddrs: list[
multiaddr.Multiaddr
] = [] # Store multiaddrs as Multiaddr objects
async def handle_incoming_connection(
self,
stream: INetStream,
remote_peer_id: ID,
) -> RawConnection:
"""
Handle an incoming relay connection.
Parameters
----------
stream : INetStream
The incoming stream
remote_peer_id : ID
The remote peer's ID
Returns
-------
RawConnection
The established connection
Raises
------
ConnectionError
If the connection cannot be established
"""
if not self.config.enable_stop:
raise ConnectionError("Stop role is not enabled")
try:
# Read STOP message
msg_bytes = await stream.read()
stop_msg = StopMessage()
stop_msg.ParseFromString(msg_bytes)
if stop_msg.type != StopMessage.CONNECT:
raise ConnectionError("Invalid STOP message type")
# Create raw connection
return RawConnection(stream=stream, initiator=False)
except Exception as e:
await stream.close()
raise ConnectionError(f"Failed to handle incoming connection: {str(e)}")
async def run(self) -> None:
"""Run the listener service."""
# Implementation would go here
async def listen(self, maddr: multiaddr.Multiaddr, nursery: trio.Nursery) -> bool:
"""
Start listening on the given multiaddr.
Parameters
----------
maddr : multiaddr.Multiaddr
The multiaddr to listen on
nursery : trio.Nursery
The nursery to run tasks in
Returns
-------
bool
True if listening successfully started
"""
# Convert string to Multiaddr if needed
addr = (
maddr
if isinstance(maddr, multiaddr.Multiaddr)
else multiaddr.Multiaddr(maddr)
)
self.multiaddrs.append(addr)
return True
def get_addrs(self) -> tuple[multiaddr.Multiaddr, ...]:
"""
Get the listening addresses.
Returns
-------
tuple[multiaddr.Multiaddr, ...]
Tuple of listening multiaddresses
"""
return tuple(self.multiaddrs)
async def close(self) -> None:
"""Close the listener."""
self.multiaddrs.clear()
await self.manager.stop()

View File

@ -1,7 +1,3 @@
from typing import (
Optional,
)
from libp2p.abc import (
ISecureConn,
)
@ -49,5 +45,5 @@ class BaseSession(ISecureConn):
def get_remote_peer(self) -> ID:
return self.remote_peer
def get_remote_public_key(self) -> Optional[PublicKey]:
def get_remote_public_key(self) -> PublicKey:
return self.remote_permanent_pubkey

View File

@ -1,7 +1,7 @@
import secrets
from typing import (
from collections.abc import (
Callable,
)
import secrets
from libp2p.abc import (
ISecureTransport,

View File

@ -1,8 +1,7 @@
from typing import (
Optional,
)
from collections.abc import Callable
from libp2p.abc import (
IPeerStore,
IRawConnection,
ISecureConn,
)
@ -10,6 +9,7 @@ from libp2p.crypto.exceptions import (
MissingDeserializerError,
)
from libp2p.crypto.keys import (
KeyPair,
PrivateKey,
PublicKey,
)
@ -34,11 +34,15 @@ from libp2p.network.connection.exceptions import (
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerstore import (
PeerStoreError,
)
from libp2p.security.base_session import (
BaseSession,
)
from libp2p.security.base_transport import (
BaseSecureTransport,
default_secure_bytes_provider,
)
from libp2p.security.exceptions import (
HandshakeFailure,
@ -87,13 +91,13 @@ class InsecureSession(BaseSession):
async def write(self, data: bytes) -> None:
await self.conn.write(data)
async def read(self, n: int = None) -> bytes:
async def read(self, n: int | None = None) -> bytes:
return await self.conn.read(n)
async def close(self) -> None:
await self.conn.close()
def get_remote_address(self) -> Optional[tuple[str, int]]:
def get_remote_address(self) -> tuple[str, int] | None:
"""
Delegate to the underlying connection's get_remote_address method.
"""
@ -105,7 +109,8 @@ async def run_handshake(
local_private_key: PrivateKey,
conn: IRawConnection,
is_initiator: bool,
remote_peer_id: ID,
remote_peer_id: ID | None,
peerstore: IPeerStore | None = None,
) -> ISecureConn:
"""Raise `HandshakeFailure` when handshake failed."""
msg = make_exchange_message(local_private_key.get_public_key())
@ -124,6 +129,15 @@ async def run_handshake(
remote_msg.ParseFromString(remote_msg_bytes)
received_peer_id = ID(remote_msg.id)
# Verify that `remote_peer_id` isn't `None`
# That is the only condition that `remote_peer_id` would not need to be checked
# against the `recieved_peer_id` gotten from the outbound/recieved `msg`.
# The check against `received_peer_id` happens in the next if-block
if is_initiator and remote_peer_id is None:
raise HandshakeFailure(
"remote peer ID cannot be None if `is_initiator` is set to `True`"
)
# Verify if the receive `ID` matches the one we originally initialize the session.
# We only need to check it when we are the initiator, because only in that condition
# we possibly knows the `ID` of the remote.
@ -159,7 +173,14 @@ async def run_handshake(
conn=conn,
)
# TODO: Store `pubkey` and `peer_id` to `PeerStore`
# Store `pubkey` and `peer_id` to `PeerStore`
if peerstore is not None:
try:
peerstore.add_pubkey(received_peer_id, received_pubkey)
except PeerStoreError:
# If peer ID and pubkey don't match, it would have already been caught above
# This might happen if the peer is already in the store
pass
return secure_conn
@ -170,6 +191,18 @@ class InsecureTransport(BaseSecureTransport):
transport does not add any additional security.
"""
def __init__(
self,
local_key_pair: KeyPair,
secure_bytes_provider: Callable[[int], bytes] | None = None,
peerstore: IPeerStore | None = None,
) -> None:
# If secure_bytes_provider is None, use the default one
if secure_bytes_provider is None:
secure_bytes_provider = default_secure_bytes_provider
super().__init__(local_key_pair, secure_bytes_provider)
self.peerstore = peerstore
async def secure_inbound(self, conn: IRawConnection) -> ISecureConn:
"""
Secure the connection, either locally or by communicating with opposing
@ -178,8 +211,9 @@ class InsecureTransport(BaseSecureTransport):
:return: secure connection object (that implements secure_conn_interface)
"""
# For inbound connections, we don't know the remote peer ID yet
return await run_handshake(
self.local_peer, self.local_private_key, conn, False, None
self.local_peer, self.local_private_key, conn, False, None, self.peerstore
)
async def secure_outbound(self, conn: IRawConnection, peer_id: ID) -> ISecureConn:
@ -190,7 +224,7 @@ class InsecureTransport(BaseSecureTransport):
:return: secure connection object (that implements secure_conn_interface)
"""
return await run_handshake(
self.local_peer, self.local_private_key, conn, True, peer_id
self.local_peer, self.local_private_key, conn, True, peer_id, self.peerstore
)

View File

@ -1,5 +1,4 @@
from typing import (
Optional,
cast,
)
@ -10,7 +9,6 @@ from libp2p.abc import (
)
from libp2p.io.abc import (
EncryptedMsgReadWriter,
MsgReadWriteCloser,
ReadWriteCloser,
)
from libp2p.io.msgio import (
@ -40,7 +38,7 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
implemented by the subclasses.
"""
read_writer: MsgReadWriteCloser
read_writer: NoisePacketReadWriter
noise_state: NoiseState
# FIXME: This prefix is added in msg#3 in Go. Check whether it's a desired behavior.
@ -50,12 +48,12 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
self.read_writer = NoisePacketReadWriter(cast(ReadWriteCloser, conn))
self.noise_state = noise_state
async def write_msg(self, data: bytes, prefix_encoded: bool = False) -> None:
data_encrypted = self.encrypt(data)
async def write_msg(self, msg: bytes, prefix_encoded: bool = False) -> None:
data_encrypted = self.encrypt(msg)
if prefix_encoded:
await self.read_writer.write_msg(self.prefix + data_encrypted)
else:
await self.read_writer.write_msg(data_encrypted)
# Manually add the prefix if needed
data_encrypted = self.prefix + data_encrypted
await self.read_writer.write_msg(data_encrypted)
async def read_msg(self, prefix_encoded: bool = False) -> bytes:
noise_msg_encrypted = await self.read_writer.read_msg()
@ -67,10 +65,11 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
async def close(self) -> None:
await self.read_writer.close()
def get_remote_address(self) -> Optional[tuple[str, int]]:
def get_remote_address(self) -> tuple[str, int] | None:
# Delegate to the underlying connection if possible
if hasattr(self.read_writer, "read_write_closer") and hasattr(
self.read_writer.read_write_closer, "get_remote_address"
self.read_writer.read_write_closer,
"get_remote_address",
):
return self.read_writer.read_write_closer.get_remote_address()
return None
@ -78,7 +77,7 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
class NoiseHandshakeReadWriter(BaseNoiseMsgReadWriter):
def encrypt(self, data: bytes) -> bytes:
return self.noise_state.write_message(data)
return bytes(self.noise_state.write_message(data))
def decrypt(self, data: bytes) -> bytes:
return bytes(self.noise_state.read_message(data))

View File

@ -19,7 +19,7 @@ SIGNED_DATA_PREFIX = "noise-libp2p-static-key:"
class NoiseHandshakePayload:
id_pubkey: PublicKey
id_sig: bytes
early_data: bytes = None
early_data: bytes | None = None
def serialize(self) -> bytes:
msg = noise_pb.NoiseHandshakePayload(

View File

@ -7,8 +7,10 @@ from cryptography.hazmat.primitives import (
serialization,
)
from noise.backends.default.keypairs import KeyPair as NoiseKeyPair
from noise.connection import Keypair as NoiseKeypairEnum
from noise.connection import NoiseConnection as NoiseState
from noise.connection import (
Keypair as NoiseKeypairEnum,
NoiseConnection as NoiseState,
)
from libp2p.abc import (
IRawConnection,
@ -47,14 +49,12 @@ from .messages import (
class IPattern(ABC):
@abstractmethod
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
...
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn: ...
@abstractmethod
async def handshake_outbound(
self, conn: IRawConnection, remote_peer: ID
) -> ISecureConn:
...
) -> ISecureConn: ...
class BasePattern(IPattern):
@ -62,13 +62,15 @@ class BasePattern(IPattern):
noise_static_key: PrivateKey
local_peer: ID
libp2p_privkey: PrivateKey
early_data: bytes
early_data: bytes | None
def create_noise_state(self) -> NoiseState:
noise_state = NoiseState.from_name(self.protocol_name)
noise_state.set_keypair_from_private_bytes(
NoiseKeypairEnum.STATIC, self.noise_static_key.to_bytes()
)
if noise_state.noise_protocol is None:
raise NoiseStateError("noise_protocol is not initialized")
return noise_state
def make_handshake_payload(self) -> NoiseHandshakePayload:
@ -84,7 +86,7 @@ class PatternXX(BasePattern):
local_peer: ID,
libp2p_privkey: PrivateKey,
noise_static_key: PrivateKey,
early_data: bytes = None,
early_data: bytes | None = None,
) -> None:
self.protocol_name = b"Noise_XX_25519_ChaChaPoly_SHA256"
self.local_peer = local_peer
@ -96,7 +98,12 @@ class PatternXX(BasePattern):
noise_state = self.create_noise_state()
noise_state.set_as_responder()
noise_state.start_handshake()
if noise_state.noise_protocol is None:
raise NoiseStateError("noise_protocol is not initialized")
handshake_state = noise_state.noise_protocol.handshake_state
if handshake_state is None:
raise NoiseStateError("Handshake state is not initialized")
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
# Consume msg#1.
@ -145,7 +152,11 @@ class PatternXX(BasePattern):
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
noise_state.set_as_initiator()
noise_state.start_handshake()
if noise_state.noise_protocol is None:
raise NoiseStateError("noise_protocol is not initialized")
handshake_state = noise_state.noise_protocol.handshake_state
if handshake_state is None:
raise NoiseStateError("Handshake state is not initialized")
# Send msg#1, which is *not* encrypted.
msg_1 = b""
@ -195,6 +206,8 @@ class PatternXX(BasePattern):
@staticmethod
def _get_pubkey_from_noise_keypair(key_pair: NoiseKeyPair) -> PublicKey:
# Use `Ed25519PublicKey` since 25519 is used in our pattern.
if key_pair.public is None:
raise NoiseStateError("public key is not initialized")
raw_bytes = key_pair.public.public_bytes(
serialization.Encoding.Raw, serialization.PublicFormat.Raw
)

View File

@ -26,7 +26,7 @@ class Transport(ISecureTransport):
libp2p_privkey: PrivateKey
noise_privkey: PrivateKey
local_peer: ID
early_data: bytes
early_data: bytes | None
with_noise_pipes: bool
# NOTE: Implementations that support Noise Pipes must decide whether to use
@ -37,8 +37,8 @@ class Transport(ISecureTransport):
def __init__(
self,
libp2p_keypair: KeyPair,
noise_privkey: PrivateKey = None,
early_data: bytes = None,
noise_privkey: PrivateKey,
early_data: bytes | None = None,
with_noise_pipes: bool = False,
) -> None:
self.libp2p_privkey = libp2p_keypair.private_key

Some files were not shown because too many files have changed in this diff Show More