Merge branch 'main' into feat/619-store-pubkey-peerid-peerstore

This commit is contained in:
Soham Bhoir
2025-06-10 21:12:28 +05:30
committed by GitHub
123 changed files with 2849 additions and 1444 deletions

View File

@ -16,10 +16,10 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python: ['3.9', '3.10', '3.11', '3.12', '3.13']
python: ["3.10", "3.11", "3.12", "3.13"]
toxenv: [core, interop, lint, wheel, demos]
include:
- python: '3.10'
- python: "3.10"
toxenv: docs
fail-fast: false
steps:
@ -46,7 +46,7 @@ jobs:
runs-on: windows-latest
strategy:
matrix:
python-version: ['3.11', '3.12', '3.13']
python-version: ["3.11", "3.12", "3.13"]
toxenv: [core, wheel]
fail-fast: false
steps:

7
.gitignore vendored
View File

@ -146,6 +146,9 @@ instance/
# PyBuilder
target/
# PyRight Config
pyrightconfig.json
# Jupyter Notebook
.ipynb_checkpoints
@ -171,3 +174,7 @@ env.bak/
# mkdocs documentation
/site
#lockfiles
uv.lock
poetry.lock

View File

@ -1,59 +1,49 @@
exclude: '.project-template|docs/conf.py|.*pb2\..*'
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: check-yaml
- id: check-toml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
- id: check-yaml
- id: check-toml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/asottile/pyupgrade
rev: v3.20.0
hooks:
- id: pyupgrade
args: [--py39-plus]
- repo: https://github.com/psf/black
rev: 23.9.1
- id: pyupgrade
args: [--py310-plus]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.10
hooks:
- id: black
- repo: https://github.com/PyCQA/flake8
rev: 6.1.0
hooks:
- id: flake8
additional_dependencies:
- flake8-bugbear==23.9.16
exclude: setup.py
- repo: https://github.com/PyCQA/autoflake
rev: v2.2.1
hooks:
- id: autoflake
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
- repo: https://github.com/pycqa/pydocstyle
rev: 6.3.0
hooks:
- id: pydocstyle
additional_dependencies:
- tomli # required until >= python311
- repo: https://github.com/executablebooks/mdformat
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.22
hooks:
- id: mdformat
- id: mdformat
additional_dependencies:
- mdformat-gfm
- repo: local
- mdformat-gfm
- repo: local
hooks:
- id: mypy-local
- id: mypy-local
name: run mypy with all dev dependencies present
entry: python -m mypy -p libp2p
entry: mypy -p libp2p
language: system
always_run: true
pass_filenames: false
- repo: local
- repo: local
hooks:
- id: check-rst-files
- id: pyrefly-local
name: run pyrefly typecheck locally
entry: pyrefly check
language: system
always_run: true
pass_filenames: false
- repo: local
hooks:
- id: check-rst-files
name: Check for .rst files in the top-level directory
entry: python -c "import glob, sys; rst_files = glob.glob('*.rst'); sys.exit(1) if rst_files else sys.exit(0)"
language: system

View File

@ -1,71 +0,0 @@
#!/usr/bin/env python3
import os
import sys
import re
from pathlib import Path
def _find_files(project_root):
path_exclude_pattern = r"\.git($|\/)|venv|_build"
file_exclude_pattern = r"fill_template_vars\.py|\.swp$"
filepaths = []
for dir_path, _dir_names, file_names in os.walk(project_root):
if not re.search(path_exclude_pattern, dir_path):
for file in file_names:
if not re.search(file_exclude_pattern, file):
filepaths.append(str(Path(dir_path, file)))
return filepaths
def _replace(pattern, replacement, project_root):
print(f"Replacing values: {pattern}")
for file in _find_files(project_root):
try:
with open(file) as f:
content = f.read()
content = re.sub(pattern, replacement, content)
with open(file, "w") as f:
f.write(content)
except UnicodeDecodeError:
pass
def main():
project_root = Path(os.path.realpath(sys.argv[0])).parent.parent
module_name = input("What is your python module name? ")
pypi_input = input(f"What is your pypi package name? (default: {module_name}) ")
pypi_name = pypi_input or module_name
repo_input = input(f"What is your github project name? (default: {pypi_name}) ")
repo_name = repo_input or pypi_name
rtd_input = input(
f"What is your readthedocs.org project name? (default: {pypi_name}) "
)
rtd_name = rtd_input or pypi_name
project_input = input(
f"What is your project name (ex: at the top of the README)? (default: {repo_name}) "
)
project_name = project_input or repo_name
short_description = input("What is a one-liner describing the project? ")
_replace("<MODULE_NAME>", module_name, project_root)
_replace("<PYPI_NAME>", pypi_name, project_root)
_replace("<REPO_NAME>", repo_name, project_root)
_replace("<RTD_NAME>", rtd_name, project_root)
_replace("<PROJECT_NAME>", project_name, project_root)
_replace("<SHORT_DESCRIPTION>", short_description, project_root)
os.makedirs(project_root / module_name, exist_ok=True)
Path(project_root / module_name / "__init__.py").touch()
Path(project_root / module_name / "py.typed").touch()
if __name__ == "__main__":
main()

View File

@ -1,39 +0,0 @@
#!/usr/bin/env python3
import os
import sys
from pathlib import Path
import subprocess
def main():
template_dir = Path(os.path.dirname(sys.argv[0]))
template_vars_file = template_dir / "template_vars.txt"
fill_template_vars_script = template_dir / "fill_template_vars.py"
with open(template_vars_file, "r") as input_file:
content_lines = input_file.readlines()
process = subprocess.Popen(
[sys.executable, str(fill_template_vars_script)],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
for line in content_lines:
process.stdin.write(line)
process.stdin.flush()
stdout, stderr = process.communicate()
if process.returncode != 0:
print(f"Error occurred: {stderr}")
sys.exit(1)
print(stdout)
if __name__ == "__main__":
main()

View File

@ -1,6 +0,0 @@
libp2p
libp2p
py-libp2p
py-libp2p
py-libp2p
The Python implementation of the libp2p networking stack

View File

@ -7,12 +7,14 @@ help:
@echo "clean-pyc - remove Python file artifacts"
@echo "clean - run clean-build and clean-pyc"
@echo "dist - build package and cat contents of the dist directory"
@echo "fix - fix formatting & linting issues with ruff"
@echo "lint - fix linting issues with pre-commit"
@echo "test - run tests quickly with the default Python"
@echo "docs - generate docs and open in browser (linux-docs for version on linux)"
@echo "package-test - build package and install it in a venv for manual testing"
@echo "notes - consume towncrier newsfragments and update release notes in docs - requires bump to be set"
@echo "release - package and upload a release (does not run notes target) - requires bump to be set"
@echo "pr - run clean, fix, lint, typecheck, and test i.e basically everything you need to do before creating a PR"
clean-build:
rm -fr build/
@ -37,8 +39,16 @@ lint:
&& pre-commit run --all-files --show-diff-on-failure \
)
fix:
python -m ruff check --fix
typecheck:
pre-commit run mypy-local --all-files && pre-commit run pyrefly-local --all-files
test:
python -m pytest tests
python -m pytest tests -n auto
pr: clean fix lint typecheck test
# protobufs management

View File

@ -15,14 +15,24 @@
# documentation root, use os.path.abspath to make it absolute, like shown here.
# sys.path.insert(0, os.path.abspath('.'))
import doctest
import os
import sys
from unittest.mock import MagicMock
DIR = os.path.dirname(__file__)
with open(os.path.join(DIR, "../setup.py"), "r") as f:
for line in f:
if "version=" in line:
setup_version = line.split('"')[1]
break
try:
import tomllib
except ModuleNotFoundError:
# For Python < 3.11
import tomli as tomllib # type: ignore (In case of >3.11 Pyrefly doesnt find tomli , which is right but a false flag)
# Path to pyproject.toml (assuming conf.py is in a 'docs' subdirectory)
pyproject_path = os.path.join(os.path.dirname(__file__), "..", "pyproject.toml")
with open(pyproject_path, "rb") as f:
pyproject_data = tomllib.load(f)
setup_version = pyproject_data["project"]["version"]
# -- General configuration ------------------------------------------------
@ -302,7 +312,6 @@ intersphinx_mapping = {
# -- Doctest configuration ----------------------------------------
import doctest
doctest_default_flags = (
0
@ -317,10 +326,9 @@ doctest_default_flags = (
# Mock out dependencies that are unbuildable on readthedocs, as recommended here:
# https://docs.readthedocs.io/en/rel/faq.html#i-get-import-errors-on-libraries-that-depend-on-c-modules
import sys
from unittest.mock import MagicMock
# Add new modules to mock here (it should be the same list as those excluded in setup.py)
# Add new modules to mock here (it should be the same list
# as those excluded in pyproject.toml)
MOCK_MODULES = [
"fastecdsa",
"fastecdsa.encoding",
@ -338,4 +346,4 @@ todo_include_todos = True
# Allow duplicate object descriptions
nitpicky = False
nitpick_ignore = [("py:class", "type")]
nitpick_ignore = [("py:class", "type")]

View File

@ -40,7 +40,6 @@ async def write_data(stream: INetStream) -> None:
async def run(port: int, destination: str) -> None:
localhost_ip = "127.0.0.1"
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
host = new_host()
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
@ -54,8 +53,8 @@ async def run(port: int, destination: str) -> None:
print(
"Run this from the same folder in another console:\n\n"
f"chat-demo -p {int(port) + 1} "
f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id().pretty()}\n"
f"chat-demo "
f"-d {host.get_addrs()[0]}\n"
)
print("Waiting for incoming connection...")
@ -87,9 +86,7 @@ def main() -> None:
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
)
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"-p", "--port", default=8000, type=int, help="source port number"
)
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
parser.add_argument(
"-d",
"--destination",
@ -98,9 +95,6 @@ def main() -> None:
)
args = parser.parse_args()
if not args.port:
raise RuntimeError("was not able to determine a local port")
try:
trio.run(run, *(args.port, args.destination))
except KeyboardInterrupt:

View File

@ -9,8 +9,10 @@ from libp2p import (
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
from libp2p.security.noise.transport import Transport as NoiseTransport
from libp2p.security.noise.transport import (
PROTOCOL_ID as NOISE_PROTOCOL_ID,
Transport as NoiseTransport,
)
async def main():

View File

@ -9,8 +9,10 @@ from libp2p import (
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.security.secio.transport import ID as SECIO_PROTOCOL_ID
from libp2p.security.secio.transport import Transport as SecioTransport
from libp2p.security.secio.transport import (
ID as SECIO_PROTOCOL_ID,
Transport as SecioTransport,
)
async def main():
@ -22,9 +24,6 @@ async def main():
secio_transport = SecioTransport(
# local_key_pair: The key pair used for libp2p identity and authentication
local_key_pair=key_pair,
# secure_bytes_provider: Optional function to generate secure random bytes
# (defaults to secrets.token_bytes)
secure_bytes_provider=None, # Use default implementation
)
# Create a security options dictionary mapping protocol ID to transport

View File

@ -9,10 +9,9 @@ from libp2p import (
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
from libp2p.security.noise.transport import Transport as NoiseTransport
from libp2p.stream_muxer.mplex.mplex import (
MPLEX_PROTOCOL_ID,
from libp2p.security.noise.transport import (
PROTOCOL_ID as NOISE_PROTOCOL_ID,
Transport as NoiseTransport,
)
@ -37,14 +36,8 @@ async def main():
# Create a security options dictionary mapping protocol ID to transport
security_options = {NOISE_PROTOCOL_ID: noise_transport}
# Create a muxer options dictionary mapping protocol ID to muxer class
# We don't need to instantiate the muxer here, the host will do that for us
muxer_options = {MPLEX_PROTOCOL_ID: None}
# Create a host with the key pair, Noise security, and mplex multiplexer
host = new_host(
key_pair=key_pair, sec_opt=security_options, muxer_opt=muxer_options
)
host = new_host(key_pair=key_pair, sec_opt=security_options)
# Configure the listening address
port = 8000

View File

@ -0,0 +1,263 @@
"""
Enhanced NetStream Example for py-libp2p with State Management
This example demonstrates the new NetStream features including:
- State tracking and transitions
- Proper error handling and validation
- Resource cleanup and event notifications
- Thread-safe operations with Trio locks
Based on the standard echo demo but enhanced to show NetStream state management.
"""
import argparse
import random
import secrets
import multiaddr
import trio
from libp2p import (
new_host,
)
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.custom_types import (
TProtocol,
)
from libp2p.network.stream.exceptions import (
StreamClosed,
StreamEOF,
StreamReset,
)
from libp2p.network.stream.net_stream import (
NetStream,
StreamState,
)
from libp2p.peer.peerinfo import (
info_from_p2p_addr,
)
PROTOCOL_ID = TProtocol("/echo/1.0.0")
async def enhanced_echo_handler(stream: NetStream) -> None:
"""
Enhanced echo handler that demonstrates NetStream state management.
"""
print(f"New connection established: {stream}")
print(f"Initial stream state: {await stream.state}")
try:
# Verify stream is in expected initial state
assert await stream.state == StreamState.OPEN
assert await stream.is_readable()
assert await stream.is_writable()
print("✓ Stream initialized in OPEN state")
# Read incoming data with proper state checking
print("Waiting for client data...")
while await stream.is_readable():
try:
# Read data from client
data = await stream.read(1024)
if not data:
print("Received empty data, client may have closed")
break
print(f"Received: {data.decode('utf-8').strip()}")
# Check if we can still write before echoing
if await stream.is_writable():
await stream.write(data)
print(f"Echoed: {data.decode('utf-8').strip()}")
else:
print("Cannot echo - stream not writable")
break
except StreamEOF:
print("Client closed their write side (EOF)")
break
except StreamReset:
print("Stream was reset by client")
return
except StreamClosed as e:
print(f"Stream operation failed: {e}")
break
# Demonstrate graceful closure
current_state = await stream.state
print(f"Current state before close: {current_state}")
if current_state not in [StreamState.CLOSE_BOTH, StreamState.RESET]:
await stream.close()
print("Server closed write side")
final_state = await stream.state
print(f"Final stream state: {final_state}")
except Exception as e:
print(f"Handler error: {e}")
# Reset stream on unexpected errors
if await stream.state not in [StreamState.RESET, StreamState.CLOSE_BOTH]:
await stream.reset()
print("Stream reset due to error")
async def enhanced_client_demo(stream: NetStream) -> None:
"""
Enhanced client that demonstrates various NetStream state scenarios.
"""
print(f"Client stream established: {stream}")
print(f"Initial state: {await stream.state}")
try:
# Verify initial state
assert await stream.state == StreamState.OPEN
print("✓ Client stream in OPEN state")
# Scenario 1: Normal communication
message = b"Hello from enhanced NetStream client!\n"
if await stream.is_writable():
await stream.write(message)
print(f"Sent: {message.decode('utf-8').strip()}")
else:
print("Cannot write - stream not writable")
return
# Close write side to signal EOF to server
await stream.close()
print("Client closed write side")
# Verify state transition
state_after_close = await stream.state
print(f"State after close: {state_after_close}")
assert state_after_close == StreamState.CLOSE_WRITE
assert await stream.is_readable() # Should still be readable
assert not await stream.is_writable() # Should not be writable
# Try to write (should fail)
try:
await stream.write(b"This should fail")
print("ERROR: Write succeeded when it should have failed!")
except StreamClosed as e:
print(f"✓ Expected error when writing to closed stream: {e}")
# Read the echo response
if await stream.is_readable():
try:
response = await stream.read()
print(f"Received echo: {response.decode('utf-8').strip()}")
except StreamEOF:
print("Server closed their write side")
except StreamReset:
print("Stream was reset")
# Check final state
final_state = await stream.state
print(f"Final client state: {final_state}")
except Exception as e:
print(f"Client error: {e}")
# Reset on error
await stream.reset()
print("Client reset stream due to error")
async def run_enhanced_demo(
port: int, destination: str, seed: int | None = None
) -> None:
"""
Run enhanced echo demo with NetStream state management.
"""
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
# Generate or use provided key
if seed:
random.seed(seed)
secret_number = random.getrandbits(32 * 8)
secret = secret_number.to_bytes(length=32, byteorder="big")
else:
secret = secrets.token_bytes(32)
host = new_host(key_pair=create_new_key_pair(secret))
async with host.run(listen_addrs=[listen_addr]):
print(f"Host ID: {host.get_id().to_string()}")
print("=" * 60)
if not destination: # Server mode
print("🖥️ ENHANCED ECHO SERVER MODE")
print("=" * 60)
# type: ignore: Stream is type of NetStream
host.set_stream_handler(PROTOCOL_ID, enhanced_echo_handler)
print(
"Run client from another console:\n"
f"python3 example_net_stream.py "
f"-d {host.get_addrs()[0]}\n"
)
print("Waiting for connections...")
print("Press Ctrl+C to stop server")
await trio.sleep_forever()
else: # Client mode
print("📱 ENHANCED ECHO CLIENT MODE")
print("=" * 60)
# Connect to server
maddr = multiaddr.Multiaddr(destination)
info = info_from_p2p_addr(maddr)
await host.connect(info)
print(f"Connected to server: {info.peer_id.pretty()}")
# Create stream and run enhanced demo
stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
if isinstance(stream, NetStream):
await enhanced_client_demo(stream)
print("\n" + "=" * 60)
print("CLIENT DEMO COMPLETE")
def main() -> None:
example_maddr = (
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
)
parser = argparse.ArgumentParser(
formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
parser.add_argument(
"-d",
"--destination",
type=str,
help=f"destination multiaddr string, e.g. {example_maddr}",
)
parser.add_argument(
"-s",
"--seed",
type=int,
help="seed for deterministic peer ID generation",
)
parser.add_argument(
"--demo-states", action="store_true", help="run state transition demo only"
)
args = parser.parse_args()
try:
trio.run(run_enhanced_demo, args.port, args.destination, args.seed)
except KeyboardInterrupt:
print("\n👋 Demo interrupted by user")
except Exception as e:
print(f"❌ Demo failed: {e}")
if __name__ == "__main__":
main()

View File

@ -12,10 +12,9 @@ from libp2p.crypto.secp256k1 import (
from libp2p.peer.peerinfo import (
info_from_p2p_addr,
)
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
from libp2p.security.noise.transport import Transport as NoiseTransport
from libp2p.stream_muxer.mplex.mplex import (
MPLEX_PROTOCOL_ID,
from libp2p.security.noise.transport import (
PROTOCOL_ID as NOISE_PROTOCOL_ID,
Transport as NoiseTransport,
)
@ -40,14 +39,8 @@ async def main():
# Create a security options dictionary mapping protocol ID to transport
security_options = {NOISE_PROTOCOL_ID: noise_transport}
# Create a muxer options dictionary mapping protocol ID to muxer class
# We don't need to instantiate the muxer here, the host will do that for us
muxer_options = {MPLEX_PROTOCOL_ID: None}
# Create a host with the key pair, Noise security, and mplex multiplexer
host = new_host(
key_pair=key_pair, sec_opt=security_options, muxer_opt=muxer_options
)
host = new_host(key_pair=key_pair, sec_opt=security_options)
# Configure the listening address
port = 8000

View File

@ -9,10 +9,9 @@ from libp2p import (
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
from libp2p.security.noise.transport import Transport as NoiseTransport
from libp2p.stream_muxer.mplex.mplex import (
MPLEX_PROTOCOL_ID,
from libp2p.security.noise.transport import (
PROTOCOL_ID as NOISE_PROTOCOL_ID,
Transport as NoiseTransport,
)
@ -37,14 +36,8 @@ async def main():
# Create a security options dictionary mapping protocol ID to transport
security_options = {NOISE_PROTOCOL_ID: noise_transport}
# Create a muxer options dictionary mapping protocol ID to muxer class
# We don't need to instantiate the muxer here, the host will do that for us
muxer_options = {MPLEX_PROTOCOL_ID: None}
# Create a host with the key pair, Noise security, and mplex multiplexer
host = new_host(
key_pair=key_pair, sec_opt=security_options, muxer_opt=muxer_options
)
host = new_host(key_pair=key_pair, sec_opt=security_options)
# Configure the listening address
port = 8000

View File

@ -29,8 +29,7 @@ async def _echo_stream_handler(stream: INetStream) -> None:
await stream.close()
async def run(port: int, destination: str, seed: int = None) -> None:
localhost_ip = "127.0.0.1"
async def run(port: int, destination: str, seed: int | None = None) -> None:
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
if seed:
@ -53,8 +52,8 @@ async def run(port: int, destination: str, seed: int = None) -> None:
print(
"Run this from the same folder in another console:\n\n"
f"echo-demo -p {int(port) + 1} "
f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id().pretty()}\n"
f"echo-demo "
f"-d {host.get_addrs()[0]}\n"
)
print("Waiting for incoming connections...")
await trio.sleep_forever()
@ -73,6 +72,7 @@ async def run(port: int, destination: str, seed: int = None) -> None:
msg = b"hi, there!\n"
await stream.write(msg)
# TODO: check why the stream is closed after the first write ???
# Notify the other side about EOF
await stream.close()
response = await stream.read()
@ -94,9 +94,7 @@ def main() -> None:
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
)
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"-p", "--port", default=8000, type=int, help="source port number"
)
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
parser.add_argument(
"-d",
"--destination",
@ -110,10 +108,6 @@ def main() -> None:
help="provide a seed to the random number generator (e.g. to fix peer IDs across runs)", # noqa: E501
)
args = parser.parse_args()
if not args.port:
raise RuntimeError("was not able to determine a local port")
try:
trio.run(run, args.port, args.destination, args.seed)
except KeyboardInterrupt:

View File

@ -61,20 +61,20 @@ async def run(port: int, destination: str) -> None:
async with host_a.run(listen_addrs=[listen_addr]):
print(
"First host listening. Run this from another console:\n\n"
f"identify-demo -p {int(port) + 1} "
f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host_a.get_id().pretty()}\n"
f"identify-demo "
f"-d {host_a.get_addrs()[0]}\n"
)
print("Waiting for incoming identify request...")
await trio.sleep_forever()
else:
# Create second host (dialer)
print(f"dialer (host_b) listening on /ip4/{localhost_ip}/tcp/{port}")
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}")
host_b = new_host()
async with host_b.run(listen_addrs=[listen_addr]):
# Connect to the first host
print(f"dialer (host_b) listening on {host_b.get_addrs()[0]}")
maddr = multiaddr.Multiaddr(destination)
info = info_from_p2p_addr(maddr)
print(f"Second host connecting to peer: {info.peer_id}")
@ -104,13 +104,11 @@ def main() -> None:
"""
example_maddr = (
"/ip4/127.0.0.1/tcp/8888/p2p/QmQn4SwGkDZkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
"/ip4/127.0.0.1/tcp/8888/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
)
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"-p", "--port", default=8888, type=int, help="source port number"
)
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
parser.add_argument(
"-d",
"--destination",
@ -119,9 +117,6 @@ def main() -> None:
)
args = parser.parse_args()
if not args.port:
raise RuntimeError("failed to determine local port")
try:
trio.run(run, *(args.port, args.destination))
except KeyboardInterrupt:

View File

@ -38,17 +38,17 @@ from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.identity.identify import (
ID as ID_IDENTIFY,
identify_handler_for,
)
from libp2p.identity.identify import ID as ID_IDENTIFY
from libp2p.identity.identify.pb.identify_pb2 import (
Identify,
)
from libp2p.identity.identify_push import (
ID_PUSH as ID_IDENTIFY_PUSH,
identify_push_handler_for,
push_identify_to_peer,
)
from libp2p.identity.identify_push import ID_PUSH as ID_IDENTIFY_PUSH
from libp2p.peer.peerinfo import (
info_from_p2p_addr,
)
@ -56,9 +56,6 @@ from libp2p.peer.peerinfo import (
# Configure logging
logger = logging.getLogger("libp2p.identity.identify-push-example")
# Default port configuration
DEFAULT_PORT = 8888
def custom_identify_push_handler_for(host):
"""
@ -241,25 +238,16 @@ def main() -> None:
"""Parse arguments and start the appropriate mode."""
description = """
This program demonstrates the libp2p identify/push protocol.
Without arguments, it runs as a listener on port 8888.
With -d parameter, it runs as a dialer on port 8889.
Without arguments, it runs as a listener on random port.
With -d parameter, it runs as a dialer on random port.
"""
example = (
f"/ip4/127.0.0.1/tcp/{DEFAULT_PORT}/p2p/"
"QmQn4SwGkDZkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
)
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"-p",
"--port",
type=int,
help=(
f"port to listen on (default: {DEFAULT_PORT} for listener, "
f"{DEFAULT_PORT + 1} for dialer)"
),
)
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
parser.add_argument(
"-d",
"--destination",
@ -270,13 +258,11 @@ def main() -> None:
try:
if args.destination:
# Run in dialer mode with default port DEFAULT_PORT + 1 if not specified
port = args.port if args.port is not None else DEFAULT_PORT + 1
trio.run(run_dialer, port, args.destination)
# Run in dialer mode with random available port if not specified
trio.run(run_dialer, args.port, args.destination)
else:
# Run in listener mode with default port DEFAULT_PORT if not specified
port = args.port if args.port is not None else DEFAULT_PORT
trio.run(run_listener, port)
# Run in listener mode with random available port if not specified
trio.run(run_listener, args.port)
except KeyboardInterrupt:
print("\nInterrupted by user")
logger.info("Interrupted by user")

View File

@ -55,7 +55,6 @@ async def send_ping(stream: INetStream) -> None:
async def run(port: int, destination: str) -> None:
localhost_ip = "127.0.0.1"
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
host = new_host(listen_addrs=[listen_addr])
@ -65,8 +64,8 @@ async def run(port: int, destination: str) -> None:
print(
"Run this from the same folder in another console:\n\n"
f"ping-demo -p {int(port) + 1} "
f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id().pretty()}\n"
f"ping-demo "
f"-d {host.get_addrs()[0]}\n"
)
print("Waiting for incoming connection...")
@ -96,10 +95,8 @@ def main() -> None:
)
parser = argparse.ArgumentParser(description=description)
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
parser.add_argument(
"-p", "--port", default=8000, type=int, help="source port number"
)
parser.add_argument(
"-d",
"--destination",
@ -108,9 +105,6 @@ def main() -> None:
)
args = parser.parse_args()
if not args.port:
raise RuntimeError("failed to determine local port")
try:
trio.run(run, *(args.port, args.destination))
except KeyboardInterrupt:

View File

@ -1,9 +1,6 @@
import argparse
import logging
import socket
from typing import (
Optional,
)
import base58
import multiaddr
@ -109,7 +106,7 @@ async def monitor_peer_topics(pubsub, nursery, termination_event):
await trio.sleep(2)
async def run(topic: str, destination: Optional[str], port: Optional[int]) -> None:
async def run(topic: str, destination: str | None, port: int | None) -> None:
# Initialize network settings
localhost_ip = "127.0.0.1"

View File

@ -152,12 +152,12 @@ def get_default_muxer_options() -> TMuxerOptions:
def new_swarm(
key_pair: Optional[KeyPair] = None,
muxer_opt: Optional[TMuxerOptions] = None,
sec_opt: Optional[TSecurityOptions] = None,
peerstore_opt: Optional[IPeerStore] = None,
muxer_preference: Optional[Literal["YAMUX", "MPLEX"]] = None,
listen_addrs: Optional[Sequence[multiaddr.Multiaddr]] = None,
key_pair: KeyPair | None = None,
muxer_opt: TMuxerOptions | None = None,
sec_opt: TSecurityOptions | None = None,
peerstore_opt: IPeerStore | None = None,
muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
) -> INetworkService:
"""
Create a swarm instance based on the parameters.
@ -238,13 +238,13 @@ def new_swarm(
def new_host(
key_pair: Optional[KeyPair] = None,
muxer_opt: Optional[TMuxerOptions] = None,
sec_opt: Optional[TSecurityOptions] = None,
peerstore_opt: Optional[IPeerStore] = None,
disc_opt: Optional[IPeerRouting] = None,
muxer_preference: Optional[Literal["YAMUX", "MPLEX"]] = None,
listen_addrs: Sequence[multiaddr.Multiaddr] = None,
key_pair: KeyPair | None = None,
muxer_opt: TMuxerOptions | None = None,
sec_opt: TSecurityOptions | None = None,
peerstore_opt: IPeerStore | None = None,
disc_opt: IPeerRouting | None = None,
muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
) -> IHost:
"""
Create a new libp2p host based on the given parameters.

View File

@ -8,6 +8,10 @@ from collections.abc import (
KeysView,
Sequence,
)
from contextlib import AbstractAsyncContextManager
from types import (
TracebackType,
)
from typing import (
TYPE_CHECKING,
Any,
@ -156,7 +160,11 @@ class IMuxedConn(ABC):
event_started: trio.Event
@abstractmethod
def __init__(self, conn: ISecureConn, peer_id: ID) -> None:
def __init__(
self,
conn: ISecureConn,
peer_id: ID,
) -> None:
"""
Initialize a new multiplexed connection.
@ -215,7 +223,7 @@ class IMuxedConn(ABC):
"""
class IMuxedStream(ReadWriteCloser):
class IMuxedStream(ReadWriteCloser, AsyncContextManager["IMuxedStream"]):
"""
Interface for a multiplexed stream.
@ -249,6 +257,20 @@ class IMuxedStream(ReadWriteCloser):
otherwise False.
"""
@abstractmethod
async def __aenter__(self) -> "IMuxedStream":
"""Enter the async context manager."""
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Exit the async context manager and close the stream."""
await self.close()
# -------------------------- net_stream interface.py --------------------------
@ -269,7 +291,7 @@ class INetStream(ReadWriteCloser):
muxed_conn: IMuxedConn
@abstractmethod
def get_protocol(self) -> TProtocol:
def get_protocol(self) -> TProtocol | None:
"""
Retrieve the protocol identifier for the stream.
@ -898,7 +920,7 @@ class INetwork(ABC):
"""
@abstractmethod
async def listen(self, *multiaddrs: Sequence[Multiaddr]) -> bool:
async def listen(self, *multiaddrs: Multiaddr) -> bool:
"""
Start listening on one or more multiaddresses.
@ -1156,7 +1178,9 @@ class IHost(ABC):
"""
@abstractmethod
def run(self, listen_addrs: Sequence[Multiaddr]) -> AsyncContextManager[None]:
def run(
self, listen_addrs: Sequence[Multiaddr]
) -> AbstractAsyncContextManager[None]:
"""
Run the host and start listening on the specified multiaddresses.
@ -1416,6 +1440,60 @@ class IPeerData(ABC):
"""
@abstractmethod
def update_last_identified(self) -> None:
"""
Updates timestamp to current time.
"""
@abstractmethod
def get_last_identified(self) -> int:
"""
Fetch the last identified timestamp
Returns
-------
last_identified_timestamp
The lastIdentified time of peer.
"""
@abstractmethod
def get_ttl(self) -> int:
"""
Get ttl value for the peer for validity check
Returns
-------
int
The ttl of the peer.
"""
@abstractmethod
def set_ttl(self, ttl: int) -> None:
"""
Set ttl value for the peer for validity check
Parameters
----------
ttl : int
The ttl for the peer.
"""
@abstractmethod
def is_expired(self) -> bool:
"""
Check if the peer is expired based on last_identified and ttl
Returns
-------
bool
True, if last_identified + ttl > current_time
"""
# ------------------ multiselect_communicator interface.py ------------------
@ -1546,7 +1624,7 @@ class IMultiselectMuxer(ABC):
and its corresponding handler for communication.
"""
handlers: dict[TProtocol, StreamHandlerFn]
handlers: dict[TProtocol | None, StreamHandlerFn | None]
@abstractmethod
def add_handler(self, protocol: TProtocol, handler: StreamHandlerFn) -> None:
@ -1562,7 +1640,7 @@ class IMultiselectMuxer(ABC):
"""
def get_protocols(self) -> tuple[TProtocol, ...]:
def get_protocols(self) -> tuple[TProtocol | None, ...]:
"""
Retrieve the protocols for which handlers have been registered.
@ -1577,7 +1655,7 @@ class IMultiselectMuxer(ABC):
@abstractmethod
async def negotiate(
self, communicator: IMultiselectCommunicator
) -> tuple[TProtocol, StreamHandlerFn]:
) -> tuple[TProtocol | None, StreamHandlerFn | None]:
"""
Negotiate a protocol selection with a multiselect client.
@ -1654,7 +1732,7 @@ class IPeerRouting(ABC):
"""
@abstractmethod
async def find_peer(self, peer_id: ID) -> PeerInfo:
async def find_peer(self, peer_id: ID) -> PeerInfo | None:
"""
Search for a peer with the specified peer ID.
@ -1822,6 +1900,11 @@ class IPubsubRouter(ABC):
"""
mesh: dict[str, set[ID]]
fanout: dict[str, set[ID]]
peer_protocol: dict[ID, TProtocol]
degree: int
@abstractmethod
def get_protocols(self) -> list[TProtocol]:
"""
@ -1847,7 +1930,7 @@ class IPubsubRouter(ABC):
"""
@abstractmethod
def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None:
def add_peer(self, peer_id: ID, protocol_id: TProtocol | None) -> None:
"""
Notify the router that a new peer has connected.

View File

@ -116,15 +116,15 @@ def initialize_pair(
EncryptionParameters(
cipher_type,
hash_type,
first_half[0:iv_size],
first_half[iv_size + cipher_key_size :],
first_half[iv_size : iv_size + cipher_key_size],
bytes(first_half[0:iv_size]),
bytes(first_half[iv_size + cipher_key_size :]),
bytes(first_half[iv_size : iv_size + cipher_key_size]),
),
EncryptionParameters(
cipher_type,
hash_type,
second_half[0:iv_size],
second_half[iv_size + cipher_key_size :],
second_half[iv_size : iv_size + cipher_key_size],
bytes(second_half[0:iv_size]),
bytes(second_half[iv_size + cipher_key_size :]),
bytes(second_half[iv_size : iv_size + cipher_key_size]),
),
)

View File

@ -9,29 +9,40 @@ from libp2p.crypto.keys import (
if sys.platform != "win32":
from fastecdsa import (
curve as curve_types,
keys,
point,
)
from fastecdsa import curve as curve_types
from fastecdsa.encoding.sec1 import (
SEC1Encoder,
)
else:
from coincurve import PrivateKey as CPrivateKey
from coincurve import PublicKey as CPublicKey
from coincurve import (
PrivateKey as CPrivateKey,
PublicKey as CPublicKey,
)
def infer_local_type(curve: str) -> object:
"""
Convert a str representation of some elliptic curve to a
representation understood by the backend of this module.
"""
if curve != "P-256":
raise NotImplementedError("Only P-256 curve is supported")
if sys.platform != "win32":
if sys.platform != "win32":
def infer_local_type(curve: str) -> curve_types.Curve:
"""
Convert a str representation of some elliptic curve to a
representation understood by the backend of this module.
"""
if curve != "P-256":
raise NotImplementedError("Only P-256 curve is supported")
return curve_types.P256
return "P-256" # coincurve only supports P-256
else:
def infer_local_type(curve: str) -> str:
"""
Convert a str representation of some elliptic curve to a
representation understood by the backend of this module.
"""
if curve != "P-256":
raise NotImplementedError("Only P-256 curve is supported")
return "P-256" # coincurve only supports P-256
if sys.platform != "win32":
@ -68,7 +79,10 @@ if sys.platform != "win32":
return cls(private_key_impl, curve_type)
def to_bytes(self) -> bytes:
return keys.export_key(self.impl, self.curve)
key_str = keys.export_key(self.impl, self.curve)
if key_str is None:
raise Exception("Key not found")
return key_str.encode()
def get_type(self) -> KeyType:
return KeyType.ECC_P256

View File

@ -4,8 +4,10 @@ from Crypto.Hash import (
from nacl.exceptions import (
BadSignatureError,
)
from nacl.public import PrivateKey as PrivateKeyImpl
from nacl.public import PublicKey as PublicKeyImpl
from nacl.public import (
PrivateKey as PrivateKeyImpl,
PublicKey as PublicKeyImpl,
)
from nacl.signing import (
SigningKey,
VerifyKey,
@ -48,7 +50,7 @@ class Ed25519PrivateKey(PrivateKey):
self.impl = impl
@classmethod
def new(cls, seed: bytes = None) -> "Ed25519PrivateKey":
def new(cls, seed: bytes | None = None) -> "Ed25519PrivateKey":
if not seed:
seed = utils.random()
@ -75,7 +77,7 @@ class Ed25519PrivateKey(PrivateKey):
return Ed25519PublicKey(self.impl.public_key)
def create_new_key_pair(seed: bytes = None) -> KeyPair:
def create_new_key_pair(seed: bytes | None = None) -> KeyPair:
private_key = Ed25519PrivateKey.new(seed)
public_key = private_key.get_public_key()
return KeyPair(private_key, public_key)

View File

@ -1,6 +1,6 @@
from collections.abc import Callable
import sys
from typing import (
Callable,
cast,
)

View File

@ -81,12 +81,10 @@ class PrivateKey(Key):
"""A ``PrivateKey`` represents a cryptographic private key."""
@abstractmethod
def sign(self, data: bytes) -> bytes:
...
def sign(self, data: bytes) -> bytes: ...
@abstractmethod
def get_public_key(self) -> PublicKey:
...
def get_public_key(self) -> PublicKey: ...
def _serialize_to_protobuf(self) -> crypto_pb2.PrivateKey:
"""Return the protobuf representation of this ``Key``."""

View File

@ -37,7 +37,7 @@ class Secp256k1PrivateKey(PrivateKey):
self.impl = impl
@classmethod
def new(cls, secret: bytes = None) -> "Secp256k1PrivateKey":
def new(cls, secret: bytes | None = None) -> "Secp256k1PrivateKey":
private_key_impl = coincurve.PrivateKey(secret)
return cls(private_key_impl)
@ -65,7 +65,7 @@ class Secp256k1PrivateKey(PrivateKey):
return Secp256k1PublicKey(public_key_impl)
def create_new_key_pair(secret: bytes = None) -> KeyPair:
def create_new_key_pair(secret: bytes | None = None) -> KeyPair:
"""
Returns a new Secp256k1 keypair derived from the provided ``secret``, a
sequence of bytes corresponding to some integer between 0 and the group

View File

@ -1,13 +1,9 @@
from collections.abc import (
Awaitable,
Callable,
Mapping,
)
from typing import (
TYPE_CHECKING,
Callable,
NewType,
Union,
)
from typing import TYPE_CHECKING, NewType, Union, cast
if TYPE_CHECKING:
from libp2p.abc import (
@ -16,15 +12,9 @@ if TYPE_CHECKING:
ISecureTransport,
)
else:
class INetStream:
pass
class IMuxedConn:
pass
class ISecureTransport:
pass
IMuxedConn = cast(type, object)
INetStream = cast(type, object)
ISecureTransport = cast(type, object)
from libp2p.io.abc import (
@ -38,10 +28,10 @@ from libp2p.pubsub.pb import (
)
TProtocol = NewType("TProtocol", str)
StreamHandlerFn = Callable[["INetStream"], Awaitable[None]]
StreamHandlerFn = Callable[[INetStream], Awaitable[None]]
THandler = Callable[[ReadWriteCloser], Awaitable[None]]
TSecurityOptions = Mapping[TProtocol, "ISecureTransport"]
TMuxerClass = type["IMuxedConn"]
TSecurityOptions = Mapping[TProtocol, ISecureTransport]
TMuxerClass = type[IMuxedConn]
TMuxerOptions = Mapping[TProtocol, TMuxerClass]
SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]

View File

@ -1,7 +1,4 @@
import logging
from typing import (
Union,
)
from libp2p.custom_types import (
TProtocol,
@ -94,7 +91,7 @@ class AutoNATService:
finally:
await stream.close()
async def _handle_request(self, request: Union[bytes, Message]) -> Message:
async def _handle_request(self, request: bytes | Message) -> Message:
"""
Process an AutoNAT protocol request.

View File

@ -84,26 +84,23 @@ class AutoNAT:
request: Any,
target: str,
options: tuple[Any, ...] = (),
channel_credentials: Optional[Any] = None,
call_credentials: Optional[Any] = None,
channel_credentials: Any | None = None,
call_credentials: Any | None = None,
insecure: bool = False,
compression: Optional[Any] = None,
wait_for_ready: Optional[bool] = None,
timeout: Optional[float] = None,
metadata: Optional[list[tuple[str, str]]] = None,
compression: Any | None = None,
wait_for_ready: bool | None = None,
timeout: float | None = None,
metadata: list[tuple[str, str]] | None = None,
) -> Any:
return grpc.experimental.unary_unary(
request,
target,
channel = grpc.secure_channel(target, channel_credentials) if channel_credentials else grpc.insecure_channel(target)
return channel.unary_unary(
"/autonat.pb.AutoNAT/Dial",
autonat__pb2.Message.SerializeToString,
autonat__pb2.Message.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
request_serializer=autonat__pb2.Message.SerializeToString,
response_deserializer=autonat__pb2.Message.FromString,
_registered_method=True,
)(
request,
timeout=timeout,
metadata=metadata,
wait_for_ready=wait_for_ready,
)

View File

@ -3,6 +3,7 @@ from collections.abc import (
Sequence,
)
from contextlib import (
AbstractAsyncContextManager,
asynccontextmanager,
)
import logging
@ -88,14 +89,14 @@ class BasicHost(IHost):
def __init__(
self,
network: INetworkService,
default_protocols: "OrderedDict[TProtocol, StreamHandlerFn]" = None,
default_protocols: Optional["OrderedDict[TProtocol, StreamHandlerFn]"] = None,
) -> None:
self._network = network
self._network.set_stream_handler(self._swarm_stream_handler)
self.peerstore = self._network.peerstore
# Protocol muxing
default_protocols = default_protocols or get_default_protocols(self)
self.multiselect = Multiselect(default_protocols)
self.multiselect = Multiselect(dict(default_protocols.items()))
self.multiselect_client = MultiselectClient()
def get_id(self) -> ID:
@ -147,19 +148,23 @@ class BasicHost(IHost):
"""
return list(self._network.connections.keys())
@asynccontextmanager
async def run(
def run(
self, listen_addrs: Sequence[multiaddr.Multiaddr]
) -> AsyncIterator[None]:
) -> AbstractAsyncContextManager[None]:
"""
Run the host instance and listen to ``listen_addrs``.
:param listen_addrs: a sequence of multiaddrs that we want to listen to
"""
network = self.get_network()
async with background_trio_service(network):
await network.listen(*listen_addrs)
yield
@asynccontextmanager
async def _run() -> AsyncIterator[None]:
network = self.get_network()
async with background_trio_service(network):
await network.listen(*listen_addrs)
yield
return _run()
def set_stream_handler(
self, protocol_id: TProtocol, stream_handler: StreamHandlerFn
@ -229,7 +234,7 @@ class BasicHost(IHost):
:param peer_info: peer_info of the peer we want to connect to
:type peer_info: peer.peerinfo.PeerInfo
"""
self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 10)
self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 120)
# there is already a connection to this peer
if peer_info.peer_id in self._network.connections:
@ -258,6 +263,15 @@ class BasicHost(IHost):
await net_stream.reset()
return
net_stream.set_protocol(protocol)
if handler is None:
logger.debug(
"no handler for protocol %s, closing stream from peer %s",
protocol,
net_stream.muxed_conn.peer_id,
)
await net_stream.reset()
return
await handler(net_stream)
def get_live_peers(self) -> list[ID]:
@ -277,7 +291,7 @@ class BasicHost(IHost):
"""
return peer_id in self._network.connections
def get_peer_connection_info(self, peer_id: ID) -> Optional[INetConn]:
def get_peer_connection_info(self, peer_id: ID) -> INetConn | None:
"""
Get connection information for a specific peer if connected.

View File

@ -9,13 +9,13 @@ from libp2p.abc import (
IHost,
)
from libp2p.host.ping import (
ID as PingID,
handle_ping,
)
from libp2p.host.ping import ID as PingID
from libp2p.identity.identify.identify import (
ID as IdentifyID,
identify_handler_for,
)
from libp2p.identity.identify.identify import ID as IdentifyID
if TYPE_CHECKING:
from libp2p.custom_types import (

View File

@ -40,8 +40,8 @@ class RoutedHost(BasicHost):
found_peer_info = await self._router.find_peer(peer_info.peer_id)
if not found_peer_info:
raise ConnectionFailure("Unable to find Peer address")
self.peerstore.add_addrs(peer_info.peer_id, found_peer_info.addrs, 10)
self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 10)
self.peerstore.add_addrs(peer_info.peer_id, found_peer_info.addrs, 120)
self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 120)
# there is already a connection to this peer
if peer_info.peer_id in self._network.connections:

View File

@ -1,7 +1,4 @@
import logging
from typing import (
Optional,
)
from multiaddr import (
Multiaddr,
@ -40,8 +37,8 @@ def _multiaddr_to_bytes(maddr: Multiaddr) -> bytes:
def _remote_address_to_multiaddr(
remote_address: Optional[tuple[str, int]]
) -> Optional[Multiaddr]:
remote_address: tuple[str, int] | None,
) -> Multiaddr | None:
"""Convert a (host, port) tuple to a Multiaddr."""
if remote_address is None:
return None
@ -58,7 +55,7 @@ def _remote_address_to_multiaddr(
def _mk_identify_protobuf(
host: IHost, observed_multiaddr: Optional[Multiaddr]
host: IHost, observed_multiaddr: Multiaddr | None
) -> Identify:
public_key = host.get_public_key()
laddrs = host.get_addrs()
@ -81,15 +78,14 @@ def identify_handler_for(host: IHost) -> StreamHandlerFn:
peer_id = (
stream.muxed_conn.peer_id
) # remote peer_id is in class Mplex (mplex.py )
observed_multiaddr: Multiaddr | None = None
# Get the remote address
try:
remote_address = stream.get_remote_address()
# Convert to multiaddr
if remote_address:
observed_multiaddr = _remote_address_to_multiaddr(remote_address)
else:
observed_multiaddr = None
logger.debug(
"Connection from remote peer %s, address: %s, multiaddr: %s",
peer_id,

View File

@ -1,7 +1,4 @@
import logging
from typing import (
Optional,
)
from multiaddr import (
Multiaddr,
@ -135,7 +132,7 @@ async def _update_peerstore_from_identify(
async def push_identify_to_peer(
host: IHost, peer_id: ID, observed_multiaddr: Optional[Multiaddr] = None
host: IHost, peer_id: ID, observed_multiaddr: Multiaddr | None = None
) -> bool:
"""
Push an identify message to a specific peer.
@ -172,8 +169,8 @@ async def push_identify_to_peer(
async def push_identify_to_peers(
host: IHost,
peer_ids: Optional[set[ID]] = None,
observed_multiaddr: Optional[Multiaddr] = None,
peer_ids: set[ID] | None = None,
observed_multiaddr: Multiaddr | None = None,
) -> None:
"""
Push an identify message to multiple peers in parallel.

View File

@ -2,27 +2,22 @@ from abc import (
ABC,
abstractmethod,
)
from typing import (
Optional,
)
from typing import Any
class Closer(ABC):
@abstractmethod
async def close(self) -> None:
...
async def close(self) -> None: ...
class Reader(ABC):
@abstractmethod
async def read(self, n: int = None) -> bytes:
...
async def read(self, n: int | None = None) -> bytes: ...
class Writer(ABC):
@abstractmethod
async def write(self, data: bytes) -> None:
...
async def write(self, data: bytes) -> None: ...
class WriteCloser(Writer, Closer):
@ -39,7 +34,7 @@ class ReadWriter(Reader, Writer):
class ReadWriteCloser(Reader, Writer, Closer):
@abstractmethod
def get_remote_address(self) -> Optional[tuple[str, int]]:
def get_remote_address(self) -> tuple[str, int] | None:
"""
Return the remote address of the connected peer.
@ -50,14 +45,12 @@ class ReadWriteCloser(Reader, Writer, Closer):
class MsgReader(ABC):
@abstractmethod
async def read_msg(self) -> bytes:
...
async def read_msg(self) -> bytes: ...
class MsgWriter(ABC):
@abstractmethod
async def write_msg(self, msg: bytes) -> None:
...
async def write_msg(self, msg: bytes) -> None: ...
class MsgReadWriteCloser(MsgReader, MsgWriter, Closer):
@ -66,19 +59,26 @@ class MsgReadWriteCloser(MsgReader, MsgWriter, Closer):
class Encrypter(ABC):
@abstractmethod
def encrypt(self, data: bytes) -> bytes:
...
def encrypt(self, data: bytes) -> bytes: ...
@abstractmethod
def decrypt(self, data: bytes) -> bytes:
...
def decrypt(self, data: bytes) -> bytes: ...
class EncryptedMsgReadWriter(MsgReadWriteCloser, Encrypter):
"""Read/write message with encryption/decryption."""
def get_remote_address(self) -> Optional[tuple[str, int]]:
conn: Any | None
def __init__(self, conn: Any | None = None):
self.conn = conn
def get_remote_address(self) -> tuple[str, int] | None:
"""Get remote address if supported by the underlying connection."""
if hasattr(self, "conn") and hasattr(self.conn, "get_remote_address"):
if (
self.conn is not None
and hasattr(self, "conn")
and hasattr(self.conn, "get_remote_address")
):
return self.conn.get_remote_address()
return None

View File

@ -5,6 +5,7 @@ from that repo: "a simple package to r/w length-delimited slices."
NOTE: currently missing the capability to indicate lengths by "varint" method.
"""
from abc import (
abstractmethod,
)
@ -60,12 +61,10 @@ class BaseMsgReadWriter(MsgReadWriteCloser):
return await read_exactly(self.read_write_closer, length)
@abstractmethod
async def next_msg_len(self) -> int:
...
async def next_msg_len(self) -> int: ...
@abstractmethod
def encode_msg(self, msg: bytes) -> bytes:
...
def encode_msg(self, msg: bytes) -> bytes: ...
async def close(self) -> None:
await self.read_write_closer.close()

View File

@ -1,7 +1,4 @@
import logging
from typing import (
Optional,
)
import trio
@ -34,7 +31,7 @@ class TrioTCPStream(ReadWriteCloser):
except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
raise IOException from error
async def read(self, n: int = None) -> bytes:
async def read(self, n: int | None = None) -> bytes:
async with self.read_lock:
if n is not None and n == 0:
return b""
@ -46,7 +43,7 @@ class TrioTCPStream(ReadWriteCloser):
async def close(self) -> None:
await self.stream.aclose()
def get_remote_address(self) -> Optional[tuple[str, int]]:
def get_remote_address(self) -> tuple[str, int] | None:
"""Return the remote address as (host, port) tuple."""
try:
return self.stream.socket.getpeername()

View File

@ -14,12 +14,14 @@ async def read_exactly(
"""
NOTE: relying on exceptions to break out on erroneous conditions, like EOF
"""
data = await reader.read(n)
buffer = bytearray()
buffer.extend(await reader.read(n))
for _ in range(retry_count):
if len(data) < n:
remaining = n - len(data)
data += await reader.read(remaining)
if len(buffer) < n:
remaining = n - len(buffer)
buffer.extend(await reader.read(remaining))
else:
return data
raise IncompleteReadError({"requested_count": n, "received_count": len(data)})
return bytes(buffer)
raise IncompleteReadError({"requested_count": n, "received_count": len(buffer)})

View File

@ -1,7 +1,3 @@
from typing import (
Optional,
)
from libp2p.abc import (
IRawConnection,
)
@ -32,7 +28,7 @@ class RawConnection(IRawConnection):
except IOException as error:
raise RawConnError from error
async def read(self, n: int = None) -> bytes:
async def read(self, n: int | None = None) -> bytes:
"""
Read up to ``n`` bytes from the underlying stream. This call is
delegated directly to the underlying ``self.reader``.
@ -47,6 +43,6 @@ class RawConnection(IRawConnection):
async def close(self) -> None:
await self.stream.close()
def get_remote_address(self) -> Optional[tuple[str, int]]:
def get_remote_address(self) -> tuple[str, int] | None:
"""Delegate to the underlying stream's get_remote_address method."""
return self.stream.get_remote_address()

View File

@ -22,7 +22,7 @@ if TYPE_CHECKING:
"""
Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee14e334f7caeedad722/swarm_conn.go # noqa: E501
Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee14e334f7caeedad722/swarm_conn.go
"""
@ -32,7 +32,11 @@ class SwarmConn(INetConn):
streams: set[NetStream]
event_closed: trio.Event
def __init__(self, muxed_conn: IMuxedConn, swarm: "Swarm") -> None:
def __init__(
self,
muxed_conn: IMuxedConn,
swarm: "Swarm",
) -> None:
self.muxed_conn = muxed_conn
self.swarm = swarm
self.streams = set()
@ -40,7 +44,7 @@ class SwarmConn(INetConn):
self.event_started = trio.Event()
if hasattr(muxed_conn, "on_close"):
logging.debug(f"Setting on_close for peer {muxed_conn.peer_id}")
muxed_conn.on_close = self._on_muxed_conn_closed
setattr(muxed_conn, "on_close", self._on_muxed_conn_closed)
else:
logging.error(
f"muxed_conn for peer {muxed_conn.peer_id} has no on_close attribute"

View File

@ -1,7 +1,9 @@
from typing import (
Optional,
from enum import (
Enum,
)
import trio
from libp2p.abc import (
IMuxedStream,
INetStream,
@ -23,19 +25,103 @@ from .exceptions import (
)
# TODO: Handle exceptions from `muxed_stream`
# TODO: Add stream state
# - Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501
class NetStream(INetStream):
muxed_stream: IMuxedStream
protocol_id: Optional[TProtocol]
class StreamState(Enum):
"""NetStream States"""
OPEN = "open"
CLOSE_READ = "close_read"
CLOSE_WRITE = "close_write"
CLOSE_BOTH = "close_both"
RESET = "reset"
class NetStream(INetStream):
"""
Summary
_______
A Network stream implementation.
NetStream wraps a muxed stream and provides proper state tracking, resource cleanup,
and event notification capabilities.
State Machine
_____________
.. code:: markdown
[CREATED] → OPEN → CLOSE_READ → CLOSE_BOTH → [CLEANUP]
↓ ↗ ↗
CLOSE_WRITE → ← ↗
↓ ↗
RESET → → → → → → → →
State Transitions
_________________
- OPEN → CLOSE_READ: EOF encountered during read()
- OPEN → CLOSE_WRITE: Explicit close() call
- OPEN → RESET: reset() call or critical stream error
- CLOSE_READ → CLOSE_BOTH: Explicit close() call
- CLOSE_WRITE → CLOSE_BOTH: EOF encountered during read()
- Any state → RESET: reset() call
Terminal States (trigger cleanup)
_________________________________
- CLOSE_BOTH: Stream fully closed, triggers resource cleanup
- RESET: Stream reset/terminated, triggers resource cleanup
Operation Validity by State
___________________________
OPEN: read() ✓ write() ✓ close() ✓ reset() ✓
CLOSE_READ: read() ✗ write() ✓ close() ✓ reset() ✓
CLOSE_WRITE: read() ✓ write() ✗ close() ✓ reset() ✓
CLOSE_BOTH: read() ✗ write() ✗ close() ✓ reset() ✓
RESET: read() ✗ write() ✗ close() ✓ reset() ✓
Cleanup Process (triggered by CLOSE_BOTH or RESET)
__________________________________________________
1. Remove stream from SwarmConn
2. Notify all listeners with ClosedStream event
3. Decrement reference counter
4. Background cleanup via nursery (if provided)
Thread Safety
_____________
All state operations are protected by trio.Lock() for safe concurrent access.
State checks and modifications are atomic operations.
Example: See :file:`examples/doc-examples/example_net_stream.py`
:param muxed_stream (IMuxedStream): The underlying muxed stream
:param nursery (Optional[trio.Nursery]): Nursery for background cleanup tasks
:raises StreamClosed: When attempting invalid operations on closed streams
:raises StreamEOF: When EOF is encountered during read operations
:raises StreamReset: When the underlying stream has been reset
"""
muxed_stream: IMuxedStream
protocol_id: TProtocol | None
__stream_state: StreamState
def __init__(
self, muxed_stream: IMuxedStream, nursery: trio.Nursery | None = None
) -> None:
super().__init__()
def __init__(self, muxed_stream: IMuxedStream) -> None:
self.muxed_stream = muxed_stream
self.muxed_conn = muxed_stream.muxed_conn
self.protocol_id = None
def get_protocol(self) -> TProtocol:
# For background tasks
self._nursery = nursery
# State management
self.__stream_state = StreamState.OPEN
self._state_lock = trio.Lock()
# For notification handling
self._notify_lock = trio.Lock()
def get_protocol(self) -> TProtocol | None:
"""
:return: protocol id that stream runs on
"""
@ -47,42 +133,176 @@ class NetStream(INetStream):
"""
self.protocol_id = protocol_id
async def read(self, n: int = None) -> bytes:
@property
async def state(self) -> StreamState:
"""Get current stream state."""
async with self._state_lock:
return self.__stream_state
async def read(self, n: int | None = None) -> bytes:
"""
Read from stream.
:param n: number of bytes to read
:return: bytes of input
:raises StreamClosed: If `NetStream` is closed for reading
:raises StreamReset: If `NetStream` is reset
:raises StreamEOF: If trying to read after reaching end of file
:return: Bytes read from the stream
"""
async with self._state_lock:
if self.__stream_state in [
StreamState.CLOSE_READ,
StreamState.CLOSE_BOTH,
]:
raise StreamClosed("Stream is closed for reading")
if self.__stream_state == StreamState.RESET:
raise StreamReset("Stream is reset, cannot be used to read")
try:
return await self.muxed_stream.read(n)
data = await self.muxed_stream.read(n)
return data
except MuxedStreamEOF as error:
async with self._state_lock:
if self.__stream_state == StreamState.CLOSE_WRITE:
self.__stream_state = StreamState.CLOSE_BOTH
await self._remove()
elif self.__stream_state == StreamState.OPEN:
self.__stream_state = StreamState.CLOSE_READ
raise StreamEOF() from error
except MuxedStreamReset as error:
async with self._state_lock:
if self.__stream_state in [
StreamState.OPEN,
StreamState.CLOSE_READ,
StreamState.CLOSE_WRITE,
]:
self.__stream_state = StreamState.RESET
await self._remove()
raise StreamReset() from error
async def write(self, data: bytes) -> None:
"""
Write to stream.
:return: number of bytes written
:param data: bytes to write
:raises StreamClosed: If `NetStream` is closed for writing or reset
:raises StreamClosed: If `StreamError` occurred while writing
"""
async with self._state_lock:
if self.__stream_state in [
StreamState.CLOSE_WRITE,
StreamState.CLOSE_BOTH,
StreamState.RESET,
]:
raise StreamClosed("Stream is closed for writing")
try:
await self.muxed_stream.write(data)
except (MuxedStreamClosed, MuxedStreamError) as error:
async with self._state_lock:
if self.__stream_state == StreamState.OPEN:
self.__stream_state = StreamState.CLOSE_WRITE
elif self.__stream_state == StreamState.CLOSE_READ:
self.__stream_state = StreamState.CLOSE_BOTH
await self._remove()
raise StreamClosed() from error
async def close(self) -> None:
"""Close stream."""
"""Close stream for writing."""
async with self._state_lock:
if self.__stream_state in [
StreamState.CLOSE_BOTH,
StreamState.RESET,
StreamState.CLOSE_WRITE,
]:
return
await self.muxed_stream.close()
async with self._state_lock:
if self.__stream_state == StreamState.CLOSE_READ:
self.__stream_state = StreamState.CLOSE_BOTH
await self._remove()
elif self.__stream_state == StreamState.OPEN:
self.__stream_state = StreamState.CLOSE_WRITE
async def reset(self) -> None:
"""Reset stream, closing both ends."""
async with self._state_lock:
if self.__stream_state == StreamState.RESET:
return
await self.muxed_stream.reset()
def get_remote_address(self) -> Optional[tuple[str, int]]:
async with self._state_lock:
if self.__stream_state in [
StreamState.OPEN,
StreamState.CLOSE_READ,
StreamState.CLOSE_WRITE,
]:
self.__stream_state = StreamState.RESET
await self._remove()
async def _remove(self) -> None:
"""
Remove stream from connection and notify listeners.
This is called when the stream is fully closed or reset.
"""
if hasattr(self.muxed_conn, "remove_stream"):
remove_stream = getattr(self.muxed_conn, "remove_stream")
await remove_stream(self)
# Notify in background using Trio nursery if available
if self._nursery:
self._nursery.start_soon(self._notify_closed)
else:
await self._notify_closed()
async def _notify_closed(self) -> None:
"""
Notify all listeners that the stream has been closed.
This runs in a separate task to avoid blocking the main flow.
"""
async with self._notify_lock:
if hasattr(self.muxed_conn, "swarm"):
swarm = getattr(self.muxed_conn, "swarm")
if hasattr(swarm, "notify_all"):
await swarm.notify_all(
lambda notifiee: notifiee.closed_stream(swarm, self)
)
if hasattr(swarm, "refs") and hasattr(swarm.refs, "done"):
swarm.refs.done()
def get_remote_address(self) -> tuple[str, int] | None:
"""Delegate to the underlying muxed stream."""
return self.muxed_stream.get_remote_address()
# TODO: `remove`: Called by close and write when the stream is in specific states.
# It notifies `ClosedStream` after `SwarmConn.remove_stream` is called.
# Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501
async def is_closed(self) -> bool:
"""Check if stream is closed."""
current_state = await self.state
return current_state in [StreamState.CLOSE_BOTH, StreamState.RESET]
async def is_readable(self) -> bool:
"""Check if stream is readable."""
current_state = await self.state
return current_state not in [
StreamState.CLOSE_READ,
StreamState.CLOSE_BOTH,
StreamState.RESET,
]
async def is_writable(self) -> bool:
"""Check if stream is writable."""
current_state = await self.state
return current_state not in [
StreamState.CLOSE_WRITE,
StreamState.CLOSE_BOTH,
StreamState.RESET,
]
def __str__(self) -> str:
"""String representation of the stream."""
return f"<NetStream[{self.__stream_state.value}] protocol={self.protocol_id}>"

View File

@ -1,7 +1,4 @@
import logging
from typing import (
Optional,
)
from multiaddr import (
Multiaddr,
@ -75,7 +72,7 @@ class Swarm(Service, INetworkService):
connections: dict[ID, INetConn]
listeners: dict[str, IListener]
common_stream_handler: StreamHandlerFn
listener_nursery: Optional[trio.Nursery]
listener_nursery: trio.Nursery | None
event_listener_nursery_created: trio.Event
notifees: list[INotifee]
@ -340,7 +337,9 @@ class Swarm(Service, INetworkService):
if hasattr(self, "transport") and self.transport is not None:
# Check if transport has close method before calling it
if hasattr(self.transport, "close"):
await self.transport.close()
await self.transport.close() # type: ignore
# Ignoring the type above since `transport` may not have a close method
# and we have already checked it with hasattr
logger.debug("swarm successfully closed")
@ -360,7 +359,11 @@ class Swarm(Service, INetworkService):
and start to monitor the connection for its new streams and
disconnection.
"""
swarm_conn = SwarmConn(muxed_conn, self)
swarm_conn = SwarmConn(
muxed_conn,
self,
)
self.manager.run_task(muxed_conn.start)
await muxed_conn.event_started.wait()
self.manager.run_task(swarm_conn.start)

View File

@ -1,7 +1,4 @@
import hashlib
from typing import (
Union,
)
import base58
import multihash
@ -24,7 +21,7 @@ if ENABLE_INLINING:
_digest: bytes
def __init__(self) -> None:
self._digest = bytearray()
self._digest = b""
def update(self, input: bytes) -> None:
self._digest += input
@ -39,8 +36,8 @@ if ENABLE_INLINING:
class ID:
_bytes: bytes
_xor_id: int = None
_b58_str: str = None
_xor_id: int | None = None
_b58_str: str | None = None
def __init__(self, peer_id_bytes: bytes) -> None:
self._bytes = peer_id_bytes
@ -93,7 +90,7 @@ class ID:
return cls(mh_digest.encode())
def sha256_digest(data: Union[str, bytes]) -> bytes:
def sha256_digest(data: str | bytes) -> bytes:
if isinstance(data, str):
data = data.encode("utf8")
return hashlib.sha256(data).digest()

View File

@ -1,6 +1,7 @@
from collections.abc import (
Sequence,
)
import time
from typing import (
Any,
)
@ -19,11 +20,13 @@ from libp2p.crypto.keys import (
class PeerData(IPeerData):
pubkey: PublicKey
privkey: PrivateKey
pubkey: PublicKey | None
privkey: PrivateKey | None
metadata: dict[Any, Any]
protocols: list[str]
addrs: list[Multiaddr]
last_identified: int
ttl: int # Keep ttl=0 by default for always valid
def __init__(self) -> None:
self.pubkey = None
@ -31,6 +34,8 @@ class PeerData(IPeerData):
self.metadata = {}
self.protocols = []
self.addrs = []
self.last_identified = int(time.time())
self.ttl = 0
def get_protocols(self) -> list[str]:
"""
@ -115,6 +120,36 @@ class PeerData(IPeerData):
raise PeerDataError("private key not found")
return self.privkey
def update_last_identified(self) -> None:
self.last_identified = int(time.time())
def get_last_identified(self) -> int:
"""
:return: last identified timestamp
"""
return self.last_identified
def get_ttl(self) -> int:
"""
:return: ttl for current peer
"""
return self.ttl
def set_ttl(self, ttl: int) -> None:
"""
:param ttl: ttl to set
"""
self.ttl = ttl
def is_expired(self) -> bool:
"""
:return: true, if last_identified+ttl > current_time
"""
# for ttl = 0; peer_data is always valid
if self.ttl > 0 and self.last_identified + self.ttl < int(time.time()):
return True
return False
class PeerDataError(KeyError):
"""Raised when a key is not found in peer metadata."""

View File

@ -32,21 +32,31 @@ def info_from_p2p_addr(addr: multiaddr.Multiaddr) -> PeerInfo:
if not addr:
raise InvalidAddrError("`addr` should not be `None`")
parts = addr.split()
parts: list[multiaddr.Multiaddr] = addr.split()
if not parts:
raise InvalidAddrError(
f"`parts`={parts} should at least have a protocol `P_P2P`"
)
p2p_part = parts[-1]
last_protocol_code = p2p_part.protocols()[0].code
if last_protocol_code != multiaddr.protocols.P_P2P:
p2p_protocols = p2p_part.protocols()
if not p2p_protocols:
raise InvalidAddrError("The last part of the address has no protocols")
last_protocol = p2p_protocols[0]
if last_protocol is None:
raise InvalidAddrError("The last protocol is None")
last_protocol_code = last_protocol.code
if last_protocol_code != multiaddr.multiaddr.protocols.P_P2P:
raise InvalidAddrError(
f"The last protocol should be `P_P2P` instead of `{last_protocol_code}`"
)
# make sure the /p2p value parses as a peer.ID
peer_id_str: str = p2p_part.value_for_protocol(multiaddr.protocols.P_P2P)
peer_id_str = p2p_part.value_for_protocol(multiaddr.multiaddr.protocols.P_P2P)
if peer_id_str is None:
raise InvalidAddrError("Missing value for /p2p protocol in multiaddr")
peer_id: ID = ID.from_base58(peer_id_str)
# we might have received just an / p2p part, which means there's no addr.

View File

@ -4,7 +4,6 @@ from collections import (
from collections.abc import (
Sequence,
)
import sys
from typing import (
Any,
)
@ -33,7 +32,7 @@ from .peerinfo import (
PeerInfo,
)
PERMANENT_ADDR_TTL = sys.maxsize
PERMANENT_ADDR_TTL = 0
class PeerStore(IPeerStore):
@ -49,6 +48,8 @@ class PeerStore(IPeerStore):
"""
if peer_id in self.peer_data_map:
peer_data = self.peer_data_map[peer_id]
if peer_data.is_expired():
peer_data.clear_addrs()
return PeerInfo(peer_id, peer_data.get_addrs())
raise PeerStoreError("peer ID not found")
@ -84,6 +85,18 @@ class PeerStore(IPeerStore):
"""
return list(self.peer_data_map.keys())
def valid_peer_ids(self) -> list[ID]:
"""
:return: all of the valid peer IDs stored in peer store
"""
valid_peer_ids: list[ID] = []
for peer_id, peer_data in self.peer_data_map.items():
if not peer_data.is_expired():
valid_peer_ids.append(peer_id)
else:
peer_data.clear_addrs()
return valid_peer_ids
def get(self, peer_id: ID, key: str) -> Any:
"""
:param peer_id: peer ID to get peer data for
@ -108,7 +121,7 @@ class PeerStore(IPeerStore):
peer_data = self.peer_data_map[peer_id]
peer_data.put_metadata(key, val)
def add_addr(self, peer_id: ID, addr: Multiaddr, ttl: int) -> None:
def add_addr(self, peer_id: ID, addr: Multiaddr, ttl: int = 0) -> None:
"""
:param peer_id: peer ID to add address for
:param addr:
@ -116,24 +129,30 @@ class PeerStore(IPeerStore):
"""
self.add_addrs(peer_id, [addr], ttl)
def add_addrs(self, peer_id: ID, addrs: Sequence[Multiaddr], ttl: int) -> None:
def add_addrs(self, peer_id: ID, addrs: Sequence[Multiaddr], ttl: int = 0) -> None:
"""
:param peer_id: peer ID to add address for
:param addrs:
:param ttl: time-to-live for the this record
"""
# Ignore ttl for now
peer_data = self.peer_data_map[peer_id]
peer_data.add_addrs(list(addrs))
peer_data.set_ttl(ttl)
peer_data.update_last_identified()
def addrs(self, peer_id: ID) -> list[Multiaddr]:
"""
:param peer_id: peer ID to get addrs for
:return: list of addrs
:return: list of addrs of a valid peer.
:raise PeerStoreError: if peer ID not found
"""
if peer_id in self.peer_data_map:
return self.peer_data_map[peer_id].get_addrs()
peer_data = self.peer_data_map[peer_id]
if not peer_data.is_expired():
return peer_data.get_addrs()
else:
peer_data.clear_addrs()
raise PeerStoreError("peer ID is expired")
raise PeerStoreError("peer ID not found")
def clear_addrs(self, peer_id: ID) -> None:
@ -153,7 +172,11 @@ class PeerStore(IPeerStore):
for peer_id in self.peer_data_map:
if len(self.peer_data_map[peer_id].get_addrs()) >= 1:
output.append(peer_id)
peer_data = self.peer_data_map[peer_id]
if not peer_data.is_expired():
output.append(peer_id)
else:
peer_data.clear_addrs()
return output
def add_pubkey(self, peer_id: ID, pubkey: PublicKey) -> None:

View File

@ -23,16 +23,20 @@ class Multiselect(IMultiselectMuxer):
communication.
"""
handlers: dict[TProtocol, StreamHandlerFn]
handlers: dict[TProtocol | None, StreamHandlerFn | None]
def __init__(
self, default_handlers: dict[TProtocol, StreamHandlerFn] = None
self,
default_handlers: None
| (dict[TProtocol | None, StreamHandlerFn | None]) = None,
) -> None:
if not default_handlers:
default_handlers = {}
self.handlers = default_handlers
def add_handler(self, protocol: TProtocol, handler: StreamHandlerFn) -> None:
def add_handler(
self, protocol: TProtocol | None, handler: StreamHandlerFn | None
) -> None:
"""
Store the handler with the given protocol.
@ -41,9 +45,10 @@ class Multiselect(IMultiselectMuxer):
"""
self.handlers[protocol] = handler
# FIXME: Make TProtocol Optional[TProtocol] to keep types consistent
async def negotiate(
self, communicator: IMultiselectCommunicator
) -> tuple[TProtocol, StreamHandlerFn]:
) -> tuple[TProtocol, StreamHandlerFn | None]:
"""
Negotiate performs protocol selection.
@ -60,7 +65,7 @@ class Multiselect(IMultiselectMuxer):
raise MultiselectError() from error
if command == "ls":
supported_protocols = list(self.handlers.keys())
supported_protocols = [p for p in self.handlers.keys() if p is not None]
response = "\n".join(supported_protocols) + "\n"
try:
@ -82,6 +87,8 @@ class Multiselect(IMultiselectMuxer):
except MultiselectCommunicatorError as error:
raise MultiselectError() from error
raise MultiselectError("Negotiation failed: no matching protocol")
async def handshake(self, communicator: IMultiselectCommunicator) -> None:
"""
Perform handshake to agree on multiselect protocol.

View File

@ -22,6 +22,9 @@ from libp2p.utils import (
encode_varint_prefixed,
)
from .exceptions import (
PubsubRouterError,
)
from .pb import (
rpc_pb2,
)
@ -37,7 +40,7 @@ logger = logging.getLogger("libp2p.pubsub.floodsub")
class FloodSub(IPubsubRouter):
protocols: list[TProtocol]
pubsub: Pubsub
pubsub: Pubsub | None
def __init__(self, protocols: Sequence[TProtocol]) -> None:
self.protocols = list(protocols)
@ -58,7 +61,7 @@ class FloodSub(IPubsubRouter):
"""
self.pubsub = pubsub
def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None:
def add_peer(self, peer_id: ID, protocol_id: TProtocol | None) -> None:
"""
Notifies the router that a new peer has been connected.
@ -108,17 +111,22 @@ class FloodSub(IPubsubRouter):
logger.debug("publishing message %s", pubsub_msg)
if self.pubsub is None:
raise PubsubRouterError("pubsub not attached to this instance")
else:
pubsub = self.pubsub
for peer_id in peers_gen:
if peer_id not in self.pubsub.peers:
if peer_id not in pubsub.peers:
continue
stream = self.pubsub.peers[peer_id]
stream = pubsub.peers[peer_id]
# FIXME: We should add a `WriteMsg` similar to write delimited messages.
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
try:
await stream.write(encode_varint_prefixed(rpc_msg.SerializeToString()))
except StreamClosed:
logger.debug("Fail to publish message to %s: stream closed", peer_id)
self.pubsub._handle_dead_peer(peer_id)
pubsub._handle_dead_peer(peer_id)
async def join(self, topic: str) -> None:
"""
@ -150,12 +158,16 @@ class FloodSub(IPubsubRouter):
:param origin: peer id of the peer the message originate from.
:return: a generator of the peer ids who we send data to.
"""
if self.pubsub is None:
raise PubsubRouterError("pubsub not attached to this instance")
else:
pubsub = self.pubsub
for topic in topic_ids:
if topic not in self.pubsub.peer_topics:
if topic not in pubsub.peer_topics:
continue
for peer_id in self.pubsub.peer_topics[topic]:
for peer_id in pubsub.peer_topics[topic]:
if peer_id in (msg_forwarder, origin):
continue
if peer_id not in self.pubsub.peers:
if peer_id not in pubsub.peers:
continue
yield peer_id

View File

@ -10,6 +10,7 @@ from collections.abc import (
)
import logging
import random
import time
from typing import (
Any,
DefaultDict,
@ -66,7 +67,7 @@ logger = logging.getLogger("libp2p.pubsub.gossipsub")
class GossipSub(IPubsubRouter, Service):
protocols: list[TProtocol]
pubsub: Pubsub
pubsub: Pubsub | None
degree: int
degree_high: int
@ -80,8 +81,7 @@ class GossipSub(IPubsubRouter, Service):
# The protocol peer supports
peer_protocol: dict[ID, TProtocol]
# TODO: Add `time_since_last_publish`
# Create topic --> time since last publish map.
time_since_last_publish: dict[str, int]
mcache: MessageCache
@ -98,7 +98,7 @@ class GossipSub(IPubsubRouter, Service):
degree: int,
degree_low: int,
degree_high: int,
direct_peers: Sequence[PeerInfo] = None,
direct_peers: Sequence[PeerInfo] | None = None,
time_to_live: int = 60,
gossip_window: int = 3,
gossip_history: int = 5,
@ -138,10 +138,9 @@ class GossipSub(IPubsubRouter, Service):
self.direct_peers[direct_peer.peer_id] = direct_peer
self.direct_connect_interval = direct_connect_interval
self.direct_connect_initial_delay = direct_connect_initial_delay
self.time_since_last_publish = {}
async def run(self) -> None:
if self.pubsub is None:
raise NoPubsubAttached
self.manager.run_daemon_task(self.heartbeat)
if len(self.direct_peers) > 0:
self.manager.run_daemon_task(self.direct_connect_heartbeat)
@ -172,7 +171,7 @@ class GossipSub(IPubsubRouter, Service):
logger.debug("attached to pusub")
def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None:
def add_peer(self, peer_id: ID, protocol_id: TProtocol | None) -> None:
"""
Notifies the router that a new peer has been connected.
@ -181,6 +180,9 @@ class GossipSub(IPubsubRouter, Service):
"""
logger.debug("adding peer %s with protocol %s", peer_id, protocol_id)
if protocol_id is None:
raise ValueError("Protocol cannot be None")
if protocol_id not in (PROTOCOL_ID, floodsub.PROTOCOL_ID):
# We should never enter here. Becuase the `protocol_id` is registered by
# your pubsub instance in multistream-select, but it is not the protocol
@ -242,6 +244,8 @@ class GossipSub(IPubsubRouter, Service):
logger.debug("publishing message %s", pubsub_msg)
for peer_id in peers_gen:
if self.pubsub is None:
raise NoPubsubAttached
if peer_id not in self.pubsub.peers:
continue
stream = self.pubsub.peers[peer_id]
@ -253,6 +257,8 @@ class GossipSub(IPubsubRouter, Service):
except StreamClosed:
logger.debug("Fail to publish message to %s: stream closed", peer_id)
self.pubsub._handle_dead_peer(peer_id)
for topic in pubsub_msg.topicIDs:
self.time_since_last_publish[topic] = int(time.time())
def _get_peers_to_send(
self, topic_ids: Iterable[str], msg_forwarder: ID, origin: ID
@ -266,6 +272,8 @@ class GossipSub(IPubsubRouter, Service):
"""
send_to: set[ID] = set()
for topic in topic_ids:
if self.pubsub is None:
raise NoPubsubAttached
if topic not in self.pubsub.peer_topics:
continue
@ -315,6 +323,9 @@ class GossipSub(IPubsubRouter, Service):
:param topic: topic to join
"""
if self.pubsub is None:
raise NoPubsubAttached
logger.debug("joining topic %s", topic)
if topic in self.mesh:
@ -342,6 +353,7 @@ class GossipSub(IPubsubRouter, Service):
await self.emit_graft(topic, peer)
self.fanout.pop(topic, None)
self.time_since_last_publish.pop(topic, None)
async def leave(self, topic: str) -> None:
# Note: the comments here are the near-exact algorithm description from the spec
@ -464,6 +476,8 @@ class GossipSub(IPubsubRouter, Service):
await trio.sleep(self.direct_connect_initial_delay)
while True:
for direct_peer in self.direct_peers:
if self.pubsub is None:
raise NoPubsubAttached
if direct_peer not in self.pubsub.peers:
try:
await self.pubsub.host.connect(self.direct_peers[direct_peer])
@ -481,6 +495,8 @@ class GossipSub(IPubsubRouter, Service):
peers_to_graft: DefaultDict[ID, list[str]] = defaultdict(list)
peers_to_prune: DefaultDict[ID, list[str]] = defaultdict(list)
for topic in self.mesh:
if self.pubsub is None:
raise NoPubsubAttached
# Skip if no peers have subscribed to the topic
if topic not in self.pubsub.peer_topics:
continue
@ -514,20 +530,26 @@ class GossipSub(IPubsubRouter, Service):
def fanout_heartbeat(self) -> None:
# Note: the comments here are the exact pseudocode from the spec
for topic in self.fanout:
# Delete topic entry if it's not in `pubsub.peer_topics`
# or (TODO) if it's time-since-last-published > ttl
if topic not in self.pubsub.peer_topics:
for topic in list(self.fanout):
if (
self.pubsub is not None
and topic not in self.pubsub.peer_topics
and self.time_since_last_publish.get(topic, 0) + self.time_to_live
< int(time.time())
):
# Remove topic from fanout
del self.fanout[topic]
else:
# Check if fanout peers are still in the topic and remove the ones that are not # noqa: E501
# ref: https://github.com/libp2p/go-libp2p-pubsub/blob/01b9825fbee1848751d90a8469e3f5f43bac8466/gossipsub.go#L498-L504 # noqa: E501
in_topic_fanout_peers = [
peer
for peer in self.fanout[topic]
if peer in self.pubsub.peer_topics[topic]
]
in_topic_fanout_peers: list[ID] = []
if self.pubsub is not None:
in_topic_fanout_peers = [
peer
for peer in self.fanout[topic]
if peer in self.pubsub.peer_topics[topic]
]
self.fanout[topic] = set(in_topic_fanout_peers)
num_fanout_peers_in_topic = len(self.fanout[topic])
@ -547,6 +569,8 @@ class GossipSub(IPubsubRouter, Service):
for topic in self.mesh:
msg_ids = self.mcache.window(topic)
if msg_ids:
if self.pubsub is None:
raise NoPubsubAttached
# Get all pubsub peers in a topic and only add them if they are
# gossipsub peers too
if topic in self.pubsub.peer_topics:
@ -566,6 +590,8 @@ class GossipSub(IPubsubRouter, Service):
for topic in self.fanout:
msg_ids = self.mcache.window(topic)
if msg_ids:
if self.pubsub is None:
raise NoPubsubAttached
# Get all pubsub peers in topic and only add if they are
# gossipsub peers also
if topic in self.pubsub.peer_topics:
@ -614,6 +640,8 @@ class GossipSub(IPubsubRouter, Service):
def _get_in_topic_gossipsub_peers_from_minus(
self, topic: str, num_to_select: int, minus: Iterable[ID]
) -> list[ID]:
if self.pubsub is None:
raise NoPubsubAttached
gossipsub_peers_in_topic = {
peer_id
for peer_id in self.pubsub.peer_topics[topic]
@ -627,6 +655,8 @@ class GossipSub(IPubsubRouter, Service):
self, ihave_msg: rpc_pb2.ControlIHave, sender_peer_id: ID
) -> None:
"""Checks the seen set and requests unknown messages with an IWANT message."""
if self.pubsub is None:
raise NoPubsubAttached
# Get list of all seen (seqnos, from) from the (seqno, from) tuples in
# seen_messages cache
seen_seqnos_and_peers = [
@ -659,7 +689,7 @@ class GossipSub(IPubsubRouter, Service):
msgs_to_forward: list[rpc_pb2.Message] = []
for msg_id_iwant in msg_ids:
# Check if the wanted message ID is present in mcache
msg: rpc_pb2.Message = self.mcache.get(msg_id_iwant)
msg: rpc_pb2.Message | None = self.mcache.get(msg_id_iwant)
# Cache hit
if msg:
@ -677,6 +707,8 @@ class GossipSub(IPubsubRouter, Service):
# 2) Serialize that packet
rpc_msg: bytes = packet.SerializeToString()
if self.pubsub is None:
raise NoPubsubAttached
# 3) Get the stream to this peer
if sender_peer_id not in self.pubsub.peers:
@ -731,9 +763,9 @@ class GossipSub(IPubsubRouter, Service):
def pack_control_msgs(
self,
ihave_msgs: list[rpc_pb2.ControlIHave],
graft_msgs: list[rpc_pb2.ControlGraft],
prune_msgs: list[rpc_pb2.ControlPrune],
ihave_msgs: list[rpc_pb2.ControlIHave] | None,
graft_msgs: list[rpc_pb2.ControlGraft] | None,
prune_msgs: list[rpc_pb2.ControlPrune] | None,
) -> rpc_pb2.ControlMessage:
control_msg: rpc_pb2.ControlMessage = rpc_pb2.ControlMessage()
if ihave_msgs:
@ -765,7 +797,7 @@ class GossipSub(IPubsubRouter, Service):
await self.emit_control_message(control_msg, to_peer)
async def emit_graft(self, topic: str, to_peer: ID) -> None:
async def emit_graft(self, topic: str, id: ID) -> None:
"""Emit graft message, sent to to_peer, for topic."""
graft_msg: rpc_pb2.ControlGraft = rpc_pb2.ControlGraft()
graft_msg.topicID = topic
@ -773,9 +805,9 @@ class GossipSub(IPubsubRouter, Service):
control_msg: rpc_pb2.ControlMessage = rpc_pb2.ControlMessage()
control_msg.graft.extend([graft_msg])
await self.emit_control_message(control_msg, to_peer)
await self.emit_control_message(control_msg, id)
async def emit_prune(self, topic: str, to_peer: ID) -> None:
async def emit_prune(self, topic: str, id: ID) -> None:
"""Emit graft message, sent to to_peer, for topic."""
prune_msg: rpc_pb2.ControlPrune = rpc_pb2.ControlPrune()
prune_msg.topicID = topic
@ -783,11 +815,13 @@ class GossipSub(IPubsubRouter, Service):
control_msg: rpc_pb2.ControlMessage = rpc_pb2.ControlMessage()
control_msg.prune.extend([prune_msg])
await self.emit_control_message(control_msg, to_peer)
await self.emit_control_message(control_msg, id)
async def emit_control_message(
self, control_msg: rpc_pb2.ControlMessage, to_peer: ID
) -> None:
if self.pubsub is None:
raise NoPubsubAttached
# Add control message to packet
packet: rpc_pb2.RPC = rpc_pb2.RPC()
packet.control.CopyFrom(control_msg)

View File

@ -1,9 +1,6 @@
from collections.abc import (
Sequence,
)
from typing import (
Optional,
)
from .pb import (
rpc_pb2,
@ -66,7 +63,7 @@ class MessageCache:
self.history[0].append(CacheEntry(mid, msg.topicIDs))
def get(self, mid: tuple[bytes, bytes]) -> Optional[rpc_pb2.Message]:
def get(self, mid: tuple[bytes, bytes]) -> rpc_pb2.Message | None:
"""
Get a message from the mcache.

View File

@ -4,6 +4,7 @@ from __future__ import (
import base64
from collections.abc import (
Callable,
KeysView,
)
import functools
@ -11,7 +12,6 @@ import hashlib
import logging
import time
from typing import (
Callable,
NamedTuple,
cast,
)
@ -53,6 +53,9 @@ from libp2p.network.stream.exceptions import (
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerdata import (
PeerDataError,
)
from libp2p.tools.async_service import (
Service,
)
@ -120,7 +123,10 @@ class Pubsub(Service, IPubsub):
# Indicate if we should enforce signature verification
strict_signing: bool
sign_key: PrivateKey
sign_key: PrivateKey | None
# Set of blacklisted peer IDs
blacklisted_peers: set[ID]
event_handle_peer_queue_started: trio.Event
event_handle_dead_peer_queue_started: trio.Event
@ -129,7 +135,7 @@ class Pubsub(Service, IPubsub):
self,
host: IHost,
router: IPubsubRouter,
cache_size: int = None,
cache_size: int | None = None,
seen_ttl: int = 120,
sweep_interval: int = 60,
strict_signing: bool = True,
@ -201,6 +207,9 @@ class Pubsub(Service, IPubsub):
self.counter = int(time.time())
# Set of blacklisted peer IDs
self.blacklisted_peers = set()
self.event_handle_peer_queue_started = trio.Event()
self.event_handle_dead_peer_queue_started = trio.Event()
@ -320,6 +329,82 @@ class Pubsub(Service, IPubsub):
if topic in self.topic_validators
)
def add_to_blacklist(self, peer_id: ID) -> None:
"""
Add a peer to the blacklist.
When a peer is blacklisted:
- Any existing connection to that peer is immediately closed and removed
- The peer is removed from all topic subscription mappings
- Future connection attempts from this peer will be rejected
- Messages forwarded by or originating from this peer will be dropped
- The peer will not be able to participate in pubsub communication
:param peer_id: the peer ID to blacklist
"""
self.blacklisted_peers.add(peer_id)
logger.debug("Added peer %s to blacklist", peer_id)
self.manager.run_task(self._teardown_if_connected, peer_id)
async def _teardown_if_connected(self, peer_id: ID) -> None:
"""Close their stream and remove them if connected"""
stream = self.peers.get(peer_id)
if stream is not None:
try:
await stream.reset()
except Exception:
pass
del self.peers[peer_id]
# Also remove from any subscription maps:
for _topic, peerset in self.peer_topics.items():
if peer_id in peerset:
peerset.discard(peer_id)
def remove_from_blacklist(self, peer_id: ID) -> None:
"""
Remove a peer from the blacklist.
Once removed from the blacklist:
- The peer can establish new connections to this node
- Messages from this peer will be processed normally
- The peer can participate in topic subscriptions and message forwarding
:param peer_id: the peer ID to remove from blacklist
"""
self.blacklisted_peers.discard(peer_id)
logger.debug("Removed peer %s from blacklist", peer_id)
def is_peer_blacklisted(self, peer_id: ID) -> bool:
"""
Check if a peer is blacklisted.
:param peer_id: the peer ID to check
:return: True if peer is blacklisted, False otherwise
"""
return peer_id in self.blacklisted_peers
def clear_blacklist(self) -> None:
"""
Clear all peers from the blacklist.
This removes all blacklist restrictions, allowing previously blacklisted
peers to:
- Establish new connections
- Send and forward messages
- Participate in topic subscriptions
"""
self.blacklisted_peers.clear()
logger.debug("Cleared all peers from blacklist")
def get_blacklisted_peers(self) -> set[ID]:
"""
Get a copy of the current blacklisted peers.
Returns a snapshot of all currently blacklisted peer IDs. These peers
are completely isolated from pubsub communication - their connections
are rejected and their messages are dropped.
:return: a set containing all blacklisted peer IDs
"""
return self.blacklisted_peers.copy()
async def stream_handler(self, stream: INetStream) -> None:
"""
Stream handler for pubsub. Gets invoked whenever a new stream is
@ -346,6 +431,10 @@ class Pubsub(Service, IPubsub):
await self.event_handle_dead_peer_queue_started.wait()
async def _handle_new_peer(self, peer_id: ID) -> None:
if self.is_peer_blacklisted(peer_id):
logger.debug("Rejecting blacklisted peer %s", peer_id)
return
try:
stream: INetStream = await self.host.new_stream(peer_id, self.protocols)
except SwarmException as error:
@ -359,7 +448,6 @@ class Pubsub(Service, IPubsub):
except StreamClosed:
logger.debug("Fail to add new peer %s: stream closed", peer_id)
return
# TODO: Check if the peer in black list.
try:
self.router.add_peer(peer_id, stream.get_protocol())
except Exception as error:
@ -549,6 +637,9 @@ class Pubsub(Service, IPubsub):
if self.strict_signing:
priv_key = self.sign_key
if priv_key is None:
raise PeerDataError("private key not found")
signature = priv_key.sign(
PUBSUB_SIGNING_PREFIX.encode() + msg.SerializeToString()
)
@ -609,9 +700,20 @@ class Pubsub(Service, IPubsub):
"""
logger.debug("attempting to publish message %s", msg)
# TODO: Check if the `source` is in the blacklist. If yes, reject.
# Check if the message forwarder (source) is in the blacklist. If yes, reject.
if self.is_peer_blacklisted(msg_forwarder):
logger.debug(
"Rejecting message from blacklisted source peer %s", msg_forwarder
)
return
# TODO: Check if the `from` is in the blacklist. If yes, reject.
# Check if the message originator (from) is in the blacklist. If yes, reject.
msg_from_peer = ID(msg.from_id)
if self.is_peer_blacklisted(msg_from_peer):
logger.debug(
"Rejecting message from blacklisted originator peer %s", msg_from_peer
)
return
# If the message is processed before, return(i.e., don't further process the message) # noqa: E501
if self._is_msg_seen(msg):

View File

@ -1,7 +1,3 @@
from typing import (
Optional,
)
from libp2p.abc import (
ISecureConn,
)
@ -49,5 +45,5 @@ class BaseSession(ISecureConn):
def get_remote_peer(self) -> ID:
return self.remote_peer
def get_remote_public_key(self) -> Optional[PublicKey]:
def get_remote_public_key(self) -> PublicKey:
return self.remote_permanent_pubkey

View File

@ -1,7 +1,7 @@
import secrets
from typing import (
from collections.abc import (
Callable,
)
import secrets
from libp2p.abc import (
ISecureTransport,

View File

@ -93,13 +93,13 @@ class InsecureSession(BaseSession):
async def write(self, data: bytes) -> None:
await self.conn.write(data)
async def read(self, n: int = None) -> bytes:
async def read(self, n: int | None = None) -> bytes:
return await self.conn.read(n)
async def close(self) -> None:
await self.conn.close()
def get_remote_address(self) -> Optional[tuple[str, int]]:
def get_remote_address(self) -> tuple[str, int] | None:
"""
Delegate to the underlying connection's get_remote_address method.
"""
@ -131,6 +131,15 @@ async def run_handshake(
remote_msg.ParseFromString(remote_msg_bytes)
received_peer_id = ID(remote_msg.id)
# Verify that `remote_peer_id` isn't `None`
# That is the only condition that `remote_peer_id` would not need to be checked
# against the `recieved_peer_id` gotten from the outbound/recieved `msg`.
# The check against `received_peer_id` happens in the next if-block
if is_initiator and remote_peer_id is None:
raise HandshakeFailure(
"remote peer ID cannot be None if `is_initiator` is set to `True`"
)
# Verify if the receive `ID` matches the one we originally initialize the session.
# We only need to check it when we are the initiator, because only in that condition
# we possibly knows the `ID` of the remote.

View File

@ -1,5 +1,4 @@
from typing import (
Optional,
cast,
)
@ -10,7 +9,6 @@ from libp2p.abc import (
)
from libp2p.io.abc import (
EncryptedMsgReadWriter,
MsgReadWriteCloser,
ReadWriteCloser,
)
from libp2p.io.msgio import (
@ -40,7 +38,7 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
implemented by the subclasses.
"""
read_writer: MsgReadWriteCloser
read_writer: NoisePacketReadWriter
noise_state: NoiseState
# FIXME: This prefix is added in msg#3 in Go. Check whether it's a desired behavior.
@ -50,12 +48,12 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
self.read_writer = NoisePacketReadWriter(cast(ReadWriteCloser, conn))
self.noise_state = noise_state
async def write_msg(self, data: bytes, prefix_encoded: bool = False) -> None:
data_encrypted = self.encrypt(data)
async def write_msg(self, msg: bytes, prefix_encoded: bool = False) -> None:
data_encrypted = self.encrypt(msg)
if prefix_encoded:
await self.read_writer.write_msg(self.prefix + data_encrypted)
else:
await self.read_writer.write_msg(data_encrypted)
# Manually add the prefix if needed
data_encrypted = self.prefix + data_encrypted
await self.read_writer.write_msg(data_encrypted)
async def read_msg(self, prefix_encoded: bool = False) -> bytes:
noise_msg_encrypted = await self.read_writer.read_msg()
@ -67,10 +65,11 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
async def close(self) -> None:
await self.read_writer.close()
def get_remote_address(self) -> Optional[tuple[str, int]]:
def get_remote_address(self) -> tuple[str, int] | None:
# Delegate to the underlying connection if possible
if hasattr(self.read_writer, "read_write_closer") and hasattr(
self.read_writer.read_write_closer, "get_remote_address"
self.read_writer.read_write_closer,
"get_remote_address",
):
return self.read_writer.read_write_closer.get_remote_address()
return None
@ -78,7 +77,7 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
class NoiseHandshakeReadWriter(BaseNoiseMsgReadWriter):
def encrypt(self, data: bytes) -> bytes:
return self.noise_state.write_message(data)
return bytes(self.noise_state.write_message(data))
def decrypt(self, data: bytes) -> bytes:
return bytes(self.noise_state.read_message(data))

View File

@ -19,7 +19,7 @@ SIGNED_DATA_PREFIX = "noise-libp2p-static-key:"
class NoiseHandshakePayload:
id_pubkey: PublicKey
id_sig: bytes
early_data: bytes = None
early_data: bytes | None = None
def serialize(self) -> bytes:
msg = noise_pb.NoiseHandshakePayload(

View File

@ -7,8 +7,10 @@ from cryptography.hazmat.primitives import (
serialization,
)
from noise.backends.default.keypairs import KeyPair as NoiseKeyPair
from noise.connection import Keypair as NoiseKeypairEnum
from noise.connection import NoiseConnection as NoiseState
from noise.connection import (
Keypair as NoiseKeypairEnum,
NoiseConnection as NoiseState,
)
from libp2p.abc import (
IRawConnection,
@ -47,14 +49,12 @@ from .messages import (
class IPattern(ABC):
@abstractmethod
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
...
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn: ...
@abstractmethod
async def handshake_outbound(
self, conn: IRawConnection, remote_peer: ID
) -> ISecureConn:
...
) -> ISecureConn: ...
class BasePattern(IPattern):
@ -62,13 +62,15 @@ class BasePattern(IPattern):
noise_static_key: PrivateKey
local_peer: ID
libp2p_privkey: PrivateKey
early_data: bytes
early_data: bytes | None
def create_noise_state(self) -> NoiseState:
noise_state = NoiseState.from_name(self.protocol_name)
noise_state.set_keypair_from_private_bytes(
NoiseKeypairEnum.STATIC, self.noise_static_key.to_bytes()
)
if noise_state.noise_protocol is None:
raise NoiseStateError("noise_protocol is not initialized")
return noise_state
def make_handshake_payload(self) -> NoiseHandshakePayload:
@ -84,7 +86,7 @@ class PatternXX(BasePattern):
local_peer: ID,
libp2p_privkey: PrivateKey,
noise_static_key: PrivateKey,
early_data: bytes = None,
early_data: bytes | None = None,
) -> None:
self.protocol_name = b"Noise_XX_25519_ChaChaPoly_SHA256"
self.local_peer = local_peer
@ -96,7 +98,12 @@ class PatternXX(BasePattern):
noise_state = self.create_noise_state()
noise_state.set_as_responder()
noise_state.start_handshake()
if noise_state.noise_protocol is None:
raise NoiseStateError("noise_protocol is not initialized")
handshake_state = noise_state.noise_protocol.handshake_state
if handshake_state is None:
raise NoiseStateError("Handshake state is not initialized")
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
# Consume msg#1.
@ -145,7 +152,11 @@ class PatternXX(BasePattern):
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
noise_state.set_as_initiator()
noise_state.start_handshake()
if noise_state.noise_protocol is None:
raise NoiseStateError("noise_protocol is not initialized")
handshake_state = noise_state.noise_protocol.handshake_state
if handshake_state is None:
raise NoiseStateError("Handshake state is not initialized")
# Send msg#1, which is *not* encrypted.
msg_1 = b""
@ -195,6 +206,8 @@ class PatternXX(BasePattern):
@staticmethod
def _get_pubkey_from_noise_keypair(key_pair: NoiseKeyPair) -> PublicKey:
# Use `Ed25519PublicKey` since 25519 is used in our pattern.
if key_pair.public is None:
raise NoiseStateError("public key is not initialized")
raw_bytes = key_pair.public.public_bytes(
serialization.Encoding.Raw, serialization.PublicFormat.Raw
)

View File

@ -26,7 +26,7 @@ class Transport(ISecureTransport):
libp2p_privkey: PrivateKey
noise_privkey: PrivateKey
local_peer: ID
early_data: bytes
early_data: bytes | None
with_noise_pipes: bool
# NOTE: Implementations that support Noise Pipes must decide whether to use
@ -37,8 +37,8 @@ class Transport(ISecureTransport):
def __init__(
self,
libp2p_keypair: KeyPair,
noise_privkey: PrivateKey = None,
early_data: bytes = None,
noise_privkey: PrivateKey,
early_data: bytes | None = None,
with_noise_pipes: bool = False,
) -> None:
self.libp2p_privkey = libp2p_keypair.private_key

View File

@ -2,9 +2,6 @@ from dataclasses import (
dataclass,
)
import itertools
from typing import (
Optional,
)
import multihash
@ -14,14 +11,10 @@ from libp2p.abc import (
)
from libp2p.crypto.authenticated_encryption import (
EncryptionParameters as AuthenticatedEncryptionParameters,
)
from libp2p.crypto.authenticated_encryption import (
InvalidMACException,
)
from libp2p.crypto.authenticated_encryption import (
MacAndCipher as Encrypter,
initialize_pair as initialize_pair_for_encryption,
)
from libp2p.crypto.authenticated_encryption import MacAndCipher as Encrypter
from libp2p.crypto.ecc import (
ECCPublicKey,
)
@ -91,6 +84,8 @@ class SecioPacketReadWriter(FixedSizeLenMsgReadWriter):
class SecioMsgReadWriter(EncryptedMsgReadWriter):
read_writer: SecioPacketReadWriter
local_encrypter: Encrypter
remote_encrypter: Encrypter
def __init__(
self,
@ -213,7 +208,8 @@ async def _response_to_msg(read_writer: SecioPacketReadWriter, msg: bytes) -> by
def _mk_multihash_sha256(data: bytes) -> bytes:
return multihash.digest(data, "sha2-256")
mh = multihash.digest(data, "sha2-256")
return mh.encode()
def _mk_score(public_key: PublicKey, nonce: bytes) -> bytes:
@ -270,7 +266,7 @@ def _select_encryption_parameters(
async def _establish_session_parameters(
local_peer: PeerID,
local_private_key: PrivateKey,
remote_peer: Optional[PeerID],
remote_peer: PeerID | None,
conn: SecioPacketReadWriter,
nonce: bytes,
) -> tuple[SessionParameters, bytes]:
@ -399,7 +395,7 @@ async def create_secure_session(
local_peer: PeerID,
local_private_key: PrivateKey,
conn: IRawConnection,
remote_peer: PeerID = None,
remote_peer: PeerID | None = None,
) -> ISecureConn:
"""
Attempt the initial `secio` handshake with the remote peer.

View File

@ -1,7 +1,4 @@
import io
from typing import (
Optional,
)
from libp2p.crypto.keys import (
PrivateKey,
@ -44,7 +41,7 @@ class SecureSession(BaseSession):
self._reset_internal_buffer()
def get_remote_address(self) -> Optional[tuple[str, int]]:
def get_remote_address(self) -> tuple[str, int] | None:
"""Delegate to the underlying connection's get_remote_address method."""
return self.conn.get_remote_address()
@ -53,7 +50,7 @@ class SecureSession(BaseSession):
self.low_watermark = 0
self.high_watermark = 0
def _drain(self, n: int) -> bytes:
def _drain(self, n: int | None) -> bytes:
if self.low_watermark == self.high_watermark:
return b""
@ -75,7 +72,7 @@ class SecureSession(BaseSession):
self.low_watermark = 0
self.high_watermark = len(msg)
async def read(self, n: int = None) -> bytes:
async def read(self, n: int | None = None) -> bytes:
if n == 0:
return b""
@ -85,6 +82,9 @@ class SecureSession(BaseSession):
msg = await self.conn.read_msg()
if n is None:
return msg
if n < len(msg):
self._fill(msg)
return self._drain(n)

View File

@ -1,7 +1,4 @@
import logging
from typing import (
Optional,
)
import trio
@ -168,7 +165,7 @@ class Mplex(IMuxedConn):
raise MplexUnavailable
async def send_message(
self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID
self, flag: HeaderTags, data: bytes | None, stream_id: StreamID
) -> int:
"""
Send a message over the connection.
@ -366,6 +363,6 @@ class Mplex(IMuxedConn):
self.event_closed.set()
await self.new_stream_send_channel.aclose()
def get_remote_address(self) -> Optional[tuple[str, int]]:
def get_remote_address(self) -> tuple[str, int] | None:
"""Delegate to the underlying Mplex connection's secured_conn."""
return self.secured_conn.get_remote_address()

View File

@ -1,6 +1,8 @@
from types import (
TracebackType,
)
from typing import (
TYPE_CHECKING,
Optional,
)
import trio
@ -37,9 +39,12 @@ class MplexStream(IMuxedStream):
name: str
stream_id: StreamID
muxed_conn: "Mplex"
read_deadline: int
write_deadline: int
# NOTE: All methods used here are part of `Mplex` which is a derived
# class of IMuxedConn. Ignoring this type assignment should not pose
# any risk.
muxed_conn: "Mplex" # type: ignore[assignment]
read_deadline: int | None
write_deadline: int | None
# TODO: Add lock for read/write to avoid interleaving receiving messages?
close_lock: trio.Lock
@ -89,7 +94,7 @@ class MplexStream(IMuxedStream):
self._buf = self._buf[len(payload) :]
return bytes(payload)
def _read_return_when_blocked(self) -> bytes:
def _read_return_when_blocked(self) -> bytearray:
buf = bytearray()
while True:
try:
@ -99,7 +104,7 @@ class MplexStream(IMuxedStream):
break
return buf
async def read(self, n: int = None) -> bytes:
async def read(self, n: int | None = None) -> bytes:
"""
Read up to n bytes. Read possibly returns fewer than `n` bytes, if
there are not enough bytes in the Mplex buffer. If `n is None`, read
@ -254,6 +259,19 @@ class MplexStream(IMuxedStream):
self.write_deadline = ttl
return True
def get_remote_address(self) -> Optional[tuple[str, int]]:
def get_remote_address(self) -> tuple[str, int] | None:
"""Delegate to the parent Mplex connection."""
return self.muxed_conn.get_remote_address()
async def __aenter__(self) -> "MplexStream":
"""Enter the async context manager."""
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Exit the async context manager and close the stream."""
await self.close()

View File

@ -95,7 +95,7 @@ class MuxerMultistream:
if protocol == PROTOCOL_ID:
async with trio.open_nursery():
def on_close() -> None:
async def on_close() -> None:
pass
return Yamux(

View File

@ -3,15 +3,19 @@ Yamux stream multiplexer implementation for py-libp2p.
This is the preferred multiplexing protocol due to its performance and feature set.
Mplex is also available for legacy compatibility but may be deprecated in the future.
"""
from collections.abc import (
Awaitable,
Callable,
)
import inspect
import logging
import struct
from types import (
TracebackType,
)
from typing import (
Callable,
Optional,
Any,
)
import trio
@ -74,6 +78,19 @@ class YamuxStream(IMuxedStream):
self.recv_window = DEFAULT_WINDOW_SIZE
self.window_lock = trio.Lock()
async def __aenter__(self) -> "YamuxStream":
"""Enter the async context manager."""
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Exit the async context manager and close the stream."""
await self.close()
async def write(self, data: bytes) -> None:
if self.send_closed:
raise MuxedStreamError("Stream is closed for sending")
@ -110,7 +127,7 @@ class YamuxStream(IMuxedStream):
if self.send_window < DEFAULT_WINDOW_SIZE // 2:
await self.send_window_update()
async def send_window_update(self, increment: Optional[int] = None) -> None:
async def send_window_update(self, increment: int | None = None) -> None:
"""Send a window update to peer."""
if increment is None:
increment = DEFAULT_WINDOW_SIZE - self.recv_window
@ -125,7 +142,7 @@ class YamuxStream(IMuxedStream):
)
await self.conn.secured_conn.write(header)
async def read(self, n: int = -1) -> bytes:
async def read(self, n: int | None = -1) -> bytes:
# Handle None value for n by converting it to -1
if n is None:
n = -1
@ -145,8 +162,7 @@ class YamuxStream(IMuxedStream):
if buffer and len(buffer) > 0:
# Wait for closure even if data is available
logging.debug(
f"Stream {self.stream_id}:"
f"Waiting for FIN before returning data"
f"Stream {self.stream_id}:Waiting for FIN before returning data"
)
await self.conn.stream_events[self.stream_id].wait()
self.conn.stream_events[self.stream_id] = trio.Event()
@ -224,7 +240,7 @@ class YamuxStream(IMuxedStream):
"""
raise NotImplementedError("Yamux does not support setting read deadlines")
def get_remote_address(self) -> Optional[tuple[str, int]]:
def get_remote_address(self) -> tuple[str, int] | None:
"""
Returns the remote address of the underlying connection.
"""
@ -252,8 +268,8 @@ class Yamux(IMuxedConn):
self,
secured_conn: ISecureConn,
peer_id: ID,
is_initiator: Optional[bool] = None,
on_close: Optional[Callable[[], Awaitable[None]]] = None,
is_initiator: bool | None = None,
on_close: Callable[[], Awaitable[Any]] | None = None,
) -> None:
self.secured_conn = secured_conn
self.peer_id = peer_id
@ -267,7 +283,7 @@ class Yamux(IMuxedConn):
self.is_initiator_value = (
is_initiator if is_initiator is not None else secured_conn.is_initiator
)
self.next_stream_id = 1 if self.is_initiator_value else 2
self.next_stream_id: int = 1 if self.is_initiator_value else 2
self.streams: dict[int, YamuxStream] = {}
self.streams_lock = trio.Lock()
self.new_stream_send_channel: MemorySendChannel[YamuxStream]
@ -281,7 +297,7 @@ class Yamux(IMuxedConn):
self.event_started = trio.Event()
self.stream_buffers: dict[int, bytearray] = {}
self.stream_events: dict[int, trio.Event] = {}
self._nursery: Optional[Nursery] = None
self._nursery: Nursery | None = None
async def start(self) -> None:
logging.debug(f"Starting Yamux for {self.peer_id}")
@ -449,8 +465,14 @@ class Yamux(IMuxedConn):
# Wait for data if stream is still open
logging.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}")
await self.stream_events[stream_id].wait()
self.stream_events[stream_id] = trio.Event()
try:
await self.stream_events[stream_id].wait()
self.stream_events[stream_id] = trio.Event()
except KeyError:
raise MuxedStreamEOF("Stream was removed")
# This line should never be reached, but satisfies the type checker
raise MuxedStreamEOF("Unexpected end of read_stream")
async def handle_incoming(self) -> None:
while not self.event_shutting_down.is_set():
@ -458,8 +480,7 @@ class Yamux(IMuxedConn):
header = await self.secured_conn.read(HEADER_SIZE)
if not header or len(header) < HEADER_SIZE:
logging.debug(
f"Connection closed or"
f"incomplete header for peer {self.peer_id}"
f"Connection closed orincomplete header for peer {self.peer_id}"
)
self.event_shutting_down.set()
await self._cleanup_on_error()
@ -528,8 +549,7 @@ class Yamux(IMuxedConn):
)
elif error_code == GO_AWAY_PROTOCOL_ERROR:
logging.error(
f"Received GO_AWAY for peer"
f"{self.peer_id}: Protocol error"
f"Received GO_AWAY for peer{self.peer_id}: Protocol error"
)
elif error_code == GO_AWAY_INTERNAL_ERROR:
logging.error(

View File

@ -1,12 +1,10 @@
# Copied from https://github.com/ethereum/async-service
import os
from typing import (
Any,
)
from typing import Any
def get_task_name(value: Any, explicit_name: str = None) -> str:
def get_task_name(value: Any, explicit_name: str | None = None) -> str:
# inline import to ensure `_utils` is always importable from the rest of
# the module.
from .abc import ( # noqa: F401

View File

@ -28,33 +28,27 @@ class TaskAPI(Hashable):
parent: Optional["TaskWithChildrenAPI"]
@abstractmethod
async def run(self) -> None:
...
async def run(self) -> None: ...
@abstractmethod
async def cancel(self) -> None:
...
async def cancel(self) -> None: ...
@property
@abstractmethod
def is_done(self) -> bool:
...
def is_done(self) -> bool: ...
@abstractmethod
async def wait_done(self) -> None:
...
async def wait_done(self) -> None: ...
class TaskWithChildrenAPI(TaskAPI):
children: set[TaskAPI]
@abstractmethod
def add_child(self, child: TaskAPI) -> None:
...
def add_child(self, child: TaskAPI) -> None: ...
@abstractmethod
def discard_child(self, child: TaskAPI) -> None:
...
def discard_child(self, child: TaskAPI) -> None: ...
class ServiceAPI(ABC):
@ -212,7 +206,11 @@ class InternalManagerAPI(ManagerAPI):
@trio_typing.takes_callable_and_args
@abstractmethod
def run_task(
self, async_fn: AsyncFn, *args: Any, daemon: bool = False, name: str = None
self,
async_fn: AsyncFn,
*args: Any,
daemon: bool = False,
name: str | None = None,
) -> None:
"""
Run a task in the background. If the function throws an exception it
@ -225,7 +223,9 @@ class InternalManagerAPI(ManagerAPI):
@trio_typing.takes_callable_and_args
@abstractmethod
def run_daemon_task(self, async_fn: AsyncFn, *args: Any, name: str = None) -> None:
def run_daemon_task(
self, async_fn: AsyncFn, *args: Any, name: str | None = None
) -> None:
"""
Run a daemon task in the background.
@ -235,7 +235,7 @@ class InternalManagerAPI(ManagerAPI):
@abstractmethod
def run_child_service(
self, service: ServiceAPI, daemon: bool = False, name: str = None
self, service: ServiceAPI, daemon: bool = False, name: str | None = None
) -> "ManagerAPI":
"""
Run a service in the background. If the function throws an exception it
@ -248,7 +248,7 @@ class InternalManagerAPI(ManagerAPI):
@abstractmethod
def run_daemon_child_service(
self, service: ServiceAPI, name: str = None
self, service: ServiceAPI, name: str | None = None
) -> "ManagerAPI":
"""
Run a daemon service in the background.

View File

@ -9,6 +9,7 @@ from collections import (
)
from collections.abc import (
Awaitable,
Callable,
Iterable,
Sequence,
)
@ -16,8 +17,6 @@ import logging
import sys
from typing import (
Any,
Callable,
Optional,
TypeVar,
cast,
)
@ -98,7 +97,7 @@ def as_service(service_fn: LogicFnType) -> type[ServiceAPI]:
class BaseTask(TaskAPI):
def __init__(
self, name: str, daemon: bool, parent: Optional[TaskWithChildrenAPI]
self, name: str, daemon: bool, parent: TaskWithChildrenAPI | None
) -> None:
# meta
self.name = name
@ -125,7 +124,7 @@ class BaseTask(TaskAPI):
class BaseTaskWithChildren(BaseTask, TaskWithChildrenAPI):
def __init__(
self, name: str, daemon: bool, parent: Optional[TaskWithChildrenAPI]
self, name: str, daemon: bool, parent: TaskWithChildrenAPI | None
) -> None:
super().__init__(name, daemon, parent)
self.children = set()
@ -142,26 +141,20 @@ T = TypeVar("T", bound="BaseFunctionTask")
class BaseFunctionTask(BaseTaskWithChildren):
@classmethod
def iterate_tasks(cls: type[T], *tasks: TaskAPI) -> Iterable[T]:
def iterate_tasks(cls, *tasks: TaskAPI) -> Iterable["BaseFunctionTask"]:
"""Iterate over all tasks of this class type and their children recursively."""
for task in tasks:
if isinstance(task, cls):
if isinstance(task, BaseFunctionTask):
yield task
else:
continue
yield from cls.iterate_tasks(
*(
child_task
for child_task in task.children
if isinstance(child_task, cls)
)
)
if isinstance(task, TaskWithChildrenAPI):
yield from cls.iterate_tasks(*task.children)
def __init__(
self,
name: str,
daemon: bool,
parent: Optional[TaskWithChildrenAPI],
parent: TaskWithChildrenAPI | None,
async_fn: AsyncFn,
async_fn_args: Sequence[Any],
) -> None:
@ -259,12 +252,15 @@ class BaseManager(InternalManagerAPI):
# Wait API
#
def run_daemon_task(
self, async_fn: Callable[..., Awaitable[Any]], *args: Any, name: str = None
self,
async_fn: Callable[..., Awaitable[Any]],
*args: Any,
name: str | None = None,
) -> None:
self.run_task(async_fn, *args, daemon=True, name=name)
def run_daemon_child_service(
self, service: ServiceAPI, name: str = None
self, service: ServiceAPI, name: str | None = None
) -> ManagerAPI:
return self.run_child_service(service, daemon=True, name=name)
@ -286,8 +282,7 @@ class BaseManager(InternalManagerAPI):
# Task Management
#
@abstractmethod
def _schedule_task(self, task: TaskAPI) -> None:
...
def _schedule_task(self, task: TaskAPI) -> None: ...
def _common_run_task(self, task: TaskAPI) -> None:
if not self.is_running:
@ -307,7 +302,7 @@ class BaseManager(InternalManagerAPI):
self._schedule_task(task)
def _add_child_task(
self, parent: Optional[TaskWithChildrenAPI], task: TaskAPI
self, parent: TaskWithChildrenAPI | None, task: TaskAPI
) -> None:
if parent is None:
all_children = self._root_tasks

View File

@ -6,7 +6,9 @@ from __future__ import (
from collections.abc import (
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Iterable,
Sequence,
)
from contextlib import (
@ -16,7 +18,6 @@ import functools
import sys
from typing import (
Any,
Callable,
Optional,
TypeVar,
cast,
@ -59,6 +60,16 @@ from .typing import (
class FunctionTask(BaseFunctionTask):
_trio_task: trio.lowlevel.Task | None = None
@classmethod
def iterate_tasks(cls, *tasks: TaskAPI) -> Iterable[FunctionTask]:
"""Iterate over all FunctionTask instances and their children recursively."""
for task in tasks:
if isinstance(task, FunctionTask):
yield task
if isinstance(task, TaskWithChildrenAPI):
yield from cls.iterate_tasks(*task.children)
def __init__(
self,
name: str,
@ -75,7 +86,7 @@ class FunctionTask(BaseFunctionTask):
# Each task gets its own `CancelScope` which is how we can manually
# control cancellation order of the task DAG
self._cancel_scope = trio.CancelScope()
self._cancel_scope = trio.CancelScope() # type: ignore[call-arg]
#
# Trio specific API
@ -309,7 +320,7 @@ class TrioManager(BaseManager):
async_fn: Callable[..., Awaitable[Any]],
*args: Any,
daemon: bool = False,
name: str = None,
name: str | None = None,
) -> None:
task = FunctionTask(
name=get_task_name(async_fn, name),
@ -322,7 +333,7 @@ class TrioManager(BaseManager):
self._common_run_task(task)
def run_child_service(
self, service: ServiceAPI, daemon: bool = False, name: str = None
self, service: ServiceAPI, daemon: bool = False, name: str | None = None
) -> ManagerAPI:
task = ChildServiceTask(
name=get_task_name(service, name),
@ -416,7 +427,12 @@ def external_api(func: TFunc) -> TFunc:
async with trio.open_nursery() as nursery:
# mypy's type hints for start_soon break with this invocation.
nursery.start_soon(
_wait_api_fn, self, func, args, kwargs, send_channel # type: ignore
_wait_api_fn, # type: ignore
self,
func,
args,
kwargs,
send_channel,
)
nursery.start_soon(_wait_finished, self, func, send_channel)
result, err = await receive_channel.receive()

View File

@ -2,13 +2,13 @@
from collections.abc import (
Awaitable,
Callable,
)
from types import (
TracebackType,
)
from typing import (
Any,
Callable,
)
EXC_INFO = tuple[type[BaseException], BaseException, TracebackType]

View File

@ -32,7 +32,7 @@ class GossipsubParams(NamedTuple):
degree: int = 10
degree_low: int = 9
degree_high: int = 11
direct_peers: Sequence[PeerInfo] = None
direct_peers: Sequence[PeerInfo] = []
time_to_live: int = 30
gossip_window: int = 3
gossip_history: int = 5

View File

@ -1,10 +1,8 @@
from collections.abc import (
Awaitable,
)
import logging
from typing import (
Callable,
)
import logging
import trio
@ -63,12 +61,12 @@ async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm) -> None:
logging.debug(
"Swarm connection verification failed on attempt"
+ f" {attempt+1}, retrying..."
+ f" {attempt + 1}, retrying..."
)
except Exception as e:
last_error = e
logging.debug(f"Swarm connection attempt {attempt+1} failed: {e}")
logging.debug(f"Swarm connection attempt {attempt + 1} failed: {e}")
await trio.sleep(retry_delay)
# If we got here, all retries failed
@ -115,12 +113,12 @@ async def connect(node1: IHost, node2: IHost) -> None:
return
logging.debug(
f"Connection verification failed on attempt {attempt+1}, retrying..."
f"Connection verification failed on attempt {attempt + 1}, retrying..."
)
except Exception as e:
last_error = e
logging.debug(f"Connection attempt {attempt+1} failed: {e}")
logging.debug(f"Connection attempt {attempt + 1} failed: {e}")
await trio.sleep(retry_delay)
# If we got here, all retries failed

View File

@ -1,11 +1,9 @@
from collections.abc import (
Awaitable,
Callable,
Sequence,
)
import logging
from typing import (
Callable,
)
from multiaddr import (
Multiaddr,
@ -44,7 +42,7 @@ class TCPListener(IListener):
self.handler = handler_function
# TODO: Get rid of `nursery`?
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> None:
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
"""
Put listener in listening mode and wait for incoming connections.
@ -56,7 +54,7 @@ class TCPListener(IListener):
handler: Callable[[trio.SocketStream], Awaitable[None]],
port: int,
host: str,
task_status: TaskStatus[Sequence[trio.SocketListener]] = None,
task_status: TaskStatus[Sequence[trio.SocketListener]],
) -> None:
"""Just a proxy function to add logging here."""
logger.debug("serve_tcp %s %s", host, port)
@ -67,18 +65,53 @@ class TCPListener(IListener):
remote_port: int = 0
try:
tcp_stream = TrioTCPStream(stream)
remote_host, remote_port = tcp_stream.get_remote_address()
remote_tuple = tcp_stream.get_remote_address()
if remote_tuple is not None:
remote_host, remote_port = remote_tuple
await self.handler(tcp_stream)
except Exception:
logger.debug(f"Connection from {remote_host}:{remote_port} failed.")
listeners = await nursery.start(
tcp_port_str = maddr.value_for_protocol("tcp")
if tcp_port_str is None:
logger.error(f"Cannot listen: TCP port is missing in multiaddress {maddr}")
return False
try:
tcp_port = int(tcp_port_str)
except ValueError:
logger.error(
f"Cannot listen: Invalid TCP port '{tcp_port_str}' "
f"in multiaddress {maddr}"
)
return False
ip4_host_str = maddr.value_for_protocol("ip4")
# For trio.serve_tcp, ip4_host_str (as host argument) can be None,
# which typically means listen on all available interfaces.
started_listeners = await nursery.start(
serve_tcp,
handler,
int(maddr.value_for_protocol("tcp")),
maddr.value_for_protocol("ip4"),
tcp_port,
ip4_host_str,
)
self.listeners.extend(listeners)
if started_listeners is None:
# This implies that task_status.started() was not called within serve_tcp,
# likely because trio.serve_tcp itself failed to start (e.g., port in use).
logger.error(
f"Failed to start TCP listener for {maddr}: "
f"`nursery.start` returned None. "
"This might be due to issues like the port already "
"being in use or invalid host."
)
return False
self.listeners.extend(started_listeners)
return True
def get_addrs(self) -> tuple[Multiaddr, ...]:
"""
@ -105,15 +138,42 @@ class TCP(ITransport):
:return: `RawConnection` if successful
:raise OpenConnectionError: raised when failed to open connection
"""
self.host = maddr.value_for_protocol("ip4")
self.port = int(maddr.value_for_protocol("tcp"))
host_str = maddr.value_for_protocol("ip4")
port_str = maddr.value_for_protocol("tcp")
if host_str is None:
raise OpenConnectionError(
f"Failed to dial {maddr}: IP address not found in multiaddr."
)
if port_str is None:
raise OpenConnectionError(
f"Failed to dial {maddr}: TCP port not found in multiaddr."
)
try:
stream = await trio.open_tcp_stream(self.host, self.port)
except OSError as error:
raise OpenConnectionError from error
read_write_closer = TrioTCPStream(stream)
port_int = int(port_str)
except ValueError:
raise OpenConnectionError(
f"Failed to dial {maddr}: Invalid TCP port '{port_str}'."
)
try:
# trio.open_tcp_stream requires host to be str or bytes, not None.
stream = await trio.open_tcp_stream(host_str, port_int)
except OSError as error:
# OSError is common for network issues like "Connection refused"
# or "Host unreachable".
raise OpenConnectionError(
f"Failed to open TCP stream to {maddr}: {error}"
) from error
except Exception as error:
# Catch other potential errors from trio.open_tcp_stream and wrap them.
raise OpenConnectionError(
f"An unexpected error occurred when dialing {maddr}: {error}"
) from error
read_write_closer = TrioTCPStream(stream)
return RawConnection(read_write_closer, True)
def create_listener(self, handler_function: THandler) -> TCPListener:

View File

@ -13,15 +13,13 @@ import sys
import threading
from typing import (
Any,
Optional,
Union,
)
# Create a log queue
log_queue: "queue.Queue[Any]" = queue.Queue()
# Store the current listener to stop it on exit
_current_listener: Optional[logging.handlers.QueueListener] = None
_current_listener: logging.handlers.QueueListener | None = None
# Event to track when the listener is ready
_listener_ready = threading.Event()
@ -135,7 +133,7 @@ def setup_logging() -> None:
formatter = logging.Formatter(DEFAULT_LOG_FORMAT)
# Configure handlers
handlers: list[Union[logging.StreamHandler[Any], logging.FileHandler]] = []
handlers: list[logging.StreamHandler[Any] | logging.FileHandler] = []
# Console handler
console_handler = logging.StreamHandler(sys.stderr)

View File

@ -0,0 +1 @@
The `NetStream.state` property is now async and requires `await`. Update any direct state access to use `await stream.state`.

View File

@ -0,0 +1 @@
Added proper state management and resource cleanup to `NetStream`, fixing memory leaks and improved error handling.

View File

@ -0,0 +1 @@
Modernizes several aspects of the project, notably using ``pyproject.toml`` for project info instead of ``setup.py``, using ``ruff`` to replace several separate linting tools, and ``pyrefly`` in addition to ``mypy`` for typing. Also includes changes across the codebase to conform to new linting and typing rules.

View File

@ -0,0 +1 @@
Removes support for python 3.9 and updates some code conventions, notably using ``|`` operator in typing instead of ``Optional`` or ``Union``

View File

@ -0,0 +1 @@
implement AsyncContextManager for IMuxedStream to support async with

View File

@ -0,0 +1 @@
feat: add method to compute time since last message published by a peer and remove fanout peers based on ttl.

View File

@ -0,0 +1 @@
implement blacklist management for `pubsub.Pubsub` with methods to get, add, remove, check, and clear blacklisted peer IDs.

View File

@ -0,0 +1 @@
fix: remove expired peers from peerstore based on TTL

View File

@ -0,0 +1 @@
Updated examples to automatically use random port, when `-p` flag is not given

View File

@ -1,20 +1,105 @@
[tool.autoflake]
exclude = "__init__.py"
remove_all_unused_imports = true
[tool.isort]
combine_as_imports = false
extra_standard_library = "pytest"
force_grid_wrap = 1
force_sort_within_sections = true
force_to_top = "pytest"
honor_noqa = true
known_first_party = "libp2p"
known_third_party = "anyio,factory,lru,p2pclient,pytest,noise"
multi_line_output = 3
profile = "black"
skip_glob= "*_pb2*.py, *.pyi"
use_parentheses = true
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "libp2p"
version = "0.2.7"
description = "libp2p: The Python implementation of the libp2p networking stack"
readme = "README.md"
requires-python = ">=3.10, <4.0"
license = { text = "MIT AND Apache-2.0" }
keywords = ["libp2p", "p2p"]
authors = [
{ name = "The Ethereum Foundation", email = "snakecharmers@ethereum.org" },
]
dependencies = [
"base58>=1.0.3",
"coincurve>=10.0.0",
"exceptiongroup>=1.2.0; python_version < '3.11'",
"grpcio>=1.41.0",
"lru-dict>=1.1.6",
"multiaddr>=0.0.9",
"mypy-protobuf>=3.0.0",
"noiseprotocol>=0.3.0",
"protobuf>=3.20.1,<4.0.0",
"pycryptodome>=3.9.2",
"pymultihash>=0.8.2",
"pynacl>=1.3.0",
"rpcudp>=3.0.0",
"trio-typing>=0.0.4",
"trio>=0.26.0",
"fastecdsa==2.3.2; sys_platform != 'win32'",
]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Natural Language :: English",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
]
[project.urls]
Homepage = "https://github.com/libp2p/py-libp2p"
[project.scripts]
chat-demo = "examples.chat.chat:main"
echo-demo = "examples.echo.echo:main"
ping-demo = "examples.ping.ping:main"
identify-demo = "examples.identify.identify:main"
identify-push-demo = "examples.identify_push.identify_push_demo:run_main"
identify-push-listener-dialer-demo = "examples.identify_push.identify_push_listener_dialer:main"
pubsub-demo = "examples.pubsub.pubsub:main"
[project.optional-dependencies]
dev = [
"build>=0.9.0",
"bump_my_version>=0.19.0",
"ipython",
"mypy>=1.15.0",
"pre-commit>=3.4.0",
"tox>=4.0.0",
"twine",
"wheel",
"setuptools>=42",
"sphinx>=6.0.0",
"sphinx_rtd_theme>=1.0.0",
"towncrier>=24,<25",
"p2pclient==0.2.0",
"pytest>=7.0.0",
"pytest-xdist>=2.4.0",
"pytest-trio>=0.5.2",
"factory-boy>=2.12.0,<3.0.0",
"ruff>=0.11.10",
"pyrefly (>=0.17.1,<0.18.0)",
]
docs = [
"sphinx>=6.0.0",
"sphinx_rtd_theme>=1.0.0",
"towncrier>=24,<25",
"tomli; python_version < '3.11'",
]
test = [
"p2pclient==0.2.0",
"pytest>=7.0.0",
"pytest-xdist>=2.4.0",
"pytest-trio>=0.5.2",
"factory-boy>=2.12.0,<3.0.0",
]
[tool.setuptools]
include-package-data = true
[tool.setuptools.packages.find]
exclude = ["scripts*", "tests*"]
[tool.setuptools.package-data]
libp2p = ["py.typed"]
[tool.mypy]
check_untyped_defs = true
@ -27,37 +112,12 @@ disallow_untyped_defs = true
ignore_missing_imports = true
incremental = false
strict_equality = true
strict_optional = false
strict_optional = true
warn_redundant_casts = true
warn_return_any = false
warn_unused_configs = true
warn_unused_ignores = true
warn_unused_ignores = false
[tool.pydocstyle]
# All error codes found here:
# http://www.pydocstyle.org/en/3.0.0/error_codes.html
#
# Ignored:
# D1 - Missing docstring error codes
#
# Selected:
# D2 - Whitespace error codes
# D3 - Quote error codes
# D4 - Content related error codes
select = "D2,D3,D4"
# Extra ignores:
# D200 - One-line docstring should fit on one line with quotes
# D203 - 1 blank line required before class docstring
# D204 - 1 blank line required after class docstring
# D205 - 1 blank line required between summary line and description
# D212 - Multi-line docstring summary should start at the first line
# D302 - Use u""" for Unicode docstrings
# D400 - First line should end with a period
# D401 - First line should be in imperative mood
# D412 - No blank lines allowed between a section header and its content
# D415 - First line should end with a period, question mark, or exclamation point
add-ignore = "D200,D203,D204,D205,D212,D302,D400,D401,D412,D415"
# Explanation:
# D400 - Enabling this error code seems to make it a requirement that the first
@ -138,8 +198,8 @@ parse = """
)?
"""
serialize = [
"{major}.{minor}.{patch}-{stage}.{devnum}",
"{major}.{minor}.{patch}",
"{major}.{minor}.{patch}-{stage}.{devnum}",
"{major}.{minor}.{patch}",
]
search = "{current_version}"
replace = "{new_version}"
@ -156,11 +216,7 @@ message = "Bump version: {current_version} → {new_version}"
[tool.bumpversion.parts.stage]
optional_value = "stable"
first_value = "stable"
values = [
"alpha",
"beta",
"stable",
]
values = ["alpha", "beta", "stable"]
[tool.bumpversion.part.devnum]
@ -168,3 +224,63 @@ values = [
filename = "setup.py"
search = "version=\"{current_version}\""
replace = "version=\"{new_version}\""
[[tool.bumpversion.files]]
filename = "pyproject.toml" # Keep pyproject.toml version in sync
search = 'version = "{current_version}"'
replace = 'version = "{new_version}"'
[tool.ruff]
line-length = 88
exclude = ["__init__.py", "*_pb2*.py", "*.pyi"]
[tool.ruff.lint]
select = [
"F", # Pyflakes
"E", # pycodestyle errors
"W", # pycodestyle warnings
"I", # isort
"D", # pydocstyle
]
# Ignores from pydocstyle and any other desired ones
ignore = [
"D100",
"D101",
"D102",
"D103",
"D105",
"D106",
"D107",
"D200",
"D203",
"D204",
"D205",
"D212",
"D400",
"D401",
"D412",
"D415",
]
[tool.ruff.lint.isort]
force-wrap-aliases = true
combine-as-imports = true
extra-standard-library = []
force-sort-within-sections = true
known-first-party = ["libp2p", "tests"]
known-third-party = ["anyio", "factory", "lru", "p2pclient", "pytest", "noise"]
force-to-top = ["pytest"]
[tool.ruff.format]
# Using Ruff's Black-compatible formatter.
# Options like quote-style = "double" or indent-style = "space" can be set here if needed.
[tool.pyrefly]
project_includes = ["libp2p", "examples", "tests"]
project_excludes = [
"**/.project-template/**",
"**/docs/conf.py",
"**/*pb2.py",
"**/*.pyi",
".venv/**",
]

117
setup.py
View File

@ -1,117 +0,0 @@
#!/usr/bin/env python
import sys
from setuptools import (
find_packages,
setup,
)
description = "libp2p: The Python implementation of the libp2p networking stack"
# Platform-specific dependencies
if sys.platform == "win32":
crypto_requires = [] # We'll use coincurve instead of fastecdsa on Windows
else:
crypto_requires = ["fastecdsa==1.7.5"]
extras_require = {
"dev": [
"build>=0.9.0",
"bump_my_version>=0.19.0",
"ipython",
"mypy==1.10.0",
"pre-commit>=3.4.0",
"tox>=4.0.0",
"twine",
"wheel",
],
"docs": [
"sphinx>=6.0.0",
"sphinx_rtd_theme>=1.0.0",
"towncrier>=24,<25",
],
"test": [
"p2pclient==0.2.0",
"pytest>=7.0.0",
"pytest-xdist>=2.4.0",
"pytest-trio>=0.5.2",
"factory-boy>=2.12.0,<3.0.0",
],
}
extras_require["dev"] = (
extras_require["dev"] + extras_require["docs"] + extras_require["test"]
)
try:
with open("./README.md", encoding="utf-8") as readme:
long_description = readme.read()
except FileNotFoundError:
long_description = description
install_requires = [
"base58>=1.0.3",
"coincurve>=10.0.0",
"exceptiongroup>=1.2.0; python_version < '3.11'",
"grpcio>=1.41.0",
"lru-dict>=1.1.6",
"multiaddr>=0.0.9",
"mypy-protobuf>=3.0.0",
"noiseprotocol>=0.3.0",
"protobuf>=6.30.1",
"pycryptodome>=3.9.2",
"pymultihash>=0.8.2",
"pynacl>=1.3.0",
"rpcudp>=3.0.0",
"trio-typing>=0.0.4",
"trio>=0.26.0",
]
# Add platform-specific dependencies
install_requires.extend(crypto_requires)
setup(
name="libp2p",
# *IMPORTANT*: Don't manually change the version here. See Contributing docs for the release process.
version="0.2.7",
description=description,
long_description=long_description,
long_description_content_type="text/markdown",
author="The Ethereum Foundation",
author_email="snakecharmers@ethereum.org",
url="https://github.com/libp2p/py-libp2p",
include_package_data=True,
install_requires=install_requires,
python_requires=">=3.9, <4",
extras_require=extras_require,
py_modules=["libp2p"],
license="MIT AND Apache-2.0",
license_files=("LICENSE-MIT", "LICENSE-APACHE"),
zip_safe=False,
keywords="libp2p p2p",
packages=find_packages(exclude=["scripts", "scripts.*", "tests", "tests.*"]),
package_data={"libp2p": ["py.typed"]},
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Natural Language :: English",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
],
platforms=["unix", "linux", "osx", "win32"],
entry_points={
"console_scripts": [
"chat-demo=examples.chat.chat:main",
"echo-demo=examples.echo.echo:main",
"ping-demo=examples.ping.ping:main",
"identify-demo=examples.identify.identify:main",
"identify-push-demo=examples.identify_push.identify_push_demo:run_main",
"identify-push-listener-dialer-demo=examples.identify_push.identify_push_listener_dialer:main",
"pubsub-demo=examples.pubsub.pubsub:main",
],
},
)

View File

@ -209,6 +209,18 @@ async def ping_demo(host_a, host_b):
async def pubsub_demo(host_a, host_b):
gossipsub_a = GossipSub(
[GOSSIPSUB_PROTOCOL_ID],
3,
2,
4,
)
gossipsub_b = GossipSub(
[GOSSIPSUB_PROTOCOL_ID],
3,
2,
4,
)
gossipsub_a = GossipSub([GOSSIPSUB_PROTOCOL_ID], 3, 2, 4, None, 1, 1)
gossipsub_b = GossipSub([GOSSIPSUB_PROTOCOL_ID], 3, 2, 4, None, 1, 1)
pubsub_a = Pubsub(host_a, gossipsub_a)

View File

@ -76,18 +76,18 @@ async def test_update_status():
# Less than 2 successful dials should result in PRIVATE status
service.dial_results = {
ID("peer1"): True,
ID("peer2"): False,
ID("peer3"): False,
ID(b"peer1"): True,
ID(b"peer2"): False,
ID(b"peer3"): False,
}
service.update_status()
assert service.status == AutoNATStatus.PRIVATE
# 2 or more successful dials should result in PUBLIC status
service.dial_results = {
ID("peer1"): True,
ID("peer2"): True,
ID("peer3"): False,
ID(b"peer1"): True,
ID(b"peer2"): True,
ID(b"peer3"): False,
}
service.update_status()
assert service.status == AutoNATStatus.PUBLIC

View File

@ -22,9 +22,10 @@ async def test_host_routing_success():
@pytest.mark.trio
async def test_host_routing_fail():
async with RoutedHostFactory.create_batch_and_listen(
2
) as routed_hosts, HostFactory.create_batch_and_listen(1) as basic_hosts:
async with (
RoutedHostFactory.create_batch_and_listen(2) as routed_hosts,
HostFactory.create_batch_and_listen(1) as basic_hosts,
):
# routing fails because host_c does not use routing
with pytest.raises(ConnectionFailure):
await routed_hosts[0].connect(PeerInfo(basic_hosts[0].get_id(), []))

View File

@ -218,7 +218,6 @@ async def test_push_identify_to_peers_with_explicit_params(security_protocol):
This test ensures all parameters of push_identify_to_peers are properly tested.
"""
# Create four hosts to thoroughly test selective pushing
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,

View File

@ -8,23 +8,20 @@ into network after network has already started listening
TODO: Add tests for closed_stream, listen_close when those
features are implemented in swarm
"""
import enum
import pytest
from multiaddr import Multiaddr
import trio
from libp2p.abc import (
INetConn,
INetStream,
INetwork,
INotifee,
)
from libp2p.tools.async_service import (
background_trio_service,
)
from libp2p.tools.constants import (
LISTEN_MADDR,
)
from libp2p.tools.utils import (
connect_swarm,
)
from libp2p.tools.utils import connect_swarm
from tests.utils.factories import (
SwarmFactory,
)
@ -40,169 +37,94 @@ class Event(enum.Enum):
class MyNotifee(INotifee):
def __init__(self, events):
def __init__(self, events: list[Event]):
self.events = events
async def opened_stream(self, network, stream):
async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
self.events.append(Event.OpenedStream)
async def closed_stream(self, network, stream):
async def closed_stream(self, network: INetwork, stream: INetStream) -> None:
# TODO: It is not implemented yet.
pass
async def connected(self, network, conn):
async def connected(self, network: INetwork, conn: INetConn) -> None:
self.events.append(Event.Connected)
async def disconnected(self, network, conn):
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
self.events.append(Event.Disconnected)
async def listen(self, network, _multiaddr):
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
self.events.append(Event.Listen)
async def listen_close(self, network, _multiaddr):
async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
# TODO: It is not implemented yet.
pass
@pytest.mark.trio
async def test_notify(security_protocol):
swarms = [SwarmFactory(security_protocol=security_protocol) for _ in range(2)]
events_0_0 = []
events_1_0 = []
events_0_without_listen = []
# Helper to wait for specific event
async def wait_for_event(events_list, expected_event, timeout=1.0):
start_time = trio.current_time()
while trio.current_time() - start_time < timeout:
if expected_event in events_list:
return True
await trio.sleep(0.01)
async def wait_for_event(events_list, event, timeout=1.0):
with trio.move_on_after(timeout):
while event not in events_list:
await trio.sleep(0.01)
return True
return False
# Run swarms.
async with background_trio_service(swarms[0]), background_trio_service(swarms[1]):
# Register events before listening
swarms[0].register_notifee(MyNotifee(events_0_0))
swarms[1].register_notifee(MyNotifee(events_1_0))
# Event lists for notifees
events_0_0 = []
events_0_1 = []
events_1_0 = []
events_1_1 = []
# Listen
async with trio.open_nursery() as nursery:
nursery.start_soon(swarms[0].listen, LISTEN_MADDR)
nursery.start_soon(swarms[1].listen, LISTEN_MADDR)
# Create two swarms, but do not listen yet
async with SwarmFactory.create_batch_and_listen(2) as swarms:
# Register notifees before listening
notifee_0_0 = MyNotifee(events_0_0)
notifee_0_1 = MyNotifee(events_0_1)
notifee_1_0 = MyNotifee(events_1_0)
notifee_1_1 = MyNotifee(events_1_1)
# Wait for Listen events
assert await wait_for_event(events_0_0, Event.Listen)
assert await wait_for_event(events_1_0, Event.Listen)
swarms[0].register_notifee(notifee_0_0)
swarms[0].register_notifee(notifee_0_1)
swarms[1].register_notifee(notifee_1_0)
swarms[1].register_notifee(notifee_1_1)
swarms[0].register_notifee(MyNotifee(events_0_without_listen))
# Connected
# Connect swarms
await connect_swarm(swarms[0], swarms[1])
assert await wait_for_event(events_0_0, Event.Connected)
assert await wait_for_event(events_1_0, Event.Connected)
assert await wait_for_event(events_0_without_listen, Event.Connected)
# OpenedStream: first
await swarms[0].new_stream(swarms[1].get_peer_id())
# OpenedStream: second
await swarms[0].new_stream(swarms[1].get_peer_id())
# OpenedStream: third, but different direction.
await swarms[1].new_stream(swarms[0].get_peer_id())
# Create a stream
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
await stream.close()
# Clear any duplicate events that might have occurred
events_0_0.copy()
events_1_0.copy()
events_0_without_listen.copy()
# TODO: Check `ClosedStream` and `ListenClose` events after they are ready.
# Disconnected
# Close peer
await swarms[0].close_peer(swarms[1].get_peer_id())
assert await wait_for_event(events_0_0, Event.Disconnected)
assert await wait_for_event(events_1_0, Event.Disconnected)
assert await wait_for_event(events_0_without_listen, Event.Disconnected)
# Connected again, but different direction.
await connect_swarm(swarms[1], swarms[0])
# Wait for events
assert await wait_for_event(events_0_0, Event.Connected, 1.0)
assert await wait_for_event(events_0_0, Event.OpenedStream, 1.0)
# assert await wait_for_event(
# events_0_0, Event.ClosedStream, 1.0
# ) # Not implemented
assert await wait_for_event(events_0_0, Event.Disconnected, 1.0)
# Get the index of the first disconnected event
disconnect_idx_0_0 = events_0_0.index(Event.Disconnected)
disconnect_idx_1_0 = events_1_0.index(Event.Disconnected)
disconnect_idx_without_listen = events_0_without_listen.index(
Event.Disconnected
)
assert await wait_for_event(events_0_1, Event.Connected, 1.0)
assert await wait_for_event(events_0_1, Event.OpenedStream, 1.0)
# assert await wait_for_event(
# events_0_1, Event.ClosedStream, 1.0
# ) # Not implemented
assert await wait_for_event(events_0_1, Event.Disconnected, 1.0)
# Check for connected event after disconnect
assert await wait_for_event(
events_0_0[disconnect_idx_0_0 + 1 :], Event.Connected
)
assert await wait_for_event(
events_1_0[disconnect_idx_1_0 + 1 :], Event.Connected
)
assert await wait_for_event(
events_0_without_listen[disconnect_idx_without_listen + 1 :],
Event.Connected,
)
assert await wait_for_event(events_1_0, Event.Connected, 1.0)
assert await wait_for_event(events_1_0, Event.OpenedStream, 1.0)
# assert await wait_for_event(
# events_1_0, Event.ClosedStream, 1.0
# ) # Not implemented
assert await wait_for_event(events_1_0, Event.Disconnected, 1.0)
# Disconnected again, but different direction.
await swarms[1].close_peer(swarms[0].get_peer_id())
# Find index of the second connected event
second_connect_idx_0_0 = events_0_0.index(
Event.Connected, disconnect_idx_0_0 + 1
)
second_connect_idx_1_0 = events_1_0.index(
Event.Connected, disconnect_idx_1_0 + 1
)
second_connect_idx_without_listen = events_0_without_listen.index(
Event.Connected, disconnect_idx_without_listen + 1
)
# Check for second disconnected event
assert await wait_for_event(
events_0_0[second_connect_idx_0_0 + 1 :], Event.Disconnected
)
assert await wait_for_event(
events_1_0[second_connect_idx_1_0 + 1 :], Event.Disconnected
)
assert await wait_for_event(
events_0_without_listen[second_connect_idx_without_listen + 1 :],
Event.Disconnected,
)
# Verify the core sequence of events
expected_events_without_listen = [
Event.Connected,
Event.Disconnected,
Event.Connected,
Event.Disconnected,
]
# Filter events to check only pattern we care about
# (skipping OpenedStream which may vary)
filtered_events_0_0 = [
e
for e in events_0_0
if e in [Event.Listen, Event.Connected, Event.Disconnected]
]
filtered_events_1_0 = [
e
for e in events_1_0
if e in [Event.Listen, Event.Connected, Event.Disconnected]
]
filtered_events_without_listen = [
e
for e in events_0_without_listen
if e in [Event.Connected, Event.Disconnected]
]
# Check that the pattern matches
assert filtered_events_0_0[0] == Event.Listen, "First event should be Listen"
assert filtered_events_1_0[0] == Event.Listen, "First event should be Listen"
# Check pattern: Connected -> Disconnected -> Connected -> Disconnected
assert filtered_events_0_0[1:5] == expected_events_without_listen
assert filtered_events_1_0[1:5] == expected_events_without_listen
assert filtered_events_without_listen[:4] == expected_events_without_listen
assert await wait_for_event(events_1_1, Event.Connected, 1.0)
assert await wait_for_event(events_1_1, Event.OpenedStream, 1.0)
# assert await wait_for_event(
# events_1_1, Event.ClosedStream, 1.0
# ) # Not implemented
assert await wait_for_event(events_1_1, Event.Disconnected, 1.0)

View File

@ -13,6 +13,9 @@ from libp2p import (
from libp2p.network.exceptions import (
SwarmException,
)
from libp2p.network.swarm import (
Swarm,
)
from libp2p.tools.utils import (
connect_swarm,
)
@ -166,12 +169,14 @@ async def test_swarm_multiaddr(security_protocol):
def test_new_swarm_defaults_to_tcp():
swarm = new_swarm()
assert isinstance(swarm, Swarm)
assert isinstance(swarm.transport, TCP)
def test_new_swarm_tcp_multiaddr_supported():
addr = Multiaddr("/ip4/127.0.0.1/tcp/9999")
swarm = new_swarm(listen_addrs=[addr])
assert isinstance(swarm, Swarm)
assert isinstance(swarm.transport, TCP)

View File

@ -1,5 +1,9 @@
import pytest
from multiaddr import (
Multiaddr,
)
from libp2p.peer.id import ID
from libp2p.peer.peerstore import (
PeerStore,
PeerStoreError,
@ -11,51 +15,72 @@ from libp2p.peer.peerstore import (
def test_addrs_empty():
with pytest.raises(PeerStoreError):
store = PeerStore()
val = store.addrs("peer")
val = store.addrs(ID(b"peer"))
assert not val
def test_add_addr_single():
store = PeerStore()
store.add_addr("peer1", "/foo", 10)
store.add_addr("peer1", "/bar", 10)
store.add_addr("peer2", "/baz", 10)
store.add_addr(ID(b"peer1"), Multiaddr("/ip4/127.0.0.1/tcp/4001"), 10)
store.add_addr(ID(b"peer1"), Multiaddr("/ip4/127.0.0.1/tcp/4002"), 10)
store.add_addr(ID(b"peer2"), Multiaddr("/ip4/127.0.0.1/tcp/4003"), 10)
assert store.addrs("peer1") == ["/foo", "/bar"]
assert store.addrs("peer2") == ["/baz"]
assert store.addrs(ID(b"peer1")) == [
Multiaddr("/ip4/127.0.0.1/tcp/4001"),
Multiaddr("/ip4/127.0.0.1/tcp/4002"),
]
assert store.addrs(ID(b"peer2")) == [Multiaddr("/ip4/127.0.0.1/tcp/4003")]
def test_add_addrs_multiple():
store = PeerStore()
store.add_addrs("peer1", ["/foo1", "/bar1"], 10)
store.add_addrs("peer2", ["/foo2"], 10)
store.add_addrs(
ID(b"peer1"),
[Multiaddr("/ip4/127.0.0.1/tcp/40011"), Multiaddr("/ip4/127.0.0.1/tcp/40021")],
10,
)
store.add_addrs(ID(b"peer2"), [Multiaddr("/ip4/127.0.0.1/tcp/40012")], 10)
assert store.addrs("peer1") == ["/foo1", "/bar1"]
assert store.addrs("peer2") == ["/foo2"]
assert store.addrs(ID(b"peer1")) == [
Multiaddr("/ip4/127.0.0.1/tcp/40011"),
Multiaddr("/ip4/127.0.0.1/tcp/40021"),
]
assert store.addrs(ID(b"peer2")) == [Multiaddr("/ip4/127.0.0.1/tcp/40012")]
def test_clear_addrs():
store = PeerStore()
store.add_addrs("peer1", ["/foo1", "/bar1"], 10)
store.add_addrs("peer2", ["/foo2"], 10)
store.clear_addrs("peer1")
store.add_addrs(
ID(b"peer1"),
[Multiaddr("/ip4/127.0.0.1/tcp/40011"), Multiaddr("/ip4/127.0.0.1/tcp/40021")],
10,
)
store.add_addrs(ID(b"peer2"), [Multiaddr("/ip4/127.0.0.1/tcp/40012")], 10)
store.clear_addrs(ID(b"peer1"))
assert store.addrs("peer1") == []
assert store.addrs("peer2") == ["/foo2"]
assert store.addrs(ID(b"peer1")) == []
assert store.addrs(ID(b"peer2")) == [Multiaddr("/ip4/127.0.0.1/tcp/40012")]
store.add_addrs("peer1", ["/foo1", "/bar1"], 10)
store.add_addrs(
ID(b"peer1"),
[Multiaddr("/ip4/127.0.0.1/tcp/40011"), Multiaddr("/ip4/127.0.0.1/tcp/40021")],
10,
)
assert store.addrs("peer1") == ["/foo1", "/bar1"]
assert store.addrs(ID(b"peer1")) == [
Multiaddr("/ip4/127.0.0.1/tcp/40011"),
Multiaddr("/ip4/127.0.0.1/tcp/40021"),
]
def test_peers_with_addrs():
store = PeerStore()
store.add_addrs("peer1", [], 10)
store.add_addrs("peer2", ["/foo"], 10)
store.add_addrs("peer3", ["/bar"], 10)
store.add_addrs(ID(b"peer1"), [], 10)
store.add_addrs(ID(b"peer2"), [Multiaddr("/ip4/127.0.0.1/tcp/4001")], 10)
store.add_addrs(ID(b"peer3"), [Multiaddr("/ip4/127.0.0.1/tcp/4002")], 10)
assert set(store.peers_with_addrs()) == {"peer2", "peer3"}
assert set(store.peers_with_addrs()) == {ID(b"peer2"), ID(b"peer3")}
store.clear_addrs("peer2")
store.clear_addrs(ID(b"peer2"))
assert set(store.peers_with_addrs()) == {"peer3"}
assert set(store.peers_with_addrs()) == {ID(b"peer3")}

View File

@ -23,9 +23,7 @@ kBZ7WvkmPV3aPL6jnwp2pXepntdVnaTiSxJ1dkXShZ/VSSDNZMYKY306EtHrIu3NZHtXhdyHKcggDXr
qkBrdgErAkAlpGPojUwemOggr4FD8sLX1ot2hDJyyV7OK2FXfajWEYJyMRL1Gm9Uk1+Un53RAkJneqp
JGAzKpyttXBTIDO51AkEA98KTiROMnnU8Y6Mgcvr68/SMIsvCYMt9/mtwSBGgl80VaTQ5Hpaktl6Xbh
VUt5Wv0tRxlXZiViCGCD1EtrrwTw==
""".replace(
"\n", ""
)
""".replace("\n", "")
EXPECTED_PEER_ID = "QmRK3JgmVEGiewxWbhpXLJyjWuGuLeSTMTndA1coMHEy5o"

View File

@ -1,4 +1,7 @@
from collections.abc import Sequence
import pytest
from multiaddr import Multiaddr
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
@ -8,7 +11,7 @@ from libp2p.peer.peerdata import (
PeerDataError,
)
MOCK_ADDR = "/peer"
MOCK_ADDR = Multiaddr("/ip4/127.0.0.1/tcp/4001")
MOCK_KEYPAIR = create_new_key_pair()
MOCK_PUBKEY = MOCK_KEYPAIR.public_key
MOCK_PRIVKEY = MOCK_KEYPAIR.private_key
@ -23,7 +26,7 @@ def test_get_protocols_empty():
# Test case when adding protocols
def test_add_protocols():
peer_data = PeerData()
protocols = ["protocol1", "protocol2"]
protocols: Sequence[str] = ["protocol1", "protocol2"]
peer_data.add_protocols(protocols)
assert peer_data.get_protocols() == protocols
@ -31,7 +34,7 @@ def test_add_protocols():
# Test case when setting protocols
def test_set_protocols():
peer_data = PeerData()
protocols = ["protocolA", "protocolB"]
protocols: Sequence[str] = ["protocol1", "protocol2"]
peer_data.set_protocols(protocols)
assert peer_data.get_protocols() == protocols
@ -39,7 +42,7 @@ def test_set_protocols():
# Test case when adding addresses
def test_add_addrs():
peer_data = PeerData()
addresses = [MOCK_ADDR]
addresses: Sequence[Multiaddr] = [MOCK_ADDR]
peer_data.add_addrs(addresses)
assert peer_data.get_addrs() == addresses
@ -47,7 +50,7 @@ def test_add_addrs():
# Test case when adding same address more than once
def test_add_dup_addrs():
peer_data = PeerData()
addresses = [MOCK_ADDR, MOCK_ADDR]
addresses: Sequence[Multiaddr] = [MOCK_ADDR, MOCK_ADDR]
peer_data.add_addrs(addresses)
peer_data.add_addrs(addresses)
assert peer_data.get_addrs() == [MOCK_ADDR]
@ -56,7 +59,7 @@ def test_add_dup_addrs():
# Test case for clearing addresses
def test_clear_addrs():
peer_data = PeerData()
addresses = [MOCK_ADDR]
addresses: Sequence[Multiaddr] = [MOCK_ADDR]
peer_data.add_addrs(addresses)
peer_data.clear_addrs()
assert peer_data.get_addrs() == []

View File

@ -6,16 +6,12 @@ import multihash
from libp2p.crypto.rsa import (
create_new_key_pair,
)
import libp2p.peer.id as PeerID
from libp2p.peer.id import (
ID,
)
ALPHABETS = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
# ensure we are not in "debug" mode for the following tests
PeerID.FRIENDLY_IDS = False
def test_eq_impl_for_bytes():
random_id_string = ""
@ -70,8 +66,8 @@ def test_eq_true():
def test_eq_false():
peer_id = ID("efgh")
other = ID("abcd")
peer_id = ID(b"efgh")
other = ID(b"abcd")
assert peer_id != other
@ -91,7 +87,7 @@ def test_id_from_base58():
for _ in range(10):
random_id_string += random.choice(ALPHABETS)
expected = ID(base58.b58decode(random_id_string))
actual = ID.from_base58(random_id_string.encode())
actual = ID.from_base58(random_id_string)
assert actual == expected

View File

@ -17,10 +17,14 @@ VALID_MULTI_ADDR_STR = "/ip4/127.0.0.1/tcp/8000/p2p/3YgLAeMKSAPcGqZkAt8mREqhQXmJ
def test_init_():
random_addrs = [random.randint(0, 255) for r in range(4)]
random_addrs = [
multiaddr.Multiaddr(f"/ip4/127.0.0.1/tcp/{1000 + i}") for i in range(4)
]
random_id_string = ""
for _ in range(10):
random_id_string += random.SystemRandom().choice(ALPHABETS)
peer_id = ID(random_id_string.encode())
peer_info = PeerInfo(peer_id, random_addrs)

View File

@ -1,5 +1,6 @@
import pytest
from libp2p.peer.id import ID
from libp2p.peer.peerstore import (
PeerStore,
PeerStoreError,
@ -11,36 +12,36 @@ from libp2p.peer.peerstore import (
def test_get_empty():
with pytest.raises(PeerStoreError):
store = PeerStore()
val = store.get("peer", "key")
val = store.get(ID(b"peer"), "key")
assert not val
def test_put_get_simple():
store = PeerStore()
store.put("peer", "key", "val")
assert store.get("peer", "key") == "val"
store.put(ID(b"peer"), "key", "val")
assert store.get(ID(b"peer"), "key") == "val"
def test_put_get_update():
store = PeerStore()
store.put("peer", "key1", "val1")
store.put("peer", "key2", "val2")
store.put("peer", "key2", "new val2")
store.put(ID(b"peer"), "key1", "val1")
store.put(ID(b"peer"), "key2", "val2")
store.put(ID(b"peer"), "key2", "new val2")
assert store.get("peer", "key1") == "val1"
assert store.get("peer", "key2") == "new val2"
assert store.get(ID(b"peer"), "key1") == "val1"
assert store.get(ID(b"peer"), "key2") == "new val2"
def test_put_get_two_peers():
store = PeerStore()
store.put("peer1", "key1", "val1")
store.put("peer2", "key1", "val1 prime")
store.put(ID(b"peer1"), "key1", "val1")
store.put(ID(b"peer2"), "key1", "val1 prime")
assert store.get("peer1", "key1") == "val1"
assert store.get("peer2", "key1") == "val1 prime"
assert store.get(ID(b"peer1"), "key1") == "val1"
assert store.get(ID(b"peer2"), "key1") == "val1 prime"
# Try update
store.put("peer2", "key1", "new val1")
store.put(ID(b"peer2"), "key1", "new val1")
assert store.get("peer1", "key1") == "val1"
assert store.get("peer2", "key1") == "new val1"
assert store.get(ID(b"peer1"), "key1") == "val1"
assert store.get(ID(b"peer2"), "key1") == "new val1"

View File

@ -1,5 +1,9 @@
import pytest
import time
import pytest
from multiaddr import Multiaddr
from libp2p.peer.id import ID
from libp2p.peer.peerstore import (
PeerStore,
PeerStoreError,
@ -11,52 +15,77 @@ from libp2p.peer.peerstore import (
def test_peer_info_empty():
store = PeerStore()
with pytest.raises(PeerStoreError):
store.peer_info("peer")
store.peer_info(ID(b"peer"))
def test_peer_info_basic():
store = PeerStore()
store.add_addr("peer", "/foo", 10)
info = store.peer_info("peer")
store.add_addr(ID(b"peer"), Multiaddr("/ip4/127.0.0.1/tcp/4001"), 1)
assert info.peer_id == "peer"
assert info.addrs == ["/foo"]
# update ttl to new value
store.add_addr(ID(b"peer"), Multiaddr("/ip4/127.0.0.1/tcp/4002"), 2)
time.sleep(1)
info = store.peer_info(ID(b"peer"))
assert info.peer_id == ID(b"peer")
assert info.addrs == [
Multiaddr("/ip4/127.0.0.1/tcp/4001"),
Multiaddr("/ip4/127.0.0.1/tcp/4002"),
]
# Check that addresses are cleared after ttl
time.sleep(2)
info = store.peer_info(ID(b"peer"))
assert info.peer_id == ID(b"peer")
assert info.addrs == []
assert store.peer_ids() == [ID(b"peer")]
assert store.valid_peer_ids() == []
# Check if all the data remains valid if ttl is set to default(0)
def test_peer_permanent_ttl():
store = PeerStore()
store.add_addr(ID(b"peer"), Multiaddr("/ip4/127.0.0.1/tcp/4001"))
time.sleep(1)
info = store.peer_info(ID(b"peer"))
assert info.peer_id == ID(b"peer")
assert info.addrs == [Multiaddr("/ip4/127.0.0.1/tcp/4001")]
def test_add_get_protocols_basic():
store = PeerStore()
store.add_protocols("peer1", ["p1", "p2"])
store.add_protocols("peer2", ["p3"])
store.add_protocols(ID(b"peer1"), ["p1", "p2"])
store.add_protocols(ID(b"peer2"), ["p3"])
assert set(store.get_protocols("peer1")) == {"p1", "p2"}
assert set(store.get_protocols("peer2")) == {"p3"}
assert set(store.get_protocols(ID(b"peer1"))) == {"p1", "p2"}
assert set(store.get_protocols(ID(b"peer2"))) == {"p3"}
def test_add_get_protocols_extend():
store = PeerStore()
store.add_protocols("peer1", ["p1", "p2"])
store.add_protocols("peer1", ["p3"])
store.add_protocols(ID(b"peer1"), ["p1", "p2"])
store.add_protocols(ID(b"peer1"), ["p3"])
assert set(store.get_protocols("peer1")) == {"p1", "p2", "p3"}
assert set(store.get_protocols(ID(b"peer1"))) == {"p1", "p2", "p3"}
def test_set_protocols():
store = PeerStore()
store.add_protocols("peer1", ["p1", "p2"])
store.add_protocols("peer2", ["p3"])
store.add_protocols(ID(b"peer1"), ["p1", "p2"])
store.add_protocols(ID(b"peer2"), ["p3"])
store.set_protocols("peer1", ["p4"])
store.set_protocols("peer2", [])
store.set_protocols(ID(b"peer1"), ["p4"])
store.set_protocols(ID(b"peer2"), [])
assert set(store.get_protocols("peer1")) == {"p4"}
assert set(store.get_protocols("peer2")) == set()
assert set(store.get_protocols(ID(b"peer1"))) == {"p4"}
assert set(store.get_protocols(ID(b"peer2"))) == set()
# Test with methods from other Peer interfaces.
def test_peers():
store = PeerStore()
store.add_protocols("peer1", [])
store.put("peer2", "key", "val")
store.add_addr("peer3", "/foo", 10)
store.add_protocols(ID(b"peer1"), [])
store.put(ID(b"peer2"), "key", "val")
store.add_addr(ID(b"peer3"), Multiaddr("/ip4/127.0.0.1/tcp/4001"), 10)
assert set(store.peer_ids()) == {"peer1", "peer2", "peer3"}
assert set(store.peer_ids()) == {ID(b"peer1"), ID(b"peer2"), ID(b"peer3")}

Some files were not shown because too many files have changed in this diff Show More