mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Compare commits
1 Commits
async-vali
...
varun-r-ma
| Author | SHA1 | Date | |
|---|---|---|---|
| 5983c08379 |
6
.github/workflows/tox.yml
vendored
6
.github/workflows/tox.yml
vendored
@ -16,10 +16,10 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python: ["3.10", "3.11", "3.12", "3.13"]
|
||||
python: ['3.9', '3.10', '3.11', '3.12', '3.13']
|
||||
toxenv: [core, interop, lint, wheel, demos]
|
||||
include:
|
||||
- python: "3.10"
|
||||
- python: '3.10'
|
||||
toxenv: docs
|
||||
fail-fast: false
|
||||
steps:
|
||||
@ -46,7 +46,7 @@ jobs:
|
||||
runs-on: windows-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.11", "3.12", "3.13"]
|
||||
python-version: ['3.11', '3.12', '3.13']
|
||||
toxenv: [core, wheel]
|
||||
fail-fast: false
|
||||
steps:
|
||||
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@ -146,9 +146,6 @@ instance/
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# PyRight Config
|
||||
pyrightconfig.json
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
@ -174,7 +171,3 @@ env.bak/
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
#lockfiles
|
||||
uv.lock
|
||||
poetry.lock
|
||||
|
||||
@ -1,49 +1,59 @@
|
||||
exclude: '.project-template|docs/conf.py|.*pb2\..*'
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
- id: check-toml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.20.0
|
||||
- id: check-yaml
|
||||
- id: check-toml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.15.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: [--py310-plus]
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.11.10
|
||||
- id: pyupgrade
|
||||
args: [--py39-plus]
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.9.1
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix, --exit-non-zero-on-fix]
|
||||
- id: ruff-format
|
||||
- repo: https://github.com/executablebooks/mdformat
|
||||
- id: black
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 6.1.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
additional_dependencies:
|
||||
- flake8-bugbear==23.9.16
|
||||
exclude: setup.py
|
||||
- repo: https://github.com/PyCQA/autoflake
|
||||
rev: v2.2.1
|
||||
hooks:
|
||||
- id: autoflake
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.12.0
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/pycqa/pydocstyle
|
||||
rev: 6.3.0
|
||||
hooks:
|
||||
- id: pydocstyle
|
||||
additional_dependencies:
|
||||
- tomli # required until >= python311
|
||||
- repo: https://github.com/executablebooks/mdformat
|
||||
rev: 0.7.22
|
||||
hooks:
|
||||
- id: mdformat
|
||||
- id: mdformat
|
||||
additional_dependencies:
|
||||
- mdformat-gfm
|
||||
- repo: local
|
||||
- mdformat-gfm
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: mypy-local
|
||||
- id: mypy-local
|
||||
name: run mypy with all dev dependencies present
|
||||
entry: mypy -p libp2p
|
||||
entry: python -m mypy -p libp2p
|
||||
language: system
|
||||
always_run: true
|
||||
pass_filenames: false
|
||||
- repo: local
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: pyrefly-local
|
||||
name: run pyrefly typecheck locally
|
||||
entry: pyrefly check
|
||||
language: system
|
||||
always_run: true
|
||||
pass_filenames: false
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: check-rst-files
|
||||
- id: check-rst-files
|
||||
name: Check for .rst files in the top-level directory
|
||||
entry: python -c "import glob, sys; rst_files = glob.glob('*.rst'); sys.exit(1) if rst_files else sys.exit(0)"
|
||||
language: system
|
||||
|
||||
71
.project-template/fill_template_vars.py
Normal file
71
.project-template/fill_template_vars.py
Normal file
@ -0,0 +1,71 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _find_files(project_root):
|
||||
path_exclude_pattern = r"\.git($|\/)|venv|_build"
|
||||
file_exclude_pattern = r"fill_template_vars\.py|\.swp$"
|
||||
filepaths = []
|
||||
for dir_path, _dir_names, file_names in os.walk(project_root):
|
||||
if not re.search(path_exclude_pattern, dir_path):
|
||||
for file in file_names:
|
||||
if not re.search(file_exclude_pattern, file):
|
||||
filepaths.append(str(Path(dir_path, file)))
|
||||
|
||||
return filepaths
|
||||
|
||||
|
||||
def _replace(pattern, replacement, project_root):
|
||||
print(f"Replacing values: {pattern}")
|
||||
for file in _find_files(project_root):
|
||||
try:
|
||||
with open(file) as f:
|
||||
content = f.read()
|
||||
content = re.sub(pattern, replacement, content)
|
||||
with open(file, "w") as f:
|
||||
f.write(content)
|
||||
except UnicodeDecodeError:
|
||||
pass
|
||||
|
||||
|
||||
def main():
|
||||
project_root = Path(os.path.realpath(sys.argv[0])).parent.parent
|
||||
|
||||
module_name = input("What is your python module name? ")
|
||||
|
||||
pypi_input = input(f"What is your pypi package name? (default: {module_name}) ")
|
||||
pypi_name = pypi_input or module_name
|
||||
|
||||
repo_input = input(f"What is your github project name? (default: {pypi_name}) ")
|
||||
repo_name = repo_input or pypi_name
|
||||
|
||||
rtd_input = input(
|
||||
f"What is your readthedocs.org project name? (default: {pypi_name}) "
|
||||
)
|
||||
rtd_name = rtd_input or pypi_name
|
||||
|
||||
project_input = input(
|
||||
f"What is your project name (ex: at the top of the README)? (default: {repo_name}) "
|
||||
)
|
||||
project_name = project_input or repo_name
|
||||
|
||||
short_description = input("What is a one-liner describing the project? ")
|
||||
|
||||
_replace("<MODULE_NAME>", module_name, project_root)
|
||||
_replace("<PYPI_NAME>", pypi_name, project_root)
|
||||
_replace("<REPO_NAME>", repo_name, project_root)
|
||||
_replace("<RTD_NAME>", rtd_name, project_root)
|
||||
_replace("<PROJECT_NAME>", project_name, project_root)
|
||||
_replace("<SHORT_DESCRIPTION>", short_description, project_root)
|
||||
|
||||
os.makedirs(project_root / module_name, exist_ok=True)
|
||||
Path(project_root / module_name / "__init__.py").touch()
|
||||
Path(project_root / module_name / "py.typed").touch()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
39
.project-template/refill_template_vars.py
Normal file
39
.project-template/refill_template_vars.py
Normal file
@ -0,0 +1,39 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
|
||||
|
||||
def main():
|
||||
template_dir = Path(os.path.dirname(sys.argv[0]))
|
||||
template_vars_file = template_dir / "template_vars.txt"
|
||||
fill_template_vars_script = template_dir / "fill_template_vars.py"
|
||||
|
||||
with open(template_vars_file, "r") as input_file:
|
||||
content_lines = input_file.readlines()
|
||||
|
||||
process = subprocess.Popen(
|
||||
[sys.executable, str(fill_template_vars_script)],
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
)
|
||||
|
||||
for line in content_lines:
|
||||
process.stdin.write(line)
|
||||
process.stdin.flush()
|
||||
|
||||
stdout, stderr = process.communicate()
|
||||
|
||||
if process.returncode != 0:
|
||||
print(f"Error occurred: {stderr}")
|
||||
sys.exit(1)
|
||||
|
||||
print(stdout)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
6
.project-template/template_vars.txt
Normal file
6
.project-template/template_vars.txt
Normal file
@ -0,0 +1,6 @@
|
||||
libp2p
|
||||
libp2p
|
||||
py-libp2p
|
||||
py-libp2p
|
||||
py-libp2p
|
||||
The Python implementation of the libp2p networking stack
|
||||
19
Makefile
19
Makefile
@ -7,14 +7,12 @@ help:
|
||||
@echo "clean-pyc - remove Python file artifacts"
|
||||
@echo "clean - run clean-build and clean-pyc"
|
||||
@echo "dist - build package and cat contents of the dist directory"
|
||||
@echo "fix - fix formatting & linting issues with ruff"
|
||||
@echo "lint - fix linting issues with pre-commit"
|
||||
@echo "test - run tests quickly with the default Python"
|
||||
@echo "docs - generate docs and open in browser (linux-docs for version on linux)"
|
||||
@echo "package-test - build package and install it in a venv for manual testing"
|
||||
@echo "notes - consume towncrier newsfragments and update release notes in docs - requires bump to be set"
|
||||
@echo "release - package and upload a release (does not run notes target) - requires bump to be set"
|
||||
@echo "pr - run clean, fix, lint, typecheck, and test i.e basically everything you need to do before creating a PR"
|
||||
|
||||
clean-build:
|
||||
rm -fr build/
|
||||
@ -39,16 +37,8 @@ lint:
|
||||
&& pre-commit run --all-files --show-diff-on-failure \
|
||||
)
|
||||
|
||||
fix:
|
||||
python -m ruff check --fix
|
||||
|
||||
typecheck:
|
||||
pre-commit run mypy-local --all-files && pre-commit run pyrefly-local --all-files
|
||||
|
||||
test:
|
||||
python -m pytest tests -n auto
|
||||
|
||||
pr: clean fix lint typecheck test
|
||||
python -m pytest tests
|
||||
|
||||
# protobufs management
|
||||
|
||||
@ -58,10 +48,7 @@ PB = libp2p/crypto/pb/crypto.proto \
|
||||
libp2p/security/secio/pb/spipe.proto \
|
||||
libp2p/security/noise/pb/noise.proto \
|
||||
libp2p/identity/identify/pb/identify.proto \
|
||||
libp2p/host/autonat/pb/autonat.proto \
|
||||
libp2p/relay/circuit_v2/pb/circuit.proto \
|
||||
libp2p/kad_dht/pb/kademlia.proto
|
||||
|
||||
libp2p/host/autonat/pb/autonat.proto
|
||||
PY = $(PB:.proto=_pb2.py)
|
||||
PYI = $(PB:.proto=_pb2.pyi)
|
||||
|
||||
@ -93,7 +80,7 @@ validate-newsfragments:
|
||||
check-docs: build-docs validate-newsfragments
|
||||
|
||||
build-docs:
|
||||
sphinx-apidoc -o docs/ . "*conftest*" tests/
|
||||
sphinx-apidoc -o docs/ . setup.py "*conftest*" tests/
|
||||
$(MAKE) -C docs clean
|
||||
$(MAKE) -C docs html
|
||||
$(MAKE) -C docs doctest
|
||||
|
||||
30
docs/conf.py
30
docs/conf.py
@ -15,24 +15,14 @@
|
||||
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
||||
# sys.path.insert(0, os.path.abspath('.'))
|
||||
|
||||
import doctest
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
try:
|
||||
import tomllib
|
||||
except ModuleNotFoundError:
|
||||
# For Python < 3.11
|
||||
import tomli as tomllib # type: ignore (In case of >3.11 Pyrefly doesnt find tomli , which is right but a false flag)
|
||||
|
||||
# Path to pyproject.toml (assuming conf.py is in a 'docs' subdirectory)
|
||||
pyproject_path = os.path.join(os.path.dirname(__file__), "..", "pyproject.toml")
|
||||
|
||||
with open(pyproject_path, "rb") as f:
|
||||
pyproject_data = tomllib.load(f)
|
||||
|
||||
setup_version = pyproject_data["project"]["version"]
|
||||
DIR = os.path.dirname(__file__)
|
||||
with open(os.path.join(DIR, "../setup.py"), "r") as f:
|
||||
for line in f:
|
||||
if "version=" in line:
|
||||
setup_version = line.split('"')[1]
|
||||
break
|
||||
|
||||
# -- General configuration ------------------------------------------------
|
||||
|
||||
@ -312,6 +302,7 @@ intersphinx_mapping = {
|
||||
|
||||
# -- Doctest configuration ----------------------------------------
|
||||
|
||||
import doctest
|
||||
|
||||
doctest_default_flags = (
|
||||
0
|
||||
@ -326,9 +317,10 @@ doctest_default_flags = (
|
||||
# Mock out dependencies that are unbuildable on readthedocs, as recommended here:
|
||||
# https://docs.readthedocs.io/en/rel/faq.html#i-get-import-errors-on-libraries-that-depend-on-c-modules
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Add new modules to mock here (it should be the same list
|
||||
# as those excluded in pyproject.toml)
|
||||
# Add new modules to mock here (it should be the same list as those excluded in setup.py)
|
||||
MOCK_MODULES = [
|
||||
"fastecdsa",
|
||||
"fastecdsa.encoding",
|
||||
@ -346,4 +338,4 @@ todo_include_todos = True
|
||||
|
||||
# Allow duplicate object descriptions
|
||||
nitpicky = False
|
||||
nitpick_ignore = [("py:class", "type")]
|
||||
nitpick_ignore = [("py:class", "type")]
|
||||
@ -1,499 +0,0 @@
|
||||
Circuit Relay v2 Example
|
||||
========================
|
||||
|
||||
This example demonstrates how to use Circuit Relay v2 in py-libp2p. It includes three components:
|
||||
|
||||
1. A relay node that provides relay services
|
||||
2. A destination node that accepts relayed connections
|
||||
3. A source node that connects to the destination through the relay
|
||||
|
||||
Prerequisites
|
||||
-------------
|
||||
|
||||
First, ensure you have py-libp2p installed:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python -m pip install libp2p
|
||||
Collecting libp2p
|
||||
...
|
||||
Successfully installed libp2p-x.x.x
|
||||
|
||||
Relay Node
|
||||
----------
|
||||
|
||||
Create a file named ``relay_node.py`` with the following content:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import trio
|
||||
import logging
|
||||
import multiaddr
|
||||
import traceback
|
||||
|
||||
from libp2p import new_host
|
||||
from libp2p.relay.circuit_v2.protocol import CircuitV2Protocol
|
||||
from libp2p.relay.circuit_v2.transport import CircuitV2Transport
|
||||
from libp2p.relay.circuit_v2.config import RelayConfig
|
||||
from libp2p.tools.async_service import background_trio_service
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger("relay_node")
|
||||
|
||||
async def run_relay():
|
||||
listen_addr = multiaddr.Multiaddr("/ip4/0.0.0.0/tcp/9000")
|
||||
host = new_host()
|
||||
|
||||
config = RelayConfig(
|
||||
enable_hop=True, # Act as a relay
|
||||
enable_stop=True, # Accept relayed connections
|
||||
enable_client=False, # Don't use other relays
|
||||
max_circuit_duration=3600, # 1 hour
|
||||
max_circuit_bytes=1024 * 1024 * 10, # 10MB
|
||||
)
|
||||
|
||||
# Initialize the relay protocol with allow_hop=True to act as a relay
|
||||
protocol = CircuitV2Protocol(host, limits=config.limits, allow_hop=True)
|
||||
print(f"Created relay protocol with hop enabled: {protocol.allow_hop}")
|
||||
|
||||
# Start the protocol service
|
||||
async with host.run(listen_addrs=[listen_addr]):
|
||||
peer_id = host.get_id()
|
||||
print("\n" + "="*50)
|
||||
print(f"Relay node started with ID: {peer_id}")
|
||||
print(f"Relay node multiaddr: /ip4/127.0.0.1/tcp/9000/p2p/{peer_id}")
|
||||
print("="*50 + "\n")
|
||||
print(f"Listening on: {host.get_addrs()}")
|
||||
|
||||
try:
|
||||
async with background_trio_service(protocol):
|
||||
print("Protocol service started")
|
||||
|
||||
transport = CircuitV2Transport(host, protocol, config)
|
||||
print("Relay service started successfully")
|
||||
print(f"Relay limits: {protocol.limits}")
|
||||
|
||||
while True:
|
||||
await trio.sleep(10)
|
||||
print("Relay node still running...")
|
||||
print(f"Active connections: {len(host.get_network().connections)}")
|
||||
except Exception as e:
|
||||
print(f"Error in relay service: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
trio.run(run_relay)
|
||||
except Exception as e:
|
||||
print(f"Error running relay: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
Destination Node
|
||||
----------------
|
||||
|
||||
Create a file named ``destination_node.py`` with the following content:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import trio
|
||||
import logging
|
||||
import multiaddr
|
||||
import traceback
|
||||
import sys
|
||||
|
||||
from libp2p import new_host
|
||||
from libp2p.relay.circuit_v2.protocol import CircuitV2Protocol
|
||||
from libp2p.relay.circuit_v2.transport import CircuitV2Transport
|
||||
from libp2p.relay.circuit_v2.config import RelayConfig
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
from libp2p.tools.async_service import background_trio_service
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger("destination_node")
|
||||
|
||||
async def handle_echo_stream(stream):
|
||||
"""Handle incoming stream by echoing received data."""
|
||||
try:
|
||||
print(f"New echo stream from: {stream.get_protocol()}")
|
||||
while True:
|
||||
data = await stream.read(1024)
|
||||
if not data:
|
||||
print("Stream closed by remote")
|
||||
break
|
||||
|
||||
message = data.decode('utf-8')
|
||||
print(f"Received: {message}")
|
||||
|
||||
response = f"Echo: {message}".encode('utf-8')
|
||||
await stream.write(response)
|
||||
print(f"Sent response: Echo: {message}")
|
||||
except Exception as e:
|
||||
print(f"Error handling stream: {e}")
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
await stream.close()
|
||||
print("Stream closed")
|
||||
|
||||
async def run_destination(relay_peer_id=None):
|
||||
"""
|
||||
Run a simple destination node that accepts connections.
|
||||
This is a simplified version that doesn't use the relay functionality.
|
||||
"""
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/9001")
|
||||
host = new_host()
|
||||
|
||||
# Configure as a relay receiver (stop)
|
||||
config = RelayConfig(
|
||||
enable_stop=True, # Accept relayed connections
|
||||
enable_client=True, # Use relays for outbound connections
|
||||
max_circuit_duration=3600, # 1 hour
|
||||
max_circuit_bytes=1024 * 1024 * 10, # 10MB
|
||||
)
|
||||
|
||||
# Initialize the relay protocol
|
||||
protocol = CircuitV2Protocol(host, limits=config.limits, allow_hop=False)
|
||||
|
||||
async with host.run(listen_addrs=[listen_addr]):
|
||||
# Print host information
|
||||
dest_peer_id = host.get_id()
|
||||
print("\n" + "="*50)
|
||||
print(f"Destination node started with ID: {dest_peer_id}")
|
||||
print(f"Use this ID in the source node: {dest_peer_id}")
|
||||
print("="*50 + "\n")
|
||||
print(f"Listening on: {host.get_addrs()}")
|
||||
|
||||
# Set stream handler for the echo protocol
|
||||
host.set_stream_handler("/echo/1.0.0", handle_echo_stream)
|
||||
print("Registered echo protocol handler")
|
||||
|
||||
# Start the protocol service in the background
|
||||
async with background_trio_service(protocol):
|
||||
print("Protocol service started")
|
||||
|
||||
# Create and register the transport
|
||||
transport = CircuitV2Transport(host, protocol, config)
|
||||
print("Transport created")
|
||||
|
||||
# Create a listener for relayed connections
|
||||
listener = transport.create_listener(handle_echo_stream)
|
||||
print("Created relay listener")
|
||||
|
||||
# Start listening for relayed connections
|
||||
async with trio.open_nursery() as nursery:
|
||||
await listener.listen("/p2p-circuit", nursery)
|
||||
print("Destination node ready to accept relayed connections")
|
||||
|
||||
if not relay_peer_id:
|
||||
print("No relay peer ID provided. Please enter the relay's peer ID:")
|
||||
print("Waiting for relay peer ID input...")
|
||||
while True:
|
||||
if sys.stdin.isatty(): # Only try to read from stdin if it's a terminal
|
||||
try:
|
||||
relay_peer_id = input("Enter relay peer ID: ").strip()
|
||||
if relay_peer_id:
|
||||
break
|
||||
except EOFError:
|
||||
await trio.sleep(5)
|
||||
else:
|
||||
print("No terminal detected. Waiting for relay peer ID as command line argument.")
|
||||
await trio.sleep(10)
|
||||
continue
|
||||
|
||||
# Connect to the relay node with the provided relay peer ID
|
||||
relay_addr_str = f"/ip4/127.0.0.1/tcp/9000/p2p/{relay_peer_id}"
|
||||
print(f"Connecting to relay at {relay_addr_str}")
|
||||
|
||||
try:
|
||||
# Convert string address to multiaddr, then to peer info
|
||||
relay_maddr = multiaddr.Multiaddr(relay_addr_str)
|
||||
relay_peer_info = info_from_p2p_addr(relay_maddr)
|
||||
await host.connect(relay_peer_info)
|
||||
print("Connected to relay successfully")
|
||||
|
||||
# Add the relay to the transport's discovery
|
||||
transport.discovery._add_relay(relay_peer_info.peer_id)
|
||||
print(f"Added relay {relay_peer_info.peer_id} to discovery")
|
||||
|
||||
# Keep the node running
|
||||
while True:
|
||||
await trio.sleep(10)
|
||||
print("Destination node still running...")
|
||||
except Exception as e:
|
||||
print(f"Failed to connect to relay: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Starting destination node...")
|
||||
relay_id = None
|
||||
if len(sys.argv) > 1:
|
||||
relay_id = sys.argv[1]
|
||||
print(f"Using provided relay ID: {relay_id}")
|
||||
trio.run(run_destination, relay_id)
|
||||
|
||||
Source Node
|
||||
-----------
|
||||
|
||||
Create a file named ``source_node.py`` with the following content:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import trio
|
||||
import logging
|
||||
import multiaddr
|
||||
import traceback
|
||||
import sys
|
||||
|
||||
from libp2p import new_host
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.relay.circuit_v2.protocol import CircuitV2Protocol
|
||||
from libp2p.relay.circuit_v2.transport import CircuitV2Transport
|
||||
from libp2p.relay.circuit_v2.config import RelayConfig
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
from libp2p.tools.async_service import background_trio_service
|
||||
from libp2p.relay.circuit_v2.discovery import RelayInfo
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger("source_node")
|
||||
|
||||
async def run_source(relay_peer_id=None, destination_peer_id=None):
|
||||
# Create a libp2p host
|
||||
listen_addr = multiaddr.Multiaddr("/ip4/0.0.0.0/tcp/9002")
|
||||
host = new_host()
|
||||
|
||||
# Configure as a relay client
|
||||
config = RelayConfig(
|
||||
enable_client=True, # Use relays for outbound connections
|
||||
max_circuit_duration=3600, # 1 hour
|
||||
max_circuit_bytes=1024 * 1024 * 10, # 10MB
|
||||
)
|
||||
|
||||
# Initialize the relay protocol
|
||||
protocol = CircuitV2Protocol(host, limits=config.limits, allow_hop=False)
|
||||
|
||||
# Start the protocol service
|
||||
async with host.run(listen_addrs=[listen_addr]):
|
||||
# Print host information
|
||||
print(f"Source node started with ID: {host.get_id()}")
|
||||
print(f"Listening on: {host.get_addrs()}")
|
||||
|
||||
# Start the protocol service in the background
|
||||
async with background_trio_service(protocol):
|
||||
print("Protocol service started")
|
||||
|
||||
# Create and register the transport
|
||||
transport = CircuitV2Transport(host, protocol, config)
|
||||
|
||||
# Get relay peer ID if not provided
|
||||
if not relay_peer_id:
|
||||
print("No relay peer ID provided. Please enter the relay's peer ID:")
|
||||
while True:
|
||||
if sys.stdin.isatty(): # Only try to read from stdin if it's a terminal
|
||||
try:
|
||||
relay_peer_id = input("Enter relay peer ID: ").strip()
|
||||
if relay_peer_id:
|
||||
break
|
||||
except EOFError:
|
||||
await trio.sleep(5)
|
||||
else:
|
||||
print("No terminal detected. Waiting for relay peer ID as command line argument.")
|
||||
await trio.sleep(10)
|
||||
continue
|
||||
|
||||
# Connect to the relay node with the provided relay peer ID
|
||||
relay_addr_str = f"/ip4/127.0.0.1/tcp/9000/p2p/{relay_peer_id}"
|
||||
print(f"Connecting to relay at {relay_addr_str}")
|
||||
|
||||
try:
|
||||
# Convert string address to multiaddr, then to peer info
|
||||
relay_maddr = multiaddr.Multiaddr(relay_addr_str)
|
||||
relay_peer_info = info_from_p2p_addr(relay_maddr)
|
||||
await host.connect(relay_peer_info)
|
||||
print("Connected to relay successfully")
|
||||
|
||||
# Manually add the relay to the discovery service
|
||||
relay_id = relay_peer_info.peer_id
|
||||
now = trio.current_time()
|
||||
|
||||
# Create relay info and add it to discovery
|
||||
relay_info = RelayInfo(
|
||||
peer_id=relay_id,
|
||||
discovered_at=now,
|
||||
last_seen=now
|
||||
)
|
||||
transport.discovery._discovered_relays[relay_id] = relay_info
|
||||
print(f"Added relay {relay_id} to discovery")
|
||||
|
||||
# Start relay discovery in the background
|
||||
async with background_trio_service(transport.discovery):
|
||||
print("Relay discovery started")
|
||||
|
||||
# Wait for relay discovery
|
||||
await trio.sleep(5)
|
||||
print("Relay discovery completed")
|
||||
|
||||
# Get destination peer ID if not provided
|
||||
if not destination_peer_id:
|
||||
print("No destination peer ID provided. Please enter the destination's peer ID:")
|
||||
while True:
|
||||
if sys.stdin.isatty(): # Only try to read from stdin if it's a terminal
|
||||
try:
|
||||
destination_peer_id = input("Enter destination peer ID: ").strip()
|
||||
if destination_peer_id:
|
||||
break
|
||||
except EOFError:
|
||||
await trio.sleep(5)
|
||||
else:
|
||||
print("No terminal detected. Waiting for destination peer ID as command line argument.")
|
||||
await trio.sleep(10)
|
||||
continue
|
||||
|
||||
print(f"Attempting to connect to {destination_peer_id} via relay")
|
||||
|
||||
# Check if we have any discovered relays
|
||||
discovered_relays = list(transport.discovery._discovered_relays.keys())
|
||||
print(f"Discovered relays: {discovered_relays}")
|
||||
|
||||
try:
|
||||
# Create a circuit relay multiaddr for the destination
|
||||
dest_id = ID.from_base58(destination_peer_id)
|
||||
|
||||
# Create a circuit multiaddr that includes the relay
|
||||
# Format: /ip4/127.0.0.1/tcp/9000/p2p/RELAY_ID/p2p-circuit/p2p/DEST_ID
|
||||
circuit_addr = multiaddr.Multiaddr(f"{relay_addr_str}/p2p-circuit/p2p/{destination_peer_id}")
|
||||
print(f"Created circuit address: {circuit_addr}")
|
||||
|
||||
# Dial using the circuit address
|
||||
connection = await transport.dial(circuit_addr)
|
||||
print("Connection established through relay!")
|
||||
|
||||
# Open a stream using the echo protocol
|
||||
stream = await connection.new_stream("/echo/1.0.0")
|
||||
|
||||
# Send messages periodically
|
||||
for i in range(5):
|
||||
message = f"Hello from source, message {i+1}"
|
||||
print(f"Sending: {message}")
|
||||
|
||||
await stream.write(message.encode('utf-8'))
|
||||
response = await stream.read(1024)
|
||||
|
||||
print(f"Received: {response.decode('utf-8')}")
|
||||
await trio.sleep(1)
|
||||
|
||||
# Close the stream
|
||||
await stream.close()
|
||||
print("Stream closed")
|
||||
except Exception as e:
|
||||
print(f"Error connecting through relay: {e}")
|
||||
print("Detailed error:")
|
||||
traceback.print_exc()
|
||||
|
||||
# Keep the node running for a while
|
||||
await trio.sleep(30)
|
||||
print("Source node shutting down")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
relay_id = None
|
||||
dest_id = None
|
||||
|
||||
# Parse command line arguments if provided
|
||||
if len(sys.argv) > 1:
|
||||
relay_id = sys.argv[1]
|
||||
print(f"Using provided relay ID: {relay_id}")
|
||||
|
||||
if len(sys.argv) > 2:
|
||||
dest_id = sys.argv[2]
|
||||
print(f"Using provided destination ID: {dest_id}")
|
||||
|
||||
trio.run(run_source, relay_id, dest_id)
|
||||
|
||||
Running the Example
|
||||
-------------------
|
||||
|
||||
1. First, start the relay node:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python relay_node.py
|
||||
Created relay protocol with hop enabled: True
|
||||
|
||||
==================================================
|
||||
Relay node started with ID: QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
|
||||
Relay node multiaddr: /ip4/127.0.0.1/tcp/9000/p2p/QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
|
||||
==================================================
|
||||
|
||||
Listening on: [<Multiaddr /ip4/0.0.0.0/tcp/9000/p2p/QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx>]
|
||||
Protocol service started
|
||||
Relay service started successfully
|
||||
Relay limits: RelayLimits(duration=3600, data=10485760, max_circuit_conns=8, max_reservations=4)
|
||||
|
||||
Note the relay node\'s peer ID (in this example: `QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx`). You\'ll need this for the other nodes.
|
||||
|
||||
2. Next, start the destination node:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python destination_node.py
|
||||
Starting destination node...
|
||||
|
||||
==================================================
|
||||
Destination node started with ID: QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s
|
||||
Use this ID in the source node: QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s
|
||||
==================================================
|
||||
|
||||
Listening on: [<Multiaddr /ip4/0.0.0.0/tcp/9001/p2p/QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s>]
|
||||
Registered echo protocol handler
|
||||
Protocol service started
|
||||
Transport created
|
||||
Created relay listener
|
||||
Destination node ready to accept relayed connections
|
||||
No relay peer ID provided. Please enter the relay\'s peer ID:
|
||||
Waiting for relay peer ID input...
|
||||
Enter relay peer ID: QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
|
||||
Connecting to relay at /ip4/127.0.0.1/tcp/9000/p2p/QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
|
||||
Connected to relay successfully
|
||||
Added relay QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx to discovery
|
||||
Destination node still running...
|
||||
|
||||
Note the destination node's peer ID (in this example: `QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s`). You'll need this for the source node.
|
||||
|
||||
3. Finally, start the source node:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python source_node.py
|
||||
Source node started with ID: QmPyM56cgmFoHTgvMgGfDWRdVRQznmxCDDDg2dJ8ygVXj3
|
||||
Listening on: [<Multiaddr /ip4/0.0.0.0/tcp/9002/p2p/QmPyM56cgmFoHTgvMgGfDWRdVRQznmxCDDDg2dJ8ygVXj3>]
|
||||
Protocol service started
|
||||
No relay peer ID provided. Please enter the relay\'s peer ID:
|
||||
Enter relay peer ID: QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
|
||||
Connecting to relay at /ip4/127.0.0.1/tcp/9000/p2p/QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
|
||||
Connected to relay successfully
|
||||
Added relay QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx to discovery
|
||||
Relay discovery started
|
||||
Relay discovery completed
|
||||
No destination peer ID provided. Please enter the destination\'s peer ID:
|
||||
Enter destination peer ID: QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s
|
||||
Attempting to connect to QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s via relay
|
||||
Discovered relays: [<libp2p.peer.id.ID (QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx)>]
|
||||
Created circuit address: /ip4/127.0.0.1/tcp/9000/p2p/QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx/p2p-circuit/p2p/QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s
|
||||
|
||||
At this point, the source node will establish a connection through the relay to the destination node and start sending messages.
|
||||
|
||||
4. Alternatively, you can provide the peer IDs as command-line arguments:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
# For the destination node (provide relay ID)
|
||||
$ python destination_node.py QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
|
||||
|
||||
# For the source node (provide both relay and destination IDs)
|
||||
$ python source_node.py QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s
|
||||
|
||||
This example demonstrates how to use Circuit Relay v2 to establish connections between peers that cannot connect directly. The peer IDs are dynamically generated for each node, and the relay facilitates communication between the source and destination nodes.
|
||||
@ -1,124 +0,0 @@
|
||||
Kademlia DHT Demo
|
||||
=================
|
||||
|
||||
This example demonstrates a Kademlia Distributed Hash Table (DHT) implementation with both value storage/retrieval and content provider advertisement/discovery functionality.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python -m pip install libp2p
|
||||
Collecting libp2p
|
||||
...
|
||||
Successfully installed libp2p-x.x.x
|
||||
$ cd examples/kademlia
|
||||
$ python kademlia.py --mode server
|
||||
2025-06-13 19:51:25,424 - kademlia-example - INFO - Running in server mode on port 0
|
||||
2025-06-13 19:51:25,426 - kademlia-example - INFO - Connected to bootstrap nodes: []
|
||||
2025-06-13 19:51:25,426 - kademlia-example - INFO - To connect to this node, use: --bootstrap /ip4/127.0.0.1/tcp/28910/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
|
||||
2025-06-13 19:51:25,426 - kademlia-example - INFO - Saved server address to log: /ip4/127.0.0.1/tcp/28910/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
|
||||
2025-06-13 19:51:25,427 - kademlia-example - INFO - DHT service started in SERVER mode
|
||||
2025-06-13 19:51:25,427 - kademlia-example - INFO - Stored value 'Hello message from Sumanjeet' with key: FVDjasarSFDoLPMdgnp1dHSbW2ZAfN8NU2zNbCQeczgP
|
||||
2025-06-13 19:51:25,427 - kademlia-example - INFO - Successfully advertised as server for content: 361f2ed1183bca491b8aec11f0b9e5c06724759b0f7480ae7fb4894901993bc8
|
||||
|
||||
|
||||
Copy the line that starts with ``--bootstrap``, open a new terminal in the same folder and run the client:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python kademlia.py --mode client --bootstrap /ip4/127.0.0.1/tcp/28910/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
|
||||
2025-06-13 19:51:37,022 - kademlia-example - INFO - Running in client mode on port 0
|
||||
2025-06-13 19:51:37,026 - kademlia-example - INFO - Connected to bootstrap nodes: [<libp2p.peer.id.ID (16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef)>]
|
||||
2025-06-13 19:51:37,027 - kademlia-example - INFO - DHT service started in CLIENT mode
|
||||
2025-06-13 19:51:37,027 - kademlia-example - INFO - Looking up key: FVDjasarSFDoLPMdgnp1dHSbW2ZAfN8NU2zNbCQeczgP
|
||||
2025-06-13 19:51:37,031 - kademlia-example - INFO - Retrieved value: Hello message from Sumanjeet
|
||||
2025-06-13 19:51:37,031 - kademlia-example - INFO - Looking for servers of content: 361f2ed1183bca491b8aec11f0b9e5c06724759b0f7480ae7fb4894901993bc8
|
||||
2025-06-13 19:51:37,035 - kademlia-example - INFO - Found 1 servers for content: ['16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef']
|
||||
|
||||
Alternatively, if you run the server first, the client can automatically extract the bootstrap address from the server log file:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python kademlia.py --mode client
|
||||
2025-06-13 19:51:37,022 - kademlia-example - INFO - Running in client mode on port 0
|
||||
2025-06-13 19:51:37,026 - kademlia-example - INFO - Connected to bootstrap nodes: [<libp2p.peer.id.ID (16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef)>]
|
||||
2025-06-13 19:51:37,027 - kademlia-example - INFO - DHT service started in CLIENT mode
|
||||
2025-06-13 19:51:37,027 - kademlia-example - INFO - Looking up key: FVDjasarSFDoLPMdgnp1dHSbW2ZAfN8NU2zNbCQeczgP
|
||||
2025-06-13 19:51:37,031 - kademlia-example - INFO - Retrieved value: Hello message from Sumanjeet
|
||||
2025-06-13 19:51:37,031 - kademlia-example - INFO - Looking for servers of content: 361f2ed1183bca491b8aec11f0b9e5c06724759b0f7480ae7fb4894901993bc8
|
||||
2025-06-13 19:51:37,035 - kademlia-example - INFO - Found 1 servers for content: ['16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef']
|
||||
|
||||
The demo showcases key DHT operations:
|
||||
|
||||
- **Value Storage & Retrieval**: The server stores a value, and the client retrieves it
|
||||
- **Content Provider Discovery**: The server advertises content, and the client finds providers
|
||||
- **Peer Discovery**: Automatic bootstrap and peer routing using the Kademlia algorithm
|
||||
- **Network Resilience**: Distributed storage across multiple nodes (when available)
|
||||
|
||||
Command Line Options
|
||||
--------------------
|
||||
|
||||
The Kademlia demo supports several command line options for customization:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python kademlia.py --help
|
||||
usage: kademlia.py [-h] [--mode MODE] [--port PORT] [--bootstrap [BOOTSTRAP ...]] [--verbose]
|
||||
|
||||
Kademlia DHT example with content server functionality
|
||||
|
||||
options:
|
||||
-h, --help show this help message and exit
|
||||
--mode MODE Run as a server or client node (default: server)
|
||||
--port PORT Port to listen on (0 for random) (default: 0)
|
||||
--bootstrap [BOOTSTRAP ...]
|
||||
Multiaddrs of bootstrap nodes. Provide a space-separated list of addresses.
|
||||
This is required for client mode.
|
||||
--verbose Enable verbose logging
|
||||
|
||||
**Examples:**
|
||||
|
||||
Start server on a specific port:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python kademlia.py --mode server --port 8000
|
||||
|
||||
Start client with verbose logging:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python kademlia.py --mode client --verbose
|
||||
|
||||
Connect to multiple bootstrap nodes:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python kademlia.py --mode client --bootstrap /ip4/127.0.0.1/tcp/8000/p2p/... /ip4/127.0.0.1/tcp/8001/p2p/...
|
||||
|
||||
How It Works
|
||||
------------
|
||||
|
||||
The Kademlia DHT implementation demonstrates several key concepts:
|
||||
|
||||
**Server Mode:**
|
||||
- Stores key-value pairs in the distributed hash table
|
||||
- Advertises itself as a content provider for specific content
|
||||
- Handles incoming DHT requests from other nodes
|
||||
- Maintains routing table with known peers
|
||||
|
||||
**Client Mode:**
|
||||
- Connects to bootstrap nodes to join the network
|
||||
- Retrieves values by their keys from the DHT
|
||||
- Discovers content providers for specific content
|
||||
- Performs network lookups using the Kademlia algorithm
|
||||
|
||||
**Key Components:**
|
||||
- **Routing Table**: Organizes peers in k-buckets based on XOR distance
|
||||
- **Value Store**: Manages key-value storage with TTL (time-to-live)
|
||||
- **Provider Store**: Tracks which peers provide specific content
|
||||
- **Peer Routing**: Implements iterative lookups to find closest peers
|
||||
|
||||
The full source code for this example is below:
|
||||
|
||||
.. literalinclude:: ../examples/kademlia/kademlia.py
|
||||
:language: python
|
||||
:linenos:
|
||||
@ -1,64 +0,0 @@
|
||||
mDNS Peer Discovery Example
|
||||
===========================
|
||||
|
||||
This example demonstrates how to use mDNS (Multicast DNS) for peer discovery in py-libp2p.
|
||||
|
||||
Prerequisites
|
||||
-------------
|
||||
|
||||
First, ensure you have py-libp2p installed and your environment is activated:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python -m pip install libp2p
|
||||
|
||||
Running the Example
|
||||
-------------------
|
||||
|
||||
The mDNS demo script allows you to discover peers on your local network using mDNS. To start a peer, run:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ mdns-demo
|
||||
|
||||
You should see output similar to:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
Run this from another console to start another peer on a different port:
|
||||
|
||||
python mdns-demo -p <ANOTHER_PORT>
|
||||
|
||||
Waiting for mDNS peer discovery events...
|
||||
|
||||
2025-06-20 23:28:12,052 - libp2p.example.discovery.mdns - INFO - Starting peer Discovery
|
||||
|
||||
To discover peers, open another terminal and run the same command with a different port:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python mdns-demo -p 9001
|
||||
|
||||
You should see output indicating that a new peer has been discovered:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
Run this from the same folder in another console to start another peer on a different port:
|
||||
|
||||
python mdns-demo -p <ANOTHER_PORT>
|
||||
|
||||
Waiting for mDNS peer discovery events...
|
||||
|
||||
2025-06-20 23:43:43,786 - libp2p.example.discovery.mdns - INFO - Starting peer Discovery
|
||||
2025-06-20 23:43:43,790 - libp2p.example.discovery.mdns - INFO - Discovered: 16Uiu2HAmGxy5NdQEjZWtrYUMrzdp3Syvg7MB2E5Lx8weA9DanYxj
|
||||
|
||||
When a new peer is discovered, its peer ID will be printed in the console output.
|
||||
|
||||
How it Works
|
||||
------------
|
||||
|
||||
- Each node advertises itself on the local network using mDNS.
|
||||
- When a new peer is discovered, the handler prints its peer ID.
|
||||
- This is useful for local peer discovery without requiring a DHT or bootstrap nodes.
|
||||
|
||||
You can modify the script to perform additional actions when peers are discovered, such as opening streams or exchanging messages.
|
||||
@ -11,6 +11,3 @@ Examples
|
||||
examples.echo
|
||||
examples.ping
|
||||
examples.pubsub
|
||||
examples.circuit_relay
|
||||
examples.kademlia
|
||||
examples.mDNS
|
||||
|
||||
@ -12,6 +12,10 @@ The Python implementation of the libp2p networking stack
|
||||
getting_started
|
||||
release_notes
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Community
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: py-libp2p
|
||||
|
||||
@ -1,21 +0,0 @@
|
||||
libp2p.discovery.events package
|
||||
===============================
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
libp2p.discovery.events.peerDiscovery module
|
||||
--------------------------------------------
|
||||
|
||||
.. automodule:: libp2p.discovery.events.peerDiscovery
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Module contents
|
||||
---------------
|
||||
|
||||
.. automodule:: libp2p.discovery.events
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
@ -1,45 +0,0 @@
|
||||
libp2p.discovery.mdns package
|
||||
=============================
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
libp2p.discovery.mdns.broadcaster module
|
||||
----------------------------------------
|
||||
|
||||
.. automodule:: libp2p.discovery.mdns.broadcaster
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.discovery.mdns.listener module
|
||||
-------------------------------------
|
||||
|
||||
.. automodule:: libp2p.discovery.mdns.listener
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.discovery.mdns.mdns module
|
||||
---------------------------------
|
||||
|
||||
.. automodule:: libp2p.discovery.mdns.mdns
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.discovery.mdns.utils module
|
||||
----------------------------------
|
||||
|
||||
.. automodule:: libp2p.discovery.mdns.utils
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Module contents
|
||||
---------------
|
||||
|
||||
.. automodule:: libp2p.discovery.mdns
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
@ -1,22 +0,0 @@
|
||||
libp2p.discovery package
|
||||
========================
|
||||
|
||||
Subpackages
|
||||
-----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
libp2p.discovery.events
|
||||
libp2p.discovery.mdns
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
Module contents
|
||||
---------------
|
||||
|
||||
.. automodule:: libp2p.discovery
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
@ -1,22 +0,0 @@
|
||||
libp2p.kad\_dht.pb package
|
||||
==========================
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
libp2p.kad_dht.pb.kademlia_pb2 module
|
||||
-------------------------------------
|
||||
|
||||
.. automodule:: libp2p.kad_dht.pb.kademlia_pb2
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Module contents
|
||||
---------------
|
||||
|
||||
.. automodule:: libp2p.kad_dht.pb
|
||||
:no-index:
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
@ -1,77 +0,0 @@
|
||||
libp2p.kad\_dht package
|
||||
=======================
|
||||
|
||||
Subpackages
|
||||
-----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
libp2p.kad_dht.pb
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
libp2p.kad\_dht.kad\_dht module
|
||||
-------------------------------
|
||||
|
||||
.. automodule:: libp2p.kad_dht.kad_dht
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.kad\_dht.peer\_routing module
|
||||
------------------------------------
|
||||
|
||||
.. automodule:: libp2p.kad_dht.peer_routing
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.kad\_dht.provider\_store module
|
||||
--------------------------------------
|
||||
|
||||
.. automodule:: libp2p.kad_dht.provider_store
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.kad\_dht.routing\_table module
|
||||
-------------------------------------
|
||||
|
||||
.. automodule:: libp2p.kad_dht.routing_table
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.kad\_dht.utils module
|
||||
----------------------------
|
||||
|
||||
.. automodule:: libp2p.kad_dht.utils
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.kad\_dht.value\_store module
|
||||
-----------------------------------
|
||||
|
||||
.. automodule:: libp2p.kad_dht.value_store
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.kad\_dht.pb
|
||||
------------------
|
||||
|
||||
.. automodule:: libp2p.kad_dht.pb
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Module contents
|
||||
---------------
|
||||
|
||||
.. automodule:: libp2p.kad_dht
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
@ -1,22 +0,0 @@
|
||||
libp2p.relay.circuit_v2.pb package
|
||||
==================================
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
libp2p.relay.circuit_v2.pb.circuit_pb2 module
|
||||
---------------------------------------------
|
||||
|
||||
.. automodule:: libp2p.relay.circuit_v2.pb.circuit_pb2
|
||||
:members:
|
||||
:show-inheritance:
|
||||
:undoc-members:
|
||||
|
||||
Module contents
|
||||
---------------
|
||||
|
||||
.. automodule:: libp2p.relay.circuit_v2.pb
|
||||
:members:
|
||||
:show-inheritance:
|
||||
:undoc-members:
|
||||
:no-index:
|
||||
@ -1,70 +0,0 @@
|
||||
libp2p.relay.circuit_v2 package
|
||||
===============================
|
||||
|
||||
Subpackages
|
||||
-----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
libp2p.relay.circuit_v2.pb
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
libp2p.relay.circuit_v2.protocol module
|
||||
---------------------------------------
|
||||
|
||||
.. automodule:: libp2p.relay.circuit_v2.protocol
|
||||
:members:
|
||||
:show-inheritance:
|
||||
:undoc-members:
|
||||
|
||||
libp2p.relay.circuit_v2.transport module
|
||||
----------------------------------------
|
||||
|
||||
.. automodule:: libp2p.relay.circuit_v2.transport
|
||||
:members:
|
||||
:show-inheritance:
|
||||
:undoc-members:
|
||||
|
||||
libp2p.relay.circuit_v2.discovery module
|
||||
----------------------------------------
|
||||
|
||||
.. automodule:: libp2p.relay.circuit_v2.discovery
|
||||
:members:
|
||||
:show-inheritance:
|
||||
:undoc-members:
|
||||
|
||||
libp2p.relay.circuit_v2.resources module
|
||||
----------------------------------------
|
||||
|
||||
.. automodule:: libp2p.relay.circuit_v2.resources
|
||||
:members:
|
||||
:show-inheritance:
|
||||
:undoc-members:
|
||||
|
||||
libp2p.relay.circuit_v2.config module
|
||||
-------------------------------------
|
||||
|
||||
.. automodule:: libp2p.relay.circuit_v2.config
|
||||
:members:
|
||||
:show-inheritance:
|
||||
:undoc-members:
|
||||
|
||||
libp2p.relay.circuit_v2.protocol_buffer module
|
||||
----------------------------------------------
|
||||
|
||||
.. automodule:: libp2p.relay.circuit_v2.protocol_buffer
|
||||
:members:
|
||||
:show-inheritance:
|
||||
:undoc-members:
|
||||
|
||||
Module contents
|
||||
---------------
|
||||
|
||||
.. automodule:: libp2p.relay.circuit_v2
|
||||
:members:
|
||||
:show-inheritance:
|
||||
:undoc-members:
|
||||
:no-index:
|
||||
@ -1,19 +0,0 @@
|
||||
libp2p.relay package
|
||||
====================
|
||||
|
||||
Subpackages
|
||||
-----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
libp2p.relay.circuit_v2
|
||||
|
||||
Module contents
|
||||
---------------
|
||||
|
||||
.. automodule:: libp2p.relay
|
||||
:members:
|
||||
:show-inheritance:
|
||||
:undoc-members:
|
||||
:no-index:
|
||||
@ -8,16 +8,13 @@ Subpackages
|
||||
:maxdepth: 4
|
||||
|
||||
libp2p.crypto
|
||||
libp2p.discovery
|
||||
libp2p.host
|
||||
libp2p.identity
|
||||
libp2p.io
|
||||
libp2p.kad_dht
|
||||
libp2p.network
|
||||
libp2p.peer
|
||||
libp2p.protocol_muxer
|
||||
libp2p.pubsub
|
||||
libp2p.relay
|
||||
libp2p.security
|
||||
libp2p.stream_muxer
|
||||
libp2p.tools
|
||||
|
||||
@ -3,110 +3,6 @@ Release Notes
|
||||
|
||||
.. towncrier release notes start
|
||||
|
||||
py-libp2p v0.2.9 (2025-07-09)
|
||||
-----------------------------
|
||||
|
||||
Breaking Changes
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
- Reordered the arguments to ``upgrade_security`` to place ``is_initiator`` before ``peer_id``, and made ``peer_id`` optional.
|
||||
This allows the method to reflect the fact that peer identity is not required for inbound connections. (`#681 <https://github.com/libp2p/py-libp2p/issues/681>`__)
|
||||
|
||||
|
||||
Bugfixes
|
||||
~~~~~~~~
|
||||
|
||||
- Add timeout wrappers in:
|
||||
1. ``multiselect.py``: ``negotiate`` function
|
||||
2. ``multiselect_client.py``: ``select_one_of`` , ``query_multistream_command`` functions
|
||||
to prevent indefinite hangs when a remote peer does not respond. (`#696 <https://github.com/libp2p/py-libp2p/issues/696>`__)
|
||||
- Align stream creation logic with yamux specification (`#701 <https://github.com/libp2p/py-libp2p/issues/701>`__)
|
||||
- Fixed an issue in ``Pubsub`` where async validators were not handled reliably under concurrency. Now uses a safe aggregator list for consistent behavior. (`#702 <https://github.com/libp2p/py-libp2p/issues/702>`__)
|
||||
|
||||
|
||||
Features
|
||||
~~~~~~~~
|
||||
|
||||
- Added support for ``Kademlia DHT`` in py-libp2p. (`#579 <https://github.com/libp2p/py-libp2p/issues/579>`__)
|
||||
- Limit concurrency in ``push_identify_to_peers`` to prevent resource congestion under high peer counts. (`#621 <https://github.com/libp2p/py-libp2p/issues/621>`__)
|
||||
- Store public key and peer ID in peerstore during handshake
|
||||
|
||||
Modified the InsecureTransport class to accept an optional peerstore parameter and updated the handshake process to store the received public key and peer ID in the peerstore when available.
|
||||
|
||||
Added test cases to verify:
|
||||
1. The peerstore remains unchanged when handshake fails due to peer ID mismatch
|
||||
2. The handshake correctly adds a public key to a peer ID that already exists in the peerstore but doesn't have a public key yet (`#631 <https://github.com/libp2p/py-libp2p/issues/631>`__)
|
||||
- Fixed several flow-control and concurrency issues in the ``YamuxStream`` class. Previously, stress-testing revealed that transferring data over ``DEFAULT_WINDOW_SIZE`` would break the stream due to inconsistent window update handling and lock management. The fixes include:
|
||||
|
||||
- Removed sending of window updates during writes to maintain correct flow-control.
|
||||
- Added proper timeout handling when releasing and acquiring locks to prevent concurrency errors.
|
||||
- Corrected the ``read`` function to properly handle window updates for both ``read_until_EOF`` and ``read_n_bytes``.
|
||||
- Added event logging at ``send_window_updates`` and ``waiting_for_window_updates`` for better observability. (`#639 <https://github.com/libp2p/py-libp2p/issues/639>`__)
|
||||
- Added support for ``Multicast DNS`` in py-libp2p (`#649 <https://github.com/libp2p/py-libp2p/issues/649>`__)
|
||||
- Optimized pubsub publishing to send multiple topics in a single message instead of separate messages per topic. (`#685 <https://github.com/libp2p/py-libp2p/issues/685>`__)
|
||||
- Optimized pubsub message writing by implementing a write_msg() method that uses pre-allocated buffers and single write operations, improving performance by eliminating separate varint prefix encoding and write operations in FloodSub and GossipSub. (`#687 <https://github.com/libp2p/py-libp2p/issues/687>`__)
|
||||
- Added peer exchange and backoff logic as part of Gossipsub v1.1 upgrade (`#690 <https://github.com/libp2p/py-libp2p/issues/690>`__)
|
||||
|
||||
|
||||
Internal Changes - for py-libp2p Contributors
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
- Added sparse connect utility function to pubsub test utilities for creating test networks with configurable connectivity. (`#679 <https://github.com/libp2p/py-libp2p/issues/679>`__)
|
||||
- Added comprehensive tests for pubsub connection utility functions to verify degree limits are enforced, excess peers are handled correctly, and edge cases (degree=0, negative values, empty lists) are managed gracefully. (`#707 <https://github.com/libp2p/py-libp2p/issues/707>`__)
|
||||
- Added extra tests for identify push concurrency cap under high peer load (`#708 <https://github.com/libp2p/py-libp2p/issues/708>`__)
|
||||
|
||||
|
||||
Miscellaneous Changes
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
- `#678 <https://github.com/libp2p/py-libp2p/issues/678>`__, `#684 <https://github.com/libp2p/py-libp2p/issues/684>`__
|
||||
|
||||
|
||||
py-libp2p v0.2.8 (2025-06-10)
|
||||
-----------------------------
|
||||
|
||||
Breaking Changes
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
- The `NetStream.state` property is now async and requires `await`. Update any direct state access to use `await stream.state`. (`#300 <https://github.com/libp2p/py-libp2p/issues/300>`__)
|
||||
|
||||
|
||||
Bugfixes
|
||||
~~~~~~~~
|
||||
|
||||
- Added proper state management and resource cleanup to `NetStream`, fixing memory leaks and improved error handling. (`#300 <https://github.com/libp2p/py-libp2p/issues/300>`__)
|
||||
|
||||
|
||||
Improved Documentation
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
- Updated examples to automatically use random port, when `-p` flag is not given (`#661 <https://github.com/libp2p/py-libp2p/issues/661>`__)
|
||||
|
||||
|
||||
Features
|
||||
~~~~~~~~
|
||||
|
||||
- Allow passing `listen_addrs` to `new_swarm` to customize swarm listening behavior. (`#616 <https://github.com/libp2p/py-libp2p/issues/616>`__)
|
||||
- Feature: Support for sending `ls` command over `multistream-select` to list supported protocols from remote peer.
|
||||
This allows inspecting which protocol handlers a peer supports at runtime. (`#622 <https://github.com/libp2p/py-libp2p/issues/622>`__)
|
||||
- implement AsyncContextManager for IMuxedStream to support async with (`#629 <https://github.com/libp2p/py-libp2p/issues/629>`__)
|
||||
- feat: add method to compute time since last message published by a peer and remove fanout peers based on ttl. (`#636 <https://github.com/libp2p/py-libp2p/issues/636>`__)
|
||||
- implement blacklist management for `pubsub.Pubsub` with methods to get, add, remove, check, and clear blacklisted peer IDs. (`#641 <https://github.com/libp2p/py-libp2p/issues/641>`__)
|
||||
- fix: remove expired peers from peerstore based on TTL (`#650 <https://github.com/libp2p/py-libp2p/issues/650>`__)
|
||||
|
||||
|
||||
Internal Changes - for py-libp2p Contributors
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
- Modernizes several aspects of the project, notably using ``pyproject.toml`` for project info instead of ``setup.py``, using ``ruff`` to replace several separate linting tools, and ``pyrefly`` in addition to ``mypy`` for typing. Also includes changes across the codebase to conform to new linting and typing rules. (`#618 <https://github.com/libp2p/py-libp2p/issues/618>`__)
|
||||
|
||||
|
||||
Removals
|
||||
~~~~~~~~
|
||||
|
||||
- Removes support for python 3.9 and updates some code conventions, notably using ``|`` operator in typing instead of ``Optional`` or ``Union`` (`#618 <https://github.com/libp2p/py-libp2p/issues/618>`__)
|
||||
|
||||
|
||||
py-libp2p v0.2.7 (2025-05-22)
|
||||
-----------------------------
|
||||
|
||||
|
||||
@ -40,6 +40,7 @@ async def write_data(stream: INetStream) -> None:
|
||||
|
||||
|
||||
async def run(port: int, destination: str) -> None:
|
||||
localhost_ip = "127.0.0.1"
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
host = new_host()
|
||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
||||
@ -53,8 +54,8 @@ async def run(port: int, destination: str) -> None:
|
||||
|
||||
print(
|
||||
"Run this from the same folder in another console:\n\n"
|
||||
f"chat-demo "
|
||||
f"-d {host.get_addrs()[0]}\n"
|
||||
f"chat-demo -p {int(port) + 1} "
|
||||
f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id().pretty()}\n"
|
||||
)
|
||||
print("Waiting for incoming connection...")
|
||||
|
||||
@ -86,7 +87,9 @@ def main() -> None:
|
||||
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
|
||||
)
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
||||
parser.add_argument(
|
||||
"-p", "--port", default=8000, type=int, help="source port number"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--destination",
|
||||
@ -95,6 +98,9 @@ def main() -> None:
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.port:
|
||||
raise RuntimeError("was not able to determine a local port")
|
||||
|
||||
try:
|
||||
trio.run(run, *(args.port, args.destination))
|
||||
except KeyboardInterrupt:
|
||||
|
||||
@ -27,9 +27,6 @@ async def main():
|
||||
# secure_bytes_provider: Optional function to generate secure random bytes
|
||||
# (defaults to secrets.token_bytes)
|
||||
secure_bytes_provider=None, # Use default implementation
|
||||
# peerstore: Optional peerstore to store peer IDs and public keys
|
||||
# (defaults to None)
|
||||
peerstore=None,
|
||||
)
|
||||
|
||||
# Create a security options dictionary mapping protocol ID to transport
|
||||
|
||||
@ -9,10 +9,8 @@ from libp2p import (
|
||||
from libp2p.crypto.secp256k1 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.security.noise.transport import (
|
||||
PROTOCOL_ID as NOISE_PROTOCOL_ID,
|
||||
Transport as NoiseTransport,
|
||||
)
|
||||
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
|
||||
from libp2p.security.noise.transport import Transport as NoiseTransport
|
||||
|
||||
|
||||
async def main():
|
||||
|
||||
@ -9,10 +9,8 @@ from libp2p import (
|
||||
from libp2p.crypto.secp256k1 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.security.secio.transport import (
|
||||
ID as SECIO_PROTOCOL_ID,
|
||||
Transport as SecioTransport,
|
||||
)
|
||||
from libp2p.security.secio.transport import ID as SECIO_PROTOCOL_ID
|
||||
from libp2p.security.secio.transport import Transport as SecioTransport
|
||||
|
||||
|
||||
async def main():
|
||||
@ -24,6 +22,9 @@ async def main():
|
||||
secio_transport = SecioTransport(
|
||||
# local_key_pair: The key pair used for libp2p identity and authentication
|
||||
local_key_pair=key_pair,
|
||||
# secure_bytes_provider: Optional function to generate secure random bytes
|
||||
# (defaults to secrets.token_bytes)
|
||||
secure_bytes_provider=None, # Use default implementation
|
||||
)
|
||||
|
||||
# Create a security options dictionary mapping protocol ID to transport
|
||||
|
||||
@ -9,9 +9,10 @@ from libp2p import (
|
||||
from libp2p.crypto.secp256k1 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.security.noise.transport import (
|
||||
PROTOCOL_ID as NOISE_PROTOCOL_ID,
|
||||
Transport as NoiseTransport,
|
||||
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
|
||||
from libp2p.security.noise.transport import Transport as NoiseTransport
|
||||
from libp2p.stream_muxer.mplex.mplex import (
|
||||
MPLEX_PROTOCOL_ID,
|
||||
)
|
||||
|
||||
|
||||
@ -36,8 +37,14 @@ async def main():
|
||||
# Create a security options dictionary mapping protocol ID to transport
|
||||
security_options = {NOISE_PROTOCOL_ID: noise_transport}
|
||||
|
||||
# Create a muxer options dictionary mapping protocol ID to muxer class
|
||||
# We don't need to instantiate the muxer here, the host will do that for us
|
||||
muxer_options = {MPLEX_PROTOCOL_ID: None}
|
||||
|
||||
# Create a host with the key pair, Noise security, and mplex multiplexer
|
||||
host = new_host(key_pair=key_pair, sec_opt=security_options)
|
||||
host = new_host(
|
||||
key_pair=key_pair, sec_opt=security_options, muxer_opt=muxer_options
|
||||
)
|
||||
|
||||
# Configure the listening address
|
||||
port = 8000
|
||||
|
||||
@ -1,263 +0,0 @@
|
||||
"""
|
||||
Enhanced NetStream Example for py-libp2p with State Management
|
||||
|
||||
This example demonstrates the new NetStream features including:
|
||||
- State tracking and transitions
|
||||
- Proper error handling and validation
|
||||
- Resource cleanup and event notifications
|
||||
- Thread-safe operations with Trio locks
|
||||
|
||||
Based on the standard echo demo but enhanced to show NetStream state management.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import random
|
||||
import secrets
|
||||
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p import (
|
||||
new_host,
|
||||
)
|
||||
from libp2p.crypto.secp256k1 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.network.stream.exceptions import (
|
||||
StreamClosed,
|
||||
StreamEOF,
|
||||
StreamReset,
|
||||
)
|
||||
from libp2p.network.stream.net_stream import (
|
||||
NetStream,
|
||||
StreamState,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
info_from_p2p_addr,
|
||||
)
|
||||
|
||||
PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
||||
|
||||
|
||||
async def enhanced_echo_handler(stream: NetStream) -> None:
|
||||
"""
|
||||
Enhanced echo handler that demonstrates NetStream state management.
|
||||
"""
|
||||
print(f"New connection established: {stream}")
|
||||
print(f"Initial stream state: {await stream.state}")
|
||||
|
||||
try:
|
||||
# Verify stream is in expected initial state
|
||||
assert await stream.state == StreamState.OPEN
|
||||
assert await stream.is_readable()
|
||||
assert await stream.is_writable()
|
||||
print("✓ Stream initialized in OPEN state")
|
||||
|
||||
# Read incoming data with proper state checking
|
||||
print("Waiting for client data...")
|
||||
|
||||
while await stream.is_readable():
|
||||
try:
|
||||
# Read data from client
|
||||
data = await stream.read(1024)
|
||||
if not data:
|
||||
print("Received empty data, client may have closed")
|
||||
break
|
||||
|
||||
print(f"Received: {data.decode('utf-8').strip()}")
|
||||
|
||||
# Check if we can still write before echoing
|
||||
if await stream.is_writable():
|
||||
await stream.write(data)
|
||||
print(f"Echoed: {data.decode('utf-8').strip()}")
|
||||
else:
|
||||
print("Cannot echo - stream not writable")
|
||||
break
|
||||
|
||||
except StreamEOF:
|
||||
print("Client closed their write side (EOF)")
|
||||
break
|
||||
except StreamReset:
|
||||
print("Stream was reset by client")
|
||||
return
|
||||
except StreamClosed as e:
|
||||
print(f"Stream operation failed: {e}")
|
||||
break
|
||||
|
||||
# Demonstrate graceful closure
|
||||
current_state = await stream.state
|
||||
print(f"Current state before close: {current_state}")
|
||||
|
||||
if current_state not in [StreamState.CLOSE_BOTH, StreamState.RESET]:
|
||||
await stream.close()
|
||||
print("Server closed write side")
|
||||
|
||||
final_state = await stream.state
|
||||
print(f"Final stream state: {final_state}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Handler error: {e}")
|
||||
# Reset stream on unexpected errors
|
||||
if await stream.state not in [StreamState.RESET, StreamState.CLOSE_BOTH]:
|
||||
await stream.reset()
|
||||
print("Stream reset due to error")
|
||||
|
||||
|
||||
async def enhanced_client_demo(stream: NetStream) -> None:
|
||||
"""
|
||||
Enhanced client that demonstrates various NetStream state scenarios.
|
||||
"""
|
||||
print(f"Client stream established: {stream}")
|
||||
print(f"Initial state: {await stream.state}")
|
||||
|
||||
try:
|
||||
# Verify initial state
|
||||
assert await stream.state == StreamState.OPEN
|
||||
print("✓ Client stream in OPEN state")
|
||||
|
||||
# Scenario 1: Normal communication
|
||||
message = b"Hello from enhanced NetStream client!\n"
|
||||
|
||||
if await stream.is_writable():
|
||||
await stream.write(message)
|
||||
print(f"Sent: {message.decode('utf-8').strip()}")
|
||||
else:
|
||||
print("Cannot write - stream not writable")
|
||||
return
|
||||
|
||||
# Close write side to signal EOF to server
|
||||
await stream.close()
|
||||
print("Client closed write side")
|
||||
|
||||
# Verify state transition
|
||||
state_after_close = await stream.state
|
||||
print(f"State after close: {state_after_close}")
|
||||
assert state_after_close == StreamState.CLOSE_WRITE
|
||||
assert await stream.is_readable() # Should still be readable
|
||||
assert not await stream.is_writable() # Should not be writable
|
||||
|
||||
# Try to write (should fail)
|
||||
try:
|
||||
await stream.write(b"This should fail")
|
||||
print("ERROR: Write succeeded when it should have failed!")
|
||||
except StreamClosed as e:
|
||||
print(f"✓ Expected error when writing to closed stream: {e}")
|
||||
|
||||
# Read the echo response
|
||||
if await stream.is_readable():
|
||||
try:
|
||||
response = await stream.read()
|
||||
print(f"Received echo: {response.decode('utf-8').strip()}")
|
||||
except StreamEOF:
|
||||
print("Server closed their write side")
|
||||
except StreamReset:
|
||||
print("Stream was reset")
|
||||
|
||||
# Check final state
|
||||
final_state = await stream.state
|
||||
print(f"Final client state: {final_state}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Client error: {e}")
|
||||
# Reset on error
|
||||
await stream.reset()
|
||||
print("Client reset stream due to error")
|
||||
|
||||
|
||||
async def run_enhanced_demo(
|
||||
port: int, destination: str, seed: int | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Run enhanced echo demo with NetStream state management.
|
||||
"""
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
|
||||
# Generate or use provided key
|
||||
if seed:
|
||||
random.seed(seed)
|
||||
secret_number = random.getrandbits(32 * 8)
|
||||
secret = secret_number.to_bytes(length=32, byteorder="big")
|
||||
else:
|
||||
secret = secrets.token_bytes(32)
|
||||
|
||||
host = new_host(key_pair=create_new_key_pair(secret))
|
||||
|
||||
async with host.run(listen_addrs=[listen_addr]):
|
||||
print(f"Host ID: {host.get_id().to_string()}")
|
||||
print("=" * 60)
|
||||
|
||||
if not destination: # Server mode
|
||||
print("🖥️ ENHANCED ECHO SERVER MODE")
|
||||
print("=" * 60)
|
||||
|
||||
# type: ignore: Stream is type of NetStream
|
||||
host.set_stream_handler(PROTOCOL_ID, enhanced_echo_handler)
|
||||
|
||||
print(
|
||||
"Run client from another console:\n"
|
||||
f"python3 example_net_stream.py "
|
||||
f"-d {host.get_addrs()[0]}\n"
|
||||
)
|
||||
print("Waiting for connections...")
|
||||
print("Press Ctrl+C to stop server")
|
||||
await trio.sleep_forever()
|
||||
|
||||
else: # Client mode
|
||||
print("📱 ENHANCED ECHO CLIENT MODE")
|
||||
print("=" * 60)
|
||||
|
||||
# Connect to server
|
||||
maddr = multiaddr.Multiaddr(destination)
|
||||
info = info_from_p2p_addr(maddr)
|
||||
await host.connect(info)
|
||||
print(f"Connected to server: {info.peer_id.pretty()}")
|
||||
|
||||
# Create stream and run enhanced demo
|
||||
stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
|
||||
if isinstance(stream, NetStream):
|
||||
await enhanced_client_demo(stream)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("CLIENT DEMO COMPLETE")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
example_maddr = (
|
||||
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--destination",
|
||||
type=str,
|
||||
help=f"destination multiaddr string, e.g. {example_maddr}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--seed",
|
||||
type=int,
|
||||
help="seed for deterministic peer ID generation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--demo-states", action="store_true", help="run state transition demo only"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
trio.run(run_enhanced_demo, args.port, args.destination, args.seed)
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Demo interrupted by user")
|
||||
except Exception as e:
|
||||
print(f"❌ Demo failed: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -12,9 +12,10 @@ from libp2p.crypto.secp256k1 import (
|
||||
from libp2p.peer.peerinfo import (
|
||||
info_from_p2p_addr,
|
||||
)
|
||||
from libp2p.security.noise.transport import (
|
||||
PROTOCOL_ID as NOISE_PROTOCOL_ID,
|
||||
Transport as NoiseTransport,
|
||||
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
|
||||
from libp2p.security.noise.transport import Transport as NoiseTransport
|
||||
from libp2p.stream_muxer.mplex.mplex import (
|
||||
MPLEX_PROTOCOL_ID,
|
||||
)
|
||||
|
||||
|
||||
@ -39,8 +40,14 @@ async def main():
|
||||
# Create a security options dictionary mapping protocol ID to transport
|
||||
security_options = {NOISE_PROTOCOL_ID: noise_transport}
|
||||
|
||||
# Create a muxer options dictionary mapping protocol ID to muxer class
|
||||
# We don't need to instantiate the muxer here, the host will do that for us
|
||||
muxer_options = {MPLEX_PROTOCOL_ID: None}
|
||||
|
||||
# Create a host with the key pair, Noise security, and mplex multiplexer
|
||||
host = new_host(key_pair=key_pair, sec_opt=security_options)
|
||||
host = new_host(
|
||||
key_pair=key_pair, sec_opt=security_options, muxer_opt=muxer_options
|
||||
)
|
||||
|
||||
# Configure the listening address
|
||||
port = 8000
|
||||
|
||||
@ -9,9 +9,10 @@ from libp2p import (
|
||||
from libp2p.crypto.secp256k1 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.security.noise.transport import (
|
||||
PROTOCOL_ID as NOISE_PROTOCOL_ID,
|
||||
Transport as NoiseTransport,
|
||||
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
|
||||
from libp2p.security.noise.transport import Transport as NoiseTransport
|
||||
from libp2p.stream_muxer.mplex.mplex import (
|
||||
MPLEX_PROTOCOL_ID,
|
||||
)
|
||||
|
||||
|
||||
@ -36,8 +37,14 @@ async def main():
|
||||
# Create a security options dictionary mapping protocol ID to transport
|
||||
security_options = {NOISE_PROTOCOL_ID: noise_transport}
|
||||
|
||||
# Create a muxer options dictionary mapping protocol ID to muxer class
|
||||
# We don't need to instantiate the muxer here, the host will do that for us
|
||||
muxer_options = {MPLEX_PROTOCOL_ID: None}
|
||||
|
||||
# Create a host with the key pair, Noise security, and mplex multiplexer
|
||||
host = new_host(key_pair=key_pair, sec_opt=security_options)
|
||||
host = new_host(
|
||||
key_pair=key_pair, sec_opt=security_options, muxer_opt=muxer_options
|
||||
)
|
||||
|
||||
# Configure the listening address
|
||||
port = 8000
|
||||
|
||||
@ -20,17 +20,17 @@ from libp2p.peer.peerinfo import (
|
||||
)
|
||||
|
||||
PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
||||
MAX_READ_LEN = 2**32 - 1
|
||||
|
||||
|
||||
async def _echo_stream_handler(stream: INetStream) -> None:
|
||||
# Wait until EOF
|
||||
msg = await stream.read(MAX_READ_LEN)
|
||||
msg = await stream.read()
|
||||
await stream.write(msg)
|
||||
await stream.close()
|
||||
|
||||
|
||||
async def run(port: int, destination: str, seed: int | None = None) -> None:
|
||||
async def run(port: int, destination: str, seed: int = None) -> None:
|
||||
localhost_ip = "127.0.0.1"
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
|
||||
if seed:
|
||||
@ -53,8 +53,8 @@ async def run(port: int, destination: str, seed: int | None = None) -> None:
|
||||
|
||||
print(
|
||||
"Run this from the same folder in another console:\n\n"
|
||||
f"echo-demo "
|
||||
f"-d {host.get_addrs()[0]}\n"
|
||||
f"echo-demo -p {int(port) + 1} "
|
||||
f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id().pretty()}\n"
|
||||
)
|
||||
print("Waiting for incoming connections...")
|
||||
await trio.sleep_forever()
|
||||
@ -73,8 +73,9 @@ async def run(port: int, destination: str, seed: int | None = None) -> None:
|
||||
msg = b"hi, there!\n"
|
||||
|
||||
await stream.write(msg)
|
||||
response = await stream.read()
|
||||
# Notify the other side about EOF
|
||||
await stream.close()
|
||||
response = await stream.read()
|
||||
|
||||
print(f"Sent: {msg.decode('utf-8')}")
|
||||
print(f"Got: {response.decode('utf-8')}")
|
||||
@ -93,7 +94,9 @@ def main() -> None:
|
||||
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
|
||||
)
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
||||
parser.add_argument(
|
||||
"-p", "--port", default=8000, type=int, help="source port number"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--destination",
|
||||
@ -107,6 +110,10 @@ def main() -> None:
|
||||
help="provide a seed to the random number generator (e.g. to fix peer IDs across runs)", # noqa: E501
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.port:
|
||||
raise RuntimeError("was not able to determine a local port")
|
||||
|
||||
try:
|
||||
trio.run(run, args.port, args.destination, args.seed)
|
||||
except KeyboardInterrupt:
|
||||
|
||||
@ -8,10 +8,9 @@ import trio
|
||||
from libp2p import (
|
||||
new_host,
|
||||
)
|
||||
from libp2p.identity.identify.identify import (
|
||||
ID as IDENTIFY_PROTOCOL_ID,
|
||||
identify_handler_for,
|
||||
parse_identify_response,
|
||||
from libp2p.identity.identify.identify import ID as IDENTIFY_PROTOCOL_ID
|
||||
from libp2p.identity.identify.pb.identify_pb2 import (
|
||||
Identify,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
info_from_p2p_addr,
|
||||
@ -51,7 +50,7 @@ def print_identify_response(identify_response):
|
||||
)
|
||||
|
||||
|
||||
async def run(port: int, destination: str, use_varint_format: bool = True) -> None:
|
||||
async def run(port: int, destination: str) -> None:
|
||||
localhost_ip = "0.0.0.0"
|
||||
|
||||
if not destination:
|
||||
@ -59,36 +58,23 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}")
|
||||
host_a = new_host()
|
||||
|
||||
# Set up identify handler with specified format
|
||||
identify_handler = identify_handler_for(
|
||||
host_a, use_varint_format=use_varint_format
|
||||
)
|
||||
host_a.set_stream_handler(IDENTIFY_PROTOCOL_ID, identify_handler)
|
||||
|
||||
async with host_a.run(listen_addrs=[listen_addr]):
|
||||
# Get the actual address and replace 0.0.0.0 with 127.0.0.1 for client
|
||||
# connections
|
||||
server_addr = str(host_a.get_addrs()[0])
|
||||
client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/")
|
||||
|
||||
format_name = "length-prefixed" if use_varint_format else "raw protobuf"
|
||||
print(
|
||||
f"First host listening (using {format_name} format). "
|
||||
f"Run this from another console:\n\n"
|
||||
f"identify-demo "
|
||||
f"-d {client_addr}\n"
|
||||
"First host listening. Run this from another console:\n\n"
|
||||
f"identify-demo -p {int(port) + 1} "
|
||||
f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host_a.get_id().pretty()}\n"
|
||||
)
|
||||
print("Waiting for incoming identify request...")
|
||||
await trio.sleep_forever()
|
||||
|
||||
else:
|
||||
# Create second host (dialer)
|
||||
print(f"dialer (host_b) listening on /ip4/{localhost_ip}/tcp/{port}")
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}")
|
||||
host_b = new_host()
|
||||
|
||||
async with host_b.run(listen_addrs=[listen_addr]):
|
||||
# Connect to the first host
|
||||
print(f"dialer (host_b) listening on {host_b.get_addrs()[0]}")
|
||||
maddr = multiaddr.Multiaddr(destination)
|
||||
info = info_from_p2p_addr(maddr)
|
||||
print(f"Second host connecting to peer: {info.peer_id}")
|
||||
@ -98,18 +84,11 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No
|
||||
|
||||
try:
|
||||
print("Starting identify protocol...")
|
||||
|
||||
# Read the complete response (could be either format)
|
||||
# Read a larger chunk to get all the data before stream closes
|
||||
response = await stream.read(8192) # Read enough data in one go
|
||||
|
||||
response = await stream.read()
|
||||
await stream.close()
|
||||
|
||||
# Parse the response using the robust protocol-level function
|
||||
# This handles both old and new formats automatically
|
||||
identify_msg = parse_identify_response(response)
|
||||
identify_msg = Identify()
|
||||
identify_msg.ParseFromString(response)
|
||||
print_identify_response(identify_msg)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Identify protocol error: {e}")
|
||||
|
||||
@ -119,42 +98,32 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No
|
||||
def main() -> None:
|
||||
description = """
|
||||
This program demonstrates the libp2p identify protocol.
|
||||
First run 'identify-demo -p <PORT> [--raw-format]' to start a listener.
|
||||
First run identify-demo -p <PORT>' to start a listener.
|
||||
Then run 'identify-demo <ANOTHER_PORT> -d <DESTINATION>'
|
||||
where <DESTINATION> is the multiaddress shown by the listener.
|
||||
|
||||
Use --raw-format to send raw protobuf messages (old format) instead of
|
||||
length-prefixed protobuf messages (new format, default).
|
||||
"""
|
||||
|
||||
example_maddr = (
|
||||
"/ip4/127.0.0.1/tcp/8888/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
|
||||
"/ip4/127.0.0.1/tcp/8888/p2p/QmQn4SwGkDZkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
||||
parser.add_argument(
|
||||
"-p", "--port", default=8888, type=int, help="source port number"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--destination",
|
||||
type=str,
|
||||
help=f"destination multiaddr string, e.g. {example_maddr}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--raw-format",
|
||||
action="store_true",
|
||||
help=(
|
||||
"use raw protobuf format (old format) instead of "
|
||||
"length-prefixed (new format)"
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Determine format: raw format if --raw-format is specified, otherwise
|
||||
# length-prefixed
|
||||
use_varint_format = not args.raw_format
|
||||
if not args.port:
|
||||
raise RuntimeError("failed to determine local port")
|
||||
|
||||
try:
|
||||
trio.run(run, *(args.port, args.destination, use_varint_format))
|
||||
trio.run(run, *(args.port, args.destination))
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
|
||||
@ -38,17 +38,17 @@ from libp2p.crypto.secp256k1 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.identity.identify import (
|
||||
ID as ID_IDENTIFY,
|
||||
identify_handler_for,
|
||||
)
|
||||
from libp2p.identity.identify import ID as ID_IDENTIFY
|
||||
from libp2p.identity.identify.pb.identify_pb2 import (
|
||||
Identify,
|
||||
)
|
||||
from libp2p.identity.identify_push import (
|
||||
ID_PUSH as ID_IDENTIFY_PUSH,
|
||||
identify_push_handler_for,
|
||||
push_identify_to_peer,
|
||||
)
|
||||
from libp2p.identity.identify_push import ID_PUSH as ID_IDENTIFY_PUSH
|
||||
from libp2p.peer.peerinfo import (
|
||||
info_from_p2p_addr,
|
||||
)
|
||||
@ -56,57 +56,22 @@ from libp2p.peer.peerinfo import (
|
||||
# Configure logging
|
||||
logger = logging.getLogger("libp2p.identity.identify-push-example")
|
||||
|
||||
# Default port configuration
|
||||
DEFAULT_PORT = 8888
|
||||
|
||||
def custom_identify_push_handler_for(host, use_varint_format: bool = True):
|
||||
|
||||
def custom_identify_push_handler_for(host):
|
||||
"""
|
||||
Create a custom handler for the identify/push protocol that logs and prints
|
||||
the identity information received from the dialer.
|
||||
|
||||
Args:
|
||||
host: The libp2p host
|
||||
use_varint_format: If True, expect length-prefixed format; if False, expect
|
||||
raw protobuf
|
||||
|
||||
"""
|
||||
|
||||
async def handle_identify_push(stream: INetStream) -> None:
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
|
||||
try:
|
||||
if use_varint_format:
|
||||
# Read length-prefixed identify message from the stream
|
||||
from libp2p.utils.varint import decode_varint_from_bytes
|
||||
|
||||
# First read the varint length prefix
|
||||
length_bytes = b""
|
||||
while True:
|
||||
b = await stream.read(1)
|
||||
if not b:
|
||||
break
|
||||
length_bytes += b
|
||||
if b[0] & 0x80 == 0:
|
||||
break
|
||||
|
||||
if not length_bytes:
|
||||
logger.warning("No length prefix received from peer %s", peer_id)
|
||||
return
|
||||
|
||||
msg_length = decode_varint_from_bytes(length_bytes)
|
||||
|
||||
# Read the protobuf message
|
||||
data = await stream.read(msg_length)
|
||||
if len(data) != msg_length:
|
||||
logger.warning("Incomplete message received from peer %s", peer_id)
|
||||
return
|
||||
else:
|
||||
# Read raw protobuf message from the stream
|
||||
data = b""
|
||||
while True:
|
||||
chunk = await stream.read(4096)
|
||||
if not chunk:
|
||||
break
|
||||
data += chunk
|
||||
|
||||
# Read the identify message from the stream
|
||||
data = await stream.read()
|
||||
identify_msg = Identify()
|
||||
identify_msg.ParseFromString(data)
|
||||
|
||||
@ -167,13 +132,9 @@ def custom_identify_push_handler_for(host, use_varint_format: bool = True):
|
||||
return handle_identify_push
|
||||
|
||||
|
||||
async def run_listener(port: int, use_varint_format: bool = True) -> None:
|
||||
async def run_listener(port: int) -> None:
|
||||
"""Run a host in listener mode."""
|
||||
format_name = "length-prefixed" if use_varint_format else "raw protobuf"
|
||||
print(
|
||||
f"\n==== Starting Identify-Push Listener on port {port} "
|
||||
f"(using {format_name} format) ====\n"
|
||||
)
|
||||
print(f"\n==== Starting Identify-Push Listener on port {port} ====\n")
|
||||
|
||||
# Create key pair for the listener
|
||||
key_pair = create_new_key_pair()
|
||||
@ -181,14 +142,9 @@ async def run_listener(port: int, use_varint_format: bool = True) -> None:
|
||||
# Create the listener host
|
||||
host = new_host(key_pair=key_pair)
|
||||
|
||||
# Set up the identify and identify/push handlers with specified format
|
||||
host.set_stream_handler(
|
||||
ID_IDENTIFY, identify_handler_for(host, use_varint_format=use_varint_format)
|
||||
)
|
||||
host.set_stream_handler(
|
||||
ID_IDENTIFY_PUSH,
|
||||
identify_push_handler_for(host, use_varint_format=use_varint_format),
|
||||
)
|
||||
# Set up the identify and identify/push handlers
|
||||
host.set_stream_handler(ID_IDENTIFY, identify_handler_for(host))
|
||||
host.set_stream_handler(ID_IDENTIFY_PUSH, custom_identify_push_handler_for(host))
|
||||
|
||||
# Start listening
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
@ -212,15 +168,9 @@ async def run_listener(port: int, use_varint_format: bool = True) -> None:
|
||||
await trio.sleep_forever()
|
||||
|
||||
|
||||
async def run_dialer(
|
||||
port: int, destination: str, use_varint_format: bool = True
|
||||
) -> None:
|
||||
async def run_dialer(port: int, destination: str) -> None:
|
||||
"""Run a host in dialer mode that connects to a listener."""
|
||||
format_name = "length-prefixed" if use_varint_format else "raw protobuf"
|
||||
print(
|
||||
f"\n==== Starting Identify-Push Dialer on port {port} "
|
||||
f"(using {format_name} format) ====\n"
|
||||
)
|
||||
print(f"\n==== Starting Identify-Push Dialer on port {port} ====\n")
|
||||
|
||||
# Create key pair for the dialer
|
||||
key_pair = create_new_key_pair()
|
||||
@ -228,14 +178,9 @@ async def run_dialer(
|
||||
# Create the dialer host
|
||||
host = new_host(key_pair=key_pair)
|
||||
|
||||
# Set up the identify and identify/push handlers with specified format
|
||||
host.set_stream_handler(
|
||||
ID_IDENTIFY, identify_handler_for(host, use_varint_format=use_varint_format)
|
||||
)
|
||||
host.set_stream_handler(
|
||||
ID_IDENTIFY_PUSH,
|
||||
identify_push_handler_for(host, use_varint_format=use_varint_format),
|
||||
)
|
||||
# Set up the identify and identify/push handlers
|
||||
host.set_stream_handler(ID_IDENTIFY, identify_handler_for(host))
|
||||
host.set_stream_handler(ID_IDENTIFY_PUSH, identify_push_handler_for(host))
|
||||
|
||||
# Start listening on a different port
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
@ -264,9 +209,7 @@ async def run_dialer(
|
||||
|
||||
try:
|
||||
# Call push_identify_to_peer which returns a boolean
|
||||
success = await push_identify_to_peer(
|
||||
host, peer_info.peer_id, use_varint_format=use_varint_format
|
||||
)
|
||||
success = await push_identify_to_peer(host, peer_info.peer_id)
|
||||
|
||||
if success:
|
||||
logger.info("Identify push completed successfully!")
|
||||
@ -298,42 +241,42 @@ def main() -> None:
|
||||
"""Parse arguments and start the appropriate mode."""
|
||||
description = """
|
||||
This program demonstrates the libp2p identify/push protocol.
|
||||
Without arguments, it runs as a listener on random port.
|
||||
With -d parameter, it runs as a dialer on random port.
|
||||
|
||||
Use --raw-format to send raw protobuf messages (old format) instead of
|
||||
length-prefixed protobuf messages (new format, default).
|
||||
Without arguments, it runs as a listener on port 8888.
|
||||
With -d parameter, it runs as a dialer on port 8889.
|
||||
"""
|
||||
|
||||
example = (
|
||||
f"/ip4/127.0.0.1/tcp/{DEFAULT_PORT}/p2p/"
|
||||
"QmQn4SwGkDZkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--port",
|
||||
type=int,
|
||||
help=(
|
||||
f"port to listen on (default: {DEFAULT_PORT} for listener, "
|
||||
f"{DEFAULT_PORT + 1} for dialer)"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--destination",
|
||||
type=str,
|
||||
help="destination multiaddr string",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--raw-format",
|
||||
action="store_true",
|
||||
help=(
|
||||
"use raw protobuf format (old format) instead of "
|
||||
"length-prefixed (new format)"
|
||||
),
|
||||
help=f"destination multiaddr string, e.g. {example}",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Determine format: raw format if --raw-format is specified, otherwise
|
||||
# length-prefixed
|
||||
use_varint_format = not args.raw_format
|
||||
|
||||
try:
|
||||
if args.destination:
|
||||
# Run in dialer mode with random available port if not specified
|
||||
trio.run(run_dialer, args.port, args.destination, use_varint_format)
|
||||
# Run in dialer mode with default port DEFAULT_PORT + 1 if not specified
|
||||
port = args.port if args.port is not None else DEFAULT_PORT + 1
|
||||
trio.run(run_dialer, port, args.destination)
|
||||
else:
|
||||
# Run in listener mode with random available port if not specified
|
||||
trio.run(run_listener, args.port, use_varint_format)
|
||||
# Run in listener mode with default port DEFAULT_PORT if not specified
|
||||
port = args.port if args.port is not None else DEFAULT_PORT
|
||||
trio.run(run_listener, port)
|
||||
except KeyboardInterrupt:
|
||||
print("\nInterrupted by user")
|
||||
logger.info("Interrupted by user")
|
||||
|
||||
@ -1,300 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
A basic example of using the Kademlia DHT implementation, with all setup logic inlined.
|
||||
This example demonstrates both value storage/retrieval and content server
|
||||
advertisement/discovery.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import secrets
|
||||
import sys
|
||||
|
||||
import base58
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
)
|
||||
import trio
|
||||
|
||||
from libp2p import (
|
||||
new_host,
|
||||
)
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
)
|
||||
from libp2p.crypto.secp256k1 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.kad_dht.kad_dht import (
|
||||
DHTMode,
|
||||
KadDHT,
|
||||
)
|
||||
from libp2p.kad_dht.utils import (
|
||||
create_key_from_binary,
|
||||
)
|
||||
from libp2p.tools.async_service import (
|
||||
background_trio_service,
|
||||
)
|
||||
from libp2p.tools.utils import (
|
||||
info_from_p2p_addr,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[logging.StreamHandler()],
|
||||
)
|
||||
logger = logging.getLogger("kademlia-example")
|
||||
|
||||
# Configure DHT module loggers to inherit from the parent logger
|
||||
# This ensures all kademlia-example.* loggers use the same configuration
|
||||
# Get the directory where this script is located
|
||||
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
SERVER_ADDR_LOG = os.path.join(SCRIPT_DIR, "server_node_addr.txt")
|
||||
|
||||
# Set the level for all child loggers
|
||||
for module in [
|
||||
"kad_dht",
|
||||
"value_store",
|
||||
"peer_routing",
|
||||
"routing_table",
|
||||
"provider_store",
|
||||
]:
|
||||
child_logger = logging.getLogger(f"kademlia-example.{module}")
|
||||
child_logger.setLevel(logging.INFO)
|
||||
child_logger.propagate = True # Allow propagation to parent
|
||||
|
||||
# File to store node information
|
||||
bootstrap_nodes = []
|
||||
|
||||
|
||||
# function to take bootstrap_nodes as input and connects to them
|
||||
async def connect_to_bootstrap_nodes(host: IHost, bootstrap_addrs: list[str]) -> None:
|
||||
"""
|
||||
Connect to the bootstrap nodes provided in the list.
|
||||
|
||||
params: host: The host instance to connect to
|
||||
bootstrap_addrs: List of bootstrap node addresses
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
|
||||
"""
|
||||
for addr in bootstrap_addrs:
|
||||
try:
|
||||
peerInfo = info_from_p2p_addr(Multiaddr(addr))
|
||||
host.get_peerstore().add_addrs(peerInfo.peer_id, peerInfo.addrs, 3600)
|
||||
await host.connect(peerInfo)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to bootstrap node {addr}: {e}")
|
||||
|
||||
|
||||
def save_server_addr(addr: str) -> None:
|
||||
"""Append the server's multiaddress to the log file."""
|
||||
try:
|
||||
with open(SERVER_ADDR_LOG, "w") as f:
|
||||
f.write(addr + "\n")
|
||||
logger.info(f"Saved server address to log: {addr}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save server address: {e}")
|
||||
|
||||
|
||||
def load_server_addrs() -> list[str]:
|
||||
"""Load all server multiaddresses from the log file."""
|
||||
if not os.path.exists(SERVER_ADDR_LOG):
|
||||
return []
|
||||
try:
|
||||
with open(SERVER_ADDR_LOG) as f:
|
||||
return [line.strip() for line in f if line.strip()]
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load server addresses: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def run_node(
|
||||
port: int, mode: str, bootstrap_addrs: list[str] | None = None
|
||||
) -> None:
|
||||
"""Run a node that serves content in the DHT with setup inlined."""
|
||||
try:
|
||||
if port <= 0:
|
||||
port = random.randint(10000, 60000)
|
||||
logger.debug(f"Using port: {port}")
|
||||
|
||||
# Convert string mode to DHTMode enum
|
||||
if mode is None or mode.upper() == "CLIENT":
|
||||
dht_mode = DHTMode.CLIENT
|
||||
elif mode.upper() == "SERVER":
|
||||
dht_mode = DHTMode.SERVER
|
||||
else:
|
||||
logger.error(f"Invalid mode: {mode}. Must be 'client' or 'server'")
|
||||
sys.exit(1)
|
||||
|
||||
# Load server addresses for client mode
|
||||
if dht_mode == DHTMode.CLIENT:
|
||||
server_addrs = load_server_addrs()
|
||||
if server_addrs:
|
||||
logger.info(f"Loaded {len(server_addrs)} server addresses from log")
|
||||
bootstrap_nodes.append(server_addrs[0]) # Use the first server address
|
||||
else:
|
||||
logger.warning("No server addresses found in log file")
|
||||
|
||||
if bootstrap_addrs:
|
||||
for addr in bootstrap_addrs:
|
||||
bootstrap_nodes.append(addr)
|
||||
|
||||
key_pair = create_new_key_pair(secrets.token_bytes(32))
|
||||
host = new_host(key_pair=key_pair)
|
||||
listen_addr = Multiaddr(f"/ip4/127.0.0.1/tcp/{port}")
|
||||
|
||||
async with host.run(listen_addrs=[listen_addr]):
|
||||
peer_id = host.get_id().pretty()
|
||||
addr_str = f"/ip4/127.0.0.1/tcp/{port}/p2p/{peer_id}"
|
||||
await connect_to_bootstrap_nodes(host, bootstrap_nodes)
|
||||
dht = KadDHT(host, dht_mode)
|
||||
# take all peer ids from the host and add them to the dht
|
||||
for peer_id in host.get_peerstore().peer_ids():
|
||||
await dht.routing_table.add_peer(peer_id)
|
||||
logger.info(f"Connected to bootstrap nodes: {host.get_connected_peers()}")
|
||||
bootstrap_cmd = f"--bootstrap {addr_str}"
|
||||
logger.info("To connect to this node, use: %s", bootstrap_cmd)
|
||||
|
||||
# Save server address in server mode
|
||||
if dht_mode == DHTMode.SERVER:
|
||||
save_server_addr(addr_str)
|
||||
|
||||
# Start the DHT service
|
||||
async with background_trio_service(dht):
|
||||
logger.info(f"DHT service started in {dht_mode.value} mode")
|
||||
val_key = create_key_from_binary(b"py-libp2p kademlia example value")
|
||||
content = b"Hello from python node "
|
||||
content_key = create_key_from_binary(content)
|
||||
|
||||
if dht_mode == DHTMode.SERVER:
|
||||
# Store a value in the DHT
|
||||
msg = "Hello message from Sumanjeet"
|
||||
val_data = msg.encode()
|
||||
await dht.put_value(val_key, val_data)
|
||||
logger.info(
|
||||
f"Stored value '{val_data.decode()}'"
|
||||
f"with key: {base58.b58encode(val_key).decode()}"
|
||||
)
|
||||
|
||||
# Advertise as content server
|
||||
success = await dht.provider_store.provide(content_key)
|
||||
if success:
|
||||
logger.info(
|
||||
"Successfully advertised as server"
|
||||
f"for content: {content_key.hex()}"
|
||||
)
|
||||
else:
|
||||
logger.warning("Failed to advertise as content server")
|
||||
|
||||
else:
|
||||
# retrieve the value
|
||||
logger.info(
|
||||
"Looking up key: %s", base58.b58encode(val_key).decode()
|
||||
)
|
||||
val_data = await dht.get_value(val_key)
|
||||
if val_data:
|
||||
try:
|
||||
logger.info(f"Retrieved value: {val_data.decode()}")
|
||||
except UnicodeDecodeError:
|
||||
logger.info(f"Retrieved value (bytes): {val_data!r}")
|
||||
else:
|
||||
logger.warning("Failed to retrieve value")
|
||||
|
||||
# Also check if we can find servers for our own content
|
||||
logger.info("Looking for servers of content: %s", content_key.hex())
|
||||
providers = await dht.provider_store.find_providers(content_key)
|
||||
if providers:
|
||||
logger.info(
|
||||
"Found %d servers for content: %s",
|
||||
len(providers),
|
||||
[p.peer_id.pretty() for p in providers],
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"No servers found for content %s", content_key.hex()
|
||||
)
|
||||
|
||||
# Keep the node running
|
||||
while True:
|
||||
logger.debug(
|
||||
"Status - Connected peers: %d,"
|
||||
"Peers in store: %d, Values in store: %d",
|
||||
len(dht.host.get_connected_peers()),
|
||||
len(dht.host.get_peerstore().peer_ids()),
|
||||
len(dht.value_store.store),
|
||||
)
|
||||
await trio.sleep(10)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Server node error: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Kademlia DHT example with content server functionality"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
default="server",
|
||||
help="Run as a server or client node",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Port to listen on (0 for random)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bootstrap",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help=(
|
||||
"Multiaddrs of bootstrap nodes. "
|
||||
"Provide a space-separated list of addresses. "
|
||||
"This is required for client mode."
|
||||
),
|
||||
)
|
||||
# add option to use verbose logging
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
help="Enable verbose logging",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
# Set logging level based on verbosity
|
||||
if args.verbose:
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
else:
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the kademlia demo."""
|
||||
try:
|
||||
args = parse_args()
|
||||
logger.info(
|
||||
"Running in %s mode on port %d",
|
||||
args.mode,
|
||||
args.port,
|
||||
)
|
||||
trio.run(run_node, args.port, args.mode, args.bootstrap)
|
||||
except Exception as e:
|
||||
logger.critical(f"Script failed: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -1,74 +0,0 @@
|
||||
import argparse
|
||||
import logging
|
||||
import secrets
|
||||
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p import (
|
||||
new_host,
|
||||
)
|
||||
from libp2p.abc import PeerInfo
|
||||
from libp2p.crypto.secp256k1 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.discovery.events.peerDiscovery import peerDiscovery
|
||||
|
||||
logger = logging.getLogger("libp2p.discovery.mdns")
|
||||
logger.setLevel(logging.INFO)
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(
|
||||
logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
)
|
||||
logger.addHandler(handler)
|
||||
|
||||
# Set root logger to DEBUG to capture all logs from dependencies
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def onPeerDiscovery(peerinfo: PeerInfo):
|
||||
logger.info(f"Discovered: {peerinfo.peer_id}")
|
||||
|
||||
|
||||
async def run(port: int) -> None:
|
||||
secret = secrets.token_bytes(32)
|
||||
key_pair = create_new_key_pair(secret)
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
|
||||
peerDiscovery.register_peer_discovered_handler(onPeerDiscovery)
|
||||
|
||||
print(
|
||||
"Run this from the same folder in another console to "
|
||||
"start another peer on a different port:\n\n"
|
||||
"mdns-demo -p <ANOTHER_PORT>\n"
|
||||
)
|
||||
print("Waiting for mDNS peer discovery events...\n")
|
||||
|
||||
logger.info("Starting peer Discovery")
|
||||
host = new_host(key_pair=key_pair, enable_mDNS=True)
|
||||
async with host.run(listen_addrs=[listen_addr]):
|
||||
await trio.sleep_forever()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
description = """
|
||||
This program demonstrates mDNS peer discovery using libp2p.
|
||||
To use it, run 'mdns-demo -p <PORT>', where <PORT> is the port number.
|
||||
Start multiple peers on different ports to see discovery in action.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
||||
parser.add_argument(
|
||||
"-v", "--verbose", action="store_true", help="Enable verbose output"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if args.verbose:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
try:
|
||||
trio.run(run, args.port)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Exiting...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -55,6 +55,7 @@ async def send_ping(stream: INetStream) -> None:
|
||||
|
||||
|
||||
async def run(port: int, destination: str) -> None:
|
||||
localhost_ip = "127.0.0.1"
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
host = new_host(listen_addrs=[listen_addr])
|
||||
|
||||
@ -64,8 +65,8 @@ async def run(port: int, destination: str) -> None:
|
||||
|
||||
print(
|
||||
"Run this from the same folder in another console:\n\n"
|
||||
f"ping-demo "
|
||||
f"-d {host.get_addrs()[0]}\n"
|
||||
f"ping-demo -p {int(port) + 1} "
|
||||
f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id().pretty()}\n"
|
||||
)
|
||||
print("Waiting for incoming connection...")
|
||||
|
||||
@ -95,8 +96,10 @@ def main() -> None:
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
||||
|
||||
parser.add_argument(
|
||||
"-p", "--port", default=8000, type=int, help="source port number"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--destination",
|
||||
@ -105,6 +108,9 @@ def main() -> None:
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.port:
|
||||
raise RuntimeError("failed to determine local port")
|
||||
|
||||
try:
|
||||
trio.run(run, *(args.port, args.destination))
|
||||
except KeyboardInterrupt:
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
import argparse
|
||||
import logging
|
||||
import socket
|
||||
from typing import (
|
||||
Optional,
|
||||
)
|
||||
|
||||
import base58
|
||||
import multiaddr
|
||||
@ -106,7 +109,7 @@ async def monitor_peer_topics(pubsub, nursery, termination_event):
|
||||
await trio.sleep(2)
|
||||
|
||||
|
||||
async def run(topic: str, destination: str | None, port: int | None) -> None:
|
||||
async def run(topic: str, destination: Optional[str], port: Optional[int]) -> None:
|
||||
# Initialize network settings
|
||||
localhost_ip = "127.0.0.1"
|
||||
|
||||
|
||||
@ -32,9 +32,6 @@ from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
TSecurityOptions,
|
||||
)
|
||||
from libp2p.discovery.mdns.mdns import (
|
||||
MDNSDiscovery,
|
||||
)
|
||||
from libp2p.host.basic_host import (
|
||||
BasicHost,
|
||||
)
|
||||
@ -84,8 +81,6 @@ DEFAULT_MUXER = "YAMUX"
|
||||
# Multiplexer options
|
||||
MUXER_YAMUX = "YAMUX"
|
||||
MUXER_MPLEX = "MPLEX"
|
||||
DEFAULT_NEGOTIATE_TIMEOUT = 5
|
||||
|
||||
|
||||
|
||||
def set_default_muxer(muxer_name: Literal["YAMUX", "MPLEX"]) -> None:
|
||||
@ -157,12 +152,12 @@ def get_default_muxer_options() -> TMuxerOptions:
|
||||
|
||||
|
||||
def new_swarm(
|
||||
key_pair: KeyPair | None = None,
|
||||
muxer_opt: TMuxerOptions | None = None,
|
||||
sec_opt: TSecurityOptions | None = None,
|
||||
peerstore_opt: IPeerStore | None = None,
|
||||
muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
|
||||
listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
|
||||
key_pair: Optional[KeyPair] = None,
|
||||
muxer_opt: Optional[TMuxerOptions] = None,
|
||||
sec_opt: Optional[TSecurityOptions] = None,
|
||||
peerstore_opt: Optional[IPeerStore] = None,
|
||||
muxer_preference: Optional[Literal["YAMUX", "MPLEX"]] = None,
|
||||
listen_addrs: Optional[Sequence[multiaddr.Multiaddr]] = None,
|
||||
) -> INetworkService:
|
||||
"""
|
||||
Create a swarm instance based on the parameters.
|
||||
@ -205,9 +200,7 @@ def new_swarm(
|
||||
key_pair, noise_privkey=noise_key_pair.private_key
|
||||
),
|
||||
TProtocol(secio.ID): secio.Transport(key_pair),
|
||||
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(
|
||||
key_pair, peerstore=peerstore_opt
|
||||
),
|
||||
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair),
|
||||
}
|
||||
|
||||
# Use given muxer preference if provided, otherwise use global default
|
||||
@ -243,15 +236,13 @@ def new_swarm(
|
||||
|
||||
|
||||
def new_host(
|
||||
key_pair: KeyPair | None = None,
|
||||
muxer_opt: TMuxerOptions | None = None,
|
||||
sec_opt: TSecurityOptions | None = None,
|
||||
peerstore_opt: IPeerStore | None = None,
|
||||
disc_opt: IPeerRouting | None = None,
|
||||
muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
|
||||
listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
|
||||
enable_mDNS: bool = False,
|
||||
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
key_pair: Optional[KeyPair] = None,
|
||||
muxer_opt: Optional[TMuxerOptions] = None,
|
||||
sec_opt: Optional[TSecurityOptions] = None,
|
||||
peerstore_opt: Optional[IPeerStore] = None,
|
||||
disc_opt: Optional[IPeerRouting] = None,
|
||||
muxer_preference: Optional[Literal["YAMUX", "MPLEX"]] = None,
|
||||
listen_addrs: Sequence[multiaddr.Multiaddr] = None,
|
||||
) -> IHost:
|
||||
"""
|
||||
Create a new libp2p host based on the given parameters.
|
||||
@ -263,7 +254,6 @@ def new_host(
|
||||
:param disc_opt: optional discovery
|
||||
:param muxer_preference: optional explicit muxer preference
|
||||
:param listen_addrs: optional list of multiaddrs to listen on
|
||||
:param enable_mDNS: whether to enable mDNS discovery
|
||||
:return: return a host instance
|
||||
"""
|
||||
swarm = new_swarm(
|
||||
@ -276,7 +266,8 @@ def new_host(
|
||||
)
|
||||
|
||||
if disc_opt is not None:
|
||||
return RoutedHost(swarm, disc_opt, enable_mDNS)
|
||||
return BasicHost(network=swarm,enable_mDNS=enable_mDNS , negotitate_timeout=negotiate_timeout)
|
||||
return RoutedHost(swarm, disc_opt)
|
||||
return BasicHost(swarm)
|
||||
|
||||
|
||||
__version__ = __version("libp2p")
|
||||
|
||||
863
libp2p/abc.py
863
libp2p/abc.py
File diff suppressed because it is too large
Load Diff
@ -116,15 +116,15 @@ def initialize_pair(
|
||||
EncryptionParameters(
|
||||
cipher_type,
|
||||
hash_type,
|
||||
bytes(first_half[0:iv_size]),
|
||||
bytes(first_half[iv_size + cipher_key_size :]),
|
||||
bytes(first_half[iv_size : iv_size + cipher_key_size]),
|
||||
first_half[0:iv_size],
|
||||
first_half[iv_size + cipher_key_size :],
|
||||
first_half[iv_size : iv_size + cipher_key_size],
|
||||
),
|
||||
EncryptionParameters(
|
||||
cipher_type,
|
||||
hash_type,
|
||||
bytes(second_half[0:iv_size]),
|
||||
bytes(second_half[iv_size + cipher_key_size :]),
|
||||
bytes(second_half[iv_size : iv_size + cipher_key_size]),
|
||||
second_half[0:iv_size],
|
||||
second_half[iv_size + cipher_key_size :],
|
||||
second_half[iv_size : iv_size + cipher_key_size],
|
||||
),
|
||||
)
|
||||
|
||||
@ -9,40 +9,29 @@ from libp2p.crypto.keys import (
|
||||
|
||||
if sys.platform != "win32":
|
||||
from fastecdsa import (
|
||||
curve as curve_types,
|
||||
keys,
|
||||
point,
|
||||
)
|
||||
from fastecdsa import curve as curve_types
|
||||
from fastecdsa.encoding.sec1 import (
|
||||
SEC1Encoder,
|
||||
)
|
||||
else:
|
||||
from coincurve import (
|
||||
PrivateKey as CPrivateKey,
|
||||
PublicKey as CPublicKey,
|
||||
)
|
||||
from coincurve import PrivateKey as CPrivateKey
|
||||
from coincurve import PublicKey as CPublicKey
|
||||
|
||||
|
||||
if sys.platform != "win32":
|
||||
def infer_local_type(curve: str) -> object:
|
||||
"""
|
||||
Convert a str representation of some elliptic curve to a
|
||||
representation understood by the backend of this module.
|
||||
"""
|
||||
if curve != "P-256":
|
||||
raise NotImplementedError("Only P-256 curve is supported")
|
||||
|
||||
def infer_local_type(curve: str) -> curve_types.Curve:
|
||||
"""
|
||||
Convert a str representation of some elliptic curve to a
|
||||
representation understood by the backend of this module.
|
||||
"""
|
||||
if curve != "P-256":
|
||||
raise NotImplementedError("Only P-256 curve is supported")
|
||||
if sys.platform != "win32":
|
||||
return curve_types.P256
|
||||
else:
|
||||
|
||||
def infer_local_type(curve: str) -> str:
|
||||
"""
|
||||
Convert a str representation of some elliptic curve to a
|
||||
representation understood by the backend of this module.
|
||||
"""
|
||||
if curve != "P-256":
|
||||
raise NotImplementedError("Only P-256 curve is supported")
|
||||
return "P-256" # coincurve only supports P-256
|
||||
return "P-256" # coincurve only supports P-256
|
||||
|
||||
|
||||
if sys.platform != "win32":
|
||||
@ -79,10 +68,7 @@ if sys.platform != "win32":
|
||||
return cls(private_key_impl, curve_type)
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
key_str = keys.export_key(self.impl, self.curve)
|
||||
if key_str is None:
|
||||
raise Exception("Key not found")
|
||||
return key_str.encode()
|
||||
return keys.export_key(self.impl, self.curve)
|
||||
|
||||
def get_type(self) -> KeyType:
|
||||
return KeyType.ECC_P256
|
||||
|
||||
@ -4,10 +4,8 @@ from Crypto.Hash import (
|
||||
from nacl.exceptions import (
|
||||
BadSignatureError,
|
||||
)
|
||||
from nacl.public import (
|
||||
PrivateKey as PrivateKeyImpl,
|
||||
PublicKey as PublicKeyImpl,
|
||||
)
|
||||
from nacl.public import PrivateKey as PrivateKeyImpl
|
||||
from nacl.public import PublicKey as PublicKeyImpl
|
||||
from nacl.signing import (
|
||||
SigningKey,
|
||||
VerifyKey,
|
||||
@ -50,7 +48,7 @@ class Ed25519PrivateKey(PrivateKey):
|
||||
self.impl = impl
|
||||
|
||||
@classmethod
|
||||
def new(cls, seed: bytes | None = None) -> "Ed25519PrivateKey":
|
||||
def new(cls, seed: bytes = None) -> "Ed25519PrivateKey":
|
||||
if not seed:
|
||||
seed = utils.random()
|
||||
|
||||
@ -77,7 +75,7 @@ class Ed25519PrivateKey(PrivateKey):
|
||||
return Ed25519PublicKey(self.impl.public_key)
|
||||
|
||||
|
||||
def create_new_key_pair(seed: bytes | None = None) -> KeyPair:
|
||||
def create_new_key_pair(seed: bytes = None) -> KeyPair:
|
||||
private_key = Ed25519PrivateKey.new(seed)
|
||||
public_key = private_key.get_public_key()
|
||||
return KeyPair(private_key, public_key)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
import sys
|
||||
from typing import (
|
||||
Callable,
|
||||
cast,
|
||||
)
|
||||
|
||||
|
||||
@ -81,10 +81,12 @@ class PrivateKey(Key):
|
||||
"""A ``PrivateKey`` represents a cryptographic private key."""
|
||||
|
||||
@abstractmethod
|
||||
def sign(self, data: bytes) -> bytes: ...
|
||||
def sign(self, data: bytes) -> bytes:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_public_key(self) -> PublicKey: ...
|
||||
def get_public_key(self) -> PublicKey:
|
||||
...
|
||||
|
||||
def _serialize_to_protobuf(self) -> crypto_pb2.PrivateKey:
|
||||
"""Return the protobuf representation of this ``Key``."""
|
||||
|
||||
@ -37,7 +37,7 @@ class Secp256k1PrivateKey(PrivateKey):
|
||||
self.impl = impl
|
||||
|
||||
@classmethod
|
||||
def new(cls, secret: bytes | None = None) -> "Secp256k1PrivateKey":
|
||||
def new(cls, secret: bytes = None) -> "Secp256k1PrivateKey":
|
||||
private_key_impl = coincurve.PrivateKey(secret)
|
||||
return cls(private_key_impl)
|
||||
|
||||
@ -65,7 +65,7 @@ class Secp256k1PrivateKey(PrivateKey):
|
||||
return Secp256k1PublicKey(public_key_impl)
|
||||
|
||||
|
||||
def create_new_key_pair(secret: bytes | None = None) -> KeyPair:
|
||||
def create_new_key_pair(secret: bytes = None) -> KeyPair:
|
||||
"""
|
||||
Returns a new Secp256k1 keypair derived from the provided ``secret``, a
|
||||
sequence of bytes corresponding to some integer between 0 and the group
|
||||
|
||||
@ -1,9 +1,13 @@
|
||||
from collections.abc import (
|
||||
Awaitable,
|
||||
Callable,
|
||||
Mapping,
|
||||
)
|
||||
from typing import TYPE_CHECKING, NewType, Union, cast
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Callable,
|
||||
NewType,
|
||||
Union,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libp2p.abc import (
|
||||
@ -12,9 +16,15 @@ if TYPE_CHECKING:
|
||||
ISecureTransport,
|
||||
)
|
||||
else:
|
||||
IMuxedConn = cast(type, object)
|
||||
INetStream = cast(type, object)
|
||||
ISecureTransport = cast(type, object)
|
||||
|
||||
class INetStream:
|
||||
pass
|
||||
|
||||
class IMuxedConn:
|
||||
pass
|
||||
|
||||
class ISecureTransport:
|
||||
pass
|
||||
|
||||
|
||||
from libp2p.io.abc import (
|
||||
@ -28,10 +38,10 @@ from libp2p.pubsub.pb import (
|
||||
)
|
||||
|
||||
TProtocol = NewType("TProtocol", str)
|
||||
StreamHandlerFn = Callable[[INetStream], Awaitable[None]]
|
||||
StreamHandlerFn = Callable[["INetStream"], Awaitable[None]]
|
||||
THandler = Callable[[ReadWriteCloser], Awaitable[None]]
|
||||
TSecurityOptions = Mapping[TProtocol, ISecureTransport]
|
||||
TMuxerClass = type[IMuxedConn]
|
||||
TSecurityOptions = Mapping[TProtocol, "ISecureTransport"]
|
||||
TMuxerClass = type["IMuxedConn"]
|
||||
TMuxerOptions = Mapping[TProtocol, TMuxerClass]
|
||||
SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
|
||||
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
|
||||
|
||||
@ -1,26 +0,0 @@
|
||||
from collections.abc import (
|
||||
Callable,
|
||||
)
|
||||
|
||||
from libp2p.abc import (
|
||||
PeerInfo,
|
||||
)
|
||||
|
||||
TTL: int = 60 * 60 # Time-to-live for discovered peers in seconds
|
||||
|
||||
|
||||
class PeerDiscovery:
|
||||
def __init__(self) -> None:
|
||||
self._peer_discovered_handlers: list[Callable[[PeerInfo], None]] = []
|
||||
|
||||
def register_peer_discovered_handler(
|
||||
self, handler: Callable[[PeerInfo], None]
|
||||
) -> None:
|
||||
self._peer_discovered_handlers.append(handler)
|
||||
|
||||
def emit_peer_discovered(self, peer_info: PeerInfo) -> None:
|
||||
for handler in self._peer_discovered_handlers:
|
||||
handler(peer_info)
|
||||
|
||||
|
||||
peerDiscovery = PeerDiscovery()
|
||||
@ -1,91 +0,0 @@
|
||||
import logging
|
||||
import socket
|
||||
|
||||
from zeroconf import (
|
||||
EventLoopBlocked,
|
||||
ServiceInfo,
|
||||
Zeroconf,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("libp2p.discovery.mdns.broadcaster")
|
||||
|
||||
|
||||
class PeerBroadcaster:
|
||||
"""
|
||||
Broadcasts this peer's presence on the local network using mDNS/zeroconf.
|
||||
Registers a service with the peer's ID in the TXT record as per libp2p spec.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
zeroconf: Zeroconf,
|
||||
service_type: str,
|
||||
service_name: str,
|
||||
peer_id: str,
|
||||
port: int,
|
||||
):
|
||||
self.zeroconf = zeroconf
|
||||
self.service_type = service_type
|
||||
self.peer_id = peer_id
|
||||
self.port = port
|
||||
self.service_name = service_name
|
||||
|
||||
# Get the local IP address
|
||||
local_ip = self._get_local_ip()
|
||||
hostname = socket.gethostname()
|
||||
|
||||
self.service_info = ServiceInfo(
|
||||
type_=self.service_type,
|
||||
name=self.service_name,
|
||||
port=self.port,
|
||||
properties={b"id": self.peer_id.encode()},
|
||||
server=f"{hostname}.local.",
|
||||
addresses=[socket.inet_aton(local_ip)],
|
||||
)
|
||||
|
||||
def _get_local_ip(self) -> str:
|
||||
"""Get the local IP address of this machine"""
|
||||
try:
|
||||
# Connect to a remote address to determine the local IP
|
||||
# This doesn't actually send data
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
|
||||
s.connect(("8.8.8.8", 80))
|
||||
local_ip = s.getsockname()[0]
|
||||
return local_ip
|
||||
except Exception:
|
||||
# Fallback to localhost if we can't determine the IP
|
||||
return "127.0.0.1"
|
||||
|
||||
def register(self) -> None:
|
||||
"""Register the peer's mDNS service on the network."""
|
||||
try:
|
||||
self.zeroconf.register_service(self.service_info)
|
||||
logger.debug(f"mDNS service registered: {self.service_name}")
|
||||
except EventLoopBlocked as e:
|
||||
logger.warning(
|
||||
"EventLoopBlocked while registering mDNS '%s': %s", self.service_name, e
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Unexpected error during mDNS registration for '%s': %r",
|
||||
self.service_name,
|
||||
e,
|
||||
)
|
||||
|
||||
def unregister(self) -> None:
|
||||
"""Unregister the peer's mDNS service from the network."""
|
||||
try:
|
||||
self.zeroconf.unregister_service(self.service_info)
|
||||
logger.debug(f"mDNS service unregistered: {self.service_name}")
|
||||
except EventLoopBlocked as e:
|
||||
logger.warning(
|
||||
"EventLoopBlocked while unregistering mDNS '%s': %s",
|
||||
self.service_name,
|
||||
e,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Unexpected error during mDNS unregistration for '%s': %r",
|
||||
self.service_name,
|
||||
e,
|
||||
)
|
||||
@ -1,83 +0,0 @@
|
||||
import logging
|
||||
import socket
|
||||
|
||||
from zeroconf import (
|
||||
ServiceBrowser,
|
||||
ServiceInfo,
|
||||
ServiceListener,
|
||||
Zeroconf,
|
||||
)
|
||||
|
||||
from libp2p.abc import IPeerStore, Multiaddr
|
||||
from libp2p.discovery.events.peerDiscovery import peerDiscovery
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
|
||||
logger = logging.getLogger("libp2p.discovery.mdns.listener")
|
||||
|
||||
|
||||
class PeerListener(ServiceListener):
|
||||
"""mDNS listener — now a true ServiceListener subclass."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
peerstore: IPeerStore,
|
||||
zeroconf: Zeroconf,
|
||||
service_type: str,
|
||||
service_name: str,
|
||||
) -> None:
|
||||
self.peerstore = peerstore
|
||||
self.zeroconf = zeroconf
|
||||
self.service_type = service_type
|
||||
self.service_name = service_name
|
||||
self.discovered_services: dict[str, ID] = {}
|
||||
self.browser = ServiceBrowser(self.zeroconf, self.service_type, listener=self)
|
||||
|
||||
def add_service(self, zc: Zeroconf, type_: str, name: str) -> None:
|
||||
if name == self.service_name:
|
||||
return
|
||||
logger.debug(f"Adding service: {name}")
|
||||
info = zc.get_service_info(type_, name, timeout=5000)
|
||||
if not info:
|
||||
return
|
||||
peer_info = self._extract_peer_info(info)
|
||||
if peer_info:
|
||||
self.discovered_services[name] = peer_info.peer_id
|
||||
self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 10)
|
||||
peerDiscovery.emit_peer_discovered(peer_info)
|
||||
logger.debug(f"Discovered Peer: {peer_info.peer_id}")
|
||||
|
||||
def remove_service(self, zc: Zeroconf, type_: str, name: str) -> None:
|
||||
if name == self.service_name:
|
||||
return
|
||||
logger.debug(f"Removing service: {name}")
|
||||
peer_id = self.discovered_services.pop(name)
|
||||
self.peerstore.clear_addrs(peer_id)
|
||||
logger.debug(f"Removed Peer: {peer_id}")
|
||||
|
||||
def update_service(self, zc: Zeroconf, type_: str, name: str) -> None:
|
||||
info = zc.get_service_info(type_, name, timeout=5000)
|
||||
if not info:
|
||||
return
|
||||
peer_info = self._extract_peer_info(info)
|
||||
if peer_info:
|
||||
self.peerstore.clear_addrs(peer_info.peer_id)
|
||||
self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 10)
|
||||
logger.debug(f"Updated Peer {peer_info.peer_id}")
|
||||
|
||||
def _extract_peer_info(self, info: ServiceInfo) -> PeerInfo | None:
|
||||
try:
|
||||
addrs = [
|
||||
Multiaddr(f"/ip4/{socket.inet_ntoa(addr)}/tcp/{info.port}")
|
||||
for addr in info.addresses
|
||||
]
|
||||
pid_bytes = info.properties.get(b"id")
|
||||
if not pid_bytes:
|
||||
return None
|
||||
pid = ID.from_base58(pid_bytes.decode())
|
||||
return PeerInfo(peer_id=pid, addrs=addrs)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def stop(self) -> None:
|
||||
self.browser.cancel()
|
||||
@ -1,73 +0,0 @@
|
||||
"""
|
||||
mDNS-based peer discovery for py-libp2p.
|
||||
Conforms to https://github.com/libp2p/specs/blob/master/discovery/mdns.md
|
||||
Uses zeroconf for mDNS broadcast/listen. Async operations use trio.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from zeroconf import (
|
||||
Zeroconf,
|
||||
)
|
||||
|
||||
from libp2p.abc import (
|
||||
INetworkService,
|
||||
)
|
||||
|
||||
from .broadcaster import (
|
||||
PeerBroadcaster,
|
||||
)
|
||||
from .listener import (
|
||||
PeerListener,
|
||||
)
|
||||
from .utils import (
|
||||
stringGen,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("libp2p.discovery.mdns")
|
||||
|
||||
SERVICE_TYPE = "_p2p._udp.local."
|
||||
MCAST_PORT = 5353
|
||||
MCAST_ADDR = "224.0.0.251"
|
||||
|
||||
|
||||
class MDNSDiscovery:
|
||||
"""
|
||||
mDNS-based peer discovery for py-libp2p, using zeroconf.
|
||||
Conforms to the libp2p mDNS discovery spec.
|
||||
"""
|
||||
|
||||
def __init__(self, swarm: INetworkService, port: int = 8000):
|
||||
self.peer_id = str(swarm.get_peer_id())
|
||||
self.port = port
|
||||
self.zeroconf = Zeroconf()
|
||||
self.serviceName = f"{stringGen()}.{SERVICE_TYPE}"
|
||||
self.peerstore = swarm.peerstore
|
||||
self.swarm = swarm
|
||||
self.broadcaster = PeerBroadcaster(
|
||||
zeroconf=self.zeroconf,
|
||||
service_type=SERVICE_TYPE,
|
||||
service_name=self.serviceName,
|
||||
peer_id=self.peer_id,
|
||||
port=self.port,
|
||||
)
|
||||
self.listener = PeerListener(
|
||||
zeroconf=self.zeroconf,
|
||||
peerstore=self.peerstore,
|
||||
service_type=SERVICE_TYPE,
|
||||
service_name=self.serviceName,
|
||||
)
|
||||
|
||||
def start(self) -> None:
|
||||
"""Register this peer and start listening for others."""
|
||||
logger.debug(
|
||||
f"Starting mDNS discovery for peer {self.peer_id} on port {self.port}"
|
||||
)
|
||||
self.broadcaster.register()
|
||||
# Listener is started in constructor
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Unregister this peer and clean up zeroconf resources."""
|
||||
logger.debug("Stopping mDNS discovery")
|
||||
self.broadcaster.unregister()
|
||||
self.zeroconf.close()
|
||||
@ -1,11 +0,0 @@
|
||||
import random
|
||||
import string
|
||||
|
||||
|
||||
def stringGen(len: int = 63) -> str:
|
||||
"""Generate a random string of lowercase letters and digits."""
|
||||
charset = string.ascii_lowercase + string.digits
|
||||
result = []
|
||||
for _ in range(len):
|
||||
result.append(random.choice(charset))
|
||||
return "".join(result)
|
||||
@ -1,4 +1,7 @@
|
||||
import logging
|
||||
from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
@ -91,7 +94,7 @@ class AutoNATService:
|
||||
finally:
|
||||
await stream.close()
|
||||
|
||||
async def _handle_request(self, request: bytes | Message) -> Message:
|
||||
async def _handle_request(self, request: Union[bytes, Message]) -> Message:
|
||||
"""
|
||||
Process an AutoNAT protocol request.
|
||||
|
||||
|
||||
@ -84,23 +84,26 @@ class AutoNAT:
|
||||
request: Any,
|
||||
target: str,
|
||||
options: tuple[Any, ...] = (),
|
||||
channel_credentials: Any | None = None,
|
||||
call_credentials: Any | None = None,
|
||||
channel_credentials: Optional[Any] = None,
|
||||
call_credentials: Optional[Any] = None,
|
||||
insecure: bool = False,
|
||||
compression: Any | None = None,
|
||||
wait_for_ready: bool | None = None,
|
||||
timeout: float | None = None,
|
||||
metadata: list[tuple[str, str]] | None = None,
|
||||
compression: Optional[Any] = None,
|
||||
wait_for_ready: Optional[bool] = None,
|
||||
timeout: Optional[float] = None,
|
||||
metadata: Optional[list[tuple[str, str]]] = None,
|
||||
) -> Any:
|
||||
channel = grpc.secure_channel(target, channel_credentials) if channel_credentials else grpc.insecure_channel(target)
|
||||
return channel.unary_unary(
|
||||
"/autonat.pb.AutoNAT/Dial",
|
||||
request_serializer=autonat__pb2.Message.SerializeToString,
|
||||
response_deserializer=autonat__pb2.Message.FromString,
|
||||
_registered_method=True,
|
||||
)(
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
timeout=timeout,
|
||||
metadata=metadata,
|
||||
wait_for_ready=wait_for_ready,
|
||||
target,
|
||||
"/autonat.pb.AutoNAT/Dial",
|
||||
autonat__pb2.Message.SerializeToString,
|
||||
autonat__pb2.Message.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
)
|
||||
|
||||
@ -3,7 +3,6 @@ from collections.abc import (
|
||||
Sequence,
|
||||
)
|
||||
from contextlib import (
|
||||
AbstractAsyncContextManager,
|
||||
asynccontextmanager,
|
||||
)
|
||||
import logging
|
||||
@ -29,7 +28,6 @@ from libp2p.custom_types import (
|
||||
StreamHandlerFn,
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.discovery.mdns.mdns import MDNSDiscovery
|
||||
from libp2p.host.defaults import (
|
||||
get_default_protocols,
|
||||
)
|
||||
@ -71,7 +69,6 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
logger = logging.getLogger("libp2p.network.basic_host")
|
||||
DEFAULT_NEGOTIATE_TIMEOUT = 5
|
||||
|
||||
|
||||
class BasicHost(IHost):
|
||||
@ -91,20 +88,15 @@ class BasicHost(IHost):
|
||||
def __init__(
|
||||
self,
|
||||
network: INetworkService,
|
||||
enable_mDNS: bool = False,
|
||||
default_protocols: Optional["OrderedDict[TProtocol, StreamHandlerFn]"] = None,
|
||||
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
default_protocols: "OrderedDict[TProtocol, StreamHandlerFn]" = None,
|
||||
) -> None:
|
||||
self._network = network
|
||||
self._network.set_stream_handler(self._swarm_stream_handler)
|
||||
self.peerstore = self._network.peerstore
|
||||
self.negotiate_timeout = negotitate_timeout
|
||||
# Protocol muxing
|
||||
default_protocols = default_protocols or get_default_protocols(self)
|
||||
self.multiselect = Multiselect(dict(default_protocols.items()))
|
||||
self.multiselect = Multiselect(default_protocols)
|
||||
self.multiselect_client = MultiselectClient()
|
||||
if enable_mDNS:
|
||||
self.mDNS = MDNSDiscovery(network)
|
||||
|
||||
def get_id(self) -> ID:
|
||||
"""
|
||||
@ -155,30 +147,19 @@ class BasicHost(IHost):
|
||||
"""
|
||||
return list(self._network.connections.keys())
|
||||
|
||||
def run(
|
||||
@asynccontextmanager
|
||||
async def run(
|
||||
self, listen_addrs: Sequence[multiaddr.Multiaddr]
|
||||
) -> AbstractAsyncContextManager[None]:
|
||||
) -> AsyncIterator[None]:
|
||||
"""
|
||||
Run the host instance and listen to ``listen_addrs``.
|
||||
|
||||
:param listen_addrs: a sequence of multiaddrs that we want to listen to
|
||||
"""
|
||||
|
||||
@asynccontextmanager
|
||||
async def _run() -> AsyncIterator[None]:
|
||||
network = self.get_network()
|
||||
async with background_trio_service(network):
|
||||
await network.listen(*listen_addrs)
|
||||
if hasattr(self, "mDNS") and self.mDNS is not None:
|
||||
logger.debug("Starting mDNS Discovery")
|
||||
self.mDNS.start()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if hasattr(self, "mDNS") and self.mDNS is not None:
|
||||
self.mDNS.stop()
|
||||
|
||||
return _run()
|
||||
network = self.get_network()
|
||||
async with background_trio_service(network):
|
||||
await network.listen(*listen_addrs)
|
||||
yield
|
||||
|
||||
def set_stream_handler(
|
||||
self, protocol_id: TProtocol, stream_handler: StreamHandlerFn
|
||||
@ -192,10 +173,7 @@ class BasicHost(IHost):
|
||||
self.multiselect.add_handler(protocol_id, stream_handler)
|
||||
|
||||
async def new_stream(
|
||||
self,
|
||||
peer_id: ID,
|
||||
protocol_ids: Sequence[TProtocol],
|
||||
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
self, peer_id: ID, protocol_ids: Sequence[TProtocol]
|
||||
) -> INetStream:
|
||||
"""
|
||||
:param peer_id: peer_id that host is connecting
|
||||
@ -207,9 +185,7 @@ class BasicHost(IHost):
|
||||
# Perform protocol muxing to determine protocol to use
|
||||
try:
|
||||
selected_protocol = await self.multiselect_client.select_one_of(
|
||||
list(protocol_ids),
|
||||
MultiselectCommunicator(net_stream),
|
||||
negotitate_timeout,
|
||||
list(protocol_ids), MultiselectCommunicator(net_stream)
|
||||
)
|
||||
except MultiselectClientError as error:
|
||||
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
|
||||
@ -219,12 +195,7 @@ class BasicHost(IHost):
|
||||
net_stream.set_protocol(selected_protocol)
|
||||
return net_stream
|
||||
|
||||
async def send_command(
|
||||
self,
|
||||
peer_id: ID,
|
||||
command: str,
|
||||
response_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
) -> list[str]:
|
||||
async def send_command(self, peer_id: ID, command: str) -> list[str]:
|
||||
"""
|
||||
Send a multistream-select command to the specified peer and return
|
||||
the response.
|
||||
@ -238,7 +209,7 @@ class BasicHost(IHost):
|
||||
|
||||
try:
|
||||
response = await self.multiselect_client.query_multistream_command(
|
||||
MultiselectCommunicator(new_stream), command, response_timeout
|
||||
MultiselectCommunicator(new_stream), command
|
||||
)
|
||||
except MultiselectClientError as error:
|
||||
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
|
||||
@ -258,7 +229,7 @@ class BasicHost(IHost):
|
||||
:param peer_info: peer_info of the peer we want to connect to
|
||||
:type peer_info: peer.peerinfo.PeerInfo
|
||||
"""
|
||||
self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 120)
|
||||
self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 10)
|
||||
|
||||
# there is already a connection to this peer
|
||||
if peer_info.peer_id in self._network.connections:
|
||||
@ -277,7 +248,7 @@ class BasicHost(IHost):
|
||||
# Perform protocol muxing to determine protocol to use
|
||||
try:
|
||||
protocol, handler = await self.multiselect.negotiate(
|
||||
MultiselectCommunicator(net_stream), self.negotiate_timeout
|
||||
MultiselectCommunicator(net_stream)
|
||||
)
|
||||
except MultiselectError as error:
|
||||
peer_id = net_stream.muxed_conn.peer_id
|
||||
@ -287,15 +258,6 @@ class BasicHost(IHost):
|
||||
await net_stream.reset()
|
||||
return
|
||||
net_stream.set_protocol(protocol)
|
||||
if handler is None:
|
||||
logger.debug(
|
||||
"no handler for protocol %s, closing stream from peer %s",
|
||||
protocol,
|
||||
net_stream.muxed_conn.peer_id,
|
||||
)
|
||||
await net_stream.reset()
|
||||
return
|
||||
|
||||
await handler(net_stream)
|
||||
|
||||
def get_live_peers(self) -> list[ID]:
|
||||
@ -315,7 +277,7 @@ class BasicHost(IHost):
|
||||
"""
|
||||
return peer_id in self._network.connections
|
||||
|
||||
def get_peer_connection_info(self, peer_id: ID) -> INetConn | None:
|
||||
def get_peer_connection_info(self, peer_id: ID) -> Optional[INetConn]:
|
||||
"""
|
||||
Get connection information for a specific peer if connected.
|
||||
|
||||
|
||||
@ -9,13 +9,13 @@ from libp2p.abc import (
|
||||
IHost,
|
||||
)
|
||||
from libp2p.host.ping import (
|
||||
ID as PingID,
|
||||
handle_ping,
|
||||
)
|
||||
from libp2p.host.ping import ID as PingID
|
||||
from libp2p.identity.identify.identify import (
|
||||
ID as IdentifyID,
|
||||
identify_handler_for,
|
||||
)
|
||||
from libp2p.identity.identify.identify import ID as IdentifyID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libp2p.custom_types import (
|
||||
@ -26,8 +26,5 @@ if TYPE_CHECKING:
|
||||
|
||||
def get_default_protocols(host: IHost) -> "OrderedDict[TProtocol, StreamHandlerFn]":
|
||||
return OrderedDict(
|
||||
(
|
||||
(IdentifyID, identify_handler_for(host, use_varint_format=True)),
|
||||
(PingID, handle_ping),
|
||||
)
|
||||
((IdentifyID, identify_handler_for(host)), (PingID, handle_ping))
|
||||
)
|
||||
|
||||
@ -18,10 +18,8 @@ from libp2p.peer.peerinfo import (
|
||||
class RoutedHost(BasicHost):
|
||||
_router: IPeerRouting
|
||||
|
||||
def __init__(
|
||||
self, network: INetworkService, router: IPeerRouting, enable_mDNS: bool = False
|
||||
):
|
||||
super().__init__(network, enable_mDNS)
|
||||
def __init__(self, network: INetworkService, router: IPeerRouting):
|
||||
super().__init__(network)
|
||||
self._router = router
|
||||
|
||||
async def connect(self, peer_info: PeerInfo) -> None:
|
||||
@ -42,8 +40,8 @@ class RoutedHost(BasicHost):
|
||||
found_peer_info = await self._router.find_peer(peer_info.peer_id)
|
||||
if not found_peer_info:
|
||||
raise ConnectionFailure("Unable to find Peer address")
|
||||
self.peerstore.add_addrs(peer_info.peer_id, found_peer_info.addrs, 120)
|
||||
self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 120)
|
||||
self.peerstore.add_addrs(peer_info.peer_id, found_peer_info.addrs, 10)
|
||||
self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 10)
|
||||
|
||||
# there is already a connection to this peer
|
||||
if peer_info.peer_id in self._network.connections:
|
||||
|
||||
@ -1,4 +1,7 @@
|
||||
import logging
|
||||
from typing import (
|
||||
Optional,
|
||||
)
|
||||
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
@ -16,9 +19,7 @@ from libp2p.network.stream.exceptions import (
|
||||
StreamClosed,
|
||||
)
|
||||
from libp2p.utils import (
|
||||
decode_varint_with_size,
|
||||
get_agent_version,
|
||||
varint,
|
||||
)
|
||||
|
||||
from .pb.identify_pb2 import (
|
||||
@ -39,8 +40,8 @@ def _multiaddr_to_bytes(maddr: Multiaddr) -> bytes:
|
||||
|
||||
|
||||
def _remote_address_to_multiaddr(
|
||||
remote_address: tuple[str, int] | None,
|
||||
) -> Multiaddr | None:
|
||||
remote_address: Optional[tuple[str, int]]
|
||||
) -> Optional[Multiaddr]:
|
||||
"""Convert a (host, port) tuple to a Multiaddr."""
|
||||
if remote_address is None:
|
||||
return None
|
||||
@ -57,11 +58,11 @@ def _remote_address_to_multiaddr(
|
||||
|
||||
|
||||
def _mk_identify_protobuf(
|
||||
host: IHost, observed_multiaddr: Multiaddr | None
|
||||
host: IHost, observed_multiaddr: Optional[Multiaddr]
|
||||
) -> Identify:
|
||||
public_key = host.get_public_key()
|
||||
laddrs = host.get_addrs()
|
||||
protocols = tuple(str(p) for p in host.get_mux().get_protocols() if p is not None)
|
||||
protocols = host.get_mux().get_protocols()
|
||||
|
||||
observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b""
|
||||
return Identify(
|
||||
@ -74,60 +75,21 @@ def _mk_identify_protobuf(
|
||||
)
|
||||
|
||||
|
||||
def parse_identify_response(response: bytes) -> Identify:
|
||||
"""
|
||||
Parse identify response that could be either:
|
||||
- Old format: raw protobuf
|
||||
- New format: length-prefixed protobuf
|
||||
|
||||
This function provides backward and forward compatibility.
|
||||
"""
|
||||
# Try new format first: length-prefixed protobuf
|
||||
if len(response) >= 1:
|
||||
length, varint_size = decode_varint_with_size(response)
|
||||
if varint_size > 0 and length > 0 and varint_size + length <= len(response):
|
||||
protobuf_data = response[varint_size : varint_size + length]
|
||||
try:
|
||||
identify_response = Identify()
|
||||
identify_response.ParseFromString(protobuf_data)
|
||||
# Sanity check: must have agent_version (protocol_version is optional)
|
||||
if identify_response.agent_version:
|
||||
logger.debug(
|
||||
"Parsed length-prefixed identify response (new format)"
|
||||
)
|
||||
return identify_response
|
||||
except Exception:
|
||||
pass # Fall through to old format
|
||||
|
||||
# Fall back to old format: raw protobuf
|
||||
try:
|
||||
identify_response = Identify()
|
||||
identify_response.ParseFromString(response)
|
||||
logger.debug("Parsed raw protobuf identify response (old format)")
|
||||
return identify_response
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse identify response: {e}")
|
||||
logger.error(f"Response length: {len(response)}")
|
||||
logger.error(f"Response hex: {response.hex()}")
|
||||
raise
|
||||
|
||||
|
||||
def identify_handler_for(
|
||||
host: IHost, use_varint_format: bool = False
|
||||
) -> StreamHandlerFn:
|
||||
def identify_handler_for(host: IHost) -> StreamHandlerFn:
|
||||
async def handle_identify(stream: INetStream) -> None:
|
||||
# get observed address from ``stream``
|
||||
peer_id = (
|
||||
stream.muxed_conn.peer_id
|
||||
) # remote peer_id is in class Mplex (mplex.py )
|
||||
observed_multiaddr: Multiaddr | None = None
|
||||
|
||||
# Get the remote address
|
||||
try:
|
||||
remote_address = stream.get_remote_address()
|
||||
# Convert to multiaddr
|
||||
if remote_address:
|
||||
observed_multiaddr = _remote_address_to_multiaddr(remote_address)
|
||||
|
||||
else:
|
||||
observed_multiaddr = None
|
||||
logger.debug(
|
||||
"Connection from remote peer %s, address: %s, multiaddr: %s",
|
||||
peer_id,
|
||||
@ -142,21 +104,7 @@ def identify_handler_for(
|
||||
response = protobuf.SerializeToString()
|
||||
|
||||
try:
|
||||
if use_varint_format:
|
||||
# Send length-prefixed protobuf message (new format)
|
||||
await stream.write(varint.encode_uvarint(len(response)))
|
||||
await stream.write(response)
|
||||
logger.debug(
|
||||
"Sent new format (length-prefixed) identify response to %s",
|
||||
peer_id,
|
||||
)
|
||||
else:
|
||||
# Send raw protobuf message (old format for backward compatibility)
|
||||
await stream.write(response)
|
||||
logger.debug(
|
||||
"Sent old format (raw protobuf) identify response to %s",
|
||||
peer_id,
|
||||
)
|
||||
await stream.write(response)
|
||||
except StreamClosed:
|
||||
logger.debug("Fail to respond to %s request: stream closed", ID)
|
||||
else:
|
||||
|
||||
@ -1,4 +1,7 @@
|
||||
import logging
|
||||
from typing import (
|
||||
Optional,
|
||||
)
|
||||
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
@ -25,10 +28,6 @@ from libp2p.peer.id import (
|
||||
)
|
||||
from libp2p.utils import (
|
||||
get_agent_version,
|
||||
varint,
|
||||
)
|
||||
from libp2p.utils.varint import (
|
||||
decode_varint_from_bytes,
|
||||
)
|
||||
|
||||
from ..identify.identify import (
|
||||
@ -44,72 +43,22 @@ logger = logging.getLogger(__name__)
|
||||
ID_PUSH = TProtocol("/ipfs/id/push/1.0.0")
|
||||
PROTOCOL_VERSION = "ipfs/0.1.0"
|
||||
AGENT_VERSION = get_agent_version()
|
||||
CONCURRENCY_LIMIT = 10
|
||||
|
||||
|
||||
def identify_push_handler_for(
|
||||
host: IHost, use_varint_format: bool = True
|
||||
) -> StreamHandlerFn:
|
||||
def identify_push_handler_for(host: IHost) -> StreamHandlerFn:
|
||||
"""
|
||||
Create a handler for the identify/push protocol.
|
||||
|
||||
This handler receives pushed identify messages from remote peers and updates
|
||||
the local peerstore with the new information.
|
||||
|
||||
Args:
|
||||
host: The libp2p host.
|
||||
use_varint_format: True=length-prefixed, False=raw protobuf.
|
||||
|
||||
"""
|
||||
|
||||
async def handle_identify_push(stream: INetStream) -> None:
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
|
||||
try:
|
||||
if use_varint_format:
|
||||
# Read length-prefixed identify message from the stream
|
||||
# First read the varint length prefix
|
||||
length_bytes = b""
|
||||
while True:
|
||||
b = await stream.read(1)
|
||||
if not b:
|
||||
break
|
||||
length_bytes += b
|
||||
if b[0] & 0x80 == 0:
|
||||
break
|
||||
|
||||
if not length_bytes:
|
||||
logger.warning("No length prefix received from peer %s", peer_id)
|
||||
return
|
||||
|
||||
msg_length = decode_varint_from_bytes(length_bytes)
|
||||
|
||||
# Read the protobuf message
|
||||
data = await stream.read(msg_length)
|
||||
if len(data) != msg_length:
|
||||
logger.warning("Incomplete message received from peer %s", peer_id)
|
||||
return
|
||||
else:
|
||||
# Read raw protobuf message from the stream
|
||||
# For raw format, we need to read all data before the stream is closed
|
||||
data = b""
|
||||
try:
|
||||
# Read all available data in a single operation
|
||||
data = await stream.read()
|
||||
except StreamClosed:
|
||||
# Try to read any remaining data
|
||||
try:
|
||||
data = await stream.read()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# If we got no data, log a warning and return
|
||||
if not data:
|
||||
logger.warning(
|
||||
"No data received in raw format from peer %s", peer_id
|
||||
)
|
||||
return
|
||||
|
||||
# Read the identify message from the stream
|
||||
data = await stream.read()
|
||||
identify_msg = Identify()
|
||||
identify_msg.ParseFromString(data)
|
||||
|
||||
@ -186,11 +135,7 @@ async def _update_peerstore_from_identify(
|
||||
|
||||
|
||||
async def push_identify_to_peer(
|
||||
host: IHost,
|
||||
peer_id: ID,
|
||||
observed_multiaddr: Multiaddr | None = None,
|
||||
limit: trio.Semaphore = trio.Semaphore(CONCURRENCY_LIMIT),
|
||||
use_varint_format: bool = True,
|
||||
host: IHost, peer_id: ID, observed_multiaddr: Optional[Multiaddr] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Push an identify message to a specific peer.
|
||||
@ -198,78 +143,52 @@ async def push_identify_to_peer(
|
||||
This function opens a stream to the peer using the identify/push protocol,
|
||||
sends the identify message, and closes the stream.
|
||||
|
||||
Args:
|
||||
host: The libp2p host.
|
||||
peer_id: The peer ID to push to.
|
||||
observed_multiaddr: The observed multiaddress (optional).
|
||||
limit: Semaphore for concurrency control.
|
||||
use_varint_format: True=length-prefixed, False=raw protobuf.
|
||||
|
||||
Returns:
|
||||
bool: True if the push was successful, False otherwise.
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the push was successful, False otherwise.
|
||||
|
||||
"""
|
||||
async with limit:
|
||||
try:
|
||||
# Create a new stream to the peer using the identify/push protocol
|
||||
stream = await host.new_stream(peer_id, [ID_PUSH])
|
||||
try:
|
||||
# Create a new stream to the peer using the identify/push protocol
|
||||
stream = await host.new_stream(peer_id, [ID_PUSH])
|
||||
|
||||
# Create the identify message
|
||||
identify_msg = _mk_identify_protobuf(host, observed_multiaddr)
|
||||
response = identify_msg.SerializeToString()
|
||||
# Create the identify message
|
||||
identify_msg = _mk_identify_protobuf(host, observed_multiaddr)
|
||||
response = identify_msg.SerializeToString()
|
||||
|
||||
if use_varint_format:
|
||||
# Send length-prefixed identify message
|
||||
await stream.write(varint.encode_uvarint(len(response)))
|
||||
await stream.write(response)
|
||||
else:
|
||||
# Send raw protobuf message
|
||||
await stream.write(response)
|
||||
# Send the identify message
|
||||
await stream.write(response)
|
||||
|
||||
# Close the stream
|
||||
await stream.close()
|
||||
# Close the stream
|
||||
await stream.close()
|
||||
|
||||
logger.debug("Successfully pushed identify to peer %s", peer_id)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Error pushing identify to peer %s: %s", peer_id, e)
|
||||
return False
|
||||
logger.debug("Successfully pushed identify to peer %s", peer_id)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Error pushing identify to peer %s: %s", peer_id, e)
|
||||
return False
|
||||
|
||||
|
||||
async def push_identify_to_peers(
|
||||
host: IHost,
|
||||
peer_ids: set[ID] | None = None,
|
||||
observed_multiaddr: Multiaddr | None = None,
|
||||
use_varint_format: bool = True,
|
||||
peer_ids: Optional[set[ID]] = None,
|
||||
observed_multiaddr: Optional[Multiaddr] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Push an identify message to multiple peers in parallel.
|
||||
|
||||
If peer_ids is None, push to all connected peers.
|
||||
|
||||
Args:
|
||||
host: The libp2p host.
|
||||
peer_ids: Set of peer IDs to push to (if None, push to all connected peers).
|
||||
observed_multiaddr: The observed multiaddress (optional).
|
||||
use_varint_format: True=length-prefixed, False=raw protobuf.
|
||||
|
||||
"""
|
||||
if peer_ids is None:
|
||||
# Get all connected peers
|
||||
peer_ids = set(host.get_connected_peers())
|
||||
|
||||
# Create a single shared semaphore for concurrency control
|
||||
limit = trio.Semaphore(CONCURRENCY_LIMIT)
|
||||
peer_ids = set(host.get_peerstore().peer_ids())
|
||||
|
||||
# Push to each peer in parallel using a trio.Nursery
|
||||
# limiting concurrent connections to CONCURRENCY_LIMIT
|
||||
# TODO: Consider using a bounded nursery to limit concurrency
|
||||
# and avoid overwhelming the network. This can be done by using
|
||||
# trio.open_nursery(max_concurrent=10) or similar.
|
||||
# For now, we will use an unbounded nursery for simplicity.
|
||||
async with trio.open_nursery() as nursery:
|
||||
for peer_id in peer_ids:
|
||||
nursery.start_soon(
|
||||
push_identify_to_peer,
|
||||
host,
|
||||
peer_id,
|
||||
observed_multiaddr,
|
||||
limit,
|
||||
use_varint_format,
|
||||
)
|
||||
nursery.start_soon(push_identify_to_peer, host, peer_id, observed_multiaddr)
|
||||
|
||||
@ -2,22 +2,27 @@ from abc import (
|
||||
ABC,
|
||||
abstractmethod,
|
||||
)
|
||||
from typing import Any
|
||||
from typing import (
|
||||
Optional,
|
||||
)
|
||||
|
||||
|
||||
class Closer(ABC):
|
||||
@abstractmethod
|
||||
async def close(self) -> None: ...
|
||||
async def close(self) -> None:
|
||||
...
|
||||
|
||||
|
||||
class Reader(ABC):
|
||||
@abstractmethod
|
||||
async def read(self, n: int | None = None) -> bytes: ...
|
||||
async def read(self, n: int = None) -> bytes:
|
||||
...
|
||||
|
||||
|
||||
class Writer(ABC):
|
||||
@abstractmethod
|
||||
async def write(self, data: bytes) -> None: ...
|
||||
async def write(self, data: bytes) -> None:
|
||||
...
|
||||
|
||||
|
||||
class WriteCloser(Writer, Closer):
|
||||
@ -34,7 +39,7 @@ class ReadWriter(Reader, Writer):
|
||||
|
||||
class ReadWriteCloser(Reader, Writer, Closer):
|
||||
@abstractmethod
|
||||
def get_remote_address(self) -> tuple[str, int] | None:
|
||||
def get_remote_address(self) -> Optional[tuple[str, int]]:
|
||||
"""
|
||||
Return the remote address of the connected peer.
|
||||
|
||||
@ -45,12 +50,14 @@ class ReadWriteCloser(Reader, Writer, Closer):
|
||||
|
||||
class MsgReader(ABC):
|
||||
@abstractmethod
|
||||
async def read_msg(self) -> bytes: ...
|
||||
async def read_msg(self) -> bytes:
|
||||
...
|
||||
|
||||
|
||||
class MsgWriter(ABC):
|
||||
@abstractmethod
|
||||
async def write_msg(self, msg: bytes) -> None: ...
|
||||
async def write_msg(self, msg: bytes) -> None:
|
||||
...
|
||||
|
||||
|
||||
class MsgReadWriteCloser(MsgReader, MsgWriter, Closer):
|
||||
@ -59,26 +66,19 @@ class MsgReadWriteCloser(MsgReader, MsgWriter, Closer):
|
||||
|
||||
class Encrypter(ABC):
|
||||
@abstractmethod
|
||||
def encrypt(self, data: bytes) -> bytes: ...
|
||||
def encrypt(self, data: bytes) -> bytes:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def decrypt(self, data: bytes) -> bytes: ...
|
||||
def decrypt(self, data: bytes) -> bytes:
|
||||
...
|
||||
|
||||
|
||||
class EncryptedMsgReadWriter(MsgReadWriteCloser, Encrypter):
|
||||
"""Read/write message with encryption/decryption."""
|
||||
|
||||
conn: Any | None
|
||||
|
||||
def __init__(self, conn: Any | None = None):
|
||||
self.conn = conn
|
||||
|
||||
def get_remote_address(self) -> tuple[str, int] | None:
|
||||
def get_remote_address(self) -> Optional[tuple[str, int]]:
|
||||
"""Get remote address if supported by the underlying connection."""
|
||||
if (
|
||||
self.conn is not None
|
||||
and hasattr(self, "conn")
|
||||
and hasattr(self.conn, "get_remote_address")
|
||||
):
|
||||
if hasattr(self, "conn") and hasattr(self.conn, "get_remote_address"):
|
||||
return self.conn.get_remote_address()
|
||||
return None
|
||||
|
||||
@ -5,7 +5,6 @@ from that repo: "a simple package to r/w length-delimited slices."
|
||||
|
||||
NOTE: currently missing the capability to indicate lengths by "varint" method.
|
||||
"""
|
||||
|
||||
from abc import (
|
||||
abstractmethod,
|
||||
)
|
||||
@ -61,10 +60,12 @@ class BaseMsgReadWriter(MsgReadWriteCloser):
|
||||
return await read_exactly(self.read_write_closer, length)
|
||||
|
||||
@abstractmethod
|
||||
async def next_msg_len(self) -> int: ...
|
||||
async def next_msg_len(self) -> int:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def encode_msg(self, msg: bytes) -> bytes: ...
|
||||
def encode_msg(self, msg: bytes) -> bytes:
|
||||
...
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.read_write_closer.close()
|
||||
|
||||
@ -1,4 +1,7 @@
|
||||
import logging
|
||||
from typing import (
|
||||
Optional,
|
||||
)
|
||||
|
||||
import trio
|
||||
|
||||
@ -31,7 +34,7 @@ class TrioTCPStream(ReadWriteCloser):
|
||||
except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
|
||||
raise IOException from error
|
||||
|
||||
async def read(self, n: int | None = None) -> bytes:
|
||||
async def read(self, n: int = None) -> bytes:
|
||||
async with self.read_lock:
|
||||
if n is not None and n == 0:
|
||||
return b""
|
||||
@ -43,7 +46,7 @@ class TrioTCPStream(ReadWriteCloser):
|
||||
async def close(self) -> None:
|
||||
await self.stream.aclose()
|
||||
|
||||
def get_remote_address(self) -> tuple[str, int] | None:
|
||||
def get_remote_address(self) -> Optional[tuple[str, int]]:
|
||||
"""Return the remote address as (host, port) tuple."""
|
||||
try:
|
||||
return self.stream.socket.getpeername()
|
||||
|
||||
@ -14,14 +14,12 @@ async def read_exactly(
|
||||
"""
|
||||
NOTE: relying on exceptions to break out on erroneous conditions, like EOF
|
||||
"""
|
||||
buffer = bytearray()
|
||||
buffer.extend(await reader.read(n))
|
||||
data = await reader.read(n)
|
||||
|
||||
for _ in range(retry_count):
|
||||
if len(buffer) < n:
|
||||
remaining = n - len(buffer)
|
||||
buffer.extend(await reader.read(remaining))
|
||||
|
||||
if len(data) < n:
|
||||
remaining = n - len(data)
|
||||
data += await reader.read(remaining)
|
||||
else:
|
||||
return bytes(buffer)
|
||||
raise IncompleteReadError({"requested_count": n, "received_count": len(buffer)})
|
||||
return data
|
||||
raise IncompleteReadError({"requested_count": n, "received_count": len(data)})
|
||||
|
||||
@ -1,30 +0,0 @@
|
||||
"""
|
||||
Kademlia DHT implementation for py-libp2p.
|
||||
|
||||
This module provides a Distributed Hash Table (DHT) implementation
|
||||
based on the Kademlia protocol.
|
||||
"""
|
||||
|
||||
from .kad_dht import (
|
||||
KadDHT,
|
||||
)
|
||||
from .peer_routing import (
|
||||
PeerRouting,
|
||||
)
|
||||
from .routing_table import (
|
||||
RoutingTable,
|
||||
)
|
||||
from .utils import (
|
||||
create_key_from_binary,
|
||||
)
|
||||
from .value_store import (
|
||||
ValueStore,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"KadDHT",
|
||||
"RoutingTable",
|
||||
"PeerRouting",
|
||||
"ValueStore",
|
||||
"create_key_from_binary",
|
||||
]
|
||||
@ -1,14 +0,0 @@
|
||||
"""
|
||||
Shared constants and protocol parameters for the Kademlia DHT.
|
||||
"""
|
||||
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
|
||||
# Constants for the Kademlia algorithm
|
||||
ALPHA = 3 # Concurrency parameter
|
||||
PROTOCOL_ID = TProtocol("/ipfs/kad/1.0.0")
|
||||
QUERY_TIMEOUT = 10
|
||||
|
||||
TTL = DEFAULT_TTL = 24 * 60 * 60 # 24 hours in seconds
|
||||
@ -1,616 +0,0 @@
|
||||
"""
|
||||
Kademlia DHT implementation for py-libp2p.
|
||||
|
||||
This module provides a complete Distributed Hash Table (DHT)
|
||||
implementation based on the Kademlia algorithm and protocol.
|
||||
"""
|
||||
|
||||
from enum import (
|
||||
Enum,
|
||||
)
|
||||
import logging
|
||||
import time
|
||||
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
)
|
||||
import trio
|
||||
import varint
|
||||
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
)
|
||||
from libp2p.network.stream.net_stream import (
|
||||
INetStream,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
from libp2p.tools.async_service import (
|
||||
Service,
|
||||
)
|
||||
|
||||
from .common import (
|
||||
ALPHA,
|
||||
PROTOCOL_ID,
|
||||
QUERY_TIMEOUT,
|
||||
)
|
||||
from .pb.kademlia_pb2 import (
|
||||
Message,
|
||||
)
|
||||
from .peer_routing import (
|
||||
PeerRouting,
|
||||
)
|
||||
from .provider_store import (
|
||||
ProviderStore,
|
||||
)
|
||||
from .routing_table import (
|
||||
RoutingTable,
|
||||
)
|
||||
from .value_store import (
|
||||
ValueStore,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("kademlia-example.kad_dht")
|
||||
# logger = logging.getLogger("libp2p.kademlia")
|
||||
# Default parameters
|
||||
ROUTING_TABLE_REFRESH_INTERVAL = 60 # 1 min in seconds for testing
|
||||
|
||||
|
||||
class DHTMode(Enum):
|
||||
"""DHT operation modes."""
|
||||
|
||||
CLIENT = "CLIENT"
|
||||
SERVER = "SERVER"
|
||||
|
||||
|
||||
class KadDHT(Service):
|
||||
"""
|
||||
Kademlia DHT implementation for libp2p.
|
||||
|
||||
This class provides a DHT implementation that combines routing table management,
|
||||
peer discovery, content routing, and value storage.
|
||||
"""
|
||||
|
||||
def __init__(self, host: IHost, mode: DHTMode):
|
||||
"""
|
||||
Initialize a new Kademlia DHT node.
|
||||
|
||||
:param host: The libp2p host.
|
||||
:param mode: The mode of host (Client or Server) - must be DHTMode enum
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.host = host
|
||||
self.local_peer_id = host.get_id()
|
||||
|
||||
# Validate that mode is a DHTMode enum
|
||||
if not isinstance(mode, DHTMode):
|
||||
raise TypeError(f"mode must be DHTMode enum, got {type(mode)}")
|
||||
|
||||
self.mode = mode
|
||||
|
||||
# Initialize the routing table
|
||||
self.routing_table = RoutingTable(self.local_peer_id, self.host)
|
||||
|
||||
# Initialize peer routing
|
||||
self.peer_routing = PeerRouting(host, self.routing_table)
|
||||
|
||||
# Initialize value store
|
||||
self.value_store = ValueStore(host=host, local_peer_id=self.local_peer_id)
|
||||
|
||||
# Initialize provider store with host and peer_routing references
|
||||
self.provider_store = ProviderStore(host=host, peer_routing=self.peer_routing)
|
||||
|
||||
# Last time we republished provider records
|
||||
self._last_provider_republish = time.time()
|
||||
|
||||
# Set protocol handlers
|
||||
host.set_stream_handler(PROTOCOL_ID, self.handle_stream)
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run the DHT service."""
|
||||
logger.info(f"Starting Kademlia DHT with peer ID {self.local_peer_id}")
|
||||
|
||||
# Main service loop
|
||||
while self.manager.is_running:
|
||||
# Periodically refresh the routing table
|
||||
await self.refresh_routing_table()
|
||||
|
||||
# Check if it's time to republish provider records
|
||||
current_time = time.time()
|
||||
# await self._republish_provider_records()
|
||||
self._last_provider_republish = current_time
|
||||
|
||||
# Clean up expired values and provider records
|
||||
expired_values = self.value_store.cleanup_expired()
|
||||
if expired_values > 0:
|
||||
logger.debug(f"Cleaned up {expired_values} expired values")
|
||||
|
||||
self.provider_store.cleanup_expired()
|
||||
|
||||
# Wait before next maintenance cycle
|
||||
await trio.sleep(ROUTING_TABLE_REFRESH_INTERVAL)
|
||||
|
||||
async def switch_mode(self, new_mode: DHTMode) -> DHTMode:
|
||||
"""
|
||||
Switch the DHT mode.
|
||||
|
||||
:param new_mode: The new mode - must be DHTMode enum
|
||||
:return: The new mode as DHTMode enum
|
||||
"""
|
||||
# Validate that new_mode is a DHTMode enum
|
||||
if not isinstance(new_mode, DHTMode):
|
||||
raise TypeError(f"new_mode must be DHTMode enum, got {type(new_mode)}")
|
||||
|
||||
if new_mode == DHTMode.CLIENT:
|
||||
self.routing_table.cleanup_routing_table()
|
||||
self.mode = new_mode
|
||||
logger.info(f"Switched to {new_mode.value} mode")
|
||||
return self.mode
|
||||
|
||||
async def handle_stream(self, stream: INetStream) -> None:
|
||||
"""
|
||||
Handle an incoming DHT stream using varint length prefixes.
|
||||
"""
|
||||
if self.mode == DHTMode.CLIENT:
|
||||
stream.close
|
||||
return
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
logger.debug(f"Received DHT stream from peer {peer_id}")
|
||||
await self.add_peer(peer_id)
|
||||
logger.debug(f"Added peer {peer_id} to routing table")
|
||||
|
||||
try:
|
||||
# Read varint-prefixed length for the message
|
||||
length_prefix = b""
|
||||
while True:
|
||||
byte = await stream.read(1)
|
||||
if not byte:
|
||||
logger.warning("Stream closed while reading varint length")
|
||||
await stream.close()
|
||||
return
|
||||
length_prefix += byte
|
||||
if byte[0] & 0x80 == 0:
|
||||
break
|
||||
msg_length = varint.decode_bytes(length_prefix)
|
||||
|
||||
# Read the message bytes
|
||||
msg_bytes = await stream.read(msg_length)
|
||||
if len(msg_bytes) < msg_length:
|
||||
logger.warning("Failed to read full message from stream")
|
||||
await stream.close()
|
||||
return
|
||||
|
||||
try:
|
||||
# Parse as protobuf
|
||||
message = Message()
|
||||
message.ParseFromString(msg_bytes)
|
||||
logger.debug(
|
||||
f"Received DHT message from {peer_id}, type: {message.type}"
|
||||
)
|
||||
|
||||
# Handle FIND_NODE message
|
||||
if message.type == Message.MessageType.FIND_NODE:
|
||||
# Get target key directly from protobuf
|
||||
target_key = message.key
|
||||
|
||||
# Find closest peers to the target key
|
||||
closest_peers = self.routing_table.find_local_closest_peers(
|
||||
target_key, 20
|
||||
)
|
||||
logger.debug(f"Found {len(closest_peers)} peers close to target")
|
||||
|
||||
# Build response message with protobuf
|
||||
response = Message()
|
||||
response.type = Message.MessageType.FIND_NODE
|
||||
|
||||
# Add closest peers to response
|
||||
for peer in closest_peers:
|
||||
# Skip if the peer is the requester
|
||||
if peer == peer_id:
|
||||
continue
|
||||
|
||||
# Add peer to closerPeers field
|
||||
peer_proto = response.closerPeers.add()
|
||||
peer_proto.id = peer.to_bytes()
|
||||
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
|
||||
|
||||
# Add addresses if available
|
||||
try:
|
||||
addrs = self.host.get_peerstore().addrs(peer)
|
||||
if addrs:
|
||||
for addr in addrs:
|
||||
peer_proto.addrs.append(addr.to_bytes())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Serialize and send response
|
||||
response_bytes = response.SerializeToString()
|
||||
await stream.write(varint.encode(len(response_bytes)))
|
||||
await stream.write(response_bytes)
|
||||
logger.debug(
|
||||
f"Sent FIND_NODE response with{len(response.closerPeers)} peers"
|
||||
)
|
||||
|
||||
# Handle ADD_PROVIDER message
|
||||
elif message.type == Message.MessageType.ADD_PROVIDER:
|
||||
# Process ADD_PROVIDER
|
||||
key = message.key
|
||||
logger.debug(f"Received ADD_PROVIDER for key {key.hex()}")
|
||||
|
||||
# Extract provider information
|
||||
for provider_proto in message.providerPeers:
|
||||
try:
|
||||
# Validate that the provider is the sender
|
||||
provider_id = ID(provider_proto.id)
|
||||
if provider_id != peer_id:
|
||||
logger.warning(
|
||||
f"Provider ID {provider_id} doesn't"
|
||||
f"match sender {peer_id}, ignoring"
|
||||
)
|
||||
continue
|
||||
|
||||
# Convert addresses to Multiaddr
|
||||
addrs = []
|
||||
for addr_bytes in provider_proto.addrs:
|
||||
try:
|
||||
addrs.append(Multiaddr(addr_bytes))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse address: {e}")
|
||||
|
||||
# Add to provider store
|
||||
provider_info = PeerInfo(provider_id, addrs)
|
||||
self.provider_store.add_provider(key, provider_info)
|
||||
logger.debug(
|
||||
f"Added provider {provider_id} for key {key.hex()}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process provider info: {e}")
|
||||
|
||||
# Send acknowledgement
|
||||
response = Message()
|
||||
response.type = Message.MessageType.ADD_PROVIDER
|
||||
response.key = key
|
||||
|
||||
response_bytes = response.SerializeToString()
|
||||
await stream.write(varint.encode(len(response_bytes)))
|
||||
await stream.write(response_bytes)
|
||||
logger.debug("Sent ADD_PROVIDER acknowledgement")
|
||||
|
||||
# Handle GET_PROVIDERS message
|
||||
elif message.type == Message.MessageType.GET_PROVIDERS:
|
||||
# Process GET_PROVIDERS
|
||||
key = message.key
|
||||
logger.debug(f"Received GET_PROVIDERS request for key {key.hex()}")
|
||||
|
||||
# Find providers for the key
|
||||
providers = self.provider_store.get_providers(key)
|
||||
logger.debug(
|
||||
f"Found {len(providers)} providers for key {key.hex()}"
|
||||
)
|
||||
|
||||
# Create response
|
||||
response = Message()
|
||||
response.type = Message.MessageType.GET_PROVIDERS
|
||||
response.key = key
|
||||
|
||||
# Add provider information to response
|
||||
for provider_info in providers:
|
||||
provider_proto = response.providerPeers.add()
|
||||
provider_proto.id = provider_info.peer_id.to_bytes()
|
||||
provider_proto.connection = Message.ConnectionType.CAN_CONNECT
|
||||
|
||||
# Add addresses if available
|
||||
for addr in provider_info.addrs:
|
||||
provider_proto.addrs.append(addr.to_bytes())
|
||||
|
||||
# Also include closest peers if we don't have providers
|
||||
if not providers:
|
||||
closest_peers = self.routing_table.find_local_closest_peers(
|
||||
key, 20
|
||||
)
|
||||
logger.debug(
|
||||
f"No providers found, including {len(closest_peers)}"
|
||||
"closest peers"
|
||||
)
|
||||
|
||||
for peer in closest_peers:
|
||||
# Skip if peer is the requester
|
||||
if peer == peer_id:
|
||||
continue
|
||||
|
||||
peer_proto = response.closerPeers.add()
|
||||
peer_proto.id = peer.to_bytes()
|
||||
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
|
||||
|
||||
# Add addresses if available
|
||||
try:
|
||||
addrs = self.host.get_peerstore().addrs(peer)
|
||||
for addr in addrs:
|
||||
peer_proto.addrs.append(addr.to_bytes())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Serialize and send response
|
||||
response_bytes = response.SerializeToString()
|
||||
await stream.write(varint.encode(len(response_bytes)))
|
||||
await stream.write(response_bytes)
|
||||
logger.debug("Sent GET_PROVIDERS response")
|
||||
|
||||
# Handle GET_VALUE message
|
||||
elif message.type == Message.MessageType.GET_VALUE:
|
||||
# Process GET_VALUE
|
||||
key = message.key
|
||||
logger.debug(f"Received GET_VALUE request for key {key.hex()}")
|
||||
|
||||
value = self.value_store.get(key)
|
||||
if value:
|
||||
logger.debug(f"Found value for key {key.hex()}")
|
||||
|
||||
# Create response using protobuf
|
||||
response = Message()
|
||||
response.type = Message.MessageType.GET_VALUE
|
||||
|
||||
# Create record
|
||||
response.key = key
|
||||
response.record.key = key
|
||||
response.record.value = value
|
||||
response.record.timeReceived = str(time.time())
|
||||
|
||||
# Serialize and send response
|
||||
response_bytes = response.SerializeToString()
|
||||
await stream.write(varint.encode(len(response_bytes)))
|
||||
await stream.write(response_bytes)
|
||||
logger.debug("Sent GET_VALUE response")
|
||||
else:
|
||||
logger.debug(f"No value found for key {key.hex()}")
|
||||
|
||||
# Create response with closest peers when no value is found
|
||||
response = Message()
|
||||
response.type = Message.MessageType.GET_VALUE
|
||||
response.key = key
|
||||
|
||||
# Add closest peers to key
|
||||
closest_peers = self.routing_table.find_local_closest_peers(
|
||||
key, 20
|
||||
)
|
||||
logger.debug(
|
||||
"No value found,"
|
||||
f"including {len(closest_peers)} closest peers"
|
||||
)
|
||||
|
||||
for peer in closest_peers:
|
||||
# Skip if peer is the requester
|
||||
if peer == peer_id:
|
||||
continue
|
||||
|
||||
peer_proto = response.closerPeers.add()
|
||||
peer_proto.id = peer.to_bytes()
|
||||
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
|
||||
|
||||
# Add addresses if available
|
||||
try:
|
||||
addrs = self.host.get_peerstore().addrs(peer)
|
||||
for addr in addrs:
|
||||
peer_proto.addrs.append(addr.to_bytes())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Serialize and send response
|
||||
response_bytes = response.SerializeToString()
|
||||
await stream.write(varint.encode(len(response_bytes)))
|
||||
await stream.write(response_bytes)
|
||||
logger.debug("Sent GET_VALUE response with closest peers")
|
||||
|
||||
# Handle PUT_VALUE message
|
||||
elif message.type == Message.MessageType.PUT_VALUE and message.HasField(
|
||||
"record"
|
||||
):
|
||||
# Process PUT_VALUE
|
||||
key = message.record.key
|
||||
value = message.record.value
|
||||
success = False
|
||||
try:
|
||||
if not (key and value):
|
||||
raise ValueError(
|
||||
"Missing key or value in PUT_VALUE message"
|
||||
)
|
||||
|
||||
self.value_store.put(key, value)
|
||||
logger.debug(f"Stored value {value.hex()} for key {key.hex()}")
|
||||
success = True
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to store value {value.hex()} for key "
|
||||
f"{key.hex()}: {e}"
|
||||
)
|
||||
finally:
|
||||
# Send acknowledgement
|
||||
response = Message()
|
||||
response.type = Message.MessageType.PUT_VALUE
|
||||
if success:
|
||||
response.key = key
|
||||
response_bytes = response.SerializeToString()
|
||||
await stream.write(varint.encode(len(response_bytes)))
|
||||
await stream.write(response_bytes)
|
||||
logger.debug("Sent PUT_VALUE acknowledgement")
|
||||
|
||||
except Exception as proto_err:
|
||||
logger.warning(f"Failed to parse protobuf message: {proto_err}")
|
||||
|
||||
await stream.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling DHT stream: {e}")
|
||||
await stream.close()
|
||||
|
||||
async def refresh_routing_table(self) -> None:
|
||||
"""Refresh the routing table."""
|
||||
logger.debug("Refreshing routing table")
|
||||
await self.peer_routing.refresh_routing_table()
|
||||
|
||||
# Peer routing methods
|
||||
|
||||
async def find_peer(self, peer_id: ID) -> PeerInfo | None:
|
||||
"""
|
||||
Find a peer with the given ID.
|
||||
"""
|
||||
logger.debug(f"Finding peer: {peer_id}")
|
||||
return await self.peer_routing.find_peer(peer_id)
|
||||
|
||||
# Value storage and retrieval methods
|
||||
|
||||
async def put_value(self, key: bytes, value: bytes) -> None:
|
||||
"""
|
||||
Store a value in the DHT.
|
||||
"""
|
||||
logger.debug(f"Storing value for key {key.hex()}")
|
||||
|
||||
# 1. Store locally first
|
||||
self.value_store.put(key, value)
|
||||
try:
|
||||
decoded_value = value.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
decoded_value = value.hex()
|
||||
logger.debug(
|
||||
f"Stored value locally for key {key.hex()} with value {decoded_value}"
|
||||
)
|
||||
|
||||
# 2. Get closest peers, excluding self
|
||||
closest_peers = [
|
||||
peer
|
||||
for peer in self.routing_table.find_local_closest_peers(key)
|
||||
if peer != self.local_peer_id
|
||||
]
|
||||
logger.debug(f"Found {len(closest_peers)} peers to store value at")
|
||||
|
||||
# 3. Store at remote peers in batches of ALPHA, in parallel
|
||||
stored_count = 0
|
||||
for i in range(0, len(closest_peers), ALPHA):
|
||||
batch = closest_peers[i : i + ALPHA]
|
||||
batch_results = [False] * len(batch)
|
||||
|
||||
async def store_one(idx: int, peer: ID) -> None:
|
||||
try:
|
||||
with trio.move_on_after(QUERY_TIMEOUT):
|
||||
success = await self.value_store._store_at_peer(
|
||||
peer, key, value
|
||||
)
|
||||
batch_results[idx] = success
|
||||
if success:
|
||||
logger.debug(f"Stored value at peer {peer}")
|
||||
else:
|
||||
logger.debug(f"Failed to store value at peer {peer}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Error storing value at peer {peer}: {e}")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for idx, peer in enumerate(batch):
|
||||
nursery.start_soon(store_one, idx, peer)
|
||||
|
||||
stored_count += sum(batch_results)
|
||||
|
||||
logger.info(f"Successfully stored value at {stored_count} peers")
|
||||
|
||||
async def get_value(self, key: bytes) -> bytes | None:
|
||||
logger.debug(f"Getting value for key: {key.hex()}")
|
||||
|
||||
# 1. Check local store first
|
||||
value = self.value_store.get(key)
|
||||
if value:
|
||||
logger.debug("Found value locally")
|
||||
return value
|
||||
|
||||
# 2. Get closest peers, excluding self
|
||||
closest_peers = [
|
||||
peer
|
||||
for peer in self.routing_table.find_local_closest_peers(key)
|
||||
if peer != self.local_peer_id
|
||||
]
|
||||
logger.debug(f"Searching {len(closest_peers)} peers for value")
|
||||
|
||||
# 3. Query ALPHA peers at a time in parallel
|
||||
for i in range(0, len(closest_peers), ALPHA):
|
||||
batch = closest_peers[i : i + ALPHA]
|
||||
found_value = None
|
||||
|
||||
async def query_one(peer: ID) -> None:
|
||||
nonlocal found_value
|
||||
try:
|
||||
with trio.move_on_after(QUERY_TIMEOUT):
|
||||
value = await self.value_store._get_from_peer(peer, key)
|
||||
if value is not None and found_value is None:
|
||||
found_value = value
|
||||
logger.debug(f"Found value at peer {peer}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Error querying peer {peer}: {e}")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for peer in batch:
|
||||
nursery.start_soon(query_one, peer)
|
||||
|
||||
if found_value is not None:
|
||||
self.value_store.put(key, found_value)
|
||||
logger.info("Successfully retrieved value from network")
|
||||
return found_value
|
||||
|
||||
# 4. Not found
|
||||
logger.warning(f"Value not found for key {key.hex()}")
|
||||
return None
|
||||
|
||||
# Add these methods in the Utility methods section
|
||||
|
||||
# Utility methods
|
||||
|
||||
async def add_peer(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Add a peer to the routing table.
|
||||
|
||||
params: peer_id: The peer ID to add.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if peer was added or updated, False otherwise.
|
||||
|
||||
"""
|
||||
return await self.routing_table.add_peer(peer_id)
|
||||
|
||||
async def provide(self, key: bytes) -> bool:
|
||||
"""
|
||||
Reference to provider_store.provide for convenience.
|
||||
"""
|
||||
return await self.provider_store.provide(key)
|
||||
|
||||
async def find_providers(self, key: bytes, count: int = 20) -> list[PeerInfo]:
|
||||
"""
|
||||
Reference to provider_store.find_providers for convenience.
|
||||
"""
|
||||
return await self.provider_store.find_providers(key, count)
|
||||
|
||||
def get_routing_table_size(self) -> int:
|
||||
"""
|
||||
Get the number of peers in the routing table.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Number of peers.
|
||||
|
||||
"""
|
||||
return self.routing_table.size()
|
||||
|
||||
def get_value_store_size(self) -> int:
|
||||
"""
|
||||
Get the number of items in the value store.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Number of items.
|
||||
|
||||
"""
|
||||
return self.value_store.size()
|
||||
@ -1,38 +0,0 @@
|
||||
syntax = "proto3";
|
||||
|
||||
message Record {
|
||||
bytes key = 1;
|
||||
bytes value = 2;
|
||||
string timeReceived = 5;
|
||||
};
|
||||
|
||||
message Message {
|
||||
enum MessageType {
|
||||
PUT_VALUE = 0;
|
||||
GET_VALUE = 1;
|
||||
ADD_PROVIDER = 2;
|
||||
GET_PROVIDERS = 3;
|
||||
FIND_NODE = 4;
|
||||
PING = 5;
|
||||
}
|
||||
|
||||
enum ConnectionType {
|
||||
NOT_CONNECTED = 0;
|
||||
CONNECTED = 1;
|
||||
CAN_CONNECT = 2;
|
||||
CANNOT_CONNECT = 3;
|
||||
}
|
||||
|
||||
message Peer {
|
||||
bytes id = 1;
|
||||
repeated bytes addrs = 2;
|
||||
ConnectionType connection = 3;
|
||||
}
|
||||
|
||||
MessageType type = 1;
|
||||
int32 clusterLevelRaw = 10;
|
||||
bytes key = 2;
|
||||
Record record = 3;
|
||||
repeated Peer closerPeers = 8;
|
||||
repeated Peer providerPeers = 9;
|
||||
}
|
||||
@ -1,33 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: libp2p/kad_dht/pb/kademlia.proto
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xca\x03\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x1aN\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x62\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
DESCRIPTOR._options = None
|
||||
_globals['_RECORD']._serialized_start=36
|
||||
_globals['_RECORD']._serialized_end=94
|
||||
_globals['_MESSAGE']._serialized_start=97
|
||||
_globals['_MESSAGE']._serialized_end=555
|
||||
_globals['_MESSAGE_PEER']._serialized_start=281
|
||||
_globals['_MESSAGE_PEER']._serialized_end=359
|
||||
_globals['_MESSAGE_MESSAGETYPE']._serialized_start=361
|
||||
_globals['_MESSAGE_MESSAGETYPE']._serialized_end=466
|
||||
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_start=468
|
||||
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_end=555
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
@ -1,133 +0,0 @@
|
||||
"""
|
||||
@generated by mypy-protobuf. Do not edit manually!
|
||||
isort:skip_file
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import collections.abc
|
||||
import google.protobuf.descriptor
|
||||
import google.protobuf.internal.containers
|
||||
import google.protobuf.internal.enum_type_wrapper
|
||||
import google.protobuf.message
|
||||
import sys
|
||||
import typing
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
import typing as typing_extensions
|
||||
else:
|
||||
import typing_extensions
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
|
||||
@typing.final
|
||||
class Record(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
KEY_FIELD_NUMBER: builtins.int
|
||||
VALUE_FIELD_NUMBER: builtins.int
|
||||
TIMERECEIVED_FIELD_NUMBER: builtins.int
|
||||
key: builtins.bytes
|
||||
value: builtins.bytes
|
||||
timeReceived: builtins.str
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
key: builtins.bytes = ...,
|
||||
value: builtins.bytes = ...,
|
||||
timeReceived: builtins.str = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["key", b"key", "timeReceived", b"timeReceived", "value", b"value"]) -> None: ...
|
||||
|
||||
global___Record = Record
|
||||
|
||||
@typing.final
|
||||
class Message(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
class _MessageType:
|
||||
ValueType = typing.NewType("ValueType", builtins.int)
|
||||
V: typing_extensions.TypeAlias = ValueType
|
||||
|
||||
class _MessageTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._MessageType.ValueType], builtins.type):
|
||||
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
|
||||
PUT_VALUE: Message._MessageType.ValueType # 0
|
||||
GET_VALUE: Message._MessageType.ValueType # 1
|
||||
ADD_PROVIDER: Message._MessageType.ValueType # 2
|
||||
GET_PROVIDERS: Message._MessageType.ValueType # 3
|
||||
FIND_NODE: Message._MessageType.ValueType # 4
|
||||
PING: Message._MessageType.ValueType # 5
|
||||
|
||||
class MessageType(_MessageType, metaclass=_MessageTypeEnumTypeWrapper): ...
|
||||
PUT_VALUE: Message.MessageType.ValueType # 0
|
||||
GET_VALUE: Message.MessageType.ValueType # 1
|
||||
ADD_PROVIDER: Message.MessageType.ValueType # 2
|
||||
GET_PROVIDERS: Message.MessageType.ValueType # 3
|
||||
FIND_NODE: Message.MessageType.ValueType # 4
|
||||
PING: Message.MessageType.ValueType # 5
|
||||
|
||||
class _ConnectionType:
|
||||
ValueType = typing.NewType("ValueType", builtins.int)
|
||||
V: typing_extensions.TypeAlias = ValueType
|
||||
|
||||
class _ConnectionTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._ConnectionType.ValueType], builtins.type):
|
||||
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
|
||||
NOT_CONNECTED: Message._ConnectionType.ValueType # 0
|
||||
CONNECTED: Message._ConnectionType.ValueType # 1
|
||||
CAN_CONNECT: Message._ConnectionType.ValueType # 2
|
||||
CANNOT_CONNECT: Message._ConnectionType.ValueType # 3
|
||||
|
||||
class ConnectionType(_ConnectionType, metaclass=_ConnectionTypeEnumTypeWrapper): ...
|
||||
NOT_CONNECTED: Message.ConnectionType.ValueType # 0
|
||||
CONNECTED: Message.ConnectionType.ValueType # 1
|
||||
CAN_CONNECT: Message.ConnectionType.ValueType # 2
|
||||
CANNOT_CONNECT: Message.ConnectionType.ValueType # 3
|
||||
|
||||
@typing.final
|
||||
class Peer(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
ID_FIELD_NUMBER: builtins.int
|
||||
ADDRS_FIELD_NUMBER: builtins.int
|
||||
CONNECTION_FIELD_NUMBER: builtins.int
|
||||
id: builtins.bytes
|
||||
connection: global___Message.ConnectionType.ValueType
|
||||
@property
|
||||
def addrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
id: builtins.bytes = ...,
|
||||
addrs: collections.abc.Iterable[builtins.bytes] | None = ...,
|
||||
connection: global___Message.ConnectionType.ValueType = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["addrs", b"addrs", "connection", b"connection", "id", b"id"]) -> None: ...
|
||||
|
||||
TYPE_FIELD_NUMBER: builtins.int
|
||||
CLUSTERLEVELRAW_FIELD_NUMBER: builtins.int
|
||||
KEY_FIELD_NUMBER: builtins.int
|
||||
RECORD_FIELD_NUMBER: builtins.int
|
||||
CLOSERPEERS_FIELD_NUMBER: builtins.int
|
||||
PROVIDERPEERS_FIELD_NUMBER: builtins.int
|
||||
type: global___Message.MessageType.ValueType
|
||||
clusterLevelRaw: builtins.int
|
||||
key: builtins.bytes
|
||||
@property
|
||||
def record(self) -> global___Record: ...
|
||||
@property
|
||||
def closerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ...
|
||||
@property
|
||||
def providerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
type: global___Message.MessageType.ValueType = ...,
|
||||
clusterLevelRaw: builtins.int = ...,
|
||||
key: builtins.bytes = ...,
|
||||
record: global___Record | None = ...,
|
||||
closerPeers: collections.abc.Iterable[global___Message.Peer] | None = ...,
|
||||
providerPeers: collections.abc.Iterable[global___Message.Peer] | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["record", b"record"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["closerPeers", b"closerPeers", "clusterLevelRaw", b"clusterLevelRaw", "key", b"key", "providerPeers", b"providerPeers", "record", b"record", "type", b"type"]) -> None: ...
|
||||
|
||||
global___Message = Message
|
||||
@ -1,415 +0,0 @@
|
||||
"""
|
||||
Peer routing implementation for Kademlia DHT.
|
||||
|
||||
This module implements the peer routing interface using Kademlia's algorithm
|
||||
to efficiently locate peers in a distributed network.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import trio
|
||||
import varint
|
||||
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
INetStream,
|
||||
IPeerRouting,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
|
||||
from .common import (
|
||||
ALPHA,
|
||||
PROTOCOL_ID,
|
||||
)
|
||||
from .pb.kademlia_pb2 import (
|
||||
Message,
|
||||
)
|
||||
from .routing_table import (
|
||||
RoutingTable,
|
||||
)
|
||||
from .utils import (
|
||||
sort_peer_ids_by_distance,
|
||||
)
|
||||
|
||||
# logger = logging.getLogger("libp2p.kademlia.peer_routing")
|
||||
logger = logging.getLogger("kademlia-example.peer_routing")
|
||||
|
||||
MAX_PEER_LOOKUP_ROUNDS = 20 # Maximum number of rounds in peer lookup
|
||||
|
||||
|
||||
class PeerRouting(IPeerRouting):
|
||||
"""
|
||||
Implementation of peer routing using the Kademlia algorithm.
|
||||
|
||||
This class provides methods to find peers in the DHT network
|
||||
and helps maintain the routing table.
|
||||
"""
|
||||
|
||||
def __init__(self, host: IHost, routing_table: RoutingTable):
|
||||
"""
|
||||
Initialize the peer routing service.
|
||||
|
||||
:param host: The libp2p host
|
||||
:param routing_table: The Kademlia routing table
|
||||
|
||||
"""
|
||||
self.host = host
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def find_peer(self, peer_id: ID) -> PeerInfo | None:
|
||||
"""
|
||||
Find a peer with the given ID.
|
||||
|
||||
:param peer_id: The ID of the peer to find
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[PeerInfo]
|
||||
The peer information if found, None otherwise
|
||||
|
||||
"""
|
||||
# Check if this is actually our peer ID
|
||||
if peer_id == self.host.get_id():
|
||||
try:
|
||||
# Return our own peer info
|
||||
return PeerInfo(peer_id, self.host.get_addrs())
|
||||
except Exception:
|
||||
logger.exception("Error getting our own peer info")
|
||||
return None
|
||||
|
||||
# First check if the peer is in our routing table
|
||||
peer_info = self.routing_table.get_peer_info(peer_id)
|
||||
if peer_info:
|
||||
logger.debug(f"Found peer {peer_id} in routing table")
|
||||
return peer_info
|
||||
|
||||
# Then check if the peer is in our peerstore
|
||||
try:
|
||||
addrs = self.host.get_peerstore().addrs(peer_id)
|
||||
if addrs:
|
||||
logger.debug(f"Found peer {peer_id} in peerstore")
|
||||
return PeerInfo(peer_id, addrs)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# If not found locally, search the network
|
||||
try:
|
||||
closest_peers = await self.find_closest_peers_network(peer_id.to_bytes())
|
||||
logger.info(f"Closest peers found: {closest_peers}")
|
||||
|
||||
# Check if we found the peer we're looking for
|
||||
for found_peer in closest_peers:
|
||||
if found_peer == peer_id:
|
||||
try:
|
||||
addrs = self.host.get_peerstore().addrs(found_peer)
|
||||
if addrs:
|
||||
return PeerInfo(found_peer, addrs)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching for peer {peer_id}: {e}")
|
||||
|
||||
# Not found
|
||||
logger.info(f"Peer {peer_id} not found")
|
||||
return None
|
||||
|
||||
async def _query_single_peer_for_closest(
|
||||
self, peer: ID, target_key: bytes, new_peers: list[ID]
|
||||
) -> None:
|
||||
"""
|
||||
Query a single peer for closest peers and append results to the shared list.
|
||||
|
||||
params: peer : ID
|
||||
The peer to query
|
||||
params: target_key : bytes
|
||||
The target key to find closest peers for
|
||||
params: new_peers : list[ID]
|
||||
Shared list to append results to
|
||||
|
||||
"""
|
||||
try:
|
||||
result = await self._query_peer_for_closest(peer, target_key)
|
||||
# Add deduplication to prevent duplicate peers
|
||||
for peer_id in result:
|
||||
if peer_id not in new_peers:
|
||||
new_peers.append(peer_id)
|
||||
logger.debug(
|
||||
"Queried peer %s for closest peers, got %d results (%d unique)",
|
||||
peer,
|
||||
len(result),
|
||||
len([p for p in result if p not in new_peers[: -len(result)]]),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Query to peer {peer} failed: {e}")
|
||||
|
||||
async def find_closest_peers_network(
|
||||
self, target_key: bytes, count: int = 20
|
||||
) -> list[ID]:
|
||||
"""
|
||||
Find the closest peers to a target key in the entire network.
|
||||
|
||||
Performs an iterative lookup by querying peers for their closest peers.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[ID]
|
||||
Closest peer IDs
|
||||
|
||||
"""
|
||||
# Start with closest peers from our routing table
|
||||
closest_peers = self.routing_table.find_local_closest_peers(target_key, count)
|
||||
logger.debug("Local closest peers: %d found", len(closest_peers))
|
||||
queried_peers: set[ID] = set()
|
||||
rounds = 0
|
||||
|
||||
# Return early if we have no peers to start with
|
||||
if not closest_peers:
|
||||
logger.warning("No local peers available for network lookup")
|
||||
return []
|
||||
|
||||
# Iterative lookup until convergence
|
||||
while rounds < MAX_PEER_LOOKUP_ROUNDS:
|
||||
rounds += 1
|
||||
logger.debug(f"Lookup round {rounds}/{MAX_PEER_LOOKUP_ROUNDS}")
|
||||
|
||||
# Find peers we haven't queried yet
|
||||
peers_to_query = [p for p in closest_peers if p not in queried_peers]
|
||||
if not peers_to_query:
|
||||
logger.debug("No more unqueried peers available, ending lookup")
|
||||
break # No more peers to query
|
||||
|
||||
# Query these peers for their closest peers to target
|
||||
peers_batch = peers_to_query[:ALPHA] # Limit to ALPHA peers at a time
|
||||
|
||||
# Mark these peers as queried before we actually query them
|
||||
for peer in peers_batch:
|
||||
queried_peers.add(peer)
|
||||
|
||||
# Run queries in parallel for this batch using trio nursery
|
||||
new_peers: list[ID] = [] # Shared array to collect all results
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for peer in peers_batch:
|
||||
nursery.start_soon(
|
||||
self._query_single_peer_for_closest, peer, target_key, new_peers
|
||||
)
|
||||
|
||||
# If we got no new peers, we're done
|
||||
if not new_peers:
|
||||
logger.debug("No new peers discovered in this round, ending lookup")
|
||||
break
|
||||
|
||||
# Update our list of closest peers
|
||||
all_candidates = closest_peers + new_peers
|
||||
old_closest_peers = closest_peers[:]
|
||||
closest_peers = sort_peer_ids_by_distance(target_key, all_candidates)[
|
||||
:count
|
||||
]
|
||||
logger.debug(f"Updated closest peers count: {len(closest_peers)}")
|
||||
|
||||
# Check if we made any progress (found closer peers)
|
||||
if closest_peers == old_closest_peers:
|
||||
logger.debug("No improvement in closest peers, ending lookup")
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"Network lookup completed after {rounds} rounds, "
|
||||
f"found {len(closest_peers)} peers"
|
||||
)
|
||||
return closest_peers
|
||||
|
||||
async def _query_peer_for_closest(self, peer: ID, target_key: bytes) -> list[ID]:
|
||||
"""
|
||||
Query a peer for their closest peers
|
||||
to the target key using varint length prefix
|
||||
"""
|
||||
stream = None
|
||||
results = []
|
||||
try:
|
||||
# Add the peer to our routing table regardless of query outcome
|
||||
try:
|
||||
addrs = self.host.get_peerstore().addrs(peer)
|
||||
if addrs:
|
||||
peer_info = PeerInfo(peer, addrs)
|
||||
await self.routing_table.add_peer(peer_info)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to add peer {peer} to routing table: {e}")
|
||||
|
||||
# Open a stream to the peer using the Kademlia protocol
|
||||
logger.debug(f"Opening stream to {peer} for closest peers query")
|
||||
try:
|
||||
stream = await self.host.new_stream(peer, [PROTOCOL_ID])
|
||||
logger.debug(f"Stream opened to {peer}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to open stream to {peer}: {e}")
|
||||
return []
|
||||
|
||||
# Create and send FIND_NODE request using protobuf
|
||||
find_node_msg = Message()
|
||||
find_node_msg.type = Message.MessageType.FIND_NODE
|
||||
find_node_msg.key = target_key # Set target key directly as bytes
|
||||
|
||||
# Serialize and send the protobuf message with varint length prefix
|
||||
proto_bytes = find_node_msg.SerializeToString()
|
||||
logger.debug(
|
||||
f"Sending FIND_NODE: {proto_bytes.hex()} (len={len(proto_bytes)})"
|
||||
)
|
||||
await stream.write(varint.encode(len(proto_bytes)))
|
||||
await stream.write(proto_bytes)
|
||||
|
||||
# Read varint-prefixed response length
|
||||
length_bytes = b""
|
||||
while True:
|
||||
b = await stream.read(1)
|
||||
if not b:
|
||||
logger.warning(
|
||||
"Error reading varint length from stream: connection closed"
|
||||
)
|
||||
return []
|
||||
length_bytes += b
|
||||
if b[0] & 0x80 == 0:
|
||||
break
|
||||
response_length = varint.decode_bytes(length_bytes)
|
||||
|
||||
# Read response data
|
||||
response_bytes = b""
|
||||
remaining = response_length
|
||||
while remaining > 0:
|
||||
chunk = await stream.read(remaining)
|
||||
if not chunk:
|
||||
logger.debug(f"Connection closed by peer {peer} while reading data")
|
||||
return []
|
||||
response_bytes += chunk
|
||||
remaining -= len(chunk)
|
||||
|
||||
# Parse the protobuf response
|
||||
response_msg = Message()
|
||||
response_msg.ParseFromString(response_bytes)
|
||||
logger.debug(
|
||||
"Received response from %s with %d peers",
|
||||
peer,
|
||||
len(response_msg.closerPeers),
|
||||
)
|
||||
|
||||
# Process closest peers from response
|
||||
if response_msg.type == Message.MessageType.FIND_NODE:
|
||||
for peer_data in response_msg.closerPeers:
|
||||
new_peer_id = ID(peer_data.id)
|
||||
if new_peer_id not in results:
|
||||
results.append(new_peer_id)
|
||||
if peer_data.addrs:
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
)
|
||||
|
||||
addrs = [Multiaddr(addr) for addr in peer_data.addrs]
|
||||
self.host.get_peerstore().add_addrs(new_peer_id, addrs, 3600)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error querying peer {peer} for closest: {e}")
|
||||
|
||||
finally:
|
||||
if stream:
|
||||
await stream.close()
|
||||
return results
|
||||
|
||||
async def _handle_kad_stream(self, stream: INetStream) -> None:
|
||||
"""
|
||||
Handle incoming Kademlia protocol streams.
|
||||
|
||||
params: stream: The incoming stream
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
|
||||
"""
|
||||
try:
|
||||
# Read message length
|
||||
length_bytes = await stream.read(4)
|
||||
if not length_bytes:
|
||||
return
|
||||
|
||||
message_length = int.from_bytes(length_bytes, byteorder="big")
|
||||
|
||||
# Read message
|
||||
message_bytes = await stream.read(message_length)
|
||||
if not message_bytes:
|
||||
return
|
||||
|
||||
# Parse protobuf message
|
||||
kad_message = Message()
|
||||
try:
|
||||
kad_message.ParseFromString(message_bytes)
|
||||
|
||||
if kad_message.type == Message.MessageType.FIND_NODE:
|
||||
# Get target key directly from protobuf message
|
||||
target_key = kad_message.key
|
||||
|
||||
# Find closest peers to target
|
||||
closest_peers = self.routing_table.find_local_closest_peers(
|
||||
target_key, 20
|
||||
)
|
||||
|
||||
# Create protobuf response
|
||||
response = Message()
|
||||
response.type = Message.MessageType.FIND_NODE
|
||||
|
||||
# Add peer information to response
|
||||
for peer_id in closest_peers:
|
||||
peer_proto = response.closerPeers.add()
|
||||
peer_proto.id = peer_id.to_bytes()
|
||||
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
|
||||
|
||||
# Add addresses if available
|
||||
try:
|
||||
addrs = self.host.get_peerstore().addrs(peer_id)
|
||||
if addrs:
|
||||
for addr in addrs:
|
||||
peer_proto.addrs.append(addr.to_bytes())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Send response
|
||||
response_bytes = response.SerializeToString()
|
||||
await stream.write(len(response_bytes).to_bytes(4, byteorder="big"))
|
||||
await stream.write(response_bytes)
|
||||
|
||||
except Exception as parse_err:
|
||||
logger.error(f"Failed to parse protocol buffer message: {parse_err}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error handling Kademlia stream: {e}")
|
||||
finally:
|
||||
await stream.close()
|
||||
|
||||
async def refresh_routing_table(self) -> None:
|
||||
"""
|
||||
Refresh the routing table by performing lookups for random keys.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
|
||||
"""
|
||||
logger.info("Refreshing routing table")
|
||||
|
||||
# Perform a lookup for ourselves to populate the routing table
|
||||
local_id = self.host.get_id()
|
||||
closest_peers = await self.find_closest_peers_network(local_id.to_bytes())
|
||||
|
||||
# Add discovered peers to routing table
|
||||
for peer_id in closest_peers:
|
||||
try:
|
||||
addrs = self.host.get_peerstore().addrs(peer_id)
|
||||
if addrs:
|
||||
peer_info = PeerInfo(peer_id, addrs)
|
||||
await self.routing_table.add_peer(peer_info)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to add discovered peer {peer_id}: {e}")
|
||||
@ -1,577 +0,0 @@
|
||||
"""
|
||||
Provider record storage for Kademlia DHT.
|
||||
|
||||
This module implements the storage for content provider records in the Kademlia DHT.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import (
|
||||
Any,
|
||||
)
|
||||
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
)
|
||||
import trio
|
||||
import varint
|
||||
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
|
||||
from .common import (
|
||||
ALPHA,
|
||||
PROTOCOL_ID,
|
||||
QUERY_TIMEOUT,
|
||||
)
|
||||
from .pb.kademlia_pb2 import (
|
||||
Message,
|
||||
)
|
||||
|
||||
# logger = logging.getLogger("libp2p.kademlia.provider_store")
|
||||
logger = logging.getLogger("kademlia-example.provider_store")
|
||||
|
||||
# Constants for provider records (based on IPFS standards)
|
||||
PROVIDER_RECORD_REPUBLISH_INTERVAL = 22 * 60 * 60 # 22 hours in seconds
|
||||
PROVIDER_RECORD_EXPIRATION_INTERVAL = 48 * 60 * 60 # 48 hours in seconds
|
||||
PROVIDER_ADDRESS_TTL = 30 * 60 # 30 minutes in seconds
|
||||
|
||||
|
||||
class ProviderRecord:
|
||||
"""
|
||||
A record for a content provider in the DHT.
|
||||
|
||||
Contains the peer information and timestamp.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider_info: PeerInfo,
|
||||
timestamp: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a new provider record.
|
||||
|
||||
:param provider_info: The provider's peer information
|
||||
:param timestamp: Time this record was created/updated
|
||||
(defaults to current time)
|
||||
|
||||
"""
|
||||
self.provider_info = provider_info
|
||||
self.timestamp = timestamp or time.time()
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""
|
||||
Check if this provider record has expired.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the record has expired
|
||||
|
||||
"""
|
||||
current_time = time.time()
|
||||
return (current_time - self.timestamp) >= PROVIDER_RECORD_EXPIRATION_INTERVAL
|
||||
|
||||
def should_republish(self) -> bool:
|
||||
"""
|
||||
Check if this provider record should be republished.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the record should be republished
|
||||
|
||||
"""
|
||||
current_time = time.time()
|
||||
return (current_time - self.timestamp) >= PROVIDER_RECORD_REPUBLISH_INTERVAL
|
||||
|
||||
@property
|
||||
def peer_id(self) -> ID:
|
||||
"""Get the provider's peer ID."""
|
||||
return self.provider_info.peer_id
|
||||
|
||||
@property
|
||||
def addresses(self) -> list[Multiaddr]:
|
||||
"""Get the provider's addresses."""
|
||||
return self.provider_info.addrs
|
||||
|
||||
|
||||
class ProviderStore:
|
||||
"""
|
||||
Store for content provider records in the Kademlia DHT.
|
||||
|
||||
Maps content keys to provider records, with support for expiration.
|
||||
"""
|
||||
|
||||
def __init__(self, host: IHost, peer_routing: Any = None) -> None:
|
||||
"""
|
||||
Initialize a new provider store.
|
||||
|
||||
:param host: The libp2p host instance (optional)
|
||||
:param peer_routing: The peer routing instance (optional)
|
||||
"""
|
||||
# Maps content keys to a dict of provider records (peer_id -> record)
|
||||
self.providers: dict[bytes, dict[str, ProviderRecord]] = {}
|
||||
self.host = host
|
||||
self.peer_routing = peer_routing
|
||||
self.providing_keys: set[bytes] = set()
|
||||
self.local_peer_id = host.get_id()
|
||||
|
||||
async def _republish_provider_records(self) -> None:
|
||||
"""Republish all provider records for content this node is providing."""
|
||||
# First, republish keys we're actively providing
|
||||
for key in self.providing_keys:
|
||||
logger.debug(f"Republishing provider record for key {key.hex()}")
|
||||
await self.provide(key)
|
||||
|
||||
# Also check for any records that should be republished
|
||||
time.time()
|
||||
for key, providers in self.providers.items():
|
||||
for peer_id_str, record in providers.items():
|
||||
# Only republish records for our own peer
|
||||
if self.local_peer_id and str(self.local_peer_id) == peer_id_str:
|
||||
if record.should_republish():
|
||||
logger.debug(
|
||||
f"Republishing old provider record for key {key.hex()}"
|
||||
)
|
||||
await self.provide(key)
|
||||
|
||||
async def provide(self, key: bytes) -> bool:
|
||||
"""
|
||||
Advertise that this node can provide a piece of content.
|
||||
|
||||
Finds the k closest peers to the key and sends them ADD_PROVIDER messages.
|
||||
|
||||
:param key: The content key (multihash) to advertise
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the advertisement was successful
|
||||
|
||||
"""
|
||||
if not self.host or not self.peer_routing:
|
||||
logger.error("Host or peer_routing not initialized, cannot provide content")
|
||||
return False
|
||||
|
||||
# Add to local provider store
|
||||
local_addrs = []
|
||||
for addr in self.host.get_addrs():
|
||||
local_addrs.append(addr)
|
||||
|
||||
local_peer_info = PeerInfo(self.host.get_id(), local_addrs)
|
||||
self.add_provider(key, local_peer_info)
|
||||
|
||||
# Track that we're providing this key
|
||||
self.providing_keys.add(key)
|
||||
|
||||
# Find the k closest peers to the key
|
||||
closest_peers = await self.peer_routing.find_closest_peers_network(key)
|
||||
logger.debug(
|
||||
"Found %d peers close to key %s for provider advertisement",
|
||||
len(closest_peers),
|
||||
key.hex(),
|
||||
)
|
||||
|
||||
# Send ADD_PROVIDER messages to these ALPHA peers in parallel.
|
||||
success_count = 0
|
||||
for i in range(0, len(closest_peers), ALPHA):
|
||||
batch = closest_peers[i : i + ALPHA]
|
||||
results: list[bool] = [False] * len(batch)
|
||||
|
||||
async def send_one(
|
||||
idx: int, peer_id: ID, results: list[bool] = results
|
||||
) -> None:
|
||||
if peer_id == self.local_peer_id:
|
||||
return
|
||||
try:
|
||||
with trio.move_on_after(QUERY_TIMEOUT):
|
||||
success = await self._send_add_provider(peer_id, key)
|
||||
results[idx] = success
|
||||
if not success:
|
||||
logger.warning(f"Failed to send ADD_PROVIDER to {peer_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error sending ADD_PROVIDER to {peer_id}: {e}")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for idx, peer_id in enumerate(batch):
|
||||
nursery.start_soon(send_one, idx, peer_id, results)
|
||||
success_count += sum(results)
|
||||
|
||||
logger.info(f"Successfully advertised to {success_count} peers")
|
||||
return success_count > 0
|
||||
|
||||
async def _send_add_provider(self, peer_id: ID, key: bytes) -> bool:
|
||||
"""
|
||||
Send ADD_PROVIDER message to a specific peer.
|
||||
|
||||
:param peer_id: The peer to send the message to
|
||||
:param key: The content key being provided
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the message was successfully sent and acknowledged
|
||||
|
||||
"""
|
||||
try:
|
||||
result = False
|
||||
# Open a stream to the peer
|
||||
stream = await self.host.new_stream(peer_id, [TProtocol(PROTOCOL_ID)])
|
||||
|
||||
# Get our addresses to include in the message
|
||||
addrs = []
|
||||
for addr in self.host.get_addrs():
|
||||
addrs.append(addr.to_bytes())
|
||||
|
||||
# Create the ADD_PROVIDER message
|
||||
message = Message()
|
||||
message.type = Message.MessageType.ADD_PROVIDER
|
||||
message.key = key
|
||||
|
||||
# Add our provider info
|
||||
provider = message.providerPeers.add()
|
||||
provider.id = self.local_peer_id.to_bytes()
|
||||
provider.addrs.extend(addrs)
|
||||
|
||||
# Serialize and send the message
|
||||
proto_bytes = message.SerializeToString()
|
||||
await stream.write(varint.encode(len(proto_bytes)))
|
||||
await stream.write(proto_bytes)
|
||||
logger.debug(f"Sent ADD_PROVIDER to {peer_id} for key {key.hex()}")
|
||||
# Read response length prefix
|
||||
length_bytes = b""
|
||||
while True:
|
||||
logger.debug("Reading response length prefix in add provider")
|
||||
b = await stream.read(1)
|
||||
if not b:
|
||||
return False
|
||||
length_bytes += b
|
||||
if b[0] & 0x80 == 0:
|
||||
break
|
||||
|
||||
response_length = varint.decode_bytes(length_bytes)
|
||||
# Read response data
|
||||
response_bytes = b""
|
||||
remaining = response_length
|
||||
while remaining > 0:
|
||||
chunk = await stream.read(remaining)
|
||||
if not chunk:
|
||||
return False
|
||||
response_bytes += chunk
|
||||
remaining -= len(chunk)
|
||||
|
||||
# Parse response
|
||||
response = Message()
|
||||
response.ParseFromString(response_bytes)
|
||||
|
||||
# Check response type
|
||||
response.type == Message.MessageType.ADD_PROVIDER
|
||||
if response.type:
|
||||
result = True
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error sending ADD_PROVIDER to {peer_id}: {e}")
|
||||
|
||||
finally:
|
||||
await stream.close()
|
||||
return result
|
||||
|
||||
async def find_providers(self, key: bytes, count: int = 20) -> list[PeerInfo]:
|
||||
"""
|
||||
Find content providers for a given key.
|
||||
|
||||
:param key: The content key to look for
|
||||
:param count: Maximum number of providers to return
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[PeerInfo]
|
||||
List of content providers
|
||||
|
||||
"""
|
||||
if not self.host or not self.peer_routing:
|
||||
logger.error("Host or peer_routing not initialized, cannot find providers")
|
||||
return []
|
||||
|
||||
# Check local provider store first
|
||||
local_providers = self.get_providers(key)
|
||||
if local_providers:
|
||||
logger.debug(
|
||||
f"Found {len(local_providers)} providers locally for {key.hex()}"
|
||||
)
|
||||
return local_providers[:count]
|
||||
logger.debug("local providers are %s", local_providers)
|
||||
|
||||
# Find the closest peers to the key
|
||||
closest_peers = await self.peer_routing.find_closest_peers_network(key)
|
||||
logger.debug(
|
||||
f"Searching {len(closest_peers)} peers for providers of {key.hex()}"
|
||||
)
|
||||
|
||||
# Query these peers for providers in batches of ALPHA, in parallel, with timeout
|
||||
all_providers = []
|
||||
for i in range(0, len(closest_peers), ALPHA):
|
||||
batch = closest_peers[i : i + ALPHA]
|
||||
batch_results: list[list[PeerInfo]] = [[] for _ in batch]
|
||||
|
||||
async def get_one(
|
||||
idx: int,
|
||||
peer_id: ID,
|
||||
batch_results: list[list[PeerInfo]] = batch_results,
|
||||
) -> None:
|
||||
if peer_id == self.local_peer_id:
|
||||
return
|
||||
try:
|
||||
with trio.move_on_after(QUERY_TIMEOUT):
|
||||
providers = await self._get_providers_from_peer(peer_id, key)
|
||||
if providers:
|
||||
for provider in providers:
|
||||
self.add_provider(key, provider)
|
||||
batch_results[idx] = providers
|
||||
else:
|
||||
logger.debug(f"No providers found at peer {peer_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get providers from {peer_id}: {e}")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for idx, peer_id in enumerate(batch):
|
||||
nursery.start_soon(get_one, idx, peer_id, batch_results)
|
||||
|
||||
for providers in batch_results:
|
||||
all_providers.extend(providers)
|
||||
if len(all_providers) >= count:
|
||||
return all_providers[:count]
|
||||
|
||||
return all_providers[:count]
|
||||
|
||||
async def _get_providers_from_peer(self, peer_id: ID, key: bytes) -> list[PeerInfo]:
|
||||
"""
|
||||
Get content providers from a specific peer.
|
||||
|
||||
:param peer_id: The peer to query
|
||||
:param key: The content key to look for
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[PeerInfo]
|
||||
List of provider information
|
||||
|
||||
"""
|
||||
providers: list[PeerInfo] = []
|
||||
try:
|
||||
# Open a stream to the peer
|
||||
stream = await self.host.new_stream(peer_id, [TProtocol(PROTOCOL_ID)])
|
||||
|
||||
try:
|
||||
# Create the GET_PROVIDERS message
|
||||
message = Message()
|
||||
message.type = Message.MessageType.GET_PROVIDERS
|
||||
message.key = key
|
||||
|
||||
# Serialize and send the message
|
||||
proto_bytes = message.SerializeToString()
|
||||
await stream.write(varint.encode(len(proto_bytes)))
|
||||
await stream.write(proto_bytes)
|
||||
|
||||
# Read response length prefix
|
||||
length_bytes = b""
|
||||
while True:
|
||||
b = await stream.read(1)
|
||||
if not b:
|
||||
return []
|
||||
length_bytes += b
|
||||
if b[0] & 0x80 == 0:
|
||||
break
|
||||
|
||||
response_length = varint.decode_bytes(length_bytes)
|
||||
# Read response data
|
||||
response_bytes = b""
|
||||
remaining = response_length
|
||||
while remaining > 0:
|
||||
chunk = await stream.read(remaining)
|
||||
if not chunk:
|
||||
return []
|
||||
response_bytes += chunk
|
||||
remaining -= len(chunk)
|
||||
|
||||
# Parse response
|
||||
response = Message()
|
||||
response.ParseFromString(response_bytes)
|
||||
|
||||
# Check response type
|
||||
if response.type != Message.MessageType.GET_PROVIDERS:
|
||||
return []
|
||||
|
||||
# Extract provider information
|
||||
providers = []
|
||||
for provider_proto in response.providerPeers:
|
||||
try:
|
||||
# Create peer ID from bytes
|
||||
provider_id = ID(provider_proto.id)
|
||||
|
||||
# Convert addresses to Multiaddr
|
||||
addrs = []
|
||||
for addr_bytes in provider_proto.addrs:
|
||||
try:
|
||||
addrs.append(Multiaddr(addr_bytes))
|
||||
except Exception:
|
||||
pass # Skip invalid addresses
|
||||
|
||||
# Create PeerInfo and add to result
|
||||
providers.append(PeerInfo(provider_id, addrs))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse provider info: {e}")
|
||||
|
||||
finally:
|
||||
await stream.close()
|
||||
return providers
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting providers from {peer_id}: {e}")
|
||||
return []
|
||||
|
||||
def add_provider(self, key: bytes, provider: PeerInfo) -> None:
|
||||
"""
|
||||
Add a provider for a given content key.
|
||||
|
||||
:param key: The content key
|
||||
:param provider: The provider's peer information
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
|
||||
"""
|
||||
# Initialize providers for this key if needed
|
||||
if key not in self.providers:
|
||||
self.providers[key] = {}
|
||||
|
||||
# Add or update the provider record
|
||||
peer_id_str = str(provider.peer_id) # Use string representation as dict key
|
||||
self.providers[key][peer_id_str] = ProviderRecord(
|
||||
provider_info=provider, timestamp=time.time()
|
||||
)
|
||||
logger.debug(f"Added provider {provider.peer_id} for key {key.hex()}")
|
||||
|
||||
def get_providers(self, key: bytes) -> list[PeerInfo]:
|
||||
"""
|
||||
Get all providers for a given content key.
|
||||
|
||||
:param key: The content key
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[PeerInfo]
|
||||
List of providers for the key
|
||||
|
||||
"""
|
||||
if key not in self.providers:
|
||||
return []
|
||||
|
||||
# Collect valid provider records (not expired)
|
||||
result = []
|
||||
current_time = time.time()
|
||||
expired_peers = []
|
||||
|
||||
for peer_id_str, record in self.providers[key].items():
|
||||
# Check if the record has expired
|
||||
if current_time - record.timestamp > PROVIDER_RECORD_EXPIRATION_INTERVAL:
|
||||
expired_peers.append(peer_id_str)
|
||||
continue
|
||||
|
||||
# Use addresses only if they haven't expired
|
||||
addresses = []
|
||||
if current_time - record.timestamp <= PROVIDER_ADDRESS_TTL:
|
||||
addresses = record.addresses
|
||||
|
||||
# Create PeerInfo and add to results
|
||||
result.append(PeerInfo(record.peer_id, addresses))
|
||||
|
||||
# Clean up expired records
|
||||
for peer_id in expired_peers:
|
||||
del self.providers[key][peer_id]
|
||||
|
||||
# Remove the key if no providers left
|
||||
if not self.providers[key]:
|
||||
del self.providers[key]
|
||||
|
||||
return result
|
||||
|
||||
def cleanup_expired(self) -> None:
|
||||
"""Remove expired provider records."""
|
||||
current_time = time.time()
|
||||
expired_keys = []
|
||||
|
||||
for key, providers in self.providers.items():
|
||||
expired_providers = []
|
||||
|
||||
for peer_id_str, record in providers.items():
|
||||
if (
|
||||
current_time - record.timestamp
|
||||
> PROVIDER_RECORD_EXPIRATION_INTERVAL
|
||||
):
|
||||
expired_providers.append(peer_id_str)
|
||||
logger.debug(
|
||||
f"Removing expired provider {peer_id_str} for key {key.hex()}"
|
||||
)
|
||||
|
||||
# Remove expired providers
|
||||
for peer_id in expired_providers:
|
||||
del providers[peer_id]
|
||||
|
||||
# Track empty keys for removal
|
||||
if not providers:
|
||||
expired_keys.append(key)
|
||||
|
||||
# Remove empty keys
|
||||
for key in expired_keys:
|
||||
del self.providers[key]
|
||||
logger.debug(f"Removed key with no providers: {key.hex()}")
|
||||
|
||||
def get_provided_keys(self, peer_id: ID) -> list[bytes]:
|
||||
"""
|
||||
Get all content keys provided by a specific peer.
|
||||
|
||||
:param peer_id: The peer ID to look for
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[bytes]
|
||||
List of content keys provided by the peer
|
||||
|
||||
"""
|
||||
peer_id_str = str(peer_id)
|
||||
result = []
|
||||
|
||||
for key, providers in self.providers.items():
|
||||
if peer_id_str in providers:
|
||||
result.append(key)
|
||||
|
||||
return result
|
||||
|
||||
def size(self) -> int:
|
||||
"""
|
||||
Get the total number of provider records in the store.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Total number of provider records across all keys
|
||||
|
||||
"""
|
||||
total = 0
|
||||
for providers in self.providers.values():
|
||||
total += len(providers)
|
||||
return total
|
||||
@ -1,600 +0,0 @@
|
||||
"""
|
||||
Kademlia DHT routing table implementation.
|
||||
"""
|
||||
|
||||
from collections import (
|
||||
OrderedDict,
|
||||
)
|
||||
import logging
|
||||
import time
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
)
|
||||
from libp2p.kad_dht.utils import (
|
||||
xor_distance,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
|
||||
from .common import (
|
||||
PROTOCOL_ID,
|
||||
)
|
||||
from .pb.kademlia_pb2 import (
|
||||
Message,
|
||||
)
|
||||
|
||||
# logger = logging.getLogger("libp2p.kademlia.routing_table")
|
||||
logger = logging.getLogger("kademlia-example.routing_table")
|
||||
|
||||
# Default parameters
|
||||
BUCKET_SIZE = 20 # k in the Kademlia paper
|
||||
MAXIMUM_BUCKETS = 256 # Maximum number of buckets (for 256-bit keys)
|
||||
PEER_REFRESH_INTERVAL = 60 # Interval to refresh peers in seconds
|
||||
STALE_PEER_THRESHOLD = 3600 # Time in seconds after which a peer is considered stale
|
||||
|
||||
|
||||
class KBucket:
|
||||
"""
|
||||
A k-bucket implementation for the Kademlia DHT.
|
||||
|
||||
Each k-bucket stores up to k (BUCKET_SIZE) peers, sorted by least-recently seen.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: IHost,
|
||||
bucket_size: int = BUCKET_SIZE,
|
||||
min_range: int = 0,
|
||||
max_range: int = 2**256,
|
||||
):
|
||||
"""
|
||||
Initialize a new k-bucket.
|
||||
|
||||
:param host: The host this bucket belongs to
|
||||
:param bucket_size: Maximum number of peers to store in the bucket
|
||||
:param min_range: Lower boundary of the bucket's key range (inclusive)
|
||||
:param max_range: Upper boundary of the bucket's key range (exclusive)
|
||||
|
||||
"""
|
||||
self.bucket_size = bucket_size
|
||||
self.host = host
|
||||
self.min_range = min_range
|
||||
self.max_range = max_range
|
||||
# Store PeerInfo objects along with last-seen timestamp
|
||||
self.peers: OrderedDict[ID, tuple[PeerInfo, float]] = OrderedDict()
|
||||
|
||||
def peer_ids(self) -> list[ID]:
|
||||
"""Get all peer IDs in the bucket."""
|
||||
return list(self.peers.keys())
|
||||
|
||||
def peer_infos(self) -> list[PeerInfo]:
|
||||
"""Get all PeerInfo objects in the bucket."""
|
||||
return [info for info, _ in self.peers.values()]
|
||||
|
||||
def get_oldest_peer(self) -> ID | None:
|
||||
"""Get the least-recently seen peer."""
|
||||
if not self.peers:
|
||||
return None
|
||||
return next(iter(self.peers.keys()))
|
||||
|
||||
async def add_peer(self, peer_info: PeerInfo) -> bool:
|
||||
"""
|
||||
Add a peer to the bucket. Returns True if the peer was added or updated,
|
||||
False if the bucket is full.
|
||||
"""
|
||||
current_time = time.time()
|
||||
peer_id = peer_info.peer_id
|
||||
|
||||
# If peer is already in the bucket, move it to the end (most recently seen)
|
||||
if peer_id in self.peers:
|
||||
self.refresh_peer_last_seen(peer_id)
|
||||
return True
|
||||
|
||||
# If bucket has space, add the peer
|
||||
if len(self.peers) < self.bucket_size:
|
||||
self.peers[peer_id] = (peer_info, current_time)
|
||||
return True
|
||||
|
||||
# If bucket is full, we need to replace the least-recently seen peer
|
||||
# Get the least-recently seen peer
|
||||
oldest_peer_id = self.get_oldest_peer()
|
||||
if oldest_peer_id is None:
|
||||
logger.warning("No oldest peer found when bucket is full")
|
||||
return False
|
||||
|
||||
# Check if the old peer is responsive to ping request
|
||||
try:
|
||||
# Try to ping the oldest peer, not the new peer
|
||||
response = await self._ping_peer(oldest_peer_id)
|
||||
if response:
|
||||
# If the old peer is still alive, we will not add the new peer
|
||||
logger.debug(
|
||||
"Old peer %s is still alive, cannot add new peer %s",
|
||||
oldest_peer_id,
|
||||
peer_id,
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
# If the old peer is unresponsive, we can replace it with the new peer
|
||||
logger.debug(
|
||||
"Old peer %s is unresponsive, replacing with new peer %s: %s",
|
||||
oldest_peer_id,
|
||||
peer_id,
|
||||
str(e),
|
||||
)
|
||||
self.peers.popitem(last=False) # Remove oldest peer
|
||||
self.peers[peer_id] = (peer_info, current_time)
|
||||
return True
|
||||
|
||||
# If we got here, the oldest peer responded but we couldn't add the new peer
|
||||
return False
|
||||
|
||||
def remove_peer(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Remove a peer from the bucket.
|
||||
Returns True if the peer was in the bucket, False otherwise.
|
||||
"""
|
||||
if peer_id in self.peers:
|
||||
del self.peers[peer_id]
|
||||
return True
|
||||
return False
|
||||
|
||||
def has_peer(self, peer_id: ID) -> bool:
|
||||
"""Check if the peer is in the bucket."""
|
||||
return peer_id in self.peers
|
||||
|
||||
def get_peer_info(self, peer_id: ID) -> PeerInfo | None:
|
||||
"""Get the PeerInfo for a given peer ID if it exists in the bucket."""
|
||||
if peer_id in self.peers:
|
||||
return self.peers[peer_id][0]
|
||||
return None
|
||||
|
||||
def size(self) -> int:
|
||||
"""Get the number of peers in the bucket."""
|
||||
return len(self.peers)
|
||||
|
||||
def get_stale_peers(self, stale_threshold_seconds: int = 3600) -> list[ID]:
|
||||
"""
|
||||
Get peers that haven't been pinged recently.
|
||||
|
||||
params: stale_threshold_seconds: Time in seconds
|
||||
params: after which a peer is considered stale
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[ID]
|
||||
List of peer IDs that need to be refreshed
|
||||
|
||||
"""
|
||||
current_time = time.time()
|
||||
stale_peers = []
|
||||
|
||||
for peer_id, (_, last_seen) in self.peers.items():
|
||||
if current_time - last_seen > stale_threshold_seconds:
|
||||
stale_peers.append(peer_id)
|
||||
|
||||
return stale_peers
|
||||
|
||||
async def _periodic_peer_refresh(self) -> None:
|
||||
"""Background task to periodically refresh peers"""
|
||||
try:
|
||||
while True:
|
||||
await trio.sleep(PEER_REFRESH_INTERVAL) # Check every minute
|
||||
|
||||
# Find stale peers (not pinged in last hour)
|
||||
stale_peers = self.get_stale_peers(
|
||||
stale_threshold_seconds=STALE_PEER_THRESHOLD
|
||||
)
|
||||
if stale_peers:
|
||||
logger.debug(f"Found {len(stale_peers)} stale peers to refresh")
|
||||
|
||||
for peer_id in stale_peers:
|
||||
try:
|
||||
# Try to ping the peer
|
||||
logger.debug("Pinging stale peer %s", peer_id)
|
||||
responce = await self._ping_peer(peer_id)
|
||||
if responce:
|
||||
# Update the last seen time
|
||||
self.refresh_peer_last_seen(peer_id)
|
||||
logger.debug(f"Refreshed peer {peer_id}")
|
||||
else:
|
||||
# If ping fails, remove the peer
|
||||
logger.debug(f"Failed to ping peer {peer_id}")
|
||||
self.remove_peer(peer_id)
|
||||
logger.info(f"Removed unresponsive peer {peer_id}")
|
||||
|
||||
logger.debug(f"Successfully refreshed peer {peer_id}")
|
||||
except Exception as e:
|
||||
# If ping fails, remove the peer
|
||||
logger.debug(
|
||||
"Failed to ping peer %s: %s",
|
||||
peer_id,
|
||||
e,
|
||||
)
|
||||
self.remove_peer(peer_id)
|
||||
logger.info(f"Removed unresponsive peer {peer_id}")
|
||||
except trio.Cancelled:
|
||||
logger.debug("Peer refresh task cancelled")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in peer refresh task: {e}", exc_info=True)
|
||||
|
||||
async def _ping_peer(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Ping a peer using protobuf message to check
|
||||
if it's still alive and update last seen time.
|
||||
|
||||
params: peer_id: The ID of the peer to ping
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if ping successful, False otherwise
|
||||
|
||||
"""
|
||||
result = False
|
||||
# Get peer info directly from the bucket
|
||||
peer_info = self.get_peer_info(peer_id)
|
||||
if not peer_info:
|
||||
raise ValueError(f"Peer {peer_id} not in bucket")
|
||||
|
||||
try:
|
||||
# Open a stream to the peer with the DHT protocol
|
||||
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
|
||||
|
||||
try:
|
||||
# Create ping protobuf message
|
||||
ping_msg = Message()
|
||||
ping_msg.type = Message.PING # Use correct enum
|
||||
|
||||
# Serialize and send with length prefix (4 bytes big-endian)
|
||||
msg_bytes = ping_msg.SerializeToString()
|
||||
logger.debug(
|
||||
f"Sending PING message to {peer_id}, size: {len(msg_bytes)} bytes"
|
||||
)
|
||||
await stream.write(len(msg_bytes).to_bytes(4, byteorder="big"))
|
||||
await stream.write(msg_bytes)
|
||||
|
||||
# Wait for response with timeout
|
||||
with trio.move_on_after(2): # 2 second timeout
|
||||
# Read response length (4 bytes)
|
||||
length_bytes = await stream.read(4)
|
||||
if not length_bytes or len(length_bytes) < 4:
|
||||
logger.warning(f"Peer {peer_id} disconnected during ping")
|
||||
return False
|
||||
|
||||
msg_len = int.from_bytes(length_bytes, byteorder="big")
|
||||
if (
|
||||
msg_len <= 0 or msg_len > 1024 * 1024
|
||||
): # Sanity check on message size
|
||||
logger.warning(
|
||||
f"Invalid message length from {peer_id}: {msg_len}"
|
||||
)
|
||||
return False
|
||||
|
||||
logger.debug(
|
||||
f"Receiving response from {peer_id}, size: {msg_len} bytes"
|
||||
)
|
||||
|
||||
# Read full message
|
||||
response_bytes = await stream.read(msg_len)
|
||||
if not response_bytes:
|
||||
logger.warning(f"Failed to read response from {peer_id}")
|
||||
return False
|
||||
|
||||
# Parse protobuf response
|
||||
response = Message()
|
||||
try:
|
||||
response.ParseFromString(response_bytes)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to parse protobuf response from {peer_id}: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
if response.type == Message.PING:
|
||||
# Update the last seen timestamp for this peer
|
||||
logger.debug(f"Successfully pinged peer {peer_id}")
|
||||
result = True
|
||||
return result
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
f"Unexpected response type from {peer_id}: {response.type}"
|
||||
)
|
||||
return False
|
||||
|
||||
# If we get here, the ping timed out
|
||||
logger.warning(f"Ping to peer {peer_id} timed out")
|
||||
return False
|
||||
|
||||
finally:
|
||||
await stream.close()
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error pinging peer {peer_id}: {str(e)}")
|
||||
return False
|
||||
|
||||
def refresh_peer_last_seen(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Update the last-seen timestamp for a peer in the bucket.
|
||||
|
||||
params: peer_id: The ID of the peer to refresh
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the peer was found and refreshed, False otherwise
|
||||
|
||||
"""
|
||||
if peer_id in self.peers:
|
||||
# Get current peer info and update the timestamp
|
||||
peer_info, _ = self.peers[peer_id]
|
||||
current_time = time.time()
|
||||
self.peers[peer_id] = (peer_info, current_time)
|
||||
# Move to end of ordered dict to mark as most recently seen
|
||||
self.peers.move_to_end(peer_id)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def key_in_range(self, key: bytes) -> bool:
|
||||
"""
|
||||
Check if a key is in the range of this bucket.
|
||||
|
||||
params: key: The key to check (bytes)
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the key is in range, False otherwise
|
||||
|
||||
"""
|
||||
key_int = int.from_bytes(key, byteorder="big")
|
||||
return self.min_range <= key_int < self.max_range
|
||||
|
||||
def split(self) -> tuple["KBucket", "KBucket"]:
|
||||
"""
|
||||
Split the bucket into two buckets.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple
|
||||
(lower_bucket, upper_bucket)
|
||||
|
||||
"""
|
||||
midpoint = (self.min_range + self.max_range) // 2
|
||||
lower_bucket = KBucket(self.host, self.bucket_size, self.min_range, midpoint)
|
||||
upper_bucket = KBucket(self.host, self.bucket_size, midpoint, self.max_range)
|
||||
|
||||
# Redistribute peers
|
||||
for peer_id, (peer_info, timestamp) in self.peers.items():
|
||||
peer_key = int.from_bytes(peer_id.to_bytes(), byteorder="big")
|
||||
if peer_key < midpoint:
|
||||
lower_bucket.peers[peer_id] = (peer_info, timestamp)
|
||||
else:
|
||||
upper_bucket.peers[peer_id] = (peer_info, timestamp)
|
||||
|
||||
return lower_bucket, upper_bucket
|
||||
|
||||
|
||||
class RoutingTable:
|
||||
"""
|
||||
The Kademlia routing table maintains information on which peers to contact for any
|
||||
given peer ID in the network.
|
||||
"""
|
||||
|
||||
def __init__(self, local_id: ID, host: IHost) -> None:
|
||||
"""
|
||||
Initialize the routing table.
|
||||
|
||||
:param local_id: The ID of the local node.
|
||||
:param host: The host this routing table belongs to.
|
||||
|
||||
"""
|
||||
self.local_id = local_id
|
||||
self.host = host
|
||||
self.buckets = [KBucket(host, BUCKET_SIZE)]
|
||||
|
||||
async def add_peer(self, peer_obj: PeerInfo | ID) -> bool:
|
||||
"""
|
||||
Add a peer to the routing table.
|
||||
|
||||
:param peer_obj: Either PeerInfo object or peer ID to add
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool: True if the peer was added or updated, False otherwise
|
||||
|
||||
"""
|
||||
peer_id = None
|
||||
peer_info = None
|
||||
|
||||
try:
|
||||
# Handle different types of input
|
||||
if isinstance(peer_obj, PeerInfo):
|
||||
# Already have PeerInfo object
|
||||
peer_info = peer_obj
|
||||
peer_id = peer_obj.peer_id
|
||||
else:
|
||||
# Assume it's a peer ID
|
||||
peer_id = peer_obj
|
||||
# Try to get addresses from the peerstore if available
|
||||
try:
|
||||
addrs = self.host.get_peerstore().addrs(peer_id)
|
||||
if addrs:
|
||||
# Create PeerInfo object
|
||||
peer_info = PeerInfo(peer_id, addrs)
|
||||
else:
|
||||
logger.debug(
|
||||
"No addresses found for peer %s in peerstore, skipping",
|
||||
peer_id,
|
||||
)
|
||||
return False
|
||||
except Exception as peerstore_error:
|
||||
# Handle case where peer is not in peerstore yet
|
||||
logger.debug(
|
||||
"Peer %s not found in peerstore: %s, skipping",
|
||||
peer_id,
|
||||
str(peerstore_error),
|
||||
)
|
||||
return False
|
||||
|
||||
# Don't add ourselves
|
||||
if peer_id == self.local_id:
|
||||
return False
|
||||
|
||||
# Find the right bucket for this peer
|
||||
bucket = self.find_bucket(peer_id)
|
||||
|
||||
# Try to add to the bucket
|
||||
success = await bucket.add_peer(peer_info)
|
||||
if success:
|
||||
logger.debug(f"Successfully added peer {peer_id} to routing table")
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error adding peer {peer_obj} to routing table: {e}")
|
||||
return False
|
||||
|
||||
def remove_peer(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Remove a peer from the routing table.
|
||||
|
||||
:param peer_id: The ID of the peer to remove
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool: True if the peer was removed, False otherwise
|
||||
|
||||
"""
|
||||
bucket = self.find_bucket(peer_id)
|
||||
return bucket.remove_peer(peer_id)
|
||||
|
||||
def find_bucket(self, peer_id: ID) -> KBucket:
|
||||
"""
|
||||
Find the bucket that would contain the given peer ID or PeerInfo.
|
||||
|
||||
:param peer_obj: Either a peer ID or a PeerInfo object
|
||||
|
||||
Returns
|
||||
-------
|
||||
KBucket: The bucket for this peer
|
||||
|
||||
"""
|
||||
for bucket in self.buckets:
|
||||
if bucket.key_in_range(peer_id.to_bytes()):
|
||||
return bucket
|
||||
|
||||
return self.buckets[0]
|
||||
|
||||
def find_local_closest_peers(self, key: bytes, count: int = 20) -> list[ID]:
|
||||
"""
|
||||
Find the closest peers to a given key.
|
||||
|
||||
:param key: The key to find closest peers to (bytes)
|
||||
:param count: Maximum number of peers to return
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[ID]: List of peer IDs closest to the key
|
||||
|
||||
"""
|
||||
# Get all peers from all buckets
|
||||
all_peers = []
|
||||
for bucket in self.buckets:
|
||||
all_peers.extend(bucket.peer_ids())
|
||||
|
||||
# Sort by XOR distance to the key
|
||||
all_peers.sort(key=lambda p: xor_distance(p.to_bytes(), key))
|
||||
|
||||
return all_peers[:count]
|
||||
|
||||
def get_peer_ids(self) -> list[ID]:
|
||||
"""
|
||||
Get all peer IDs in the routing table.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:param List[ID]: List of all peer IDs
|
||||
|
||||
"""
|
||||
peers = []
|
||||
for bucket in self.buckets:
|
||||
peers.extend(bucket.peer_ids())
|
||||
return peers
|
||||
|
||||
def get_peer_info(self, peer_id: ID) -> PeerInfo | None:
|
||||
"""
|
||||
Get the peer info for a specific peer.
|
||||
|
||||
:param peer_id: The ID of the peer to get info for
|
||||
|
||||
Returns
|
||||
-------
|
||||
PeerInfo: The peer info, or None if not found
|
||||
|
||||
"""
|
||||
bucket = self.find_bucket(peer_id)
|
||||
return bucket.get_peer_info(peer_id)
|
||||
|
||||
def peer_in_table(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Check if a peer is in the routing table.
|
||||
|
||||
:param peer_id: The ID of the peer to check
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool: True if the peer is in the routing table, False otherwise
|
||||
|
||||
"""
|
||||
bucket = self.find_bucket(peer_id)
|
||||
return bucket.has_peer(peer_id)
|
||||
|
||||
def size(self) -> int:
|
||||
"""
|
||||
Get the number of peers in the routing table.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int: Number of peers
|
||||
|
||||
"""
|
||||
count = 0
|
||||
for bucket in self.buckets:
|
||||
count += bucket.size()
|
||||
return count
|
||||
|
||||
def get_stale_peers(self, stale_threshold_seconds: int = 3600) -> list[ID]:
|
||||
"""
|
||||
Get all stale peers from all buckets
|
||||
|
||||
params: stale_threshold_seconds:
|
||||
Time in seconds after which a peer is considered stale
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[ID]
|
||||
List of stale peer IDs
|
||||
|
||||
"""
|
||||
stale_peers = []
|
||||
for bucket in self.buckets:
|
||||
stale_peers.extend(bucket.get_stale_peers(stale_threshold_seconds))
|
||||
return stale_peers
|
||||
|
||||
def cleanup_routing_table(self) -> None:
|
||||
"""
|
||||
Cleanup the routing table by removing all data.
|
||||
This is useful for resetting the routing table during tests or reinitialization.
|
||||
"""
|
||||
self.buckets = [KBucket(self.host, BUCKET_SIZE)]
|
||||
logger.info("Routing table cleaned up, all data removed.")
|
||||
@ -1,117 +0,0 @@
|
||||
"""
|
||||
Utility functions for Kademlia DHT implementation.
|
||||
"""
|
||||
|
||||
import base58
|
||||
import multihash
|
||||
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
|
||||
|
||||
def create_key_from_binary(binary_data: bytes) -> bytes:
|
||||
"""
|
||||
Creates a key for the DHT by hashing binary data with SHA-256.
|
||||
|
||||
params: binary_data: The binary data to hash.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bytes: The resulting key.
|
||||
|
||||
"""
|
||||
return multihash.digest(binary_data, "sha2-256").digest
|
||||
|
||||
|
||||
def xor_distance(key1: bytes, key2: bytes) -> int:
|
||||
"""
|
||||
Calculate the XOR distance between two keys.
|
||||
|
||||
params: key1: First key (bytes)
|
||||
params: key2: Second key (bytes)
|
||||
|
||||
Returns
|
||||
-------
|
||||
int: The XOR distance between the keys
|
||||
|
||||
"""
|
||||
# Ensure the inputs are bytes
|
||||
if not isinstance(key1, bytes) or not isinstance(key2, bytes):
|
||||
raise TypeError("Both key1 and key2 must be bytes objects")
|
||||
|
||||
# Convert to integers
|
||||
k1 = int.from_bytes(key1, byteorder="big")
|
||||
k2 = int.from_bytes(key2, byteorder="big")
|
||||
|
||||
# Calculate XOR distance
|
||||
return k1 ^ k2
|
||||
|
||||
|
||||
def bytes_to_base58(data: bytes) -> str:
|
||||
"""
|
||||
Convert bytes to base58 encoded string.
|
||||
|
||||
params: data: Input bytes
|
||||
|
||||
Returns
|
||||
-------
|
||||
str: Base58 encoded string
|
||||
|
||||
"""
|
||||
return base58.b58encode(data).decode("utf-8")
|
||||
|
||||
|
||||
def sort_peer_ids_by_distance(target_key: bytes, peer_ids: list[ID]) -> list[ID]:
|
||||
"""
|
||||
Sort a list of peer IDs by their distance to the target key.
|
||||
|
||||
params: target_key: The target key to measure distance from
|
||||
params: peer_ids: List of peer IDs to sort
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[ID]: Sorted list of peer IDs from closest to furthest
|
||||
|
||||
"""
|
||||
|
||||
def get_distance(peer_id: ID) -> int:
|
||||
# Hash the peer ID bytes to get a key for distance calculation
|
||||
peer_hash = multihash.digest(peer_id.to_bytes(), "sha2-256").digest
|
||||
return xor_distance(target_key, peer_hash)
|
||||
|
||||
return sorted(peer_ids, key=get_distance)
|
||||
|
||||
|
||||
def shared_prefix_len(first: bytes, second: bytes) -> int:
|
||||
"""
|
||||
Calculate the number of prefix bits shared by two byte sequences.
|
||||
|
||||
params: first: First byte sequence
|
||||
params: second: Second byte sequence
|
||||
|
||||
Returns
|
||||
-------
|
||||
int: Number of shared prefix bits
|
||||
|
||||
"""
|
||||
# Compare each byte to find the first bit difference
|
||||
common_length = 0
|
||||
for i in range(min(len(first), len(second))):
|
||||
byte_first = first[i]
|
||||
byte_second = second[i]
|
||||
|
||||
if byte_first == byte_second:
|
||||
common_length += 8
|
||||
else:
|
||||
# Find specific bit where they differ
|
||||
xor = byte_first ^ byte_second
|
||||
# Count leading zeros in the xor result
|
||||
for j in range(7, -1, -1):
|
||||
if (xor >> j) & 1 == 1:
|
||||
return common_length + (7 - j)
|
||||
|
||||
# This shouldn't be reached if xor != 0
|
||||
return common_length + 8
|
||||
|
||||
return common_length
|
||||
@ -1,393 +0,0 @@
|
||||
"""
|
||||
Value store implementation for Kademlia DHT.
|
||||
|
||||
Provides a way to store and retrieve key-value pairs with optional expiration.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import varint
|
||||
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
|
||||
from .common import (
|
||||
DEFAULT_TTL,
|
||||
PROTOCOL_ID,
|
||||
)
|
||||
from .pb.kademlia_pb2 import (
|
||||
Message,
|
||||
)
|
||||
|
||||
# logger = logging.getLogger("libp2p.kademlia.value_store")
|
||||
logger = logging.getLogger("kademlia-example.value_store")
|
||||
|
||||
|
||||
class ValueStore:
|
||||
"""
|
||||
Store for key-value pairs in a Kademlia DHT.
|
||||
|
||||
Values are stored with a timestamp and optional expiration time.
|
||||
"""
|
||||
|
||||
def __init__(self, host: IHost, local_peer_id: ID):
|
||||
"""
|
||||
Initialize an empty value store.
|
||||
|
||||
:param host: The libp2p host instance.
|
||||
:param local_peer_id: The local peer ID to ignore in peer requests.
|
||||
|
||||
"""
|
||||
# Store format: {key: (value, validity)}
|
||||
self.store: dict[bytes, tuple[bytes, float]] = {}
|
||||
# Store references to the host and local peer ID for making requests
|
||||
self.host = host
|
||||
self.local_peer_id = local_peer_id
|
||||
|
||||
def put(self, key: bytes, value: bytes, validity: float = 0.0) -> None:
|
||||
"""
|
||||
Store a value in the DHT.
|
||||
|
||||
:param key: The key to store the value under
|
||||
:param value: The value to store
|
||||
:param validity: validity in seconds before the value expires.
|
||||
Defaults to `DEFAULT_TTL` if set to 0.0.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
|
||||
"""
|
||||
if validity == 0.0:
|
||||
validity = time.time() + DEFAULT_TTL
|
||||
logger.debug(
|
||||
"Storing value for key %s... with validity %s", key.hex(), validity
|
||||
)
|
||||
self.store[key] = (value, validity)
|
||||
logger.debug(f"Stored value for key {key.hex()}")
|
||||
|
||||
async def _store_at_peer(self, peer_id: ID, key: bytes, value: bytes) -> bool:
|
||||
"""
|
||||
Store a value at a specific peer.
|
||||
|
||||
params: peer_id: The ID of the peer to store the value at
|
||||
params: key: The key to store
|
||||
params: value: The value to store
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the value was successfully stored, False otherwise
|
||||
|
||||
"""
|
||||
result = False
|
||||
stream = None
|
||||
try:
|
||||
# Don't try to store at ourselves
|
||||
if self.local_peer_id and peer_id == self.local_peer_id:
|
||||
result = True
|
||||
return result
|
||||
|
||||
if not self.host:
|
||||
logger.error("Host not initialized, cannot store value at peer")
|
||||
return False
|
||||
|
||||
logger.debug(f"Storing value for key {key.hex()} at peer {peer_id}")
|
||||
|
||||
# Open a stream to the peer
|
||||
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
|
||||
logger.debug(f"Opened stream to peer {peer_id}")
|
||||
|
||||
# Create the PUT_VALUE message with protobuf
|
||||
message = Message()
|
||||
message.type = Message.MessageType.PUT_VALUE
|
||||
|
||||
# Set message fields
|
||||
message.key = key
|
||||
message.record.key = key
|
||||
message.record.value = value
|
||||
message.record.timeReceived = str(time.time())
|
||||
|
||||
# Serialize and send the protobuf message with length prefix
|
||||
proto_bytes = message.SerializeToString()
|
||||
await stream.write(varint.encode(len(proto_bytes)))
|
||||
await stream.write(proto_bytes)
|
||||
logger.debug("Sent PUT_VALUE protobuf message with varint length")
|
||||
# Read varint-prefixed response length
|
||||
|
||||
length_bytes = b""
|
||||
while True:
|
||||
logger.debug("Reading varint length prefix for response...")
|
||||
b = await stream.read(1)
|
||||
if not b:
|
||||
logger.warning("Connection closed while reading varint length")
|
||||
return False
|
||||
length_bytes += b
|
||||
if b[0] & 0x80 == 0:
|
||||
break
|
||||
logger.debug(f"Received varint length bytes: {length_bytes.hex()}")
|
||||
response_length = varint.decode_bytes(length_bytes)
|
||||
logger.debug("Response length: %d bytes", response_length)
|
||||
# Read response data
|
||||
response_bytes = b""
|
||||
remaining = response_length
|
||||
while remaining > 0:
|
||||
chunk = await stream.read(remaining)
|
||||
if not chunk:
|
||||
logger.debug(
|
||||
f"Connection closed by peer {peer_id} while reading data"
|
||||
)
|
||||
return False
|
||||
response_bytes += chunk
|
||||
remaining -= len(chunk)
|
||||
|
||||
# Parse protobuf response
|
||||
response = Message()
|
||||
response.ParseFromString(response_bytes)
|
||||
|
||||
# Check if response is valid
|
||||
if response.type == Message.MessageType.PUT_VALUE:
|
||||
if response.key:
|
||||
result = True
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to store value at peer {peer_id}: {e}")
|
||||
return False
|
||||
|
||||
finally:
|
||||
if stream:
|
||||
await stream.close()
|
||||
return result
|
||||
|
||||
def get(self, key: bytes) -> bytes | None:
|
||||
"""
|
||||
Retrieve a value from the DHT.
|
||||
|
||||
params: key: The key to look up
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[bytes]
|
||||
The stored value, or None if not found or expired
|
||||
|
||||
"""
|
||||
logger.debug("Retrieving value for key %s...", key.hex()[:8])
|
||||
if key not in self.store:
|
||||
return None
|
||||
|
||||
value, validity = self.store[key]
|
||||
logger.debug(
|
||||
"Found value for key %s... with validity %s",
|
||||
key.hex(),
|
||||
validity,
|
||||
)
|
||||
# Check if the value has expired
|
||||
if validity is not None and validity < time.time():
|
||||
logger.debug(
|
||||
"Value for key %s... has expired, removing it",
|
||||
key.hex()[:8],
|
||||
)
|
||||
self.remove(key)
|
||||
return None
|
||||
|
||||
return value
|
||||
|
||||
async def _get_from_peer(self, peer_id: ID, key: bytes) -> bytes | None:
|
||||
"""
|
||||
Retrieve a value from a specific peer.
|
||||
|
||||
params: peer_id: The ID of the peer to retrieve the value from
|
||||
params: key: The key to retrieve
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[bytes]
|
||||
The value if found, None otherwise
|
||||
|
||||
"""
|
||||
stream = None
|
||||
try:
|
||||
# Don't try to get from ourselves
|
||||
if peer_id == self.local_peer_id:
|
||||
return None
|
||||
|
||||
logger.debug(f"Getting value for key {key.hex()} from peer {peer_id}")
|
||||
|
||||
# Open a stream to the peer
|
||||
stream = await self.host.new_stream(peer_id, [TProtocol(PROTOCOL_ID)])
|
||||
logger.debug(f"Opened stream to peer {peer_id} for GET_VALUE")
|
||||
|
||||
# Create the GET_VALUE message using protobuf
|
||||
message = Message()
|
||||
message.type = Message.MessageType.GET_VALUE
|
||||
message.key = key
|
||||
|
||||
# Serialize and send the protobuf message
|
||||
proto_bytes = message.SerializeToString()
|
||||
await stream.write(varint.encode(len(proto_bytes)))
|
||||
await stream.write(proto_bytes)
|
||||
|
||||
# Read response length
|
||||
length_bytes = b""
|
||||
while True:
|
||||
b = await stream.read(1)
|
||||
if not b:
|
||||
logger.warning("Connection closed while reading length")
|
||||
return None
|
||||
length_bytes += b
|
||||
if b[0] & 0x80 == 0:
|
||||
break
|
||||
response_length = varint.decode_bytes(length_bytes)
|
||||
# Read response data
|
||||
response_bytes = b""
|
||||
remaining = response_length
|
||||
while remaining > 0:
|
||||
chunk = await stream.read(remaining)
|
||||
if not chunk:
|
||||
logger.debug(
|
||||
f"Connection closed by peer {peer_id} while reading data"
|
||||
)
|
||||
return None
|
||||
response_bytes += chunk
|
||||
remaining -= len(chunk)
|
||||
|
||||
# Parse protobuf response
|
||||
try:
|
||||
response = Message()
|
||||
response.ParseFromString(response_bytes)
|
||||
logger.debug(
|
||||
f"Received protobuf response from peer"
|
||||
f" {peer_id}, type: {response.type}"
|
||||
)
|
||||
|
||||
# Process protobuf response
|
||||
if (
|
||||
response.type == Message.MessageType.GET_VALUE
|
||||
and response.HasField("record")
|
||||
and response.record.value
|
||||
):
|
||||
logger.debug(
|
||||
f"Received value for key {key.hex()} from peer {peer_id}"
|
||||
)
|
||||
return response.record.value
|
||||
|
||||
# Handle case where value is not found but peer infos are returned
|
||||
else:
|
||||
logger.debug(
|
||||
f"Value not found for key {key.hex()} from peer {peer_id},"
|
||||
f" received {len(response.closerPeers)} closer peers"
|
||||
)
|
||||
return None
|
||||
|
||||
except Exception as proto_err:
|
||||
logger.warning(f"Failed to parse as protobuf: {proto_err}")
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get value from peer {peer_id}: {e}")
|
||||
return None
|
||||
|
||||
finally:
|
||||
if stream:
|
||||
await stream.close()
|
||||
|
||||
def remove(self, key: bytes) -> bool:
|
||||
"""
|
||||
Remove a value from the DHT.
|
||||
|
||||
|
||||
params: key: The key to remove
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the key was found and removed, False otherwise
|
||||
|
||||
"""
|
||||
if key in self.store:
|
||||
del self.store[key]
|
||||
logger.debug(f"Removed value for key {key.hex()[:8]}...")
|
||||
return True
|
||||
return False
|
||||
|
||||
def has(self, key: bytes) -> bool:
|
||||
"""
|
||||
Check if a key exists in the store and hasn't expired.
|
||||
|
||||
params: key: The key to check
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the key exists and hasn't expired, False otherwise
|
||||
|
||||
"""
|
||||
if key not in self.store:
|
||||
return False
|
||||
|
||||
_, validity = self.store[key]
|
||||
if validity is not None and time.time() > validity:
|
||||
self.remove(key)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def cleanup_expired(self) -> int:
|
||||
"""
|
||||
Remove all expired values from the store.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
The number of expired values that were removed
|
||||
|
||||
"""
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
key for key, (_, validity) in self.store.items() if current_time > validity
|
||||
]
|
||||
|
||||
for key in expired_keys:
|
||||
del self.store[key]
|
||||
|
||||
if expired_keys:
|
||||
logger.debug(f"Cleaned up {len(expired_keys)} expired values")
|
||||
|
||||
return len(expired_keys)
|
||||
|
||||
def get_keys(self) -> list[bytes]:
|
||||
"""
|
||||
Get all non-expired keys in the store.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[bytes]
|
||||
List of keys
|
||||
|
||||
"""
|
||||
# Clean up expired values first
|
||||
self.cleanup_expired()
|
||||
return list(self.store.keys())
|
||||
|
||||
def size(self) -> int:
|
||||
"""
|
||||
Get the number of items in the store (after removing expired entries).
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Number of items
|
||||
|
||||
"""
|
||||
self.cleanup_expired()
|
||||
return len(self.store)
|
||||
@ -1,3 +1,7 @@
|
||||
from typing import (
|
||||
Optional,
|
||||
)
|
||||
|
||||
from libp2p.abc import (
|
||||
IRawConnection,
|
||||
)
|
||||
@ -28,7 +32,7 @@ class RawConnection(IRawConnection):
|
||||
except IOException as error:
|
||||
raise RawConnError from error
|
||||
|
||||
async def read(self, n: int | None = None) -> bytes:
|
||||
async def read(self, n: int = None) -> bytes:
|
||||
"""
|
||||
Read up to ``n`` bytes from the underlying stream. This call is
|
||||
delegated directly to the underlying ``self.reader``.
|
||||
@ -43,6 +47,6 @@ class RawConnection(IRawConnection):
|
||||
async def close(self) -> None:
|
||||
await self.stream.close()
|
||||
|
||||
def get_remote_address(self) -> tuple[str, int] | None:
|
||||
def get_remote_address(self) -> Optional[tuple[str, int]]:
|
||||
"""Delegate to the underlying stream's get_remote_address method."""
|
||||
return self.stream.get_remote_address()
|
||||
|
||||
@ -22,7 +22,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
"""
|
||||
Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee14e334f7caeedad722/swarm_conn.go
|
||||
Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee14e334f7caeedad722/swarm_conn.go # noqa: E501
|
||||
"""
|
||||
|
||||
|
||||
@ -32,11 +32,7 @@ class SwarmConn(INetConn):
|
||||
streams: set[NetStream]
|
||||
event_closed: trio.Event
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
muxed_conn: IMuxedConn,
|
||||
swarm: "Swarm",
|
||||
) -> None:
|
||||
def __init__(self, muxed_conn: IMuxedConn, swarm: "Swarm") -> None:
|
||||
self.muxed_conn = muxed_conn
|
||||
self.swarm = swarm
|
||||
self.streams = set()
|
||||
@ -44,7 +40,7 @@ class SwarmConn(INetConn):
|
||||
self.event_started = trio.Event()
|
||||
if hasattr(muxed_conn, "on_close"):
|
||||
logging.debug(f"Setting on_close for peer {muxed_conn.peer_id}")
|
||||
setattr(muxed_conn, "on_close", self._on_muxed_conn_closed)
|
||||
muxed_conn.on_close = self._on_muxed_conn_closed
|
||||
else:
|
||||
logging.error(
|
||||
f"muxed_conn for peer {muxed_conn.peer_id} has no on_close attribute"
|
||||
|
||||
@ -1,9 +1,7 @@
|
||||
from enum import (
|
||||
Enum,
|
||||
from typing import (
|
||||
Optional,
|
||||
)
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
IMuxedStream,
|
||||
INetStream,
|
||||
@ -25,103 +23,19 @@ from .exceptions import (
|
||||
)
|
||||
|
||||
|
||||
class StreamState(Enum):
|
||||
"""NetStream States"""
|
||||
|
||||
OPEN = "open"
|
||||
CLOSE_READ = "close_read"
|
||||
CLOSE_WRITE = "close_write"
|
||||
CLOSE_BOTH = "close_both"
|
||||
RESET = "reset"
|
||||
|
||||
|
||||
# TODO: Handle exceptions from `muxed_stream`
|
||||
# TODO: Add stream state
|
||||
# - Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501
|
||||
class NetStream(INetStream):
|
||||
"""
|
||||
Summary
|
||||
_______
|
||||
A Network stream implementation.
|
||||
|
||||
NetStream wraps a muxed stream and provides proper state tracking, resource cleanup,
|
||||
and event notification capabilities.
|
||||
|
||||
State Machine
|
||||
_____________
|
||||
|
||||
.. code:: markdown
|
||||
|
||||
[CREATED] → OPEN → CLOSE_READ → CLOSE_BOTH → [CLEANUP]
|
||||
↓ ↗ ↗
|
||||
CLOSE_WRITE → ← ↗
|
||||
↓ ↗
|
||||
RESET → → → → → → → →
|
||||
|
||||
State Transitions
|
||||
_________________
|
||||
- OPEN → CLOSE_READ: EOF encountered during read()
|
||||
- OPEN → CLOSE_WRITE: Explicit close() call
|
||||
- OPEN → RESET: reset() call or critical stream error
|
||||
- CLOSE_READ → CLOSE_BOTH: Explicit close() call
|
||||
- CLOSE_WRITE → CLOSE_BOTH: EOF encountered during read()
|
||||
- Any state → RESET: reset() call
|
||||
|
||||
Terminal States (trigger cleanup)
|
||||
_________________________________
|
||||
- CLOSE_BOTH: Stream fully closed, triggers resource cleanup
|
||||
- RESET: Stream reset/terminated, triggers resource cleanup
|
||||
|
||||
Operation Validity by State
|
||||
___________________________
|
||||
OPEN: read() ✓ write() ✓ close() ✓ reset() ✓
|
||||
CLOSE_READ: read() ✗ write() ✓ close() ✓ reset() ✓
|
||||
CLOSE_WRITE: read() ✓ write() ✗ close() ✓ reset() ✓
|
||||
CLOSE_BOTH: read() ✗ write() ✗ close() ✓ reset() ✓
|
||||
RESET: read() ✗ write() ✗ close() ✓ reset() ✓
|
||||
|
||||
Cleanup Process (triggered by CLOSE_BOTH or RESET)
|
||||
__________________________________________________
|
||||
1. Remove stream from SwarmConn
|
||||
2. Notify all listeners with ClosedStream event
|
||||
3. Decrement reference counter
|
||||
4. Background cleanup via nursery (if provided)
|
||||
|
||||
Thread Safety
|
||||
_____________
|
||||
All state operations are protected by trio.Lock() for safe concurrent access.
|
||||
State checks and modifications are atomic operations.
|
||||
|
||||
Example: See :file:`examples/doc-examples/example_net_stream.py`
|
||||
|
||||
:param muxed_stream (IMuxedStream): The underlying muxed stream
|
||||
:param nursery (Optional[trio.Nursery]): Nursery for background cleanup tasks
|
||||
:raises StreamClosed: When attempting invalid operations on closed streams
|
||||
:raises StreamEOF: When EOF is encountered during read operations
|
||||
:raises StreamReset: When the underlying stream has been reset
|
||||
"""
|
||||
|
||||
muxed_stream: IMuxedStream
|
||||
protocol_id: TProtocol | None
|
||||
__stream_state: StreamState
|
||||
|
||||
def __init__(
|
||||
self, muxed_stream: IMuxedStream, nursery: trio.Nursery | None = None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
protocol_id: Optional[TProtocol]
|
||||
|
||||
def __init__(self, muxed_stream: IMuxedStream) -> None:
|
||||
self.muxed_stream = muxed_stream
|
||||
self.muxed_conn = muxed_stream.muxed_conn
|
||||
self.protocol_id = None
|
||||
|
||||
# For background tasks
|
||||
self._nursery = nursery
|
||||
|
||||
# State management
|
||||
self.__stream_state = StreamState.OPEN
|
||||
self._state_lock = trio.Lock()
|
||||
|
||||
# For notification handling
|
||||
self._notify_lock = trio.Lock()
|
||||
|
||||
def get_protocol(self) -> TProtocol | None:
|
||||
def get_protocol(self) -> TProtocol:
|
||||
"""
|
||||
:return: protocol id that stream runs on
|
||||
"""
|
||||
@ -133,176 +47,42 @@ class NetStream(INetStream):
|
||||
"""
|
||||
self.protocol_id = protocol_id
|
||||
|
||||
@property
|
||||
async def state(self) -> StreamState:
|
||||
"""Get current stream state."""
|
||||
async with self._state_lock:
|
||||
return self.__stream_state
|
||||
|
||||
async def read(self, n: int | None = None) -> bytes:
|
||||
async def read(self, n: int = None) -> bytes:
|
||||
"""
|
||||
Read from stream.
|
||||
|
||||
:param n: number of bytes to read
|
||||
:raises StreamClosed: If `NetStream` is closed for reading
|
||||
:raises StreamReset: If `NetStream` is reset
|
||||
:raises StreamEOF: If trying to read after reaching end of file
|
||||
:return: Bytes read from the stream
|
||||
:return: bytes of input
|
||||
"""
|
||||
async with self._state_lock:
|
||||
if self.__stream_state in [
|
||||
StreamState.CLOSE_READ,
|
||||
StreamState.CLOSE_BOTH,
|
||||
]:
|
||||
raise StreamClosed("Stream is closed for reading")
|
||||
|
||||
if self.__stream_state == StreamState.RESET:
|
||||
raise StreamReset("Stream is reset, cannot be used to read")
|
||||
|
||||
try:
|
||||
data = await self.muxed_stream.read(n)
|
||||
return data
|
||||
return await self.muxed_stream.read(n)
|
||||
except MuxedStreamEOF as error:
|
||||
async with self._state_lock:
|
||||
if self.__stream_state == StreamState.CLOSE_WRITE:
|
||||
self.__stream_state = StreamState.CLOSE_BOTH
|
||||
await self._remove()
|
||||
elif self.__stream_state == StreamState.OPEN:
|
||||
self.__stream_state = StreamState.CLOSE_READ
|
||||
raise StreamEOF() from error
|
||||
except MuxedStreamReset as error:
|
||||
async with self._state_lock:
|
||||
if self.__stream_state in [
|
||||
StreamState.OPEN,
|
||||
StreamState.CLOSE_READ,
|
||||
StreamState.CLOSE_WRITE,
|
||||
]:
|
||||
self.__stream_state = StreamState.RESET
|
||||
await self._remove()
|
||||
raise StreamReset() from error
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""
|
||||
Write to stream.
|
||||
|
||||
:param data: bytes to write
|
||||
:raises StreamClosed: If `NetStream` is closed for writing or reset
|
||||
:raises StreamClosed: If `StreamError` occurred while writing
|
||||
:return: number of bytes written
|
||||
"""
|
||||
async with self._state_lock:
|
||||
if self.__stream_state in [
|
||||
StreamState.CLOSE_WRITE,
|
||||
StreamState.CLOSE_BOTH,
|
||||
StreamState.RESET,
|
||||
]:
|
||||
raise StreamClosed("Stream is closed for writing")
|
||||
|
||||
try:
|
||||
await self.muxed_stream.write(data)
|
||||
except (MuxedStreamClosed, MuxedStreamError) as error:
|
||||
async with self._state_lock:
|
||||
if self.__stream_state == StreamState.OPEN:
|
||||
self.__stream_state = StreamState.CLOSE_WRITE
|
||||
elif self.__stream_state == StreamState.CLOSE_READ:
|
||||
self.__stream_state = StreamState.CLOSE_BOTH
|
||||
await self._remove()
|
||||
raise StreamClosed() from error
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close stream for writing."""
|
||||
async with self._state_lock:
|
||||
if self.__stream_state in [
|
||||
StreamState.CLOSE_BOTH,
|
||||
StreamState.RESET,
|
||||
StreamState.CLOSE_WRITE,
|
||||
]:
|
||||
return
|
||||
|
||||
"""Close stream."""
|
||||
await self.muxed_stream.close()
|
||||
|
||||
async with self._state_lock:
|
||||
if self.__stream_state == StreamState.CLOSE_READ:
|
||||
self.__stream_state = StreamState.CLOSE_BOTH
|
||||
await self._remove()
|
||||
elif self.__stream_state == StreamState.OPEN:
|
||||
self.__stream_state = StreamState.CLOSE_WRITE
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset stream, closing both ends."""
|
||||
async with self._state_lock:
|
||||
if self.__stream_state == StreamState.RESET:
|
||||
return
|
||||
|
||||
await self.muxed_stream.reset()
|
||||
|
||||
async with self._state_lock:
|
||||
if self.__stream_state in [
|
||||
StreamState.OPEN,
|
||||
StreamState.CLOSE_READ,
|
||||
StreamState.CLOSE_WRITE,
|
||||
]:
|
||||
self.__stream_state = StreamState.RESET
|
||||
await self._remove()
|
||||
|
||||
async def _remove(self) -> None:
|
||||
"""
|
||||
Remove stream from connection and notify listeners.
|
||||
This is called when the stream is fully closed or reset.
|
||||
"""
|
||||
if hasattr(self.muxed_conn, "remove_stream"):
|
||||
remove_stream = getattr(self.muxed_conn, "remove_stream")
|
||||
await remove_stream(self)
|
||||
|
||||
# Notify in background using Trio nursery if available
|
||||
if self._nursery:
|
||||
self._nursery.start_soon(self._notify_closed)
|
||||
else:
|
||||
await self._notify_closed()
|
||||
|
||||
async def _notify_closed(self) -> None:
|
||||
"""
|
||||
Notify all listeners that the stream has been closed.
|
||||
This runs in a separate task to avoid blocking the main flow.
|
||||
"""
|
||||
async with self._notify_lock:
|
||||
if hasattr(self.muxed_conn, "swarm"):
|
||||
swarm = getattr(self.muxed_conn, "swarm")
|
||||
|
||||
if hasattr(swarm, "notify_all"):
|
||||
await swarm.notify_all(
|
||||
lambda notifiee: notifiee.closed_stream(swarm, self)
|
||||
)
|
||||
|
||||
if hasattr(swarm, "refs") and hasattr(swarm.refs, "done"):
|
||||
swarm.refs.done()
|
||||
|
||||
def get_remote_address(self) -> tuple[str, int] | None:
|
||||
def get_remote_address(self) -> Optional[tuple[str, int]]:
|
||||
"""Delegate to the underlying muxed stream."""
|
||||
return self.muxed_stream.get_remote_address()
|
||||
|
||||
async def is_closed(self) -> bool:
|
||||
"""Check if stream is closed."""
|
||||
current_state = await self.state
|
||||
return current_state in [StreamState.CLOSE_BOTH, StreamState.RESET]
|
||||
|
||||
async def is_readable(self) -> bool:
|
||||
"""Check if stream is readable."""
|
||||
current_state = await self.state
|
||||
return current_state not in [
|
||||
StreamState.CLOSE_READ,
|
||||
StreamState.CLOSE_BOTH,
|
||||
StreamState.RESET,
|
||||
]
|
||||
|
||||
async def is_writable(self) -> bool:
|
||||
"""Check if stream is writable."""
|
||||
current_state = await self.state
|
||||
return current_state not in [
|
||||
StreamState.CLOSE_WRITE,
|
||||
StreamState.CLOSE_BOTH,
|
||||
StreamState.RESET,
|
||||
]
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the stream."""
|
||||
return f"<NetStream[{self.__stream_state.value}] protocol={self.protocol_id}>"
|
||||
# TODO: `remove`: Called by close and write when the stream is in specific states.
|
||||
# It notifies `ClosedStream` after `SwarmConn.remove_stream` is called.
|
||||
# Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501
|
||||
|
||||
@ -1,4 +1,7 @@
|
||||
import logging
|
||||
from typing import (
|
||||
Optional,
|
||||
)
|
||||
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
@ -72,7 +75,7 @@ class Swarm(Service, INetworkService):
|
||||
connections: dict[ID, INetConn]
|
||||
listeners: dict[str, IListener]
|
||||
common_stream_handler: StreamHandlerFn
|
||||
listener_nursery: trio.Nursery | None
|
||||
listener_nursery: Optional[trio.Nursery]
|
||||
event_listener_nursery_created: trio.Event
|
||||
|
||||
notifees: list[INotifee]
|
||||
@ -187,7 +190,7 @@ class Swarm(Service, INetworkService):
|
||||
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure
|
||||
# the conn and then mux the conn
|
||||
try:
|
||||
secured_conn = await self.upgrader.upgrade_security(raw_conn, True, peer_id)
|
||||
secured_conn = await self.upgrader.upgrade_security(raw_conn, peer_id, True)
|
||||
except SecurityUpgradeFailure as error:
|
||||
logger.debug("failed to upgrade security for peer %s", peer_id)
|
||||
await raw_conn.close()
|
||||
@ -257,7 +260,10 @@ class Swarm(Service, INetworkService):
|
||||
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first
|
||||
# secure the conn and then mux the conn
|
||||
try:
|
||||
secured_conn = await self.upgrader.upgrade_security(raw_conn, False)
|
||||
# FIXME: This dummy `ID(b"")` for the remote peer is useless.
|
||||
secured_conn = await self.upgrader.upgrade_security(
|
||||
raw_conn, ID(b""), False
|
||||
)
|
||||
except SecurityUpgradeFailure as error:
|
||||
logger.debug("failed to upgrade security for peer at %s", maddr)
|
||||
await raw_conn.close()
|
||||
@ -334,9 +340,7 @@ class Swarm(Service, INetworkService):
|
||||
if hasattr(self, "transport") and self.transport is not None:
|
||||
# Check if transport has close method before calling it
|
||||
if hasattr(self.transport, "close"):
|
||||
await self.transport.close() # type: ignore
|
||||
# Ignoring the type above since `transport` may not have a close method
|
||||
# and we have already checked it with hasattr
|
||||
await self.transport.close()
|
||||
|
||||
logger.debug("swarm successfully closed")
|
||||
|
||||
@ -356,11 +360,7 @@ class Swarm(Service, INetworkService):
|
||||
and start to monitor the connection for its new streams and
|
||||
disconnection.
|
||||
"""
|
||||
swarm_conn = SwarmConn(
|
||||
muxed_conn,
|
||||
self,
|
||||
)
|
||||
|
||||
swarm_conn = SwarmConn(muxed_conn, self)
|
||||
self.manager.run_task(muxed_conn.start)
|
||||
await muxed_conn.event_started.wait()
|
||||
self.manager.run_task(swarm_conn.start)
|
||||
|
||||
@ -1,4 +1,7 @@
|
||||
import hashlib
|
||||
from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
import base58
|
||||
import multihash
|
||||
@ -21,7 +24,7 @@ if ENABLE_INLINING:
|
||||
_digest: bytes
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._digest = b""
|
||||
self._digest = bytearray()
|
||||
|
||||
def update(self, input: bytes) -> None:
|
||||
self._digest += input
|
||||
@ -36,8 +39,8 @@ if ENABLE_INLINING:
|
||||
|
||||
class ID:
|
||||
_bytes: bytes
|
||||
_xor_id: int | None = None
|
||||
_b58_str: str | None = None
|
||||
_xor_id: int = None
|
||||
_b58_str: str = None
|
||||
|
||||
def __init__(self, peer_id_bytes: bytes) -> None:
|
||||
self._bytes = peer_id_bytes
|
||||
@ -90,7 +93,7 @@ class ID:
|
||||
return cls(mh_digest.encode())
|
||||
|
||||
|
||||
def sha256_digest(data: str | bytes) -> bytes:
|
||||
def sha256_digest(data: Union[str, bytes]) -> bytes:
|
||||
if isinstance(data, str):
|
||||
data = data.encode("utf8")
|
||||
return hashlib.sha256(data).digest()
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from collections.abc import (
|
||||
Sequence,
|
||||
)
|
||||
import time
|
||||
from typing import (
|
||||
Any,
|
||||
)
|
||||
@ -18,23 +17,13 @@ from libp2p.crypto.keys import (
|
||||
PublicKey,
|
||||
)
|
||||
|
||||
"""
|
||||
Latency EWMA Smoothing governs the deacy of the EWMA (the speed at which
|
||||
is changes). This must be a normalized (0-1) value.
|
||||
1 is 100% change, 0 is no change.
|
||||
"""
|
||||
LATENCY_EWMA_SMOOTHING = 0.1
|
||||
|
||||
|
||||
class PeerData(IPeerData):
|
||||
pubkey: PublicKey | None
|
||||
privkey: PrivateKey | None
|
||||
pubkey: PublicKey
|
||||
privkey: PrivateKey
|
||||
metadata: dict[Any, Any]
|
||||
protocols: list[str]
|
||||
addrs: list[Multiaddr]
|
||||
last_identified: int
|
||||
ttl: int # Keep ttl=0 by default for always valid
|
||||
latmap: float
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.pubkey = None
|
||||
@ -42,11 +31,6 @@ class PeerData(IPeerData):
|
||||
self.metadata = {}
|
||||
self.protocols = []
|
||||
self.addrs = []
|
||||
self.last_identified = int(time.time())
|
||||
self.ttl = 0
|
||||
self.latmap = 0
|
||||
|
||||
# --------PROTO-BOOK--------
|
||||
|
||||
def get_protocols(self) -> list[str]:
|
||||
"""
|
||||
@ -66,37 +50,6 @@ class PeerData(IPeerData):
|
||||
"""
|
||||
self.protocols = list(protocols)
|
||||
|
||||
def remove_protocols(self, protocols: Sequence[str]) -> None:
|
||||
"""
|
||||
:param protocols: protocols to remove
|
||||
"""
|
||||
for protocol in protocols:
|
||||
if protocol in self.protocols:
|
||||
self.protocols.remove(protocol)
|
||||
|
||||
def supports_protocols(self, protocols: Sequence[str]) -> list[str]:
|
||||
"""
|
||||
:param protocols: protocols to check from
|
||||
:return: all supported protocols in the given list
|
||||
"""
|
||||
return [proto for proto in protocols if proto in self.protocols]
|
||||
|
||||
def first_supported_protocol(self, protocols: Sequence[str]) -> str:
|
||||
"""
|
||||
:param protocols: protocols to check from
|
||||
:return: first supported protocol in the given list
|
||||
"""
|
||||
for protocol in protocols:
|
||||
if protocol in self.protocols:
|
||||
return protocol
|
||||
|
||||
return "None supported"
|
||||
|
||||
def clear_protocol_data(self) -> None:
|
||||
"""Clear all protocols"""
|
||||
self.protocols = []
|
||||
|
||||
# -------ADDR-BOOK---------
|
||||
def add_addrs(self, addrs: Sequence[Multiaddr]) -> None:
|
||||
"""
|
||||
:param addrs: multiaddresses to add
|
||||
@ -115,7 +68,6 @@ class PeerData(IPeerData):
|
||||
"""Clear all addresses."""
|
||||
self.addrs = []
|
||||
|
||||
# -------METADATA-----------
|
||||
def put_metadata(self, key: str, val: Any) -> None:
|
||||
"""
|
||||
:param key: key in KV pair
|
||||
@ -133,11 +85,6 @@ class PeerData(IPeerData):
|
||||
return self.metadata[key]
|
||||
raise PeerDataError("key not found")
|
||||
|
||||
def clear_metadata(self) -> None:
|
||||
"""Clears metadata."""
|
||||
self.metadata = {}
|
||||
|
||||
# -------KEY-BOOK---------------
|
||||
def add_pubkey(self, pubkey: PublicKey) -> None:
|
||||
"""
|
||||
:param pubkey:
|
||||
@ -168,68 +115,6 @@ class PeerData(IPeerData):
|
||||
raise PeerDataError("private key not found")
|
||||
return self.privkey
|
||||
|
||||
def clear_keydata(self) -> None:
|
||||
"""Clears keydata"""
|
||||
self.pubkey = None
|
||||
self.privkey = None
|
||||
|
||||
# ----------METRICS--------------
|
||||
def record_latency(self, new_latency: float) -> None:
|
||||
"""
|
||||
Records a new latency measurement for the given peer
|
||||
using Exponentially Weighted Moving Average (EWMA)
|
||||
:param new_latency: the new latency value
|
||||
"""
|
||||
s = LATENCY_EWMA_SMOOTHING
|
||||
if s > 1 or s < 0:
|
||||
s = 0.1
|
||||
|
||||
if self.latmap == 0:
|
||||
self.latmap = new_latency
|
||||
else:
|
||||
prev = self.latmap
|
||||
updated = ((1.0 - s) * prev) + (s * new_latency)
|
||||
self.latmap = updated
|
||||
|
||||
def latency_EWMA(self) -> float:
|
||||
"""Returns the latency EWMA value"""
|
||||
return self.latmap
|
||||
|
||||
def clear_metrics(self) -> None:
|
||||
"""Clear the latency metrics"""
|
||||
self.latmap = 0
|
||||
|
||||
def update_last_identified(self) -> None:
|
||||
self.last_identified = int(time.time())
|
||||
|
||||
# ----------TTL------------------
|
||||
def get_last_identified(self) -> int:
|
||||
"""
|
||||
:return: last identified timestamp
|
||||
"""
|
||||
return self.last_identified
|
||||
|
||||
def get_ttl(self) -> int:
|
||||
"""
|
||||
:return: ttl for current peer
|
||||
"""
|
||||
return self.ttl
|
||||
|
||||
def set_ttl(self, ttl: int) -> None:
|
||||
"""
|
||||
:param ttl: ttl to set
|
||||
"""
|
||||
self.ttl = ttl
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""
|
||||
:return: true, if last_identified+ttl > current_time
|
||||
"""
|
||||
# for ttl = 0; peer_data is always valid
|
||||
if self.ttl > 0 and self.last_identified + self.ttl < int(time.time()):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class PeerDataError(KeyError):
|
||||
"""Raised when a key is not found in peer metadata."""
|
||||
|
||||
@ -3,11 +3,9 @@ from collections.abc import (
|
||||
)
|
||||
from typing import (
|
||||
Any,
|
||||
cast,
|
||||
)
|
||||
|
||||
import multiaddr
|
||||
from multiaddr.protocols import Protocol
|
||||
|
||||
from .id import (
|
||||
ID,
|
||||
@ -34,32 +32,21 @@ def info_from_p2p_addr(addr: multiaddr.Multiaddr) -> PeerInfo:
|
||||
if not addr:
|
||||
raise InvalidAddrError("`addr` should not be `None`")
|
||||
|
||||
parts: list[multiaddr.Multiaddr] = addr.split()
|
||||
parts = addr.split()
|
||||
if not parts:
|
||||
raise InvalidAddrError(
|
||||
f"`parts`={parts} should at least have a protocol `P_P2P`"
|
||||
)
|
||||
|
||||
p2p_part = parts[-1]
|
||||
p2p_protocols = p2p_part.protocols()
|
||||
if not p2p_protocols:
|
||||
raise InvalidAddrError("The last part of the address has no protocols")
|
||||
last_protocol = cast(Protocol, p2p_part.protocols()[0])
|
||||
|
||||
if last_protocol is None:
|
||||
raise InvalidAddrError("The last protocol is None")
|
||||
|
||||
last_protocol_code = last_protocol.code
|
||||
if last_protocol_code != multiaddr.multiaddr.protocols.P_P2P:
|
||||
last_protocol_code = p2p_part.protocols()[0].code
|
||||
if last_protocol_code != multiaddr.protocols.P_P2P:
|
||||
raise InvalidAddrError(
|
||||
f"The last protocol should be `P_P2P` instead of `{last_protocol_code}`"
|
||||
)
|
||||
|
||||
# make sure the /p2p value parses as a peer.ID
|
||||
peer_id_str = p2p_part.value_for_protocol(multiaddr.multiaddr.protocols.P_P2P)
|
||||
if peer_id_str is None:
|
||||
raise InvalidAddrError("Missing value for /p2p protocol in multiaddr")
|
||||
|
||||
peer_id_str: str = p2p_part.value_for_protocol(multiaddr.protocols.P_P2P)
|
||||
peer_id: ID = ID.from_base58(peer_id_str)
|
||||
|
||||
# we might have received just an / p2p part, which means there's no addr.
|
||||
@ -69,23 +56,5 @@ def info_from_p2p_addr(addr: multiaddr.Multiaddr) -> PeerInfo:
|
||||
return PeerInfo(peer_id, [addr])
|
||||
|
||||
|
||||
def peer_info_to_bytes(peer_info: PeerInfo) -> bytes:
|
||||
lines = [str(peer_info.peer_id)] + [str(addr) for addr in peer_info.addrs]
|
||||
return "\n".join(lines).encode("utf-8")
|
||||
|
||||
|
||||
def peer_info_from_bytes(data: bytes) -> PeerInfo:
|
||||
try:
|
||||
lines = data.decode("utf-8").splitlines()
|
||||
if not lines:
|
||||
raise InvalidAddrError("no data to decode PeerInfo")
|
||||
|
||||
peer_id = ID.from_base58(lines[0])
|
||||
addrs = [multiaddr.Multiaddr(addr_str) for addr_str in lines[1:]]
|
||||
return PeerInfo(peer_id, addrs)
|
||||
except Exception as e:
|
||||
raise InvalidAddrError(f"failed to decode PeerInfo: {e}")
|
||||
|
||||
|
||||
class InvalidAddrError(ValueError):
|
||||
pass
|
||||
|
||||
@ -2,9 +2,9 @@ from collections import (
|
||||
defaultdict,
|
||||
)
|
||||
from collections.abc import (
|
||||
AsyncIterable,
|
||||
Sequence,
|
||||
)
|
||||
import sys
|
||||
from typing import (
|
||||
Any,
|
||||
)
|
||||
@ -12,8 +12,6 @@ from typing import (
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
)
|
||||
import trio
|
||||
from trio import MemoryReceiveChannel, MemorySendChannel
|
||||
|
||||
from libp2p.abc import (
|
||||
IPeerStore,
|
||||
@ -35,7 +33,7 @@ from .peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
|
||||
PERMANENT_ADDR_TTL = 0
|
||||
PERMANENT_ADDR_TTL = sys.maxsize
|
||||
|
||||
|
||||
class PeerStore(IPeerStore):
|
||||
@ -43,7 +41,6 @@ class PeerStore(IPeerStore):
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.peer_data_map = defaultdict(PeerData)
|
||||
self.addr_update_channels: dict[ID, MemorySendChannel[Multiaddr]] = {}
|
||||
|
||||
def peer_info(self, peer_id: ID) -> PeerInfo:
|
||||
"""
|
||||
@ -52,38 +49,9 @@ class PeerStore(IPeerStore):
|
||||
"""
|
||||
if peer_id in self.peer_data_map:
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
if peer_data.is_expired():
|
||||
peer_data.clear_addrs()
|
||||
return PeerInfo(peer_id, peer_data.get_addrs())
|
||||
raise PeerStoreError("peer ID not found")
|
||||
|
||||
def peer_ids(self) -> list[ID]:
|
||||
"""
|
||||
:return: all of the peer IDs stored in peer store
|
||||
"""
|
||||
return list(self.peer_data_map.keys())
|
||||
|
||||
def clear_peerdata(self, peer_id: ID) -> None:
|
||||
"""Clears all data associated with the given peer_id."""
|
||||
if peer_id in self.peer_data_map:
|
||||
del self.peer_data_map[peer_id]
|
||||
else:
|
||||
raise PeerStoreError("peer ID not found")
|
||||
|
||||
def valid_peer_ids(self) -> list[ID]:
|
||||
"""
|
||||
:return: all of the valid peer IDs stored in peer store
|
||||
"""
|
||||
valid_peer_ids: list[ID] = []
|
||||
for peer_id, peer_data in self.peer_data_map.items():
|
||||
if not peer_data.is_expired():
|
||||
valid_peer_ids.append(peer_id)
|
||||
else:
|
||||
peer_data.clear_addrs()
|
||||
return valid_peer_ids
|
||||
|
||||
# --------PROTO-BOOK--------
|
||||
|
||||
def get_protocols(self, peer_id: ID) -> list[str]:
|
||||
"""
|
||||
:param peer_id: peer ID to get protocols for
|
||||
@ -110,31 +78,11 @@ class PeerStore(IPeerStore):
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
peer_data.set_protocols(list(protocols))
|
||||
|
||||
def remove_protocols(self, peer_id: ID, protocols: Sequence[str]) -> None:
|
||||
"""
|
||||
:param peer_id: peer ID to get info for
|
||||
:param protocols: unsupported protocols to remove
|
||||
"""
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
peer_data.remove_protocols(protocols)
|
||||
|
||||
def supports_protocols(self, peer_id: ID, protocols: Sequence[str]) -> list[str]:
|
||||
def peer_ids(self) -> list[ID]:
|
||||
"""
|
||||
:return: all of the peer IDs stored in peer store
|
||||
"""
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
return peer_data.supports_protocols(protocols)
|
||||
|
||||
def first_supported_protocol(self, peer_id: ID, protocols: Sequence[str]) -> str:
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
return peer_data.first_supported_protocol(protocols)
|
||||
|
||||
def clear_protocol_data(self, peer_id: ID) -> None:
|
||||
"""Clears prtocoldata"""
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
peer_data.clear_protocol_data()
|
||||
|
||||
# ------METADATA---------
|
||||
return list(self.peer_data_map.keys())
|
||||
|
||||
def get(self, peer_id: ID, key: str) -> Any:
|
||||
"""
|
||||
@ -160,14 +108,7 @@ class PeerStore(IPeerStore):
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
peer_data.put_metadata(key, val)
|
||||
|
||||
def clear_metadata(self, peer_id: ID) -> None:
|
||||
"""Clears metadata"""
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
peer_data.clear_metadata()
|
||||
|
||||
# -------ADDR-BOOK--------
|
||||
|
||||
def add_addr(self, peer_id: ID, addr: Multiaddr, ttl: int = 0) -> None:
|
||||
def add_addr(self, peer_id: ID, addr: Multiaddr, ttl: int) -> None:
|
||||
"""
|
||||
:param peer_id: peer ID to add address for
|
||||
:param addr:
|
||||
@ -175,37 +116,24 @@ class PeerStore(IPeerStore):
|
||||
"""
|
||||
self.add_addrs(peer_id, [addr], ttl)
|
||||
|
||||
def add_addrs(self, peer_id: ID, addrs: Sequence[Multiaddr], ttl: int = 0) -> None:
|
||||
def add_addrs(self, peer_id: ID, addrs: Sequence[Multiaddr], ttl: int) -> None:
|
||||
"""
|
||||
:param peer_id: peer ID to add address for
|
||||
:param addrs:
|
||||
:param ttl: time-to-live for the this record
|
||||
"""
|
||||
# Ignore ttl for now
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
peer_data.add_addrs(list(addrs))
|
||||
peer_data.set_ttl(ttl)
|
||||
peer_data.update_last_identified()
|
||||
|
||||
if peer_id in self.addr_update_channels:
|
||||
for addr in addrs:
|
||||
try:
|
||||
self.addr_update_channels[peer_id].send_nowait(addr)
|
||||
except trio.WouldBlock:
|
||||
pass # Or consider logging / dropping / replacing stream
|
||||
|
||||
def addrs(self, peer_id: ID) -> list[Multiaddr]:
|
||||
"""
|
||||
:param peer_id: peer ID to get addrs for
|
||||
:return: list of addrs of a valid peer.
|
||||
:return: list of addrs
|
||||
:raise PeerStoreError: if peer ID not found
|
||||
"""
|
||||
if peer_id in self.peer_data_map:
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
if not peer_data.is_expired():
|
||||
return peer_data.get_addrs()
|
||||
else:
|
||||
peer_data.clear_addrs()
|
||||
raise PeerStoreError("peer ID is expired")
|
||||
return self.peer_data_map[peer_id].get_addrs()
|
||||
raise PeerStoreError("peer ID not found")
|
||||
|
||||
def clear_addrs(self, peer_id: ID) -> None:
|
||||
@ -218,41 +146,16 @@ class PeerStore(IPeerStore):
|
||||
|
||||
def peers_with_addrs(self) -> list[ID]:
|
||||
"""
|
||||
:return: all of the peer IDs which has addrsfloat stored in peer store
|
||||
:return: all of the peer IDs which has addrs stored in peer store
|
||||
"""
|
||||
# Add all peers with addrs at least 1 to output
|
||||
output: list[ID] = []
|
||||
|
||||
for peer_id in self.peer_data_map:
|
||||
if len(self.peer_data_map[peer_id].get_addrs()) >= 1:
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
if not peer_data.is_expired():
|
||||
output.append(peer_id)
|
||||
else:
|
||||
peer_data.clear_addrs()
|
||||
output.append(peer_id)
|
||||
return output
|
||||
|
||||
async def addr_stream(self, peer_id: ID) -> AsyncIterable[Multiaddr]:
|
||||
"""
|
||||
Returns an async stream of newly added addresses for the given peer.
|
||||
|
||||
This function allows consumers to subscribe to address updates for a peer
|
||||
and receive each new address as it is added via `add_addr` or `add_addrs`.
|
||||
|
||||
:param peer_id: The ID of the peer to monitor address updates for.
|
||||
:return: An async iterator yielding Multiaddr instances as they are added.
|
||||
"""
|
||||
send: MemorySendChannel[Multiaddr]
|
||||
receive: MemoryReceiveChannel[Multiaddr]
|
||||
|
||||
send, receive = trio.open_memory_channel(0)
|
||||
self.addr_update_channels[peer_id] = send
|
||||
|
||||
async for addr in receive:
|
||||
yield addr
|
||||
|
||||
# -------KEY-BOOK---------
|
||||
|
||||
def add_pubkey(self, peer_id: ID, pubkey: PublicKey) -> None:
|
||||
"""
|
||||
:param peer_id: peer ID to add public key for
|
||||
@ -313,45 +216,6 @@ class PeerStore(IPeerStore):
|
||||
self.add_pubkey(peer_id, key_pair.public_key)
|
||||
self.add_privkey(peer_id, key_pair.private_key)
|
||||
|
||||
def peer_with_keys(self) -> list[ID]:
|
||||
"""Returns the peer_ids for which keys are stored"""
|
||||
return [
|
||||
peer_id
|
||||
for peer_id, pdata in self.peer_data_map.items()
|
||||
if pdata.pubkey is not None
|
||||
]
|
||||
|
||||
def clear_keydata(self, peer_id: ID) -> None:
|
||||
"""Clears the keys of the peer"""
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
peer_data.clear_keydata()
|
||||
|
||||
# --------METRICS--------
|
||||
|
||||
def record_latency(self, peer_id: ID, RTT: float) -> None:
|
||||
"""
|
||||
Records a new latency measurement for the given peer
|
||||
using Exponentially Weighted Moving Average (EWMA)
|
||||
|
||||
:param peer_id: peer ID to get private key for
|
||||
:param RTT: the new latency value (round trip time)
|
||||
"""
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
peer_data.record_latency(RTT)
|
||||
|
||||
def latency_EWMA(self, peer_id: ID) -> float:
|
||||
"""
|
||||
:param peer_id: peer ID to get private key for
|
||||
:return: The latency EWMA value for that peer
|
||||
"""
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
return peer_data.latency_EWMA()
|
||||
|
||||
def clear_metrics(self, peer_id: ID) -> None:
|
||||
"""Clear the latency metrics"""
|
||||
peer_data = self.peer_data_map[peer_id]
|
||||
peer_data.clear_metrics()
|
||||
|
||||
|
||||
class PeerStoreError(KeyError):
|
||||
"""Raised when peer ID is not found in peer store."""
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
IMultiselectCommunicator,
|
||||
IMultiselectMuxer,
|
||||
@ -16,7 +14,6 @@ from .exceptions import (
|
||||
|
||||
MULTISELECT_PROTOCOL_ID = "/multistream/1.0.0"
|
||||
PROTOCOL_NOT_FOUND_MSG = "na"
|
||||
DEFAULT_NEGOTIATE_TIMEOUT = 5
|
||||
|
||||
|
||||
class Multiselect(IMultiselectMuxer):
|
||||
@ -26,20 +23,16 @@ class Multiselect(IMultiselectMuxer):
|
||||
communication.
|
||||
"""
|
||||
|
||||
handlers: dict[TProtocol | None, StreamHandlerFn | None]
|
||||
handlers: dict[TProtocol, StreamHandlerFn]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default_handlers: None
|
||||
| (dict[TProtocol | None, StreamHandlerFn | None]) = None,
|
||||
self, default_handlers: dict[TProtocol, StreamHandlerFn] = None
|
||||
) -> None:
|
||||
if not default_handlers:
|
||||
default_handlers = {}
|
||||
self.handlers = default_handlers
|
||||
|
||||
def add_handler(
|
||||
self, protocol: TProtocol | None, handler: StreamHandlerFn | None
|
||||
) -> None:
|
||||
def add_handler(self, protocol: TProtocol, handler: StreamHandlerFn) -> None:
|
||||
"""
|
||||
Store the handler with the given protocol.
|
||||
|
||||
@ -48,70 +41,46 @@ class Multiselect(IMultiselectMuxer):
|
||||
"""
|
||||
self.handlers[protocol] = handler
|
||||
|
||||
# FIXME: Make TProtocol Optional[TProtocol] to keep types consistent
|
||||
async def negotiate(
|
||||
self,
|
||||
communicator: IMultiselectCommunicator,
|
||||
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
) -> tuple[TProtocol, StreamHandlerFn | None]:
|
||||
self, communicator: IMultiselectCommunicator
|
||||
) -> tuple[TProtocol, StreamHandlerFn]:
|
||||
"""
|
||||
Negotiate performs protocol selection.
|
||||
|
||||
:param stream: stream to negotiate on
|
||||
:param negotiate_timeout: timeout for negotiation
|
||||
:return: selected protocol name, handler function
|
||||
:raise MultiselectError: raised when negotiation failed
|
||||
"""
|
||||
try:
|
||||
with trio.fail_after(negotiate_timeout):
|
||||
await self.handshake(communicator)
|
||||
await self.handshake(communicator)
|
||||
|
||||
while True:
|
||||
while True:
|
||||
try:
|
||||
command = await communicator.read()
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectError() from error
|
||||
|
||||
if command == "ls":
|
||||
supported_protocols = list(self.handlers.keys())
|
||||
response = "\n".join(supported_protocols) + "\n"
|
||||
|
||||
try:
|
||||
await communicator.write(response)
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectError() from error
|
||||
|
||||
else:
|
||||
protocol = TProtocol(command)
|
||||
if protocol in self.handlers:
|
||||
try:
|
||||
command = await communicator.read()
|
||||
await communicator.write(protocol)
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectError() from error
|
||||
|
||||
if command == "ls":
|
||||
supported_protocols = [
|
||||
p for p in self.handlers.keys() if p is not None
|
||||
]
|
||||
response = "\n".join(supported_protocols) + "\n"
|
||||
|
||||
try:
|
||||
await communicator.write(response)
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectError() from error
|
||||
|
||||
else:
|
||||
protocol = TProtocol(command)
|
||||
if protocol in self.handlers:
|
||||
try:
|
||||
await communicator.write(protocol)
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectError() from error
|
||||
|
||||
return protocol, self.handlers[protocol]
|
||||
try:
|
||||
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectError() from error
|
||||
|
||||
raise MultiselectError("Negotiation failed: no matching protocol")
|
||||
except trio.TooSlowError:
|
||||
raise MultiselectError("handshake read timeout")
|
||||
|
||||
def get_protocols(self) -> tuple[TProtocol | None, ...]:
|
||||
"""
|
||||
Retrieve the protocols for which handlers have been registered.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[TProtocol, ...]
|
||||
A tuple of registered protocol names.
|
||||
|
||||
"""
|
||||
return tuple(self.handlers.keys())
|
||||
return protocol, self.handlers[protocol]
|
||||
try:
|
||||
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectError() from error
|
||||
|
||||
async def handshake(self, communicator: IMultiselectCommunicator) -> None:
|
||||
"""
|
||||
|
||||
@ -2,8 +2,6 @@ from collections.abc import (
|
||||
Sequence,
|
||||
)
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
IMultiselectClient,
|
||||
IMultiselectCommunicator,
|
||||
@ -19,7 +17,6 @@ from .exceptions import (
|
||||
|
||||
MULTISELECT_PROTOCOL_ID = "/multistream/1.0.0"
|
||||
PROTOCOL_NOT_FOUND_MSG = "na"
|
||||
DEFAULT_NEGOTIATE_TIMEOUT = 5
|
||||
|
||||
|
||||
class MultiselectClient(IMultiselectClient):
|
||||
@ -43,7 +40,6 @@ class MultiselectClient(IMultiselectClient):
|
||||
|
||||
try:
|
||||
handshake_contents = await communicator.read()
|
||||
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
|
||||
@ -51,10 +47,7 @@ class MultiselectClient(IMultiselectClient):
|
||||
raise MultiselectClientError("multiselect protocol ID mismatch")
|
||||
|
||||
async def select_one_of(
|
||||
self,
|
||||
protocols: Sequence[TProtocol],
|
||||
communicator: IMultiselectCommunicator,
|
||||
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
self, protocols: Sequence[TProtocol], communicator: IMultiselectCommunicator
|
||||
) -> TProtocol:
|
||||
"""
|
||||
For each protocol, send message to multiselect selecting protocol and
|
||||
@ -63,32 +56,22 @@ class MultiselectClient(IMultiselectClient):
|
||||
|
||||
:param protocol: protocol to select
|
||||
:param communicator: communicator to use to communicate with counterparty
|
||||
:param negotiate_timeout: timeout for negotiation
|
||||
:return: selected protocol
|
||||
:raise MultiselectClientError: raised when protocol negotiation failed
|
||||
"""
|
||||
try:
|
||||
with trio.fail_after(negotitate_timeout):
|
||||
await self.handshake(communicator)
|
||||
await self.handshake(communicator)
|
||||
|
||||
for protocol in protocols:
|
||||
try:
|
||||
selected_protocol = await self.try_select(
|
||||
communicator, protocol
|
||||
)
|
||||
return selected_protocol
|
||||
except MultiselectClientError:
|
||||
pass
|
||||
for protocol in protocols:
|
||||
try:
|
||||
selected_protocol = await self.try_select(communicator, protocol)
|
||||
return selected_protocol
|
||||
except MultiselectClientError:
|
||||
pass
|
||||
|
||||
raise MultiselectClientError("protocols not supported")
|
||||
except trio.TooSlowError:
|
||||
raise MultiselectClientError("response timed out")
|
||||
raise MultiselectClientError("protocols not supported")
|
||||
|
||||
async def query_multistream_command(
|
||||
self,
|
||||
communicator: IMultiselectCommunicator,
|
||||
command: str,
|
||||
response_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
self, communicator: IMultiselectCommunicator, command: str
|
||||
) -> list[str]:
|
||||
"""
|
||||
Send a multistream-select command over the given communicator and return
|
||||
@ -96,32 +79,26 @@ class MultiselectClient(IMultiselectClient):
|
||||
|
||||
:param communicator: communicator to use to communicate with counterparty
|
||||
:param command: supported multistream-select command(e.g., ls)
|
||||
:param negotiate_timeout: timeout for negotiation
|
||||
:raise MultiselectClientError: If the communicator fails to process data.
|
||||
:return: list of strings representing the response from peer.
|
||||
"""
|
||||
await self.handshake(communicator)
|
||||
|
||||
if command == "ls":
|
||||
try:
|
||||
await communicator.write("ls")
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
else:
|
||||
raise ValueError("Command not supported")
|
||||
|
||||
try:
|
||||
with trio.fail_after(response_timeout):
|
||||
await self.handshake(communicator)
|
||||
response = await communicator.read()
|
||||
response_list = response.strip().splitlines()
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
|
||||
if command == "ls":
|
||||
try:
|
||||
await communicator.write("ls")
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
else:
|
||||
raise ValueError("Command not supported")
|
||||
|
||||
try:
|
||||
response = await communicator.read()
|
||||
response_list = response.strip().splitlines()
|
||||
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
|
||||
return response_list
|
||||
except trio.TooSlowError:
|
||||
raise MultiselectClientError("command response timed out")
|
||||
return response_list
|
||||
|
||||
async def try_select(
|
||||
self, communicator: IMultiselectCommunicator, protocol: TProtocol
|
||||
@ -141,7 +118,6 @@ class MultiselectClient(IMultiselectClient):
|
||||
|
||||
try:
|
||||
response = await communicator.read()
|
||||
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
|
||||
|
||||
@ -12,13 +12,16 @@ from libp2p.abc import (
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.network.stream.exceptions import (
|
||||
StreamClosed,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
|
||||
from .exceptions import (
|
||||
PubsubRouterError,
|
||||
from libp2p.utils import (
|
||||
encode_varint_prefixed,
|
||||
)
|
||||
|
||||
from .pb import (
|
||||
rpc_pb2,
|
||||
)
|
||||
@ -34,7 +37,7 @@ logger = logging.getLogger("libp2p.pubsub.floodsub")
|
||||
class FloodSub(IPubsubRouter):
|
||||
protocols: list[TProtocol]
|
||||
|
||||
pubsub: Pubsub | None
|
||||
pubsub: Pubsub
|
||||
|
||||
def __init__(self, protocols: Sequence[TProtocol]) -> None:
|
||||
self.protocols = list(protocols)
|
||||
@ -55,7 +58,7 @@ class FloodSub(IPubsubRouter):
|
||||
"""
|
||||
self.pubsub = pubsub
|
||||
|
||||
def add_peer(self, peer_id: ID, protocol_id: TProtocol | None) -> None:
|
||||
def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None:
|
||||
"""
|
||||
Notifies the router that a new peer has been connected.
|
||||
|
||||
@ -105,16 +108,17 @@ class FloodSub(IPubsubRouter):
|
||||
|
||||
logger.debug("publishing message %s", pubsub_msg)
|
||||
|
||||
if self.pubsub is None:
|
||||
raise PubsubRouterError("pubsub not attached to this instance")
|
||||
else:
|
||||
pubsub = self.pubsub
|
||||
|
||||
for peer_id in peers_gen:
|
||||
if peer_id not in pubsub.peers:
|
||||
if peer_id not in self.pubsub.peers:
|
||||
continue
|
||||
stream = pubsub.peers[peer_id]
|
||||
await pubsub.write_msg(stream, rpc_msg)
|
||||
stream = self.pubsub.peers[peer_id]
|
||||
# FIXME: We should add a `WriteMsg` similar to write delimited messages.
|
||||
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
|
||||
try:
|
||||
await stream.write(encode_varint_prefixed(rpc_msg.SerializeToString()))
|
||||
except StreamClosed:
|
||||
logger.debug("Fail to publish message to %s: stream closed", peer_id)
|
||||
self.pubsub._handle_dead_peer(peer_id)
|
||||
|
||||
async def join(self, topic: str) -> None:
|
||||
"""
|
||||
@ -146,16 +150,12 @@ class FloodSub(IPubsubRouter):
|
||||
:param origin: peer id of the peer the message originate from.
|
||||
:return: a generator of the peer ids who we send data to.
|
||||
"""
|
||||
if self.pubsub is None:
|
||||
raise PubsubRouterError("pubsub not attached to this instance")
|
||||
else:
|
||||
pubsub = self.pubsub
|
||||
for topic in topic_ids:
|
||||
if topic not in pubsub.peer_topics:
|
||||
if topic not in self.pubsub.peer_topics:
|
||||
continue
|
||||
for peer_id in pubsub.peer_topics[topic]:
|
||||
for peer_id in self.pubsub.peer_topics[topic]:
|
||||
if peer_id in (msg_forwarder, origin):
|
||||
continue
|
||||
if peer_id not in pubsub.peers:
|
||||
if peer_id not in self.pubsub.peers:
|
||||
continue
|
||||
yield peer_id
|
||||
|
||||
@ -24,13 +24,14 @@ from libp2p.abc import (
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.network.stream.exceptions import (
|
||||
StreamClosed,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
peer_info_from_bytes,
|
||||
peer_info_to_bytes,
|
||||
)
|
||||
from libp2p.peer.peerstore import (
|
||||
PERMANENT_ADDR_TTL,
|
||||
@ -41,6 +42,9 @@ from libp2p.pubsub import (
|
||||
from libp2p.tools.async_service import (
|
||||
Service,
|
||||
)
|
||||
from libp2p.utils import (
|
||||
encode_varint_prefixed,
|
||||
)
|
||||
|
||||
from .exceptions import (
|
||||
NoPubsubAttached,
|
||||
@ -63,7 +67,7 @@ logger = logging.getLogger("libp2p.pubsub.gossipsub")
|
||||
|
||||
class GossipSub(IPubsubRouter, Service):
|
||||
protocols: list[TProtocol]
|
||||
pubsub: Pubsub | None
|
||||
pubsub: Pubsub
|
||||
|
||||
degree: int
|
||||
degree_high: int
|
||||
@ -88,19 +92,13 @@ class GossipSub(IPubsubRouter, Service):
|
||||
direct_connect_initial_delay: float
|
||||
direct_connect_interval: int
|
||||
|
||||
do_px: bool
|
||||
px_peers_count: int
|
||||
back_off: dict[str, dict[ID, int]]
|
||||
prune_back_off: int
|
||||
unsubscribe_back_off: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
protocols: Sequence[TProtocol],
|
||||
degree: int,
|
||||
degree_low: int,
|
||||
degree_high: int,
|
||||
direct_peers: Sequence[PeerInfo] | None = None,
|
||||
direct_peers: Sequence[PeerInfo] = None,
|
||||
time_to_live: int = 60,
|
||||
gossip_window: int = 3,
|
||||
gossip_history: int = 5,
|
||||
@ -108,10 +106,6 @@ class GossipSub(IPubsubRouter, Service):
|
||||
heartbeat_interval: int = 120,
|
||||
direct_connect_initial_delay: float = 0.1,
|
||||
direct_connect_interval: int = 300,
|
||||
do_px: bool = False,
|
||||
px_peers_count: int = 16,
|
||||
prune_back_off: int = 60,
|
||||
unsubscribe_back_off: int = 10,
|
||||
) -> None:
|
||||
self.protocols = list(protocols)
|
||||
self.pubsub = None
|
||||
@ -146,13 +140,9 @@ class GossipSub(IPubsubRouter, Service):
|
||||
self.direct_connect_initial_delay = direct_connect_initial_delay
|
||||
self.time_since_last_publish = {}
|
||||
|
||||
self.do_px = do_px
|
||||
self.px_peers_count = px_peers_count
|
||||
self.back_off = dict()
|
||||
self.prune_back_off = prune_back_off
|
||||
self.unsubscribe_back_off = unsubscribe_back_off
|
||||
|
||||
async def run(self) -> None:
|
||||
if self.pubsub is None:
|
||||
raise NoPubsubAttached
|
||||
self.manager.run_daemon_task(self.heartbeat)
|
||||
if len(self.direct_peers) > 0:
|
||||
self.manager.run_daemon_task(self.direct_connect_heartbeat)
|
||||
@ -183,7 +173,7 @@ class GossipSub(IPubsubRouter, Service):
|
||||
|
||||
logger.debug("attached to pusub")
|
||||
|
||||
def add_peer(self, peer_id: ID, protocol_id: TProtocol | None) -> None:
|
||||
def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None:
|
||||
"""
|
||||
Notifies the router that a new peer has been connected.
|
||||
|
||||
@ -192,9 +182,6 @@ class GossipSub(IPubsubRouter, Service):
|
||||
"""
|
||||
logger.debug("adding peer %s with protocol %s", peer_id, protocol_id)
|
||||
|
||||
if protocol_id is None:
|
||||
raise ValueError("Protocol cannot be None")
|
||||
|
||||
if protocol_id not in (PROTOCOL_ID, floodsub.PROTOCOL_ID):
|
||||
# We should never enter here. Becuase the `protocol_id` is registered by
|
||||
# your pubsub instance in multistream-select, but it is not the protocol
|
||||
@ -256,15 +243,17 @@ class GossipSub(IPubsubRouter, Service):
|
||||
logger.debug("publishing message %s", pubsub_msg)
|
||||
|
||||
for peer_id in peers_gen:
|
||||
if self.pubsub is None:
|
||||
raise NoPubsubAttached
|
||||
if peer_id not in self.pubsub.peers:
|
||||
continue
|
||||
stream = self.pubsub.peers[peer_id]
|
||||
|
||||
# FIXME: We should add a `WriteMsg` similar to write delimited messages.
|
||||
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
|
||||
# TODO: Go use `sendRPC`, which possibly piggybacks gossip/control messages.
|
||||
await self.pubsub.write_msg(stream, rpc_msg)
|
||||
|
||||
try:
|
||||
await stream.write(encode_varint_prefixed(rpc_msg.SerializeToString()))
|
||||
except StreamClosed:
|
||||
logger.debug("Fail to publish message to %s: stream closed", peer_id)
|
||||
self.pubsub._handle_dead_peer(peer_id)
|
||||
for topic in pubsub_msg.topicIDs:
|
||||
self.time_since_last_publish[topic] = int(time.time())
|
||||
|
||||
@ -280,8 +269,6 @@ class GossipSub(IPubsubRouter, Service):
|
||||
"""
|
||||
send_to: set[ID] = set()
|
||||
for topic in topic_ids:
|
||||
if self.pubsub is None:
|
||||
raise NoPubsubAttached
|
||||
if topic not in self.pubsub.peer_topics:
|
||||
continue
|
||||
|
||||
@ -331,9 +318,6 @@ class GossipSub(IPubsubRouter, Service):
|
||||
|
||||
:param topic: topic to join
|
||||
"""
|
||||
if self.pubsub is None:
|
||||
raise NoPubsubAttached
|
||||
|
||||
logger.debug("joining topic %s", topic)
|
||||
|
||||
if topic in self.mesh:
|
||||
@ -342,22 +326,15 @@ class GossipSub(IPubsubRouter, Service):
|
||||
self.mesh[topic] = set()
|
||||
|
||||
topic_in_fanout: bool = topic in self.fanout
|
||||
fanout_peers: set[ID] = set()
|
||||
|
||||
if topic_in_fanout:
|
||||
for peer in self.fanout[topic]:
|
||||
if self._check_back_off(peer, topic):
|
||||
continue
|
||||
fanout_peers.add(peer)
|
||||
|
||||
fanout_peers: set[ID] = self.fanout[topic] if topic_in_fanout else set()
|
||||
fanout_size = len(fanout_peers)
|
||||
if not topic_in_fanout or (topic_in_fanout and fanout_size < self.degree):
|
||||
# There are less than D peers (let this number be x)
|
||||
# in the fanout for a topic (or the topic is not in the fanout).
|
||||
# Selects the remaining number of peers (D-x) from peers.gossipsub[topic].
|
||||
if self.pubsub is not None and topic in self.pubsub.peer_topics:
|
||||
if topic in self.pubsub.peer_topics:
|
||||
selected_peers = self._get_in_topic_gossipsub_peers_from_minus(
|
||||
topic, self.degree - fanout_size, fanout_peers, True
|
||||
topic, self.degree - fanout_size, fanout_peers
|
||||
)
|
||||
# Combine fanout peers with selected peers
|
||||
fanout_peers.update(selected_peers)
|
||||
@ -384,8 +361,7 @@ class GossipSub(IPubsubRouter, Service):
|
||||
return
|
||||
# Notify the peers in mesh[topic] with a PRUNE(topic) message
|
||||
for peer in self.mesh[topic]:
|
||||
await self.emit_prune(topic, peer, self.do_px, True)
|
||||
self._add_back_off(peer, topic, True)
|
||||
await self.emit_prune(topic, peer)
|
||||
|
||||
# Forget mesh[topic]
|
||||
self.mesh.pop(topic, None)
|
||||
@ -475,8 +451,8 @@ class GossipSub(IPubsubRouter, Service):
|
||||
self.fanout_heartbeat()
|
||||
# Get the peers to send IHAVE to
|
||||
peers_to_gossip = self.gossip_heartbeat()
|
||||
# Pack(piggyback) GRAFT, PRUNE and IHAVE for the same peer into
|
||||
# one control message and send it
|
||||
# Pack GRAFT, PRUNE and IHAVE for the same peer into one control message and
|
||||
# send it
|
||||
await self._emit_control_msgs(
|
||||
peers_to_graft, peers_to_prune, peers_to_gossip
|
||||
)
|
||||
@ -492,8 +468,6 @@ class GossipSub(IPubsubRouter, Service):
|
||||
await trio.sleep(self.direct_connect_initial_delay)
|
||||
while True:
|
||||
for direct_peer in self.direct_peers:
|
||||
if self.pubsub is None:
|
||||
raise NoPubsubAttached
|
||||
if direct_peer not in self.pubsub.peers:
|
||||
try:
|
||||
await self.pubsub.host.connect(self.direct_peers[direct_peer])
|
||||
@ -511,8 +485,6 @@ class GossipSub(IPubsubRouter, Service):
|
||||
peers_to_graft: DefaultDict[ID, list[str]] = defaultdict(list)
|
||||
peers_to_prune: DefaultDict[ID, list[str]] = defaultdict(list)
|
||||
for topic in self.mesh:
|
||||
if self.pubsub is None:
|
||||
raise NoPubsubAttached
|
||||
# Skip if no peers have subscribed to the topic
|
||||
if topic not in self.pubsub.peer_topics:
|
||||
continue
|
||||
@ -521,7 +493,7 @@ class GossipSub(IPubsubRouter, Service):
|
||||
if num_mesh_peers_in_topic < self.degree_low:
|
||||
# Select D - |mesh[topic]| peers from peers.gossipsub[topic] - mesh[topic] # noqa: E501
|
||||
selected_peers = self._get_in_topic_gossipsub_peers_from_minus(
|
||||
topic, self.degree - num_mesh_peers_in_topic, self.mesh[topic], True
|
||||
topic, self.degree - num_mesh_peers_in_topic, self.mesh[topic]
|
||||
)
|
||||
|
||||
for peer in selected_peers:
|
||||
@ -544,97 +516,74 @@ class GossipSub(IPubsubRouter, Service):
|
||||
peers_to_prune[peer].append(topic)
|
||||
return peers_to_graft, peers_to_prune
|
||||
|
||||
def _handle_topic_heartbeat(
|
||||
self,
|
||||
topic: str,
|
||||
current_peers: set[ID],
|
||||
is_fanout: bool = False,
|
||||
peers_to_gossip: DefaultDict[ID, dict[str, list[str]]] | None = None,
|
||||
) -> tuple[set[ID], bool]:
|
||||
"""
|
||||
Helper method to handle heartbeat for a single topic,
|
||||
supporting both fanout and gossip.
|
||||
|
||||
:param topic: The topic to handle
|
||||
:param current_peers: Current set of peers in the topic
|
||||
:param is_fanout: Whether this is a fanout topic (affects expiration check)
|
||||
:param peers_to_gossip: Optional dictionary to store peers to gossip to
|
||||
:return: Tuple of (updated_peers, should_remove_topic)
|
||||
"""
|
||||
if self.pubsub is None:
|
||||
raise NoPubsubAttached
|
||||
|
||||
# Skip if no peers have subscribed to the topic
|
||||
if topic not in self.pubsub.peer_topics:
|
||||
return current_peers, False
|
||||
|
||||
# For fanout topics, check if we should remove the topic
|
||||
if is_fanout:
|
||||
if self.time_since_last_publish.get(topic, 0) + self.time_to_live < int(
|
||||
time.time()
|
||||
):
|
||||
return set(), True
|
||||
|
||||
# Check if peers are still in the topic and remove the ones that are not
|
||||
in_topic_peers: set[ID] = {
|
||||
peer for peer in current_peers if peer in self.pubsub.peer_topics[topic]
|
||||
}
|
||||
|
||||
# If we need more peers to reach target degree
|
||||
if len(in_topic_peers) < self.degree:
|
||||
# Select additional peers from peers.gossipsub[topic]
|
||||
selected_peers = self._get_in_topic_gossipsub_peers_from_minus(
|
||||
topic, self.degree - len(in_topic_peers), in_topic_peers, True
|
||||
)
|
||||
# Add the selected peers
|
||||
in_topic_peers.update(selected_peers)
|
||||
|
||||
# Handle gossip if requested
|
||||
if peers_to_gossip is not None:
|
||||
msg_ids = self.mcache.window(topic)
|
||||
if msg_ids:
|
||||
# Select D peers from peers.gossipsub[topic] excluding current peers
|
||||
peers_to_emit_ihave_to = self._get_in_topic_gossipsub_peers_from_minus(
|
||||
topic, self.degree, current_peers, True
|
||||
)
|
||||
msg_id_strs = [str(msg_id) for msg_id in msg_ids]
|
||||
for peer in peers_to_emit_ihave_to:
|
||||
peers_to_gossip[peer][topic] = msg_id_strs
|
||||
|
||||
return in_topic_peers, False
|
||||
|
||||
def fanout_heartbeat(self) -> None:
|
||||
"""
|
||||
Maintain fanout topics by:
|
||||
1. Removing expired topics
|
||||
2. Removing peers that are no longer in the topic
|
||||
3. Adding new peers if needed to maintain the target degree
|
||||
"""
|
||||
# Note: the comments here are the exact pseudocode from the spec
|
||||
for topic in list(self.fanout):
|
||||
updated_peers, should_remove = self._handle_topic_heartbeat(
|
||||
topic, self.fanout[topic], is_fanout=True
|
||||
)
|
||||
if should_remove:
|
||||
if (
|
||||
topic not in self.pubsub.peer_topics
|
||||
and self.time_since_last_publish.get(topic, 0) + self.time_to_live
|
||||
< int(time.time())
|
||||
):
|
||||
# Remove topic from fanout
|
||||
del self.fanout[topic]
|
||||
else:
|
||||
self.fanout[topic] = updated_peers
|
||||
# Check if fanout peers are still in the topic and remove the ones that are not # noqa: E501
|
||||
# ref: https://github.com/libp2p/go-libp2p-pubsub/blob/01b9825fbee1848751d90a8469e3f5f43bac8466/gossipsub.go#L498-L504 # noqa: E501
|
||||
in_topic_fanout_peers = [
|
||||
peer
|
||||
for peer in self.fanout[topic]
|
||||
if peer in self.pubsub.peer_topics[topic]
|
||||
]
|
||||
self.fanout[topic] = set(in_topic_fanout_peers)
|
||||
num_fanout_peers_in_topic = len(self.fanout[topic])
|
||||
|
||||
# If |fanout[topic]| < D
|
||||
if num_fanout_peers_in_topic < self.degree:
|
||||
# Select D - |fanout[topic]| peers from peers.gossipsub[topic] - fanout[topic] # noqa: E501
|
||||
selected_peers = self._get_in_topic_gossipsub_peers_from_minus(
|
||||
topic,
|
||||
self.degree - num_fanout_peers_in_topic,
|
||||
self.fanout[topic],
|
||||
)
|
||||
# Add the peers to fanout[topic]
|
||||
self.fanout[topic].update(selected_peers)
|
||||
|
||||
def gossip_heartbeat(self) -> DefaultDict[ID, dict[str, list[str]]]:
|
||||
peers_to_gossip: DefaultDict[ID, dict[str, list[str]]] = defaultdict(dict)
|
||||
|
||||
# Handle mesh topics
|
||||
for topic in self.mesh:
|
||||
self._handle_topic_heartbeat(
|
||||
topic, self.mesh[topic], peers_to_gossip=peers_to_gossip
|
||||
)
|
||||
msg_ids = self.mcache.window(topic)
|
||||
if msg_ids:
|
||||
# Get all pubsub peers in a topic and only add them if they are
|
||||
# gossipsub peers too
|
||||
if topic in self.pubsub.peer_topics:
|
||||
# Select D peers from peers.gossipsub[topic]
|
||||
peers_to_emit_ihave_to = (
|
||||
self._get_in_topic_gossipsub_peers_from_minus(
|
||||
topic, self.degree, self.mesh[topic]
|
||||
)
|
||||
)
|
||||
|
||||
# Handle fanout topics that aren't in mesh
|
||||
msg_id_strs = [str(msg_id) for msg_id in msg_ids]
|
||||
for peer in peers_to_emit_ihave_to:
|
||||
peers_to_gossip[peer][topic] = msg_id_strs
|
||||
|
||||
# TODO: Refactor and Dedup. This section is the roughly the same as the above.
|
||||
# Do the same for fanout, for all topics not already hit in mesh
|
||||
for topic in self.fanout:
|
||||
if topic not in self.mesh:
|
||||
self._handle_topic_heartbeat(
|
||||
topic, self.fanout[topic], peers_to_gossip=peers_to_gossip
|
||||
)
|
||||
|
||||
msg_ids = self.mcache.window(topic)
|
||||
if msg_ids:
|
||||
# Get all pubsub peers in topic and only add if they are
|
||||
# gossipsub peers also
|
||||
if topic in self.pubsub.peer_topics:
|
||||
# Select D peers from peers.gossipsub[topic]
|
||||
peers_to_emit_ihave_to = (
|
||||
self._get_in_topic_gossipsub_peers_from_minus(
|
||||
topic, self.degree, self.fanout[topic]
|
||||
)
|
||||
)
|
||||
msg_id_strs = [str(msg) for msg in msg_ids]
|
||||
for peer in peers_to_emit_ihave_to:
|
||||
peers_to_gossip[peer][topic] = msg_id_strs
|
||||
return peers_to_gossip
|
||||
|
||||
@staticmethod
|
||||
@ -669,109 +618,21 @@ class GossipSub(IPubsubRouter, Service):
|
||||
return selection
|
||||
|
||||
def _get_in_topic_gossipsub_peers_from_minus(
|
||||
self,
|
||||
topic: str,
|
||||
num_to_select: int,
|
||||
minus: Iterable[ID],
|
||||
backoff_check: bool = False,
|
||||
self, topic: str, num_to_select: int, minus: Iterable[ID]
|
||||
) -> list[ID]:
|
||||
if self.pubsub is None:
|
||||
raise NoPubsubAttached
|
||||
gossipsub_peers_in_topic = {
|
||||
peer_id
|
||||
for peer_id in self.pubsub.peer_topics[topic]
|
||||
if self.peer_protocol[peer_id] == PROTOCOL_ID
|
||||
}
|
||||
if backoff_check:
|
||||
# filter out peers that are in back off for this topic
|
||||
gossipsub_peers_in_topic = {
|
||||
peer_id
|
||||
for peer_id in gossipsub_peers_in_topic
|
||||
if self._check_back_off(peer_id, topic) is False
|
||||
}
|
||||
return self.select_from_minus(num_to_select, gossipsub_peers_in_topic, minus)
|
||||
|
||||
def _add_back_off(
|
||||
self, peer: ID, topic: str, is_unsubscribe: bool, backoff_duration: int = 0
|
||||
) -> None:
|
||||
"""
|
||||
Add back off for a peer in a topic.
|
||||
:param peer: peer to add back off for
|
||||
:param topic: topic to add back off for
|
||||
:param is_unsubscribe: whether this is an unsubscribe operation
|
||||
:param backoff_duration: duration of back off in seconds, if 0, use default
|
||||
"""
|
||||
if topic not in self.back_off:
|
||||
self.back_off[topic] = dict()
|
||||
|
||||
backoff_till = int(time.time())
|
||||
if backoff_duration > 0:
|
||||
backoff_till += backoff_duration
|
||||
else:
|
||||
if is_unsubscribe:
|
||||
backoff_till += self.unsubscribe_back_off
|
||||
else:
|
||||
backoff_till += self.prune_back_off
|
||||
|
||||
if peer not in self.back_off[topic]:
|
||||
self.back_off[topic][peer] = backoff_till
|
||||
else:
|
||||
self.back_off[topic][peer] = max(self.back_off[topic][peer], backoff_till)
|
||||
|
||||
def _check_back_off(self, peer: ID, topic: str) -> bool:
|
||||
"""
|
||||
Check if a peer is in back off for a topic and cleanup expired back off entries.
|
||||
:param peer: peer to check
|
||||
:param topic: topic to check
|
||||
:return: True if the peer is in back off, False otherwise
|
||||
"""
|
||||
if topic not in self.back_off or peer not in self.back_off[topic]:
|
||||
return False
|
||||
if self.back_off[topic].get(peer, 0) > int(time.time()):
|
||||
return True
|
||||
else:
|
||||
del self.back_off[topic][peer]
|
||||
return False
|
||||
|
||||
async def _do_px(self, px_peers: list[rpc_pb2.PeerInfo]) -> None:
|
||||
if len(px_peers) > self.px_peers_count:
|
||||
px_peers = px_peers[: self.px_peers_count]
|
||||
|
||||
for peer in px_peers:
|
||||
peer_id: ID = ID(peer.peerID)
|
||||
|
||||
if self.pubsub and peer_id in self.pubsub.peers:
|
||||
continue
|
||||
|
||||
try:
|
||||
peer_info = peer_info_from_bytes(peer.signedPeerRecord)
|
||||
try:
|
||||
if self.pubsub is None:
|
||||
raise NoPubsubAttached
|
||||
await self.pubsub.host.connect(peer_info)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"failed to connect to px peer %s: %s",
|
||||
peer_id,
|
||||
e,
|
||||
)
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"failed to parse peer info from px peer %s: %s",
|
||||
peer_id,
|
||||
e,
|
||||
)
|
||||
continue
|
||||
|
||||
# RPC handlers
|
||||
|
||||
async def handle_ihave(
|
||||
self, ihave_msg: rpc_pb2.ControlIHave, sender_peer_id: ID
|
||||
) -> None:
|
||||
"""Checks the seen set and requests unknown messages with an IWANT message."""
|
||||
if self.pubsub is None:
|
||||
raise NoPubsubAttached
|
||||
# Get list of all seen (seqnos, from) from the (seqno, from) tuples in
|
||||
# seen_messages cache
|
||||
seen_seqnos_and_peers = [
|
||||
@ -804,7 +665,7 @@ class GossipSub(IPubsubRouter, Service):
|
||||
msgs_to_forward: list[rpc_pb2.Message] = []
|
||||
for msg_id_iwant in msg_ids:
|
||||
# Check if the wanted message ID is present in mcache
|
||||
msg: rpc_pb2.Message | None = self.mcache.get(msg_id_iwant)
|
||||
msg: rpc_pb2.Message = self.mcache.get(msg_id_iwant)
|
||||
|
||||
# Cache hit
|
||||
if msg:
|
||||
@ -820,8 +681,8 @@ class GossipSub(IPubsubRouter, Service):
|
||||
|
||||
packet.publish.extend(msgs_to_forward)
|
||||
|
||||
if self.pubsub is None:
|
||||
raise NoPubsubAttached
|
||||
# 2) Serialize that packet
|
||||
rpc_msg: bytes = packet.SerializeToString()
|
||||
|
||||
# 3) Get the stream to this peer
|
||||
if sender_peer_id not in self.pubsub.peers:
|
||||
@ -833,7 +694,14 @@ class GossipSub(IPubsubRouter, Service):
|
||||
peer_stream = self.pubsub.peers[sender_peer_id]
|
||||
|
||||
# 4) And write the packet to the stream
|
||||
await self.pubsub.write_msg(peer_stream, packet)
|
||||
try:
|
||||
await peer_stream.write(encode_varint_prefixed(rpc_msg))
|
||||
except StreamClosed:
|
||||
logger.debug(
|
||||
"Fail to responed to iwant request from %s: stream closed",
|
||||
sender_peer_id,
|
||||
)
|
||||
self.pubsub._handle_dead_peer(sender_peer_id)
|
||||
|
||||
async def handle_graft(
|
||||
self, graft_msg: rpc_pb2.ControlGraft, sender_peer_id: ID
|
||||
@ -847,53 +715,31 @@ class GossipSub(IPubsubRouter, Service):
|
||||
logger.warning(
|
||||
"GRAFT: ignoring request from direct peer %s", sender_peer_id
|
||||
)
|
||||
await self.emit_prune(topic, sender_peer_id, False, False)
|
||||
await self.emit_prune(topic, sender_peer_id)
|
||||
return
|
||||
|
||||
if self._check_back_off(sender_peer_id, topic):
|
||||
logger.warning(
|
||||
"GRAFT: ignoring request from %s, back off until %d",
|
||||
sender_peer_id,
|
||||
self.back_off[topic][sender_peer_id],
|
||||
)
|
||||
self._add_back_off(sender_peer_id, topic, False)
|
||||
await self.emit_prune(topic, sender_peer_id, False, False)
|
||||
return
|
||||
|
||||
if sender_peer_id not in self.mesh[topic]:
|
||||
self.mesh[topic].add(sender_peer_id)
|
||||
else:
|
||||
# Respond with PRUNE if not subscribed to the topic
|
||||
await self.emit_prune(topic, sender_peer_id, self.do_px, False)
|
||||
await self.emit_prune(topic, sender_peer_id)
|
||||
|
||||
async def handle_prune(
|
||||
self, prune_msg: rpc_pb2.ControlPrune, sender_peer_id: ID
|
||||
) -> None:
|
||||
topic: str = prune_msg.topicID
|
||||
backoff_till: int = prune_msg.backoff
|
||||
px_peers: list[rpc_pb2.PeerInfo] = []
|
||||
for peer in prune_msg.peers:
|
||||
px_peers.append(peer)
|
||||
|
||||
# Remove peer from mesh for topic
|
||||
if topic in self.mesh:
|
||||
if backoff_till > 0:
|
||||
self._add_back_off(sender_peer_id, topic, False, backoff_till)
|
||||
else:
|
||||
self._add_back_off(sender_peer_id, topic, False)
|
||||
|
||||
self.mesh[topic].discard(sender_peer_id)
|
||||
|
||||
if px_peers:
|
||||
await self._do_px(px_peers)
|
||||
|
||||
# RPC emitters
|
||||
|
||||
def pack_control_msgs(
|
||||
self,
|
||||
ihave_msgs: list[rpc_pb2.ControlIHave] | None,
|
||||
graft_msgs: list[rpc_pb2.ControlGraft] | None,
|
||||
prune_msgs: list[rpc_pb2.ControlPrune] | None,
|
||||
ihave_msgs: list[rpc_pb2.ControlIHave],
|
||||
graft_msgs: list[rpc_pb2.ControlGraft],
|
||||
prune_msgs: list[rpc_pb2.ControlPrune],
|
||||
) -> rpc_pb2.ControlMessage:
|
||||
control_msg: rpc_pb2.ControlMessage = rpc_pb2.ControlMessage()
|
||||
if ihave_msgs:
|
||||
@ -925,7 +771,7 @@ class GossipSub(IPubsubRouter, Service):
|
||||
|
||||
await self.emit_control_message(control_msg, to_peer)
|
||||
|
||||
async def emit_graft(self, topic: str, id: ID) -> None:
|
||||
async def emit_graft(self, topic: str, to_peer: ID) -> None:
|
||||
"""Emit graft message, sent to to_peer, for topic."""
|
||||
graft_msg: rpc_pb2.ControlGraft = rpc_pb2.ControlGraft()
|
||||
graft_msg.topicID = topic
|
||||
@ -933,34 +779,13 @@ class GossipSub(IPubsubRouter, Service):
|
||||
control_msg: rpc_pb2.ControlMessage = rpc_pb2.ControlMessage()
|
||||
control_msg.graft.extend([graft_msg])
|
||||
|
||||
await self.emit_control_message(control_msg, id)
|
||||
await self.emit_control_message(control_msg, to_peer)
|
||||
|
||||
async def emit_prune(
|
||||
self, topic: str, to_peer: ID, do_px: bool, is_unsubscribe: bool
|
||||
) -> None:
|
||||
async def emit_prune(self, topic: str, to_peer: ID) -> None:
|
||||
"""Emit graft message, sent to to_peer, for topic."""
|
||||
prune_msg: rpc_pb2.ControlPrune = rpc_pb2.ControlPrune()
|
||||
prune_msg.topicID = topic
|
||||
|
||||
back_off_duration = self.prune_back_off
|
||||
if is_unsubscribe:
|
||||
back_off_duration = self.unsubscribe_back_off
|
||||
|
||||
prune_msg.backoff = back_off_duration
|
||||
|
||||
if do_px:
|
||||
exchange_peers = self._get_in_topic_gossipsub_peers_from_minus(
|
||||
topic, self.px_peers_count, [to_peer]
|
||||
)
|
||||
for peer in exchange_peers:
|
||||
if self.pubsub is None:
|
||||
raise NoPubsubAttached
|
||||
peer_info = self.pubsub.host.get_peerstore().peer_info(peer)
|
||||
signed_peer_record: rpc_pb2.PeerInfo = rpc_pb2.PeerInfo()
|
||||
signed_peer_record.peerID = peer.to_bytes()
|
||||
signed_peer_record.signedPeerRecord = peer_info_to_bytes(peer_info)
|
||||
prune_msg.peers.append(signed_peer_record)
|
||||
|
||||
control_msg: rpc_pb2.ControlMessage = rpc_pb2.ControlMessage()
|
||||
control_msg.prune.extend([prune_msg])
|
||||
|
||||
@ -969,12 +794,12 @@ class GossipSub(IPubsubRouter, Service):
|
||||
async def emit_control_message(
|
||||
self, control_msg: rpc_pb2.ControlMessage, to_peer: ID
|
||||
) -> None:
|
||||
if self.pubsub is None:
|
||||
raise NoPubsubAttached
|
||||
# Add control message to packet
|
||||
packet: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||
packet.control.CopyFrom(control_msg)
|
||||
|
||||
rpc_msg: bytes = packet.SerializeToString()
|
||||
|
||||
# Get stream for peer from pubsub
|
||||
if to_peer not in self.pubsub.peers:
|
||||
logger.debug(
|
||||
@ -984,4 +809,8 @@ class GossipSub(IPubsubRouter, Service):
|
||||
peer_stream = self.pubsub.peers[to_peer]
|
||||
|
||||
# Write rpc to stream
|
||||
await self.pubsub.write_msg(peer_stream, packet)
|
||||
try:
|
||||
await peer_stream.write(encode_varint_prefixed(rpc_msg))
|
||||
except StreamClosed:
|
||||
logger.debug("Fail to emit control message to %s: stream closed", to_peer)
|
||||
self.pubsub._handle_dead_peer(to_peer)
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
from collections.abc import (
|
||||
Sequence,
|
||||
)
|
||||
from typing import (
|
||||
Optional,
|
||||
)
|
||||
|
||||
from .pb import (
|
||||
rpc_pb2,
|
||||
@ -63,7 +66,7 @@ class MessageCache:
|
||||
|
||||
self.history[0].append(CacheEntry(mid, msg.topicIDs))
|
||||
|
||||
def get(self, mid: tuple[bytes, bytes]) -> rpc_pb2.Message | None:
|
||||
def get(self, mid: tuple[bytes, bytes]) -> Optional[rpc_pb2.Message]:
|
||||
"""
|
||||
Get a message from the mcache.
|
||||
|
||||
|
||||
@ -47,13 +47,6 @@ message ControlGraft {
|
||||
|
||||
message ControlPrune {
|
||||
optional string topicID = 1;
|
||||
repeated PeerInfo peers = 2;
|
||||
optional uint64 backoff = 3;
|
||||
}
|
||||
|
||||
message PeerInfo {
|
||||
optional bytes peerID = 1;
|
||||
optional bytes signedPeerRecord = 2;
|
||||
}
|
||||
|
||||
message TopicDescriptor {
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: rpc.proto
|
||||
# source: libp2p/pubsub/pb/rpc.proto
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf.internal import builder as _builder
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
@ -13,39 +13,37 @@ _sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\trpc.proto\x12\tpubsub.pb\"\xb4\x01\n\x03RPC\x12-\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x16.pubsub.pb.RPC.SubOpts\x12#\n\x07publish\x18\x02 \x03(\x0b\x32\x12.pubsub.pb.Message\x12*\n\x07\x63ontrol\x18\x03 \x01(\x0b\x32\x19.pubsub.pb.ControlMessage\x1a-\n\x07SubOpts\x12\x11\n\tsubscribe\x18\x01 \x01(\x08\x12\x0f\n\x07topicid\x18\x02 \x01(\t\"i\n\x07Message\x12\x0f\n\x07\x66rom_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\r\n\x05seqno\x18\x03 \x01(\x0c\x12\x10\n\x08topicIDs\x18\x04 \x03(\t\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x12\x0b\n\x03key\x18\x06 \x01(\x0c\"\xb0\x01\n\x0e\x43ontrolMessage\x12&\n\x05ihave\x18\x01 \x03(\x0b\x32\x17.pubsub.pb.ControlIHave\x12&\n\x05iwant\x18\x02 \x03(\x0b\x32\x17.pubsub.pb.ControlIWant\x12&\n\x05graft\x18\x03 \x03(\x0b\x32\x17.pubsub.pb.ControlGraft\x12&\n\x05prune\x18\x04 \x03(\x0b\x32\x17.pubsub.pb.ControlPrune\"3\n\x0c\x43ontrolIHave\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\x12\n\nmessageIDs\x18\x02 \x03(\t\"\"\n\x0c\x43ontrolIWant\x12\x12\n\nmessageIDs\x18\x01 \x03(\t\"\x1f\n\x0c\x43ontrolGraft\x12\x0f\n\x07topicID\x18\x01 \x01(\t\"T\n\x0c\x43ontrolPrune\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\"\n\x05peers\x18\x02 \x03(\x0b\x32\x13.pubsub.pb.PeerInfo\x12\x0f\n\x07\x62\x61\x63koff\x18\x03 \x01(\x04\"4\n\x08PeerInfo\x12\x0e\n\x06peerID\x18\x01 \x01(\x0c\x12\x18\n\x10signedPeerRecord\x18\x02 \x01(\x0c\"\x87\x03\n\x0fTopicDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\x04\x61uth\x18\x02 \x01(\x0b\x32#.pubsub.pb.TopicDescriptor.AuthOpts\x12/\n\x03\x65nc\x18\x03 \x01(\x0b\x32\".pubsub.pb.TopicDescriptor.EncOpts\x1a|\n\x08\x41uthOpts\x12:\n\x04mode\x18\x01 \x01(\x0e\x32,.pubsub.pb.TopicDescriptor.AuthOpts.AuthMode\x12\x0c\n\x04keys\x18\x02 \x03(\x0c\"&\n\x08\x41uthMode\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03KEY\x10\x01\x12\x07\n\x03WOT\x10\x02\x1a\x83\x01\n\x07\x45ncOpts\x12\x38\n\x04mode\x18\x01 \x01(\x0e\x32*.pubsub.pb.TopicDescriptor.EncOpts.EncMode\x12\x11\n\tkeyHashes\x18\x02 \x03(\x0c\"+\n\x07\x45ncMode\x12\x08\n\x04NONE\x10\x00\x12\r\n\tSHAREDKEY\x10\x01\x12\x07\n\x03WOT\x10\x02')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1alibp2p/pubsub/pb/rpc.proto\x12\tpubsub.pb\"\xb4\x01\n\x03RPC\x12-\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x16.pubsub.pb.RPC.SubOpts\x12#\n\x07publish\x18\x02 \x03(\x0b\x32\x12.pubsub.pb.Message\x12*\n\x07\x63ontrol\x18\x03 \x01(\x0b\x32\x19.pubsub.pb.ControlMessage\x1a-\n\x07SubOpts\x12\x11\n\tsubscribe\x18\x01 \x01(\x08\x12\x0f\n\x07topicid\x18\x02 \x01(\t\"i\n\x07Message\x12\x0f\n\x07\x66rom_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\r\n\x05seqno\x18\x03 \x01(\x0c\x12\x10\n\x08topicIDs\x18\x04 \x03(\t\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x12\x0b\n\x03key\x18\x06 \x01(\x0c\"\xb0\x01\n\x0e\x43ontrolMessage\x12&\n\x05ihave\x18\x01 \x03(\x0b\x32\x17.pubsub.pb.ControlIHave\x12&\n\x05iwant\x18\x02 \x03(\x0b\x32\x17.pubsub.pb.ControlIWant\x12&\n\x05graft\x18\x03 \x03(\x0b\x32\x17.pubsub.pb.ControlGraft\x12&\n\x05prune\x18\x04 \x03(\x0b\x32\x17.pubsub.pb.ControlPrune\"3\n\x0c\x43ontrolIHave\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\x12\n\nmessageIDs\x18\x02 \x03(\t\"\"\n\x0c\x43ontrolIWant\x12\x12\n\nmessageIDs\x18\x01 \x03(\t\"\x1f\n\x0c\x43ontrolGraft\x12\x0f\n\x07topicID\x18\x01 \x01(\t\"\x1f\n\x0c\x43ontrolPrune\x12\x0f\n\x07topicID\x18\x01 \x01(\t\"\x87\x03\n\x0fTopicDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\x04\x61uth\x18\x02 \x01(\x0b\x32#.pubsub.pb.TopicDescriptor.AuthOpts\x12/\n\x03\x65nc\x18\x03 \x01(\x0b\x32\".pubsub.pb.TopicDescriptor.EncOpts\x1a|\n\x08\x41uthOpts\x12:\n\x04mode\x18\x01 \x01(\x0e\x32,.pubsub.pb.TopicDescriptor.AuthOpts.AuthMode\x12\x0c\n\x04keys\x18\x02 \x03(\x0c\"&\n\x08\x41uthMode\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03KEY\x10\x01\x12\x07\n\x03WOT\x10\x02\x1a\x83\x01\n\x07\x45ncOpts\x12\x38\n\x04mode\x18\x01 \x01(\x0e\x32*.pubsub.pb.TopicDescriptor.EncOpts.EncMode\x12\x11\n\tkeyHashes\x18\x02 \x03(\x0c\"+\n\x07\x45ncMode\x12\x08\n\x04NONE\x10\x00\x12\r\n\tSHAREDKEY\x10\x01\x12\x07\n\x03WOT\x10\x02')
|
||||
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'rpc_pb2', globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.pubsub.pb.rpc_pb2', globals())
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
_RPC._serialized_start=25
|
||||
_RPC._serialized_end=205
|
||||
_RPC_SUBOPTS._serialized_start=160
|
||||
_RPC_SUBOPTS._serialized_end=205
|
||||
_MESSAGE._serialized_start=207
|
||||
_MESSAGE._serialized_end=312
|
||||
_CONTROLMESSAGE._serialized_start=315
|
||||
_CONTROLMESSAGE._serialized_end=491
|
||||
_CONTROLIHAVE._serialized_start=493
|
||||
_CONTROLIHAVE._serialized_end=544
|
||||
_CONTROLIWANT._serialized_start=546
|
||||
_CONTROLIWANT._serialized_end=580
|
||||
_CONTROLGRAFT._serialized_start=582
|
||||
_CONTROLGRAFT._serialized_end=613
|
||||
_CONTROLPRUNE._serialized_start=615
|
||||
_CONTROLPRUNE._serialized_end=699
|
||||
_PEERINFO._serialized_start=701
|
||||
_PEERINFO._serialized_end=753
|
||||
_TOPICDESCRIPTOR._serialized_start=756
|
||||
_TOPICDESCRIPTOR._serialized_end=1147
|
||||
_TOPICDESCRIPTOR_AUTHOPTS._serialized_start=889
|
||||
_TOPICDESCRIPTOR_AUTHOPTS._serialized_end=1013
|
||||
_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_start=975
|
||||
_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_end=1013
|
||||
_TOPICDESCRIPTOR_ENCOPTS._serialized_start=1016
|
||||
_TOPICDESCRIPTOR_ENCOPTS._serialized_end=1147
|
||||
_TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_start=1104
|
||||
_TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_end=1147
|
||||
_RPC._serialized_start=42
|
||||
_RPC._serialized_end=222
|
||||
_RPC_SUBOPTS._serialized_start=177
|
||||
_RPC_SUBOPTS._serialized_end=222
|
||||
_MESSAGE._serialized_start=224
|
||||
_MESSAGE._serialized_end=329
|
||||
_CONTROLMESSAGE._serialized_start=332
|
||||
_CONTROLMESSAGE._serialized_end=508
|
||||
_CONTROLIHAVE._serialized_start=510
|
||||
_CONTROLIHAVE._serialized_end=561
|
||||
_CONTROLIWANT._serialized_start=563
|
||||
_CONTROLIWANT._serialized_end=597
|
||||
_CONTROLGRAFT._serialized_start=599
|
||||
_CONTROLGRAFT._serialized_end=630
|
||||
_CONTROLPRUNE._serialized_start=632
|
||||
_CONTROLPRUNE._serialized_end=663
|
||||
_TOPICDESCRIPTOR._serialized_start=666
|
||||
_TOPICDESCRIPTOR._serialized_end=1057
|
||||
_TOPICDESCRIPTOR_AUTHOPTS._serialized_start=799
|
||||
_TOPICDESCRIPTOR_AUTHOPTS._serialized_end=923
|
||||
_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_start=885
|
||||
_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_end=923
|
||||
_TOPICDESCRIPTOR_ENCOPTS._serialized_start=926
|
||||
_TOPICDESCRIPTOR_ENCOPTS._serialized_end=1057
|
||||
_TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_start=1014
|
||||
_TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_end=1057
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
||||
@ -179,43 +179,17 @@ class ControlPrune(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
TOPICID_FIELD_NUMBER: builtins.int
|
||||
PEERS_FIELD_NUMBER: builtins.int
|
||||
BACKOFF_FIELD_NUMBER: builtins.int
|
||||
topicID: builtins.str
|
||||
backoff: builtins.int
|
||||
@property
|
||||
def peers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___PeerInfo]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
topicID: builtins.str | None = ...,
|
||||
peers: collections.abc.Iterable[global___PeerInfo] | None = ...,
|
||||
backoff: builtins.int | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["backoff", b"backoff", "topicID", b"topicID"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["backoff", b"backoff", "peers", b"peers", "topicID", b"topicID"]) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["topicID", b"topicID"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["topicID", b"topicID"]) -> None: ...
|
||||
|
||||
global___ControlPrune = ControlPrune
|
||||
|
||||
@typing.final
|
||||
class PeerInfo(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
PEERID_FIELD_NUMBER: builtins.int
|
||||
SIGNEDPEERRECORD_FIELD_NUMBER: builtins.int
|
||||
peerID: builtins.bytes
|
||||
signedPeerRecord: builtins.bytes
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
peerID: builtins.bytes | None = ...,
|
||||
signedPeerRecord: builtins.bytes | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["peerID", b"peerID", "signedPeerRecord", b"signedPeerRecord"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["peerID", b"peerID", "signedPeerRecord", b"signedPeerRecord"]) -> None: ...
|
||||
|
||||
global___PeerInfo = PeerInfo
|
||||
|
||||
@typing.final
|
||||
class TopicDescriptor(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
@ -4,13 +4,17 @@ from __future__ import (
|
||||
|
||||
import base64
|
||||
from collections.abc import (
|
||||
Callable,
|
||||
KeysView,
|
||||
)
|
||||
import functools
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from typing import (
|
||||
Callable,
|
||||
NamedTuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
import base58
|
||||
import trio
|
||||
@ -26,6 +30,8 @@ from libp2p.crypto.keys import (
|
||||
PrivateKey,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
AsyncValidatorFn,
|
||||
SyncValidatorFn,
|
||||
TProtocol,
|
||||
ValidatorFn,
|
||||
)
|
||||
@ -47,9 +53,6 @@ from libp2p.network.stream.exceptions import (
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerdata import (
|
||||
PeerDataError,
|
||||
)
|
||||
from libp2p.tools.async_service import (
|
||||
Service,
|
||||
)
|
||||
@ -60,7 +63,6 @@ from libp2p.utils import (
|
||||
encode_varint_prefixed,
|
||||
read_varint_prefixed_bytes,
|
||||
)
|
||||
from libp2p.utils.varint import encode_uvarint
|
||||
|
||||
from .pb import (
|
||||
rpc_pb2,
|
||||
@ -71,11 +73,6 @@ from .pubsub_notifee import (
|
||||
from .subscription import (
|
||||
TrioSubscriptionAPI,
|
||||
)
|
||||
from .validation_throttler import (
|
||||
TopicValidator,
|
||||
ValidationResult,
|
||||
ValidationThrottler,
|
||||
)
|
||||
from .validators import (
|
||||
PUBSUB_SIGNING_PREFIX,
|
||||
signature_validator,
|
||||
@ -96,6 +93,11 @@ def get_content_addressed_msg_id(msg: rpc_pb2.Message) -> bytes:
|
||||
return base64.b64encode(hashlib.sha256(msg.data).digest())
|
||||
|
||||
|
||||
class TopicValidator(NamedTuple):
|
||||
validator: ValidatorFn
|
||||
is_async: bool
|
||||
|
||||
|
||||
class Pubsub(Service, IPubsub):
|
||||
host: IHost
|
||||
|
||||
@ -118,7 +120,7 @@ class Pubsub(Service, IPubsub):
|
||||
|
||||
# Indicate if we should enforce signature verification
|
||||
strict_signing: bool
|
||||
sign_key: PrivateKey | None
|
||||
sign_key: PrivateKey
|
||||
|
||||
# Set of blacklisted peer IDs
|
||||
blacklisted_peers: set[ID]
|
||||
@ -130,18 +132,13 @@ class Pubsub(Service, IPubsub):
|
||||
self,
|
||||
host: IHost,
|
||||
router: IPubsubRouter,
|
||||
cache_size: int | None = None,
|
||||
cache_size: int = None,
|
||||
seen_ttl: int = 120,
|
||||
sweep_interval: int = 60,
|
||||
strict_signing: bool = True,
|
||||
msg_id_constructor: Callable[
|
||||
[rpc_pb2.Message], bytes
|
||||
] = get_peer_and_seqno_msg_id,
|
||||
# TODO: these values have been copied from Go, but try to tune these dynamically
|
||||
validation_queue_size: int = 32,
|
||||
global_throttle_limit: int = 8192,
|
||||
default_topic_throttle_limit: int = 1024,
|
||||
validation_worker_count: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Construct a new Pubsub object, which is responsible for handling all
|
||||
@ -202,15 +199,7 @@ class Pubsub(Service, IPubsub):
|
||||
# Create peers map, which maps peer_id (as string) to stream (to a given peer)
|
||||
self.peers = {}
|
||||
|
||||
# Validation Throttler
|
||||
self.validation_throttler = ValidationThrottler(
|
||||
queue_size=validation_queue_size,
|
||||
global_throttle_limit=global_throttle_limit,
|
||||
default_topic_throttle_limit=default_topic_throttle_limit,
|
||||
worker_count=validation_worker_count or 4,
|
||||
)
|
||||
|
||||
# Keep a mapping of topic -> TopicValidator for easier lookup
|
||||
# Map of topic to topic validator
|
||||
self.topic_validators = {}
|
||||
|
||||
self.counter = int(time.time())
|
||||
@ -222,19 +211,10 @@ class Pubsub(Service, IPubsub):
|
||||
self.event_handle_dead_peer_queue_started = trio.Event()
|
||||
|
||||
async def run(self) -> None:
|
||||
self.manager.run_daemon_task(self._start_validation_throttler)
|
||||
self.manager.run_daemon_task(self.handle_peer_queue)
|
||||
self.manager.run_daemon_task(self.handle_dead_peer_queue)
|
||||
await self.manager.wait_finished()
|
||||
|
||||
async def _start_validation_throttler(self) -> None:
|
||||
"""Start validation throttler in current nursery context"""
|
||||
async with trio.open_nursery() as nursery:
|
||||
await self.validation_throttler.start(nursery)
|
||||
# Keep nursery alive until service stops
|
||||
while self.manager.is_running:
|
||||
await self.manager.wait_finished()
|
||||
|
||||
@property
|
||||
def my_id(self) -> ID:
|
||||
return self.host.get_id()
|
||||
@ -314,12 +294,7 @@ class Pubsub(Service, IPubsub):
|
||||
)
|
||||
|
||||
def set_topic_validator(
|
||||
self,
|
||||
topic: str,
|
||||
validator: ValidatorFn,
|
||||
is_async_validator: bool,
|
||||
timeout: float | None = None,
|
||||
throttle_limit: int | None = None,
|
||||
self, topic: str, validator: ValidatorFn, is_async_validator: bool
|
||||
) -> None:
|
||||
"""
|
||||
Register a validator under the given topic. One topic can only have one
|
||||
@ -328,18 +303,8 @@ class Pubsub(Service, IPubsub):
|
||||
:param topic: the topic to register validator under
|
||||
:param validator: the validator used to validate messages published to the topic
|
||||
:param is_async_validator: indicate if the validator is an asynchronous validator
|
||||
:param timeout: optional timeout for the validator
|
||||
:param throttle_limit: optional throttle limit for the validator
|
||||
""" # noqa: E501
|
||||
# Create throttled topic validator
|
||||
topic_validator = self.validation_throttler.create_topic_validator(
|
||||
topic=topic,
|
||||
validator=validator,
|
||||
is_async=is_async_validator,
|
||||
timeout=timeout,
|
||||
throttle_limit=throttle_limit,
|
||||
)
|
||||
self.topic_validators[topic] = topic_validator
|
||||
self.topic_validators[topic] = TopicValidator(validator, is_async_validator)
|
||||
|
||||
def remove_topic_validator(self, topic: str) -> None:
|
||||
"""
|
||||
@ -349,18 +314,17 @@ class Pubsub(Service, IPubsub):
|
||||
"""
|
||||
self.topic_validators.pop(topic, None)
|
||||
|
||||
def get_msg_validators(self, msg: rpc_pb2.Message) -> list[TopicValidator]:
|
||||
def get_msg_validators(self, msg: rpc_pb2.Message) -> tuple[TopicValidator, ...]:
|
||||
"""
|
||||
Get all validators corresponding to the topics in the message.
|
||||
|
||||
:param msg: the message published to the topic
|
||||
:return: list of topic validators for the message's topics
|
||||
"""
|
||||
return [
|
||||
return tuple(
|
||||
self.topic_validators[topic]
|
||||
for topic in msg.topicIDs
|
||||
if topic in self.topic_validators
|
||||
]
|
||||
)
|
||||
|
||||
def add_to_blacklist(self, peer_id: ID) -> None:
|
||||
"""
|
||||
@ -653,22 +617,16 @@ class Pubsub(Service, IPubsub):
|
||||
logger.debug("Fail to message peer %s: stream closed", peer_id)
|
||||
self._handle_dead_peer(peer_id)
|
||||
|
||||
async def publish(self, topic_id: str | list[str], data: bytes) -> None:
|
||||
async def publish(self, topic_id: str, data: bytes) -> None:
|
||||
"""
|
||||
Publish data to a topic or multiple topics.
|
||||
Publish data to a topic.
|
||||
|
||||
:param topic_id: topic (str) or topics (list[str]) to publish the data to
|
||||
:param topic_id: topic which we are going to publish the data to
|
||||
:param data: data which we are publishing
|
||||
"""
|
||||
# Handle both single topic (str) and multiple topics (list[str])
|
||||
if isinstance(topic_id, str):
|
||||
topic_ids = [topic_id]
|
||||
else:
|
||||
topic_ids = topic_id
|
||||
|
||||
msg = rpc_pb2.Message(
|
||||
data=data,
|
||||
topicIDs=topic_ids,
|
||||
topicIDs=[topic_id],
|
||||
# Origin is ourself.
|
||||
from_id=self.my_id.to_bytes(),
|
||||
seqno=self._next_seqno(),
|
||||
@ -676,9 +634,6 @@ class Pubsub(Service, IPubsub):
|
||||
|
||||
if self.strict_signing:
|
||||
priv_key = self.sign_key
|
||||
if priv_key is None:
|
||||
raise PeerDataError("private key not found")
|
||||
|
||||
signature = priv_key.sign(
|
||||
PUBSUB_SIGNING_PREFIX.encode() + msg.SerializeToString()
|
||||
)
|
||||
@ -696,56 +651,39 @@ class Pubsub(Service, IPubsub):
|
||||
:param msg_forwarder: the peer who forward us the message.
|
||||
:param msg: the message.
|
||||
"""
|
||||
# Get applicable validators for this message
|
||||
validators = self.get_msg_validators(msg)
|
||||
sync_topic_validators: list[SyncValidatorFn] = []
|
||||
async_topic_validators: list[AsyncValidatorFn] = []
|
||||
for topic_validator in self.get_msg_validators(msg):
|
||||
if topic_validator.is_async:
|
||||
async_topic_validators.append(
|
||||
cast(AsyncValidatorFn, topic_validator.validator)
|
||||
)
|
||||
else:
|
||||
sync_topic_validators.append(
|
||||
cast(SyncValidatorFn, topic_validator.validator)
|
||||
)
|
||||
|
||||
if not validators:
|
||||
# No validators, accept immediately
|
||||
return
|
||||
for validator in sync_topic_validators:
|
||||
if not validator(msg_forwarder, msg):
|
||||
raise ValidationError(f"Validation failed for msg={msg}")
|
||||
|
||||
# Use trio.Event for async coordination
|
||||
validation_event = trio.Event()
|
||||
result_container: dict[str, ValidationResult | None | Exception] = {
|
||||
"result": None,
|
||||
"error": None,
|
||||
}
|
||||
# TODO: Implement throttle on async validators
|
||||
|
||||
def handle_validation_result(
|
||||
result: ValidationResult, error: Exception | None
|
||||
) -> None:
|
||||
result_container["result"] = result
|
||||
result_container["error"] = error
|
||||
validation_event.set()
|
||||
if len(async_topic_validators) > 0:
|
||||
# TODO: Use a better pattern
|
||||
final_result: bool = True
|
||||
|
||||
# Submit for throttled validation
|
||||
success = await self.validation_throttler.submit_validation(
|
||||
validators=validators,
|
||||
msg_forwarder=msg_forwarder,
|
||||
msg=msg,
|
||||
result_callback=handle_validation_result,
|
||||
)
|
||||
async def run_async_validator(func: AsyncValidatorFn) -> None:
|
||||
nonlocal final_result
|
||||
result = await func(msg_forwarder, msg)
|
||||
final_result = final_result and result
|
||||
|
||||
if not success:
|
||||
# Validation was throttled at queue level
|
||||
raise ValidationError("Validation throttled at queue level")
|
||||
async with trio.open_nursery() as nursery:
|
||||
for async_validator in async_topic_validators:
|
||||
nursery.start_soon(run_async_validator, async_validator)
|
||||
|
||||
# Wait for validation result
|
||||
await validation_event.wait()
|
||||
|
||||
result = result_container["result"]
|
||||
error = result_container["error"]
|
||||
|
||||
if error:
|
||||
raise ValidationError(f"Validation error: {error}")
|
||||
|
||||
if result == ValidationResult.REJECT:
|
||||
raise ValidationError("Message validation rejected")
|
||||
elif result == ValidationResult.THROTTLED:
|
||||
raise ValidationError("Message validation throttled")
|
||||
elif result == ValidationResult.IGNORE:
|
||||
# Treat IGNORE as rejection for now, or you could silently drop
|
||||
raise ValidationError("Message validation ignored")
|
||||
# ACCEPT case - just return normally
|
||||
if not final_result:
|
||||
raise ValidationError(f"Validation failed for msg={msg}")
|
||||
|
||||
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
||||
"""
|
||||
@ -829,43 +767,3 @@ class Pubsub(Service, IPubsub):
|
||||
|
||||
def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
|
||||
return any(topic in self.topic_ids for topic in msg.topicIDs)
|
||||
|
||||
async def write_msg(self, stream: INetStream, rpc_msg: rpc_pb2.RPC) -> bool:
|
||||
"""
|
||||
Write an RPC message to a stream with proper error handling.
|
||||
|
||||
Implements WriteMsg similar to go-msgio which is used in go-libp2p
|
||||
Ref: https://github.com/libp2p/go-msgio/blob/master/protoio/uvarint_writer.go#L56
|
||||
|
||||
|
||||
:param stream: stream to write the message to
|
||||
:param rpc_msg: RPC message to write
|
||||
:return: True if successful, False if stream was closed
|
||||
"""
|
||||
try:
|
||||
# Calculate message size first
|
||||
msg_bytes = rpc_msg.SerializeToString()
|
||||
msg_size = len(msg_bytes)
|
||||
|
||||
# Calculate varint size and allocate exact buffer size needed
|
||||
|
||||
varint_bytes = encode_uvarint(msg_size)
|
||||
varint_size = len(varint_bytes)
|
||||
|
||||
# Allocate buffer with exact size (like Go's pool.Get())
|
||||
buf = bytearray(varint_size + msg_size)
|
||||
|
||||
# Write varint length prefix to buffer (like Go's binary.PutUvarint())
|
||||
buf[:varint_size] = varint_bytes
|
||||
|
||||
# Write serialized message after varint (like Go's rpc.MarshalTo())
|
||||
buf[varint_size:] = msg_bytes
|
||||
|
||||
# Single write operation (like Go's s.Write(buf))
|
||||
await stream.write(bytes(buf))
|
||||
return True
|
||||
except StreamClosed:
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
logger.debug("Fail to write message to %s: stream closed", peer_id)
|
||||
self._handle_dead_peer(peer_id)
|
||||
return False
|
||||
|
||||
@ -1,314 +0,0 @@
|
||||
from collections.abc import (
|
||||
Callable,
|
||||
)
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import logging
|
||||
from typing import (
|
||||
NamedTuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.custom_types import AsyncValidatorFn, ValidatorFn
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
|
||||
from .pb import (
|
||||
rpc_pb2,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("libp2p.pubsub.validation")
|
||||
|
||||
|
||||
class ValidationResult(Enum):
|
||||
ACCEPT = "accept"
|
||||
REJECT = "reject"
|
||||
IGNORE = "ignore"
|
||||
THROTTLED = "throttled"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationRequest:
|
||||
"""Request for message validation"""
|
||||
|
||||
validators: list["TopicValidator"]
|
||||
msg_forwarder: ID # peer ID
|
||||
msg: rpc_pb2.Message # message object
|
||||
result_callback: Callable[[ValidationResult, Exception | None], None]
|
||||
|
||||
|
||||
class TopicValidator(NamedTuple):
|
||||
topic: str
|
||||
validator: ValidatorFn
|
||||
is_async: bool
|
||||
timeout: float | None = None
|
||||
# Per-topic throttle semaphore
|
||||
throttle_semaphore: trio.Semaphore | None = None
|
||||
|
||||
|
||||
class ValidationThrottler:
|
||||
"""Manages all validation throttling mechanisms"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
queue_size: int = 32,
|
||||
global_throttle_limit: int = 8192,
|
||||
default_topic_throttle_limit: int = 1024,
|
||||
worker_count: int | None = None,
|
||||
):
|
||||
# 1. Queue-level throttling - bounded memory channel
|
||||
self._validation_send, self._validation_receive = trio.open_memory_channel[
|
||||
ValidationRequest
|
||||
](queue_size)
|
||||
|
||||
# 2. Global validation throttling - limits total concurrent async validations
|
||||
self._global_throttle = trio.Semaphore(global_throttle_limit)
|
||||
|
||||
# 3. Per-topic throttling - each validator gets its own semaphore
|
||||
self._default_topic_throttle_limit = default_topic_throttle_limit
|
||||
|
||||
# Worker management
|
||||
# TODO: Find a better way to manage worker count
|
||||
self._worker_count = worker_count or 4
|
||||
self._running = False
|
||||
|
||||
async def start(self, nursery: trio.Nursery) -> None:
|
||||
"""Start the validation workers"""
|
||||
self._running = True
|
||||
|
||||
# Start validation worker tasks
|
||||
for i in range(self._worker_count):
|
||||
nursery.start_soon(self._validation_worker, f"worker-{i}")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the validation system"""
|
||||
self._running = False
|
||||
await self._validation_send.aclose()
|
||||
|
||||
def create_topic_validator(
|
||||
self,
|
||||
topic: str,
|
||||
validator: ValidatorFn,
|
||||
is_async: bool,
|
||||
timeout: float | None = None,
|
||||
throttle_limit: int | None = None,
|
||||
) -> TopicValidator:
|
||||
"""Create a new topic validator with its own throttle"""
|
||||
limit = throttle_limit or self._default_topic_throttle_limit
|
||||
throttle_sem = trio.Semaphore(limit)
|
||||
|
||||
return TopicValidator(
|
||||
topic=topic,
|
||||
validator=validator,
|
||||
is_async=is_async,
|
||||
timeout=timeout,
|
||||
throttle_semaphore=throttle_sem,
|
||||
)
|
||||
|
||||
async def submit_validation(
|
||||
self,
|
||||
validators: list[TopicValidator],
|
||||
msg_forwarder: ID,
|
||||
msg: rpc_pb2.Message,
|
||||
result_callback: Callable[[ValidationResult, Exception | None], None],
|
||||
) -> bool:
|
||||
"""
|
||||
Submit a message for validation.
|
||||
Returns True if queued successfully, False if queue is full (throttled).
|
||||
"""
|
||||
if not self._running:
|
||||
result_callback(
|
||||
ValidationResult.REJECT, Exception("Validation system not running")
|
||||
)
|
||||
return False
|
||||
|
||||
request = ValidationRequest(
|
||||
validators=validators,
|
||||
msg_forwarder=msg_forwarder,
|
||||
msg=msg,
|
||||
result_callback=result_callback,
|
||||
)
|
||||
|
||||
try:
|
||||
# This will raise trio.WouldBlock if queue is full
|
||||
self._validation_send.send_nowait(request)
|
||||
return True
|
||||
except trio.WouldBlock:
|
||||
# Queue-level throttling: drop the message
|
||||
logger.debug(
|
||||
"Validation queue full, dropping message from %s", msg_forwarder
|
||||
)
|
||||
result_callback(
|
||||
ValidationResult.THROTTLED, Exception("Validation queue full")
|
||||
)
|
||||
return False
|
||||
|
||||
async def _validation_worker(self, worker_id: str) -> None:
|
||||
"""Worker that processes validation requests"""
|
||||
logger.debug("Validation worker %s started", worker_id)
|
||||
|
||||
async with self._validation_receive:
|
||||
async for request in self._validation_receive:
|
||||
if not self._running:
|
||||
break
|
||||
|
||||
try:
|
||||
# Process the validation request
|
||||
result = await self._validate_message(request)
|
||||
request.result_callback(result, None)
|
||||
except Exception as e:
|
||||
logger.exception("Error in validation worker %s", worker_id)
|
||||
request.result_callback(ValidationResult.REJECT, e)
|
||||
|
||||
logger.debug("Validation worker %s stopped", worker_id)
|
||||
|
||||
async def _validate_message(self, request: ValidationRequest) -> ValidationResult:
|
||||
"""Core validation logic with throttling"""
|
||||
validators = request.validators
|
||||
msg_forwarder = request.msg_forwarder
|
||||
msg = request.msg
|
||||
|
||||
if not validators:
|
||||
return ValidationResult.ACCEPT
|
||||
|
||||
# Separate sync and async validators
|
||||
sync_validators = [v for v in validators if not v.is_async]
|
||||
async_validators = [v for v in validators if v.is_async]
|
||||
|
||||
# Run synchronous validators first
|
||||
for validator in sync_validators:
|
||||
try:
|
||||
# Apply per-topic throttling even for sync validators
|
||||
if validator.throttle_semaphore:
|
||||
validator.throttle_semaphore.acquire_nowait()
|
||||
try:
|
||||
result = validator.validator(msg_forwarder, msg)
|
||||
if not result:
|
||||
return ValidationResult.REJECT
|
||||
finally:
|
||||
validator.throttle_semaphore.release()
|
||||
else:
|
||||
result = validator.validator(msg_forwarder, msg)
|
||||
if not result:
|
||||
return ValidationResult.REJECT
|
||||
except trio.WouldBlock:
|
||||
# Per-topic throttling for sync validator
|
||||
logger.debug("Sync validation throttled for topic %s", validator.topic)
|
||||
return ValidationResult.THROTTLED
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Sync validator failed for topic %s: %s", validator.topic, e
|
||||
)
|
||||
return ValidationResult.REJECT
|
||||
|
||||
# Handle async validators with global + per-topic throttling
|
||||
if async_validators:
|
||||
return await self._validate_async_validators(
|
||||
async_validators, msg_forwarder, msg
|
||||
)
|
||||
|
||||
return ValidationResult.ACCEPT
|
||||
|
||||
async def _validate_async_validators(
|
||||
self, validators: list[TopicValidator], msg_forwarder: ID, msg: rpc_pb2.Message
|
||||
) -> ValidationResult:
|
||||
"""Handle async validators with proper throttling"""
|
||||
if len(validators) == 1:
|
||||
# Fast path for single validator
|
||||
return await self._validate_single_async_validator(
|
||||
validators[0], msg_forwarder, msg
|
||||
)
|
||||
|
||||
# Multiple async validators - run them concurrently
|
||||
try:
|
||||
# Try to acquire global throttle slot
|
||||
self._global_throttle.acquire_nowait()
|
||||
except trio.WouldBlock:
|
||||
logger.debug(
|
||||
"Global validation throttle exceeded, dropping message from %s",
|
||||
msg_forwarder,
|
||||
)
|
||||
return ValidationResult.THROTTLED
|
||||
|
||||
try:
|
||||
async with trio.open_nursery() as nursery:
|
||||
results = {}
|
||||
|
||||
async def run_validator(validator: TopicValidator, index: int) -> None:
|
||||
"""Run a single async validator and store the result"""
|
||||
nonlocal results
|
||||
result = await self._validate_single_async_validator(
|
||||
validator, msg_forwarder, msg
|
||||
)
|
||||
results[index] = result
|
||||
|
||||
# Start all validators concurrently
|
||||
for i, validator in enumerate(validators):
|
||||
nursery.start_soon(run_validator, validator, i)
|
||||
|
||||
# Process results - any reject or throttle causes overall failure
|
||||
final_result = ValidationResult.ACCEPT
|
||||
for result in results.values():
|
||||
if result == ValidationResult.REJECT:
|
||||
return ValidationResult.REJECT
|
||||
elif result == ValidationResult.THROTTLED:
|
||||
final_result = ValidationResult.THROTTLED
|
||||
elif (
|
||||
result == ValidationResult.IGNORE
|
||||
and final_result == ValidationResult.ACCEPT
|
||||
):
|
||||
final_result = ValidationResult.IGNORE
|
||||
|
||||
return final_result
|
||||
|
||||
finally:
|
||||
self._global_throttle.release()
|
||||
|
||||
return ValidationResult.IGNORE
|
||||
|
||||
async def _validate_single_async_validator(
|
||||
self, validator: TopicValidator, msg_forwarder: ID, msg: rpc_pb2.Message
|
||||
) -> ValidationResult:
|
||||
"""Validate with a single async validator"""
|
||||
# Apply per-topic throttling
|
||||
if validator.throttle_semaphore:
|
||||
try:
|
||||
validator.throttle_semaphore.acquire_nowait()
|
||||
except trio.WouldBlock:
|
||||
logger.debug(
|
||||
"Per-topic validation throttled for topic %s", validator.topic
|
||||
)
|
||||
return ValidationResult.THROTTLED
|
||||
else:
|
||||
# Fallback if no throttle semaphore configured
|
||||
pass
|
||||
|
||||
try:
|
||||
# Apply timeout if configured
|
||||
result: bool
|
||||
if validator.timeout:
|
||||
with trio.fail_after(validator.timeout):
|
||||
func = cast(AsyncValidatorFn, validator.validator)
|
||||
result = await func(msg_forwarder, msg)
|
||||
else:
|
||||
func = cast(AsyncValidatorFn, validator.validator)
|
||||
result = await func(msg_forwarder, msg)
|
||||
|
||||
return ValidationResult.ACCEPT if result else ValidationResult.REJECT
|
||||
|
||||
except trio.TooSlowError:
|
||||
logger.debug("Validation timeout for topic %s", validator.topic)
|
||||
return ValidationResult.IGNORE
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Async validator failed for topic %s: %s", validator.topic, e
|
||||
)
|
||||
return ValidationResult.REJECT
|
||||
finally:
|
||||
if validator.throttle_semaphore:
|
||||
validator.throttle_semaphore.release()
|
||||
|
||||
return ValidationResult.IGNORE
|
||||
@ -1,28 +0,0 @@
|
||||
"""
|
||||
Relay module for libp2p.
|
||||
|
||||
This package includes implementations of circuit relay protocols
|
||||
for enabling connectivity between peers behind NATs or firewalls.
|
||||
"""
|
||||
|
||||
# Import the circuit_v2 module to make it accessible
|
||||
# through the relay package
|
||||
from libp2p.relay.circuit_v2 import (
|
||||
PROTOCOL_ID,
|
||||
CircuitV2Protocol,
|
||||
CircuitV2Transport,
|
||||
RelayDiscovery,
|
||||
RelayLimits,
|
||||
RelayResourceManager,
|
||||
Reservation,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CircuitV2Protocol",
|
||||
"CircuitV2Transport",
|
||||
"PROTOCOL_ID",
|
||||
"RelayDiscovery",
|
||||
"RelayLimits",
|
||||
"RelayResourceManager",
|
||||
"Reservation",
|
||||
]
|
||||
@ -1,32 +0,0 @@
|
||||
"""
|
||||
Circuit Relay v2 implementation for libp2p.
|
||||
|
||||
This package implements the Circuit Relay v2 protocol as specified in:
|
||||
https://github.com/libp2p/specs/blob/master/relay/circuit-v2.md
|
||||
"""
|
||||
|
||||
from .discovery import (
|
||||
RelayDiscovery,
|
||||
)
|
||||
from .protocol import (
|
||||
PROTOCOL_ID,
|
||||
CircuitV2Protocol,
|
||||
)
|
||||
from .resources import (
|
||||
RelayLimits,
|
||||
RelayResourceManager,
|
||||
Reservation,
|
||||
)
|
||||
from .transport import (
|
||||
CircuitV2Transport,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CircuitV2Protocol",
|
||||
"PROTOCOL_ID",
|
||||
"RelayLimits",
|
||||
"Reservation",
|
||||
"RelayResourceManager",
|
||||
"CircuitV2Transport",
|
||||
"RelayDiscovery",
|
||||
]
|
||||
@ -1,92 +0,0 @@
|
||||
"""
|
||||
Configuration management for Circuit Relay v2.
|
||||
|
||||
This module handles configuration for relay roles, resource limits,
|
||||
and discovery settings.
|
||||
"""
|
||||
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
field,
|
||||
)
|
||||
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
|
||||
from .resources import (
|
||||
RelayLimits,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RelayConfig:
|
||||
"""Configuration for Circuit Relay v2."""
|
||||
|
||||
# Role configuration
|
||||
enable_hop: bool = False # Whether to act as a relay (hop)
|
||||
enable_stop: bool = True # Whether to accept relayed connections (stop)
|
||||
enable_client: bool = True # Whether to use relays for dialing
|
||||
|
||||
# Resource limits
|
||||
limits: RelayLimits | None = None
|
||||
|
||||
# Discovery configuration
|
||||
bootstrap_relays: list[PeerInfo] = field(default_factory=list)
|
||||
min_relays: int = 3
|
||||
max_relays: int = 20
|
||||
discovery_interval: int = 300 # seconds
|
||||
|
||||
# Connection configuration
|
||||
reservation_ttl: int = 3600 # seconds
|
||||
max_circuit_duration: int = 3600 # seconds
|
||||
max_circuit_bytes: int = 1024 * 1024 * 1024 # 1GB
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Initialize default values."""
|
||||
if self.limits is None:
|
||||
self.limits = RelayLimits(
|
||||
duration=self.max_circuit_duration,
|
||||
data=self.max_circuit_bytes,
|
||||
max_circuit_conns=8,
|
||||
max_reservations=4,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HopConfig:
|
||||
"""Configuration specific to relay (hop) nodes."""
|
||||
|
||||
# Resource limits per IP
|
||||
max_reservations_per_ip: int = 8
|
||||
max_circuits_per_ip: int = 16
|
||||
|
||||
# Rate limiting
|
||||
reservation_rate_per_ip: int = 4 # per minute
|
||||
circuit_rate_per_ip: int = 8 # per minute
|
||||
|
||||
# Resource quotas
|
||||
max_circuits_total: int = 64
|
||||
max_reservations_total: int = 32
|
||||
|
||||
# Bandwidth limits
|
||||
max_bandwidth_per_circuit: int = 1024 * 1024 # 1MB/s
|
||||
max_bandwidth_total: int = 10 * 1024 * 1024 # 10MB/s
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientConfig:
|
||||
"""Configuration specific to relay clients."""
|
||||
|
||||
# Relay selection
|
||||
min_relay_score: float = 0.5
|
||||
max_relay_latency: float = 1.0 # seconds
|
||||
|
||||
# Auto-relay settings
|
||||
enable_auto_relay: bool = True
|
||||
auto_relay_timeout: int = 30 # seconds
|
||||
max_auto_relay_attempts: int = 3
|
||||
|
||||
# Reservation management
|
||||
reservation_refresh_threshold: float = 0.8 # Refresh at 80% of TTL
|
||||
max_concurrent_reservations: int = 2
|
||||
@ -1,538 +0,0 @@
|
||||
"""
|
||||
Discovery module for Circuit Relay v2.
|
||||
|
||||
This module handles discovering and tracking relay nodes in the network.
|
||||
"""
|
||||
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
)
|
||||
import logging
|
||||
import time
|
||||
from typing import (
|
||||
Any,
|
||||
Protocol as TypingProtocol,
|
||||
cast,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.tools.async_service import (
|
||||
Service,
|
||||
)
|
||||
|
||||
from .pb.circuit_pb2 import (
|
||||
HopMessage,
|
||||
)
|
||||
from .protocol import (
|
||||
PROTOCOL_ID,
|
||||
)
|
||||
from .protocol_buffer import (
|
||||
StatusCode,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("libp2p.relay.circuit_v2.discovery")
|
||||
|
||||
# Constants
|
||||
MAX_RELAYS_TO_TRACK = 10
|
||||
DEFAULT_DISCOVERY_INTERVAL = 60 # seconds
|
||||
STREAM_TIMEOUT = 10 # seconds
|
||||
|
||||
|
||||
# Extended interfaces for type checking
|
||||
@runtime_checkable
|
||||
class IHostWithMultiselect(TypingProtocol):
|
||||
"""Extended host interface with multiselect attribute."""
|
||||
|
||||
@property
|
||||
def multiselect(self) -> Any:
|
||||
"""Get the multiselect component."""
|
||||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
class RelayInfo:
|
||||
"""Information about a discovered relay."""
|
||||
|
||||
peer_id: ID
|
||||
discovered_at: float
|
||||
last_seen: float
|
||||
has_reservation: bool = False
|
||||
reservation_expires_at: float | None = None
|
||||
reservation_data_limit: int | None = None
|
||||
|
||||
|
||||
class RelayDiscovery(Service):
|
||||
"""
|
||||
Discovery service for Circuit Relay v2 nodes.
|
||||
|
||||
This service discovers and keeps track of available relay nodes, and optionally
|
||||
makes reservations with them.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: IHost,
|
||||
auto_reserve: bool = False,
|
||||
discovery_interval: int = DEFAULT_DISCOVERY_INTERVAL,
|
||||
max_relays: int = MAX_RELAYS_TO_TRACK,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the discovery service.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
host : IHost
|
||||
The libp2p host this discovery service is running on
|
||||
auto_reserve : bool
|
||||
Whether to automatically make reservations with discovered relays
|
||||
discovery_interval : int
|
||||
How often to run discovery, in seconds
|
||||
max_relays : int
|
||||
Maximum number of relays to track
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.host = host
|
||||
self.auto_reserve = auto_reserve
|
||||
self.discovery_interval = discovery_interval
|
||||
self.max_relays = max_relays
|
||||
self._discovered_relays: dict[ID, RelayInfo] = {}
|
||||
self._protocol_cache: dict[
|
||||
ID, set[str]
|
||||
] = {} # Cache protocol info to reduce queries
|
||||
self.event_started = trio.Event()
|
||||
self.is_running = False
|
||||
|
||||
async def run(self, *, task_status: Any = trio.TASK_STATUS_IGNORED) -> None:
|
||||
"""Run the discovery service."""
|
||||
try:
|
||||
self.is_running = True
|
||||
self.event_started.set()
|
||||
task_status.started()
|
||||
|
||||
# Main discovery loop
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Run initial discovery
|
||||
nursery.start_soon(self.discover_relays)
|
||||
|
||||
# Set up periodic discovery
|
||||
while True:
|
||||
await trio.sleep(self.discovery_interval)
|
||||
if not self.manager.is_running:
|
||||
break
|
||||
nursery.start_soon(self.discover_relays)
|
||||
|
||||
# Cleanup expired relays and reservations
|
||||
await self._cleanup_expired()
|
||||
|
||||
finally:
|
||||
self.is_running = False
|
||||
|
||||
async def discover_relays(self) -> None:
|
||||
r"""
|
||||
Discover relay nodes in the network.
|
||||
|
||||
This method queries the network for peers that support the
|
||||
Circuit Relay v2 protocol.
|
||||
"""
|
||||
logger.debug("Starting relay discovery")
|
||||
|
||||
try:
|
||||
# Get connected peers
|
||||
connected_peers = self.host.get_connected_peers()
|
||||
logger.debug(
|
||||
"Checking %d connected peers for relay support", len(connected_peers)
|
||||
)
|
||||
|
||||
# Check each peer if they support the relay protocol
|
||||
for peer_id in connected_peers:
|
||||
if peer_id == self.host.get_id():
|
||||
continue # Skip ourselves
|
||||
|
||||
if peer_id in self._discovered_relays:
|
||||
# Update last seen time for existing relay
|
||||
self._discovered_relays[peer_id].last_seen = time.time()
|
||||
continue
|
||||
|
||||
# Check if peer supports the relay protocol
|
||||
with trio.move_on_after(5): # Don't wait too long for protocol info
|
||||
if await self._supports_relay_protocol(peer_id):
|
||||
await self._add_relay(peer_id)
|
||||
|
||||
# Limit number of relays we track
|
||||
if len(self._discovered_relays) > self.max_relays:
|
||||
# Sort by last seen time and keep only the most recent ones
|
||||
sorted_relays = sorted(
|
||||
self._discovered_relays.items(),
|
||||
key=lambda x: x[1].last_seen,
|
||||
reverse=True,
|
||||
)
|
||||
to_remove = sorted_relays[self.max_relays :]
|
||||
for peer_id, _ in to_remove:
|
||||
del self._discovered_relays[peer_id]
|
||||
|
||||
logger.debug(
|
||||
"Discovery completed, tracking %d relays", len(self._discovered_relays)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error during relay discovery: %s", str(e))
|
||||
|
||||
async def _supports_relay_protocol(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Check if a peer supports the relay protocol.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The ID of the peer to check
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the peer supports the relay protocol, False otherwise
|
||||
|
||||
"""
|
||||
# Check cache first
|
||||
if peer_id in self._protocol_cache:
|
||||
return PROTOCOL_ID in self._protocol_cache[peer_id]
|
||||
|
||||
# Method 1: Try peerstore
|
||||
result = await self._check_via_peerstore(peer_id)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Method 2: Try direct stream connection
|
||||
result = await self._check_via_direct_connection(peer_id)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Method 3: Try protocols from mux
|
||||
result = await self._check_via_mux(peer_id)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Default: Cannot determine, assume false
|
||||
return False
|
||||
|
||||
async def _check_via_peerstore(self, peer_id: ID) -> bool | None:
|
||||
"""Check protocol support via peerstore."""
|
||||
try:
|
||||
peerstore = self.host.get_peerstore()
|
||||
proto_getter = peerstore.get_protocols
|
||||
|
||||
if not callable(proto_getter):
|
||||
return None
|
||||
if peer_id not in peerstore.peer_ids():
|
||||
return None
|
||||
try:
|
||||
# Try to get protocols
|
||||
proto_result = proto_getter(peer_id)
|
||||
|
||||
# Get protocols list
|
||||
protocols_list = []
|
||||
if hasattr(proto_result, "__await__"):
|
||||
protocols_list = await cast(Any, proto_result)
|
||||
else:
|
||||
protocols_list = proto_result
|
||||
|
||||
# Check result
|
||||
if protocols_list is not None:
|
||||
protocols = set(protocols_list)
|
||||
self._protocol_cache[peer_id] = protocols
|
||||
return PROTOCOL_ID in protocols
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.debug("Error getting protocols: %s", str(e))
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.debug("Error accessing peerstore: %s", str(e))
|
||||
return None
|
||||
|
||||
async def _check_via_direct_connection(self, peer_id: ID) -> bool | None:
|
||||
"""Check protocol support via direct connection."""
|
||||
try:
|
||||
with trio.fail_after(STREAM_TIMEOUT):
|
||||
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
|
||||
if stream:
|
||||
await stream.close()
|
||||
self._protocol_cache[peer_id] = {PROTOCOL_ID}
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Failed to open relay protocol stream to %s: %s", peer_id, str(e)
|
||||
)
|
||||
return None
|
||||
|
||||
async def _check_via_mux(self, peer_id: ID) -> bool | None:
|
||||
"""Check protocol support via mux protocols."""
|
||||
try:
|
||||
if not (hasattr(self.host, "get_mux") and self.host.get_mux() is not None):
|
||||
return None
|
||||
|
||||
mux = self.host.get_mux()
|
||||
|
||||
peer_protocols = set()
|
||||
# Get protocols from mux with proper type safety
|
||||
available_protocols = []
|
||||
if hasattr(mux, "get_protocols"):
|
||||
# Get protocols with proper typing
|
||||
mux_protocols = mux.get_protocols()
|
||||
if isinstance(mux_protocols, (list, tuple)):
|
||||
available_protocols = [
|
||||
p for p in mux.get_protocols() if p is not None
|
||||
]
|
||||
|
||||
for protocol in available_protocols:
|
||||
try:
|
||||
with trio.fail_after(2): # Quick check
|
||||
# Ensure we have a proper protocol object
|
||||
# Use string representation since we can't use isinstance
|
||||
is_tprotocol = str(type(protocol)) == str(type(TProtocol))
|
||||
protocol_obj = (
|
||||
protocol if is_tprotocol else TProtocol(str(protocol))
|
||||
)
|
||||
stream = await self.host.new_stream(peer_id, [protocol_obj])
|
||||
if stream:
|
||||
peer_protocols.add(str(protocol_obj))
|
||||
await stream.close()
|
||||
except Exception:
|
||||
pass # Ignore errors when closing the stream
|
||||
|
||||
self._protocol_cache[peer_id] = peer_protocols
|
||||
protocol_str = str(PROTOCOL_ID)
|
||||
for protocol in map(TProtocol, peer_protocols):
|
||||
if protocol == protocol_str:
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.debug("Error checking protocols via mux: %s", str(e))
|
||||
return None
|
||||
|
||||
async def _add_relay(self, peer_id: ID) -> None:
|
||||
"""
|
||||
Add a peer as a relay and optionally make a reservation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The ID of the peer to add as a relay
|
||||
|
||||
"""
|
||||
now = time.time()
|
||||
relay_info = RelayInfo(
|
||||
peer_id=peer_id,
|
||||
discovered_at=now,
|
||||
last_seen=now,
|
||||
)
|
||||
self._discovered_relays[peer_id] = relay_info
|
||||
logger.debug("Added relay %s to discovered relays", peer_id)
|
||||
|
||||
# If auto-reserve is enabled, make a reservation with this relay
|
||||
if self.auto_reserve:
|
||||
await self.make_reservation(peer_id)
|
||||
|
||||
async def make_reservation(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Make a reservation with a relay.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The ID of the relay to make a reservation with
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if reservation succeeded, False otherwise
|
||||
|
||||
"""
|
||||
if peer_id not in self._discovered_relays:
|
||||
logger.error("Cannot make reservation with unknown relay %s", peer_id)
|
||||
return False
|
||||
|
||||
stream = None
|
||||
try:
|
||||
logger.debug("Making reservation with relay %s", peer_id)
|
||||
|
||||
# Open a stream to the relay with timeout
|
||||
try:
|
||||
with trio.fail_after(STREAM_TIMEOUT):
|
||||
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
|
||||
if not stream:
|
||||
logger.error("Failed to open stream to relay %s", peer_id)
|
||||
return False
|
||||
except trio.TooSlowError:
|
||||
logger.error("Timeout opening stream to relay %s", peer_id)
|
||||
return False
|
||||
|
||||
try:
|
||||
# Create and send reservation request
|
||||
request = HopMessage(
|
||||
type=HopMessage.RESERVE,
|
||||
peer=self.host.get_id().to_bytes(),
|
||||
)
|
||||
|
||||
with trio.fail_after(STREAM_TIMEOUT):
|
||||
await stream.write(request.SerializeToString())
|
||||
|
||||
# Wait for response
|
||||
response_bytes = await stream.read()
|
||||
if not response_bytes:
|
||||
logger.error("No response received from relay %s", peer_id)
|
||||
return False
|
||||
|
||||
# Parse response
|
||||
response = HopMessage()
|
||||
response.ParseFromString(response_bytes)
|
||||
|
||||
# Check if reservation was successful
|
||||
if response.type == HopMessage.RESERVE and response.HasField(
|
||||
"status"
|
||||
):
|
||||
# Access status code directly from protobuf object
|
||||
status_code = getattr(response.status, "code", StatusCode.OK)
|
||||
|
||||
if status_code == StatusCode.OK:
|
||||
# Update relay info with reservation details
|
||||
relay_info = self._discovered_relays[peer_id]
|
||||
relay_info.has_reservation = True
|
||||
|
||||
if response.HasField("reservation") and response.HasField(
|
||||
"limit"
|
||||
):
|
||||
relay_info.reservation_expires_at = (
|
||||
response.reservation.expire
|
||||
)
|
||||
relay_info.reservation_data_limit = response.limit.data
|
||||
|
||||
logger.debug(
|
||||
"Successfully made reservation with relay %s", peer_id
|
||||
)
|
||||
return True
|
||||
|
||||
# Reservation failed
|
||||
error_message = "Unknown error"
|
||||
if response.HasField("status"):
|
||||
# Access message directly from protobuf object
|
||||
error_message = getattr(response.status, "message", "")
|
||||
|
||||
logger.warning(
|
||||
"Reservation request rejected by relay %s: %s",
|
||||
peer_id,
|
||||
error_message,
|
||||
)
|
||||
return False
|
||||
|
||||
except trio.TooSlowError:
|
||||
logger.error(
|
||||
"Timeout during reservation process with relay %s", peer_id
|
||||
)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error making reservation with relay %s: %s", peer_id, str(e))
|
||||
return False
|
||||
finally:
|
||||
# Always close the stream
|
||||
if stream:
|
||||
try:
|
||||
await stream.close()
|
||||
except Exception:
|
||||
pass # Ignore errors when closing the stream
|
||||
|
||||
return False
|
||||
|
||||
async def _cleanup_expired(self) -> None:
|
||||
"""Clean up expired relays and reservations."""
|
||||
now = time.time()
|
||||
to_remove = []
|
||||
|
||||
for peer_id, relay_info in self._discovered_relays.items():
|
||||
# Check if relay hasn't been seen in a while (3x discovery interval)
|
||||
if now - relay_info.last_seen > self.discovery_interval * 3:
|
||||
to_remove.append(peer_id)
|
||||
continue
|
||||
|
||||
# Check if reservation has expired
|
||||
if (
|
||||
relay_info.has_reservation
|
||||
and relay_info.reservation_expires_at
|
||||
and now > relay_info.reservation_expires_at
|
||||
):
|
||||
relay_info.has_reservation = False
|
||||
relay_info.reservation_expires_at = None
|
||||
relay_info.reservation_data_limit = None
|
||||
|
||||
# If auto-reserve is enabled, try to renew
|
||||
if self.auto_reserve:
|
||||
await self.make_reservation(peer_id)
|
||||
|
||||
# Remove expired relays
|
||||
for peer_id in to_remove:
|
||||
del self._discovered_relays[peer_id]
|
||||
if peer_id in self._protocol_cache:
|
||||
del self._protocol_cache[peer_id]
|
||||
|
||||
def get_relays(self) -> list[ID]:
|
||||
"""
|
||||
Get a list of discovered relay peer IDs.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[ID]
|
||||
List of discovered relay peer IDs
|
||||
|
||||
"""
|
||||
return list(self._discovered_relays.keys())
|
||||
|
||||
def get_relay_info(self, peer_id: ID) -> RelayInfo | None:
|
||||
"""
|
||||
Get information about a specific relay.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The ID of the relay to get information about
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[RelayInfo]
|
||||
Information about the relay, or None if not found
|
||||
|
||||
"""
|
||||
return self._discovered_relays.get(peer_id)
|
||||
|
||||
def get_relay(self) -> ID | None:
|
||||
"""
|
||||
Get a single relay peer ID for connection purposes.
|
||||
Prioritizes relays with active reservations.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[ID]
|
||||
ID of a discovered relay, or None if no relays found
|
||||
|
||||
"""
|
||||
if not self._discovered_relays:
|
||||
return None
|
||||
|
||||
# First try to find a relay with an active reservation
|
||||
for peer_id, relay_info in self._discovered_relays.items():
|
||||
if relay_info and relay_info.has_reservation:
|
||||
return peer_id
|
||||
|
||||
return next(iter(self._discovered_relays.keys()), None)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user