diff --git a/.github/workflows/tox.yml b/.github/workflows/tox.yml index 8fe058f6..ef963f80 100644 --- a/.github/workflows/tox.yml +++ b/.github/workflows/tox.yml @@ -16,10 +16,10 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python: ['3.9', '3.10', '3.11', '3.12', '3.13'] + python: ["3.10", "3.11", "3.12", "3.13"] toxenv: [core, interop, lint, wheel, demos] include: - - python: '3.10' + - python: "3.10" toxenv: docs fail-fast: false steps: @@ -46,7 +46,7 @@ jobs: runs-on: windows-latest strategy: matrix: - python-version: ['3.11', '3.12', '3.13'] + python-version: ["3.11", "3.12", "3.13"] toxenv: [core, wheel] fail-fast: false steps: diff --git a/.gitignore b/.gitignore index 192718c6..e46cc8aa 100644 --- a/.gitignore +++ b/.gitignore @@ -146,6 +146,9 @@ instance/ # PyBuilder target/ +# PyRight Config +pyrightconfig.json + # Jupyter Notebook .ipynb_checkpoints @@ -171,3 +174,7 @@ env.bak/ # mkdocs documentation /site + +#lockfiles +uv.lock +poetry.lock diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1712b7f1..962f4046 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,59 +1,49 @@ exclude: '.project-template|docs/conf.py|.*pb2\..*' repos: -- repo: https://github.com/pre-commit/pre-commit-hooks + - repo: https://github.com/pre-commit/pre-commit-hooks rev: v5.0.0 hooks: - - id: check-yaml - - id: check-toml - - id: end-of-file-fixer - - id: trailing-whitespace -- repo: https://github.com/asottile/pyupgrade - rev: v3.15.0 + - id: check-yaml + - id: check-toml + - id: end-of-file-fixer + - id: trailing-whitespace + - repo: https://github.com/asottile/pyupgrade + rev: v3.20.0 hooks: - - id: pyupgrade - args: [--py39-plus] -- repo: https://github.com/psf/black - rev: 23.9.1 + - id: pyupgrade + args: [--py310-plus] + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.10 hooks: - - id: black -- repo: https://github.com/PyCQA/flake8 - rev: 6.1.0 - hooks: - - id: flake8 - additional_dependencies: - - flake8-bugbear==23.9.16 - exclude: setup.py -- repo: https://github.com/PyCQA/autoflake - rev: v2.2.1 - hooks: - - id: autoflake -- repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort -- repo: https://github.com/pycqa/pydocstyle - rev: 6.3.0 - hooks: - - id: pydocstyle - additional_dependencies: - - tomli # required until >= python311 -- repo: https://github.com/executablebooks/mdformat + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + - id: ruff-format + - repo: https://github.com/executablebooks/mdformat rev: 0.7.22 hooks: - - id: mdformat + - id: mdformat additional_dependencies: - - mdformat-gfm -- repo: local + - mdformat-gfm + - repo: local hooks: - - id: mypy-local + - id: mypy-local name: run mypy with all dev dependencies present - entry: python -m mypy -p libp2p + entry: mypy -p libp2p language: system always_run: true pass_filenames: false -- repo: local + - repo: local hooks: - - id: check-rst-files + - id: pyrefly-local + name: run pyrefly typecheck locally + entry: pyrefly check + language: system + always_run: true + pass_filenames: false + + - repo: local + hooks: + - id: check-rst-files name: Check for .rst files in the top-level directory entry: python -c "import glob, sys; rst_files = glob.glob('*.rst'); sys.exit(1) if rst_files else sys.exit(0)" language: system diff --git a/.project-template/fill_template_vars.py b/.project-template/fill_template_vars.py deleted file mode 100644 index 52ceb02b..00000000 --- a/.project-template/fill_template_vars.py +++ /dev/null @@ -1,71 +0,0 @@ -#!/usr/bin/env python3 - -import os -import sys -import re -from pathlib import Path - - -def _find_files(project_root): - path_exclude_pattern = r"\.git($|\/)|venv|_build" - file_exclude_pattern = r"fill_template_vars\.py|\.swp$" - filepaths = [] - for dir_path, _dir_names, file_names in os.walk(project_root): - if not re.search(path_exclude_pattern, dir_path): - for file in file_names: - if not re.search(file_exclude_pattern, file): - filepaths.append(str(Path(dir_path, file))) - - return filepaths - - -def _replace(pattern, replacement, project_root): - print(f"Replacing values: {pattern}") - for file in _find_files(project_root): - try: - with open(file) as f: - content = f.read() - content = re.sub(pattern, replacement, content) - with open(file, "w") as f: - f.write(content) - except UnicodeDecodeError: - pass - - -def main(): - project_root = Path(os.path.realpath(sys.argv[0])).parent.parent - - module_name = input("What is your python module name? ") - - pypi_input = input(f"What is your pypi package name? (default: {module_name}) ") - pypi_name = pypi_input or module_name - - repo_input = input(f"What is your github project name? (default: {pypi_name}) ") - repo_name = repo_input or pypi_name - - rtd_input = input( - f"What is your readthedocs.org project name? (default: {pypi_name}) " - ) - rtd_name = rtd_input or pypi_name - - project_input = input( - f"What is your project name (ex: at the top of the README)? (default: {repo_name}) " - ) - project_name = project_input or repo_name - - short_description = input("What is a one-liner describing the project? ") - - _replace("", module_name, project_root) - _replace("", pypi_name, project_root) - _replace("", repo_name, project_root) - _replace("", rtd_name, project_root) - _replace("", project_name, project_root) - _replace("", short_description, project_root) - - os.makedirs(project_root / module_name, exist_ok=True) - Path(project_root / module_name / "__init__.py").touch() - Path(project_root / module_name / "py.typed").touch() - - -if __name__ == "__main__": - main() diff --git a/.project-template/refill_template_vars.py b/.project-template/refill_template_vars.py deleted file mode 100644 index 03ab7c0c..00000000 --- a/.project-template/refill_template_vars.py +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/env python3 - -import os -import sys -from pathlib import Path -import subprocess - - -def main(): - template_dir = Path(os.path.dirname(sys.argv[0])) - template_vars_file = template_dir / "template_vars.txt" - fill_template_vars_script = template_dir / "fill_template_vars.py" - - with open(template_vars_file, "r") as input_file: - content_lines = input_file.readlines() - - process = subprocess.Popen( - [sys.executable, str(fill_template_vars_script)], - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - ) - - for line in content_lines: - process.stdin.write(line) - process.stdin.flush() - - stdout, stderr = process.communicate() - - if process.returncode != 0: - print(f"Error occurred: {stderr}") - sys.exit(1) - - print(stdout) - - -if __name__ == "__main__": - main() diff --git a/.project-template/template_vars.txt b/.project-template/template_vars.txt deleted file mode 100644 index ce0a492e..00000000 --- a/.project-template/template_vars.txt +++ /dev/null @@ -1,6 +0,0 @@ -libp2p -libp2p -py-libp2p -py-libp2p -py-libp2p -The Python implementation of the libp2p networking stack diff --git a/Makefile b/Makefile index 3977db58..3f5ce5ea 100644 --- a/Makefile +++ b/Makefile @@ -7,6 +7,7 @@ help: @echo "clean-pyc - remove Python file artifacts" @echo "clean - run clean-build and clean-pyc" @echo "dist - build package and cat contents of the dist directory" + @echo "fix - fix formatting & linting issues with ruff" @echo "lint - fix linting issues with pre-commit" @echo "test - run tests quickly with the default Python" @echo "docs - generate docs and open in browser (linux-docs for version on linux)" @@ -37,8 +38,14 @@ lint: && pre-commit run --all-files --show-diff-on-failure \ ) +fix: + python -m ruff check --fix + +typecheck: + pre-commit run mypy-local --all-files && pre-commit run pyrefly-local --all-files + test: - python -m pytest tests + python -m pytest tests -n auto # protobufs management diff --git a/docs/conf.py b/docs/conf.py index 6d18b63f..446252f1 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -15,14 +15,24 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. # sys.path.insert(0, os.path.abspath('.')) +import doctest import os +import sys +from unittest.mock import MagicMock -DIR = os.path.dirname(__file__) -with open(os.path.join(DIR, "../setup.py"), "r") as f: - for line in f: - if "version=" in line: - setup_version = line.split('"')[1] - break +try: + import tomllib +except ModuleNotFoundError: + # For Python < 3.11 + import tomli as tomllib # type: ignore (In case of >3.11 Pyrefly doesnt find tomli , which is right but a false flag) + +# Path to pyproject.toml (assuming conf.py is in a 'docs' subdirectory) +pyproject_path = os.path.join(os.path.dirname(__file__), "..", "pyproject.toml") + +with open(pyproject_path, "rb") as f: + pyproject_data = tomllib.load(f) + +setup_version = pyproject_data["project"]["version"] # -- General configuration ------------------------------------------------ @@ -302,7 +312,6 @@ intersphinx_mapping = { # -- Doctest configuration ---------------------------------------- -import doctest doctest_default_flags = ( 0 @@ -317,10 +326,9 @@ doctest_default_flags = ( # Mock out dependencies that are unbuildable on readthedocs, as recommended here: # https://docs.readthedocs.io/en/rel/faq.html#i-get-import-errors-on-libraries-that-depend-on-c-modules -import sys -from unittest.mock import MagicMock -# Add new modules to mock here (it should be the same list as those excluded in setup.py) +# Add new modules to mock here (it should be the same list +# as those excluded in pyproject.toml) MOCK_MODULES = [ "fastecdsa", "fastecdsa.encoding", @@ -338,4 +346,4 @@ todo_include_todos = True # Allow duplicate object descriptions nitpicky = False -nitpick_ignore = [("py:class", "type")] \ No newline at end of file +nitpick_ignore = [("py:class", "type")] diff --git a/examples/doc-examples/example_encryption_insecure.py b/examples/doc-examples/example_encryption_insecure.py index acd947e7..dae23a68 100644 --- a/examples/doc-examples/example_encryption_insecure.py +++ b/examples/doc-examples/example_encryption_insecure.py @@ -24,9 +24,6 @@ async def main(): insecure_transport = InsecureTransport( # local_key_pair: The key pair used for libp2p identity local_key_pair=key_pair, - # secure_bytes_provider: Optional function to generate secure random bytes - # (defaults to secrets.token_bytes) - secure_bytes_provider=None, # Use default implementation ) # Create a security options dictionary mapping protocol ID to transport diff --git a/examples/doc-examples/example_encryption_noise.py b/examples/doc-examples/example_encryption_noise.py index 4918dc6f..a2a4318c 100644 --- a/examples/doc-examples/example_encryption_noise.py +++ b/examples/doc-examples/example_encryption_noise.py @@ -9,8 +9,10 @@ from libp2p import ( from libp2p.crypto.secp256k1 import ( create_new_key_pair, ) -from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID -from libp2p.security.noise.transport import Transport as NoiseTransport +from libp2p.security.noise.transport import ( + PROTOCOL_ID as NOISE_PROTOCOL_ID, + Transport as NoiseTransport, +) async def main(): diff --git a/examples/doc-examples/example_encryption_secio.py b/examples/doc-examples/example_encryption_secio.py index 6204031b..603ad6ea 100644 --- a/examples/doc-examples/example_encryption_secio.py +++ b/examples/doc-examples/example_encryption_secio.py @@ -9,8 +9,10 @@ from libp2p import ( from libp2p.crypto.secp256k1 import ( create_new_key_pair, ) -from libp2p.security.secio.transport import ID as SECIO_PROTOCOL_ID -from libp2p.security.secio.transport import Transport as SecioTransport +from libp2p.security.secio.transport import ( + ID as SECIO_PROTOCOL_ID, + Transport as SecioTransport, +) async def main(): @@ -22,9 +24,6 @@ async def main(): secio_transport = SecioTransport( # local_key_pair: The key pair used for libp2p identity and authentication local_key_pair=key_pair, - # secure_bytes_provider: Optional function to generate secure random bytes - # (defaults to secrets.token_bytes) - secure_bytes_provider=None, # Use default implementation ) # Create a security options dictionary mapping protocol ID to transport diff --git a/examples/doc-examples/example_multiplexer.py b/examples/doc-examples/example_multiplexer.py index 7cbf29f0..0d6f2662 100644 --- a/examples/doc-examples/example_multiplexer.py +++ b/examples/doc-examples/example_multiplexer.py @@ -9,10 +9,9 @@ from libp2p import ( from libp2p.crypto.secp256k1 import ( create_new_key_pair, ) -from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID -from libp2p.security.noise.transport import Transport as NoiseTransport -from libp2p.stream_muxer.mplex.mplex import ( - MPLEX_PROTOCOL_ID, +from libp2p.security.noise.transport import ( + PROTOCOL_ID as NOISE_PROTOCOL_ID, + Transport as NoiseTransport, ) @@ -37,14 +36,8 @@ async def main(): # Create a security options dictionary mapping protocol ID to transport security_options = {NOISE_PROTOCOL_ID: noise_transport} - # Create a muxer options dictionary mapping protocol ID to muxer class - # We don't need to instantiate the muxer here, the host will do that for us - muxer_options = {MPLEX_PROTOCOL_ID: None} - # Create a host with the key pair, Noise security, and mplex multiplexer - host = new_host( - key_pair=key_pair, sec_opt=security_options, muxer_opt=muxer_options - ) + host = new_host(key_pair=key_pair, sec_opt=security_options) # Configure the listening address port = 8000 diff --git a/examples/doc-examples/example_peer_discovery.py b/examples/doc-examples/example_peer_discovery.py index dd789ad0..7ceec375 100644 --- a/examples/doc-examples/example_peer_discovery.py +++ b/examples/doc-examples/example_peer_discovery.py @@ -12,10 +12,9 @@ from libp2p.crypto.secp256k1 import ( from libp2p.peer.peerinfo import ( info_from_p2p_addr, ) -from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID -from libp2p.security.noise.transport import Transport as NoiseTransport -from libp2p.stream_muxer.mplex.mplex import ( - MPLEX_PROTOCOL_ID, +from libp2p.security.noise.transport import ( + PROTOCOL_ID as NOISE_PROTOCOL_ID, + Transport as NoiseTransport, ) @@ -40,14 +39,8 @@ async def main(): # Create a security options dictionary mapping protocol ID to transport security_options = {NOISE_PROTOCOL_ID: noise_transport} - # Create a muxer options dictionary mapping protocol ID to muxer class - # We don't need to instantiate the muxer here, the host will do that for us - muxer_options = {MPLEX_PROTOCOL_ID: None} - # Create a host with the key pair, Noise security, and mplex multiplexer - host = new_host( - key_pair=key_pair, sec_opt=security_options, muxer_opt=muxer_options - ) + host = new_host(key_pair=key_pair, sec_opt=security_options) # Configure the listening address port = 8000 diff --git a/examples/doc-examples/example_running.py b/examples/doc-examples/example_running.py index c9d3d053..a0169931 100644 --- a/examples/doc-examples/example_running.py +++ b/examples/doc-examples/example_running.py @@ -9,10 +9,9 @@ from libp2p import ( from libp2p.crypto.secp256k1 import ( create_new_key_pair, ) -from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID -from libp2p.security.noise.transport import Transport as NoiseTransport -from libp2p.stream_muxer.mplex.mplex import ( - MPLEX_PROTOCOL_ID, +from libp2p.security.noise.transport import ( + PROTOCOL_ID as NOISE_PROTOCOL_ID, + Transport as NoiseTransport, ) @@ -37,14 +36,8 @@ async def main(): # Create a security options dictionary mapping protocol ID to transport security_options = {NOISE_PROTOCOL_ID: noise_transport} - # Create a muxer options dictionary mapping protocol ID to muxer class - # We don't need to instantiate the muxer here, the host will do that for us - muxer_options = {MPLEX_PROTOCOL_ID: None} - # Create a host with the key pair, Noise security, and mplex multiplexer - host = new_host( - key_pair=key_pair, sec_opt=security_options, muxer_opt=muxer_options - ) + host = new_host(key_pair=key_pair, sec_opt=security_options) # Configure the listening address port = 8000 diff --git a/examples/echo/echo.py b/examples/echo/echo.py index 0f6c28ab..382d4f27 100644 --- a/examples/echo/echo.py +++ b/examples/echo/echo.py @@ -29,7 +29,7 @@ async def _echo_stream_handler(stream: INetStream) -> None: await stream.close() -async def run(port: int, destination: str, seed: int = None) -> None: +async def run(port: int, destination: str, seed: int | None = None) -> None: localhost_ip = "127.0.0.1" listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") diff --git a/examples/identify_push/identify_push_listener_dialer.py b/examples/identify_push/identify_push_listener_dialer.py index fc106af6..7df222d7 100644 --- a/examples/identify_push/identify_push_listener_dialer.py +++ b/examples/identify_push/identify_push_listener_dialer.py @@ -38,17 +38,17 @@ from libp2p.crypto.secp256k1 import ( create_new_key_pair, ) from libp2p.identity.identify import ( + ID as ID_IDENTIFY, identify_handler_for, ) -from libp2p.identity.identify import ID as ID_IDENTIFY from libp2p.identity.identify.pb.identify_pb2 import ( Identify, ) from libp2p.identity.identify_push import ( + ID_PUSH as ID_IDENTIFY_PUSH, identify_push_handler_for, push_identify_to_peer, ) -from libp2p.identity.identify_push import ID_PUSH as ID_IDENTIFY_PUSH from libp2p.peer.peerinfo import ( info_from_p2p_addr, ) diff --git a/examples/pubsub/pubsub.py b/examples/pubsub/pubsub.py index 9f853744..9dca415f 100644 --- a/examples/pubsub/pubsub.py +++ b/examples/pubsub/pubsub.py @@ -1,9 +1,6 @@ import argparse import logging import socket -from typing import ( - Optional, -) import base58 import multiaddr @@ -109,7 +106,7 @@ async def monitor_peer_topics(pubsub, nursery, termination_event): await trio.sleep(2) -async def run(topic: str, destination: Optional[str], port: Optional[int]) -> None: +async def run(topic: str, destination: str | None, port: int | None) -> None: # Initialize network settings localhost_ip = "127.0.0.1" diff --git a/libp2p/__init__.py b/libp2p/__init__.py index c05d05e5..64f47243 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -152,12 +152,12 @@ def get_default_muxer_options() -> TMuxerOptions: def new_swarm( - key_pair: Optional[KeyPair] = None, - muxer_opt: Optional[TMuxerOptions] = None, - sec_opt: Optional[TSecurityOptions] = None, - peerstore_opt: Optional[IPeerStore] = None, - muxer_preference: Optional[Literal["YAMUX", "MPLEX"]] = None, - listen_addrs: Optional[Sequence[multiaddr.Multiaddr]] = None, + key_pair: KeyPair | None = None, + muxer_opt: TMuxerOptions | None = None, + sec_opt: TSecurityOptions | None = None, + peerstore_opt: IPeerStore | None = None, + muxer_preference: Literal["YAMUX", "MPLEX"] | None = None, + listen_addrs: Sequence[multiaddr.Multiaddr] | None = None, ) -> INetworkService: """ Create a swarm instance based on the parameters. @@ -236,13 +236,13 @@ def new_swarm( def new_host( - key_pair: Optional[KeyPair] = None, - muxer_opt: Optional[TMuxerOptions] = None, - sec_opt: Optional[TSecurityOptions] = None, - peerstore_opt: Optional[IPeerStore] = None, - disc_opt: Optional[IPeerRouting] = None, - muxer_preference: Optional[Literal["YAMUX", "MPLEX"]] = None, - listen_addrs: Sequence[multiaddr.Multiaddr] = None, + key_pair: KeyPair | None = None, + muxer_opt: TMuxerOptions | None = None, + sec_opt: TSecurityOptions | None = None, + peerstore_opt: IPeerStore | None = None, + disc_opt: IPeerRouting | None = None, + muxer_preference: Literal["YAMUX", "MPLEX"] | None = None, + listen_addrs: Sequence[multiaddr.Multiaddr] | None = None, ) -> IHost: """ Create a new libp2p host based on the given parameters. diff --git a/libp2p/abc.py b/libp2p/abc.py index f9686bac..06570eaa 100644 --- a/libp2p/abc.py +++ b/libp2p/abc.py @@ -8,6 +8,7 @@ from collections.abc import ( KeysView, Sequence, ) +from contextlib import AbstractAsyncContextManager from types import ( TracebackType, ) @@ -15,7 +16,6 @@ from typing import ( TYPE_CHECKING, Any, AsyncContextManager, - Optional, ) from multiaddr import ( @@ -160,7 +160,11 @@ class IMuxedConn(ABC): event_started: trio.Event @abstractmethod - def __init__(self, conn: ISecureConn, peer_id: ID) -> None: + def __init__( + self, + conn: ISecureConn, + peer_id: ID, + ) -> None: """ Initialize a new multiplexed connection. @@ -260,9 +264,9 @@ class IMuxedStream(ReadWriteCloser, AsyncContextManager["IMuxedStream"]): async def __aexit__( self, - exc_type: Optional[type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: """Exit the async context manager and close the stream.""" await self.close() @@ -287,7 +291,7 @@ class INetStream(ReadWriteCloser): muxed_conn: IMuxedConn @abstractmethod - def get_protocol(self) -> TProtocol: + def get_protocol(self) -> TProtocol | None: """ Retrieve the protocol identifier for the stream. @@ -916,7 +920,7 @@ class INetwork(ABC): """ @abstractmethod - async def listen(self, *multiaddrs: Sequence[Multiaddr]) -> bool: + async def listen(self, *multiaddrs: Multiaddr) -> bool: """ Start listening on one or more multiaddresses. @@ -1174,7 +1178,9 @@ class IHost(ABC): """ @abstractmethod - def run(self, listen_addrs: Sequence[Multiaddr]) -> AsyncContextManager[None]: + def run( + self, listen_addrs: Sequence[Multiaddr] + ) -> AbstractAsyncContextManager[None]: """ Run the host and start listening on the specified multiaddresses. @@ -1564,7 +1570,7 @@ class IMultiselectMuxer(ABC): and its corresponding handler for communication. """ - handlers: dict[TProtocol, StreamHandlerFn] + handlers: dict[TProtocol | None, StreamHandlerFn | None] @abstractmethod def add_handler(self, protocol: TProtocol, handler: StreamHandlerFn) -> None: @@ -1580,7 +1586,7 @@ class IMultiselectMuxer(ABC): """ - def get_protocols(self) -> tuple[TProtocol, ...]: + def get_protocols(self) -> tuple[TProtocol | None, ...]: """ Retrieve the protocols for which handlers have been registered. @@ -1595,7 +1601,7 @@ class IMultiselectMuxer(ABC): @abstractmethod async def negotiate( self, communicator: IMultiselectCommunicator - ) -> tuple[TProtocol, StreamHandlerFn]: + ) -> tuple[TProtocol | None, StreamHandlerFn | None]: """ Negotiate a protocol selection with a multiselect client. @@ -1672,7 +1678,7 @@ class IPeerRouting(ABC): """ @abstractmethod - async def find_peer(self, peer_id: ID) -> PeerInfo: + async def find_peer(self, peer_id: ID) -> PeerInfo | None: """ Search for a peer with the specified peer ID. @@ -1840,6 +1846,11 @@ class IPubsubRouter(ABC): """ + mesh: dict[str, set[ID]] + fanout: dict[str, set[ID]] + peer_protocol: dict[ID, TProtocol] + degree: int + @abstractmethod def get_protocols(self) -> list[TProtocol]: """ @@ -1865,7 +1876,7 @@ class IPubsubRouter(ABC): """ @abstractmethod - def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None: + def add_peer(self, peer_id: ID, protocol_id: TProtocol | None) -> None: """ Notify the router that a new peer has connected. diff --git a/libp2p/crypto/authenticated_encryption.py b/libp2p/crypto/authenticated_encryption.py index 7683fe90..70f15d45 100644 --- a/libp2p/crypto/authenticated_encryption.py +++ b/libp2p/crypto/authenticated_encryption.py @@ -116,15 +116,15 @@ def initialize_pair( EncryptionParameters( cipher_type, hash_type, - first_half[0:iv_size], - first_half[iv_size + cipher_key_size :], - first_half[iv_size : iv_size + cipher_key_size], + bytes(first_half[0:iv_size]), + bytes(first_half[iv_size + cipher_key_size :]), + bytes(first_half[iv_size : iv_size + cipher_key_size]), ), EncryptionParameters( cipher_type, hash_type, - second_half[0:iv_size], - second_half[iv_size + cipher_key_size :], - second_half[iv_size : iv_size + cipher_key_size], + bytes(second_half[0:iv_size]), + bytes(second_half[iv_size + cipher_key_size :]), + bytes(second_half[iv_size : iv_size + cipher_key_size]), ), ) diff --git a/libp2p/crypto/ecc.py b/libp2p/crypto/ecc.py index ec31bc3e..d78741d2 100644 --- a/libp2p/crypto/ecc.py +++ b/libp2p/crypto/ecc.py @@ -9,29 +9,40 @@ from libp2p.crypto.keys import ( if sys.platform != "win32": from fastecdsa import ( + curve as curve_types, keys, point, ) - from fastecdsa import curve as curve_types from fastecdsa.encoding.sec1 import ( SEC1Encoder, ) else: - from coincurve import PrivateKey as CPrivateKey - from coincurve import PublicKey as CPublicKey + from coincurve import ( + PrivateKey as CPrivateKey, + PublicKey as CPublicKey, + ) -def infer_local_type(curve: str) -> object: - """ - Convert a str representation of some elliptic curve to a - representation understood by the backend of this module. - """ - if curve != "P-256": - raise NotImplementedError("Only P-256 curve is supported") +if sys.platform != "win32": - if sys.platform != "win32": + def infer_local_type(curve: str) -> curve_types.Curve: + """ + Convert a str representation of some elliptic curve to a + representation understood by the backend of this module. + """ + if curve != "P-256": + raise NotImplementedError("Only P-256 curve is supported") return curve_types.P256 - return "P-256" # coincurve only supports P-256 +else: + + def infer_local_type(curve: str) -> str: + """ + Convert a str representation of some elliptic curve to a + representation understood by the backend of this module. + """ + if curve != "P-256": + raise NotImplementedError("Only P-256 curve is supported") + return "P-256" # coincurve only supports P-256 if sys.platform != "win32": @@ -68,7 +79,10 @@ if sys.platform != "win32": return cls(private_key_impl, curve_type) def to_bytes(self) -> bytes: - return keys.export_key(self.impl, self.curve) + key_str = keys.export_key(self.impl, self.curve) + if key_str is None: + raise Exception("Key not found") + return key_str.encode() def get_type(self) -> KeyType: return KeyType.ECC_P256 diff --git a/libp2p/crypto/ed25519.py b/libp2p/crypto/ed25519.py index 01a7a98f..66960676 100644 --- a/libp2p/crypto/ed25519.py +++ b/libp2p/crypto/ed25519.py @@ -4,8 +4,10 @@ from Crypto.Hash import ( from nacl.exceptions import ( BadSignatureError, ) -from nacl.public import PrivateKey as PrivateKeyImpl -from nacl.public import PublicKey as PublicKeyImpl +from nacl.public import ( + PrivateKey as PrivateKeyImpl, + PublicKey as PublicKeyImpl, +) from nacl.signing import ( SigningKey, VerifyKey, @@ -48,7 +50,7 @@ class Ed25519PrivateKey(PrivateKey): self.impl = impl @classmethod - def new(cls, seed: bytes = None) -> "Ed25519PrivateKey": + def new(cls, seed: bytes | None = None) -> "Ed25519PrivateKey": if not seed: seed = utils.random() @@ -75,7 +77,7 @@ class Ed25519PrivateKey(PrivateKey): return Ed25519PublicKey(self.impl.public_key) -def create_new_key_pair(seed: bytes = None) -> KeyPair: +def create_new_key_pair(seed: bytes | None = None) -> KeyPair: private_key = Ed25519PrivateKey.new(seed) public_key = private_key.get_public_key() return KeyPair(private_key, public_key) diff --git a/libp2p/crypto/key_exchange.py b/libp2p/crypto/key_exchange.py index 5a713fd3..f8bc13eb 100644 --- a/libp2p/crypto/key_exchange.py +++ b/libp2p/crypto/key_exchange.py @@ -1,6 +1,6 @@ +from collections.abc import Callable import sys from typing import ( - Callable, cast, ) diff --git a/libp2p/crypto/keys.py b/libp2p/crypto/keys.py index 4a4f78a6..21cf71b2 100644 --- a/libp2p/crypto/keys.py +++ b/libp2p/crypto/keys.py @@ -81,12 +81,10 @@ class PrivateKey(Key): """A ``PrivateKey`` represents a cryptographic private key.""" @abstractmethod - def sign(self, data: bytes) -> bytes: - ... + def sign(self, data: bytes) -> bytes: ... @abstractmethod - def get_public_key(self) -> PublicKey: - ... + def get_public_key(self) -> PublicKey: ... def _serialize_to_protobuf(self) -> crypto_pb2.PrivateKey: """Return the protobuf representation of this ``Key``.""" diff --git a/libp2p/crypto/secp256k1.py b/libp2p/crypto/secp256k1.py index 6ed97190..44c32162 100644 --- a/libp2p/crypto/secp256k1.py +++ b/libp2p/crypto/secp256k1.py @@ -37,7 +37,7 @@ class Secp256k1PrivateKey(PrivateKey): self.impl = impl @classmethod - def new(cls, secret: bytes = None) -> "Secp256k1PrivateKey": + def new(cls, secret: bytes | None = None) -> "Secp256k1PrivateKey": private_key_impl = coincurve.PrivateKey(secret) return cls(private_key_impl) @@ -65,7 +65,7 @@ class Secp256k1PrivateKey(PrivateKey): return Secp256k1PublicKey(public_key_impl) -def create_new_key_pair(secret: bytes = None) -> KeyPair: +def create_new_key_pair(secret: bytes | None = None) -> KeyPair: """ Returns a new Secp256k1 keypair derived from the provided ``secret``, a sequence of bytes corresponding to some integer between 0 and the group diff --git a/libp2p/custom_types.py b/libp2p/custom_types.py index 1789844c..0b844133 100644 --- a/libp2p/custom_types.py +++ b/libp2p/custom_types.py @@ -1,13 +1,9 @@ from collections.abc import ( Awaitable, + Callable, Mapping, ) -from typing import ( - TYPE_CHECKING, - Callable, - NewType, - Union, -) +from typing import TYPE_CHECKING, NewType, Union, cast if TYPE_CHECKING: from libp2p.abc import ( @@ -16,15 +12,9 @@ if TYPE_CHECKING: ISecureTransport, ) else: - - class INetStream: - pass - - class IMuxedConn: - pass - - class ISecureTransport: - pass + IMuxedConn = cast(type, object) + INetStream = cast(type, object) + ISecureTransport = cast(type, object) from libp2p.io.abc import ( @@ -38,10 +28,10 @@ from libp2p.pubsub.pb import ( ) TProtocol = NewType("TProtocol", str) -StreamHandlerFn = Callable[["INetStream"], Awaitable[None]] +StreamHandlerFn = Callable[[INetStream], Awaitable[None]] THandler = Callable[[ReadWriteCloser], Awaitable[None]] -TSecurityOptions = Mapping[TProtocol, "ISecureTransport"] -TMuxerClass = type["IMuxedConn"] +TSecurityOptions = Mapping[TProtocol, ISecureTransport] +TMuxerClass = type[IMuxedConn] TMuxerOptions = Mapping[TProtocol, TMuxerClass] SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool] AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]] diff --git a/libp2p/host/autonat/autonat.py b/libp2p/host/autonat/autonat.py index 29723a3e..ae4663f1 100644 --- a/libp2p/host/autonat/autonat.py +++ b/libp2p/host/autonat/autonat.py @@ -1,7 +1,4 @@ import logging -from typing import ( - Union, -) from libp2p.custom_types import ( TProtocol, @@ -94,7 +91,7 @@ class AutoNATService: finally: await stream.close() - async def _handle_request(self, request: Union[bytes, Message]) -> Message: + async def _handle_request(self, request: bytes | Message) -> Message: """ Process an AutoNAT protocol request. diff --git a/libp2p/host/autonat/pb/autonat_pb2_grpc.py b/libp2p/host/autonat/pb/autonat_pb2_grpc.py index de6f77d2..179738ad 100644 --- a/libp2p/host/autonat/pb/autonat_pb2_grpc.py +++ b/libp2p/host/autonat/pb/autonat_pb2_grpc.py @@ -84,26 +84,23 @@ class AutoNAT: request: Any, target: str, options: tuple[Any, ...] = (), - channel_credentials: Optional[Any] = None, - call_credentials: Optional[Any] = None, + channel_credentials: Any | None = None, + call_credentials: Any | None = None, insecure: bool = False, - compression: Optional[Any] = None, - wait_for_ready: Optional[bool] = None, - timeout: Optional[float] = None, - metadata: Optional[list[tuple[str, str]]] = None, + compression: Any | None = None, + wait_for_ready: bool | None = None, + timeout: float | None = None, + metadata: list[tuple[str, str]] | None = None, ) -> Any: - return grpc.experimental.unary_unary( - request, - target, + channel = grpc.secure_channel(target, channel_credentials) if channel_credentials else grpc.insecure_channel(target) + return channel.unary_unary( "/autonat.pb.AutoNAT/Dial", - autonat__pb2.Message.SerializeToString, - autonat__pb2.Message.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, + request_serializer=autonat__pb2.Message.SerializeToString, + response_deserializer=autonat__pb2.Message.FromString, + _registered_method=True, + )( + request, + timeout=timeout, + metadata=metadata, + wait_for_ready=wait_for_ready, ) diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 60b31fe0..6d844bee 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -3,6 +3,7 @@ from collections.abc import ( Sequence, ) from contextlib import ( + AbstractAsyncContextManager, asynccontextmanager, ) import logging @@ -88,14 +89,14 @@ class BasicHost(IHost): def __init__( self, network: INetworkService, - default_protocols: "OrderedDict[TProtocol, StreamHandlerFn]" = None, + default_protocols: Optional["OrderedDict[TProtocol, StreamHandlerFn]"] = None, ) -> None: self._network = network self._network.set_stream_handler(self._swarm_stream_handler) self.peerstore = self._network.peerstore # Protocol muxing default_protocols = default_protocols or get_default_protocols(self) - self.multiselect = Multiselect(default_protocols) + self.multiselect = Multiselect(dict(default_protocols.items())) self.multiselect_client = MultiselectClient() def get_id(self) -> ID: @@ -147,19 +148,23 @@ class BasicHost(IHost): """ return list(self._network.connections.keys()) - @asynccontextmanager - async def run( + def run( self, listen_addrs: Sequence[multiaddr.Multiaddr] - ) -> AsyncIterator[None]: + ) -> AbstractAsyncContextManager[None]: """ Run the host instance and listen to ``listen_addrs``. :param listen_addrs: a sequence of multiaddrs that we want to listen to """ - network = self.get_network() - async with background_trio_service(network): - await network.listen(*listen_addrs) - yield + + @asynccontextmanager + async def _run() -> AsyncIterator[None]: + network = self.get_network() + async with background_trio_service(network): + await network.listen(*listen_addrs) + yield + + return _run() def set_stream_handler( self, protocol_id: TProtocol, stream_handler: StreamHandlerFn @@ -258,6 +263,15 @@ class BasicHost(IHost): await net_stream.reset() return net_stream.set_protocol(protocol) + if handler is None: + logger.debug( + "no handler for protocol %s, closing stream from peer %s", + protocol, + net_stream.muxed_conn.peer_id, + ) + await net_stream.reset() + return + await handler(net_stream) def get_live_peers(self) -> list[ID]: @@ -277,7 +291,7 @@ class BasicHost(IHost): """ return peer_id in self._network.connections - def get_peer_connection_info(self, peer_id: ID) -> Optional[INetConn]: + def get_peer_connection_info(self, peer_id: ID) -> INetConn | None: """ Get connection information for a specific peer if connected. diff --git a/libp2p/host/defaults.py b/libp2p/host/defaults.py index eb454dc5..b8c50886 100644 --- a/libp2p/host/defaults.py +++ b/libp2p/host/defaults.py @@ -9,13 +9,13 @@ from libp2p.abc import ( IHost, ) from libp2p.host.ping import ( + ID as PingID, handle_ping, ) -from libp2p.host.ping import ID as PingID from libp2p.identity.identify.identify import ( + ID as IdentifyID, identify_handler_for, ) -from libp2p.identity.identify.identify import ID as IdentifyID if TYPE_CHECKING: from libp2p.custom_types import ( diff --git a/libp2p/identity/identify/identify.py b/libp2p/identity/identify/identify.py index e157a85c..5d066e37 100644 --- a/libp2p/identity/identify/identify.py +++ b/libp2p/identity/identify/identify.py @@ -1,7 +1,4 @@ import logging -from typing import ( - Optional, -) from multiaddr import ( Multiaddr, @@ -40,8 +37,8 @@ def _multiaddr_to_bytes(maddr: Multiaddr) -> bytes: def _remote_address_to_multiaddr( - remote_address: Optional[tuple[str, int]] -) -> Optional[Multiaddr]: + remote_address: tuple[str, int] | None, +) -> Multiaddr | None: """Convert a (host, port) tuple to a Multiaddr.""" if remote_address is None: return None @@ -58,7 +55,7 @@ def _remote_address_to_multiaddr( def _mk_identify_protobuf( - host: IHost, observed_multiaddr: Optional[Multiaddr] + host: IHost, observed_multiaddr: Multiaddr | None ) -> Identify: public_key = host.get_public_key() laddrs = host.get_addrs() @@ -81,15 +78,14 @@ def identify_handler_for(host: IHost) -> StreamHandlerFn: peer_id = ( stream.muxed_conn.peer_id ) # remote peer_id is in class Mplex (mplex.py ) - + observed_multiaddr: Multiaddr | None = None # Get the remote address try: remote_address = stream.get_remote_address() # Convert to multiaddr if remote_address: observed_multiaddr = _remote_address_to_multiaddr(remote_address) - else: - observed_multiaddr = None + logger.debug( "Connection from remote peer %s, address: %s, multiaddr: %s", peer_id, diff --git a/libp2p/identity/identify_push/identify_push.py b/libp2p/identity/identify_push/identify_push.py index 883b63de..c649c368 100644 --- a/libp2p/identity/identify_push/identify_push.py +++ b/libp2p/identity/identify_push/identify_push.py @@ -1,7 +1,4 @@ import logging -from typing import ( - Optional, -) from multiaddr import ( Multiaddr, @@ -135,7 +132,7 @@ async def _update_peerstore_from_identify( async def push_identify_to_peer( - host: IHost, peer_id: ID, observed_multiaddr: Optional[Multiaddr] = None + host: IHost, peer_id: ID, observed_multiaddr: Multiaddr | None = None ) -> bool: """ Push an identify message to a specific peer. @@ -172,8 +169,8 @@ async def push_identify_to_peer( async def push_identify_to_peers( host: IHost, - peer_ids: Optional[set[ID]] = None, - observed_multiaddr: Optional[Multiaddr] = None, + peer_ids: set[ID] | None = None, + observed_multiaddr: Multiaddr | None = None, ) -> None: """ Push an identify message to multiple peers in parallel. diff --git a/libp2p/io/abc.py b/libp2p/io/abc.py index 75125fd8..0ea355cf 100644 --- a/libp2p/io/abc.py +++ b/libp2p/io/abc.py @@ -2,27 +2,22 @@ from abc import ( ABC, abstractmethod, ) -from typing import ( - Optional, -) +from typing import Any class Closer(ABC): @abstractmethod - async def close(self) -> None: - ... + async def close(self) -> None: ... class Reader(ABC): @abstractmethod - async def read(self, n: int = None) -> bytes: - ... + async def read(self, n: int | None = None) -> bytes: ... class Writer(ABC): @abstractmethod - async def write(self, data: bytes) -> None: - ... + async def write(self, data: bytes) -> None: ... class WriteCloser(Writer, Closer): @@ -39,7 +34,7 @@ class ReadWriter(Reader, Writer): class ReadWriteCloser(Reader, Writer, Closer): @abstractmethod - def get_remote_address(self) -> Optional[tuple[str, int]]: + def get_remote_address(self) -> tuple[str, int] | None: """ Return the remote address of the connected peer. @@ -50,14 +45,12 @@ class ReadWriteCloser(Reader, Writer, Closer): class MsgReader(ABC): @abstractmethod - async def read_msg(self) -> bytes: - ... + async def read_msg(self) -> bytes: ... class MsgWriter(ABC): @abstractmethod - async def write_msg(self, msg: bytes) -> None: - ... + async def write_msg(self, msg: bytes) -> None: ... class MsgReadWriteCloser(MsgReader, MsgWriter, Closer): @@ -66,19 +59,26 @@ class MsgReadWriteCloser(MsgReader, MsgWriter, Closer): class Encrypter(ABC): @abstractmethod - def encrypt(self, data: bytes) -> bytes: - ... + def encrypt(self, data: bytes) -> bytes: ... @abstractmethod - def decrypt(self, data: bytes) -> bytes: - ... + def decrypt(self, data: bytes) -> bytes: ... class EncryptedMsgReadWriter(MsgReadWriteCloser, Encrypter): """Read/write message with encryption/decryption.""" - def get_remote_address(self) -> Optional[tuple[str, int]]: + conn: Any | None + + def __init__(self, conn: Any | None = None): + self.conn = conn + + def get_remote_address(self) -> tuple[str, int] | None: """Get remote address if supported by the underlying connection.""" - if hasattr(self, "conn") and hasattr(self.conn, "get_remote_address"): + if ( + self.conn is not None + and hasattr(self, "conn") + and hasattr(self.conn, "get_remote_address") + ): return self.conn.get_remote_address() return None diff --git a/libp2p/io/msgio.py b/libp2p/io/msgio.py index fa049cbd..1cf7114b 100644 --- a/libp2p/io/msgio.py +++ b/libp2p/io/msgio.py @@ -5,6 +5,7 @@ from that repo: "a simple package to r/w length-delimited slices." NOTE: currently missing the capability to indicate lengths by "varint" method. """ + from abc import ( abstractmethod, ) @@ -60,12 +61,10 @@ class BaseMsgReadWriter(MsgReadWriteCloser): return await read_exactly(self.read_write_closer, length) @abstractmethod - async def next_msg_len(self) -> int: - ... + async def next_msg_len(self) -> int: ... @abstractmethod - def encode_msg(self, msg: bytes) -> bytes: - ... + def encode_msg(self, msg: bytes) -> bytes: ... async def close(self) -> None: await self.read_write_closer.close() diff --git a/libp2p/io/trio.py b/libp2p/io/trio.py index f0301b90..29a808cd 100644 --- a/libp2p/io/trio.py +++ b/libp2p/io/trio.py @@ -1,7 +1,4 @@ import logging -from typing import ( - Optional, -) import trio @@ -34,7 +31,7 @@ class TrioTCPStream(ReadWriteCloser): except (trio.ClosedResourceError, trio.BrokenResourceError) as error: raise IOException from error - async def read(self, n: int = None) -> bytes: + async def read(self, n: int | None = None) -> bytes: async with self.read_lock: if n is not None and n == 0: return b"" @@ -46,7 +43,7 @@ class TrioTCPStream(ReadWriteCloser): async def close(self) -> None: await self.stream.aclose() - def get_remote_address(self) -> Optional[tuple[str, int]]: + def get_remote_address(self) -> tuple[str, int] | None: """Return the remote address as (host, port) tuple.""" try: return self.stream.socket.getpeername() diff --git a/libp2p/io/utils.py b/libp2p/io/utils.py index 8f873ea0..43ae1a3f 100644 --- a/libp2p/io/utils.py +++ b/libp2p/io/utils.py @@ -14,12 +14,14 @@ async def read_exactly( """ NOTE: relying on exceptions to break out on erroneous conditions, like EOF """ - data = await reader.read(n) + buffer = bytearray() + buffer.extend(await reader.read(n)) for _ in range(retry_count): - if len(data) < n: - remaining = n - len(data) - data += await reader.read(remaining) + if len(buffer) < n: + remaining = n - len(buffer) + buffer.extend(await reader.read(remaining)) + else: - return data - raise IncompleteReadError({"requested_count": n, "received_count": len(data)}) + return bytes(buffer) + raise IncompleteReadError({"requested_count": n, "received_count": len(buffer)}) diff --git a/libp2p/network/connection/raw_connection.py b/libp2p/network/connection/raw_connection.py index 2c6dd5d7..dd857327 100644 --- a/libp2p/network/connection/raw_connection.py +++ b/libp2p/network/connection/raw_connection.py @@ -1,7 +1,3 @@ -from typing import ( - Optional, -) - from libp2p.abc import ( IRawConnection, ) @@ -32,7 +28,7 @@ class RawConnection(IRawConnection): except IOException as error: raise RawConnError from error - async def read(self, n: int = None) -> bytes: + async def read(self, n: int | None = None) -> bytes: """ Read up to ``n`` bytes from the underlying stream. This call is delegated directly to the underlying ``self.reader``. @@ -47,6 +43,6 @@ class RawConnection(IRawConnection): async def close(self) -> None: await self.stream.close() - def get_remote_address(self) -> Optional[tuple[str, int]]: + def get_remote_address(self) -> tuple[str, int] | None: """Delegate to the underlying stream's get_remote_address method.""" return self.stream.get_remote_address() diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index f0fc2a36..79c8849f 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: """ -Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee14e334f7caeedad722/swarm_conn.go # noqa: E501 +Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee14e334f7caeedad722/swarm_conn.go """ @@ -32,7 +32,11 @@ class SwarmConn(INetConn): streams: set[NetStream] event_closed: trio.Event - def __init__(self, muxed_conn: IMuxedConn, swarm: "Swarm") -> None: + def __init__( + self, + muxed_conn: IMuxedConn, + swarm: "Swarm", + ) -> None: self.muxed_conn = muxed_conn self.swarm = swarm self.streams = set() @@ -40,7 +44,7 @@ class SwarmConn(INetConn): self.event_started = trio.Event() if hasattr(muxed_conn, "on_close"): logging.debug(f"Setting on_close for peer {muxed_conn.peer_id}") - muxed_conn.on_close = self._on_muxed_conn_closed + setattr(muxed_conn, "on_close", self._on_muxed_conn_closed) else: logging.error( f"muxed_conn for peer {muxed_conn.peer_id} has no on_close attribute" diff --git a/libp2p/network/stream/net_stream.py b/libp2p/network/stream/net_stream.py index 62e6f711..300f0fa4 100644 --- a/libp2p/network/stream/net_stream.py +++ b/libp2p/network/stream/net_stream.py @@ -1,7 +1,3 @@ -from typing import ( - Optional, -) - from libp2p.abc import ( IMuxedStream, INetStream, @@ -28,14 +24,14 @@ from .exceptions import ( # - Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501 class NetStream(INetStream): muxed_stream: IMuxedStream - protocol_id: Optional[TProtocol] + protocol_id: TProtocol | None def __init__(self, muxed_stream: IMuxedStream) -> None: self.muxed_stream = muxed_stream self.muxed_conn = muxed_stream.muxed_conn self.protocol_id = None - def get_protocol(self) -> TProtocol: + def get_protocol(self) -> TProtocol | None: """ :return: protocol id that stream runs on """ @@ -47,7 +43,7 @@ class NetStream(INetStream): """ self.protocol_id = protocol_id - async def read(self, n: int = None) -> bytes: + async def read(self, n: int | None = None) -> bytes: """ Read from stream. @@ -79,7 +75,7 @@ class NetStream(INetStream): async def reset(self) -> None: await self.muxed_stream.reset() - def get_remote_address(self) -> Optional[tuple[str, int]]: + def get_remote_address(self) -> tuple[str, int] | None: """Delegate to the underlying muxed stream.""" return self.muxed_stream.get_remote_address() diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 267151f6..d19b8177 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -1,7 +1,4 @@ import logging -from typing import ( - Optional, -) from multiaddr import ( Multiaddr, @@ -75,7 +72,7 @@ class Swarm(Service, INetworkService): connections: dict[ID, INetConn] listeners: dict[str, IListener] common_stream_handler: StreamHandlerFn - listener_nursery: Optional[trio.Nursery] + listener_nursery: trio.Nursery | None event_listener_nursery_created: trio.Event notifees: list[INotifee] @@ -340,7 +337,9 @@ class Swarm(Service, INetworkService): if hasattr(self, "transport") and self.transport is not None: # Check if transport has close method before calling it if hasattr(self.transport, "close"): - await self.transport.close() + await self.transport.close() # type: ignore + # Ignoring the type above since `transport` may not have a close method + # and we have already checked it with hasattr logger.debug("swarm successfully closed") @@ -360,7 +359,11 @@ class Swarm(Service, INetworkService): and start to monitor the connection for its new streams and disconnection. """ - swarm_conn = SwarmConn(muxed_conn, self) + swarm_conn = SwarmConn( + muxed_conn, + self, + ) + self.manager.run_task(muxed_conn.start) await muxed_conn.event_started.wait() self.manager.run_task(swarm_conn.start) diff --git a/libp2p/peer/id.py b/libp2p/peer/id.py index 06c7674f..0be51ea2 100644 --- a/libp2p/peer/id.py +++ b/libp2p/peer/id.py @@ -1,7 +1,4 @@ import hashlib -from typing import ( - Union, -) import base58 import multihash @@ -24,7 +21,7 @@ if ENABLE_INLINING: _digest: bytes def __init__(self) -> None: - self._digest = bytearray() + self._digest = b"" def update(self, input: bytes) -> None: self._digest += input @@ -39,8 +36,8 @@ if ENABLE_INLINING: class ID: _bytes: bytes - _xor_id: int = None - _b58_str: str = None + _xor_id: int | None = None + _b58_str: str | None = None def __init__(self, peer_id_bytes: bytes) -> None: self._bytes = peer_id_bytes @@ -93,7 +90,7 @@ class ID: return cls(mh_digest.encode()) -def sha256_digest(data: Union[str, bytes]) -> bytes: +def sha256_digest(data: str | bytes) -> bytes: if isinstance(data, str): data = data.encode("utf8") return hashlib.sha256(data).digest() diff --git a/libp2p/peer/peerdata.py b/libp2p/peer/peerdata.py index f0e52463..fa9f4f54 100644 --- a/libp2p/peer/peerdata.py +++ b/libp2p/peer/peerdata.py @@ -1,9 +1,7 @@ from collections.abc import ( Sequence, ) -from typing import ( - Any, -) +from typing import Any from multiaddr import ( Multiaddr, @@ -19,8 +17,8 @@ from libp2p.crypto.keys import ( class PeerData(IPeerData): - pubkey: PublicKey - privkey: PrivateKey + pubkey: PublicKey | None + privkey: PrivateKey | None metadata: dict[Any, Any] protocols: list[str] addrs: list[Multiaddr] diff --git a/libp2p/peer/peerinfo.py b/libp2p/peer/peerinfo.py index 024b1801..f3b3bd7b 100644 --- a/libp2p/peer/peerinfo.py +++ b/libp2p/peer/peerinfo.py @@ -32,21 +32,31 @@ def info_from_p2p_addr(addr: multiaddr.Multiaddr) -> PeerInfo: if not addr: raise InvalidAddrError("`addr` should not be `None`") - parts = addr.split() + parts: list[multiaddr.Multiaddr] = addr.split() if not parts: raise InvalidAddrError( f"`parts`={parts} should at least have a protocol `P_P2P`" ) p2p_part = parts[-1] - last_protocol_code = p2p_part.protocols()[0].code - if last_protocol_code != multiaddr.protocols.P_P2P: + p2p_protocols = p2p_part.protocols() + if not p2p_protocols: + raise InvalidAddrError("The last part of the address has no protocols") + last_protocol = p2p_protocols[0] + if last_protocol is None: + raise InvalidAddrError("The last protocol is None") + + last_protocol_code = last_protocol.code + if last_protocol_code != multiaddr.multiaddr.protocols.P_P2P: raise InvalidAddrError( f"The last protocol should be `P_P2P` instead of `{last_protocol_code}`" ) # make sure the /p2p value parses as a peer.ID - peer_id_str: str = p2p_part.value_for_protocol(multiaddr.protocols.P_P2P) + peer_id_str = p2p_part.value_for_protocol(multiaddr.multiaddr.protocols.P_P2P) + if peer_id_str is None: + raise InvalidAddrError("Missing value for /p2p protocol in multiaddr") + peer_id: ID = ID.from_base58(peer_id_str) # we might have received just an / p2p part, which means there's no addr. diff --git a/libp2p/protocol_muxer/multiselect.py b/libp2p/protocol_muxer/multiselect.py index b7ee2004..8f6e0e74 100644 --- a/libp2p/protocol_muxer/multiselect.py +++ b/libp2p/protocol_muxer/multiselect.py @@ -23,16 +23,20 @@ class Multiselect(IMultiselectMuxer): communication. """ - handlers: dict[TProtocol, StreamHandlerFn] + handlers: dict[TProtocol | None, StreamHandlerFn | None] def __init__( - self, default_handlers: dict[TProtocol, StreamHandlerFn] = None + self, + default_handlers: None + | (dict[TProtocol | None, StreamHandlerFn | None]) = None, ) -> None: if not default_handlers: default_handlers = {} self.handlers = default_handlers - def add_handler(self, protocol: TProtocol, handler: StreamHandlerFn) -> None: + def add_handler( + self, protocol: TProtocol | None, handler: StreamHandlerFn | None + ) -> None: """ Store the handler with the given protocol. @@ -41,9 +45,10 @@ class Multiselect(IMultiselectMuxer): """ self.handlers[protocol] = handler + # FIXME: Make TProtocol Optional[TProtocol] to keep types consistent async def negotiate( self, communicator: IMultiselectCommunicator - ) -> tuple[TProtocol, StreamHandlerFn]: + ) -> tuple[TProtocol, StreamHandlerFn | None]: """ Negotiate performs protocol selection. @@ -60,7 +65,7 @@ class Multiselect(IMultiselectMuxer): raise MultiselectError() from error if command == "ls": - supported_protocols = list(self.handlers.keys()) + supported_protocols = [p for p in self.handlers.keys() if p is not None] response = "\n".join(supported_protocols) + "\n" try: @@ -82,6 +87,8 @@ class Multiselect(IMultiselectMuxer): except MultiselectCommunicatorError as error: raise MultiselectError() from error + raise MultiselectError("Negotiation failed: no matching protocol") + async def handshake(self, communicator: IMultiselectCommunicator) -> None: """ Perform handshake to agree on multiselect protocol. diff --git a/libp2p/pubsub/floodsub.py b/libp2p/pubsub/floodsub.py index 884dc89a..93d01f1a 100644 --- a/libp2p/pubsub/floodsub.py +++ b/libp2p/pubsub/floodsub.py @@ -22,6 +22,9 @@ from libp2p.utils import ( encode_varint_prefixed, ) +from .exceptions import ( + PubsubRouterError, +) from .pb import ( rpc_pb2, ) @@ -37,7 +40,7 @@ logger = logging.getLogger("libp2p.pubsub.floodsub") class FloodSub(IPubsubRouter): protocols: list[TProtocol] - pubsub: Pubsub + pubsub: Pubsub | None def __init__(self, protocols: Sequence[TProtocol]) -> None: self.protocols = list(protocols) @@ -58,7 +61,7 @@ class FloodSub(IPubsubRouter): """ self.pubsub = pubsub - def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None: + def add_peer(self, peer_id: ID, protocol_id: TProtocol | None) -> None: """ Notifies the router that a new peer has been connected. @@ -108,17 +111,22 @@ class FloodSub(IPubsubRouter): logger.debug("publishing message %s", pubsub_msg) + if self.pubsub is None: + raise PubsubRouterError("pubsub not attached to this instance") + else: + pubsub = self.pubsub + for peer_id in peers_gen: - if peer_id not in self.pubsub.peers: + if peer_id not in pubsub.peers: continue - stream = self.pubsub.peers[peer_id] + stream = pubsub.peers[peer_id] # FIXME: We should add a `WriteMsg` similar to write delimited messages. # Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107 try: await stream.write(encode_varint_prefixed(rpc_msg.SerializeToString())) except StreamClosed: logger.debug("Fail to publish message to %s: stream closed", peer_id) - self.pubsub._handle_dead_peer(peer_id) + pubsub._handle_dead_peer(peer_id) async def join(self, topic: str) -> None: """ @@ -150,12 +158,16 @@ class FloodSub(IPubsubRouter): :param origin: peer id of the peer the message originate from. :return: a generator of the peer ids who we send data to. """ + if self.pubsub is None: + raise PubsubRouterError("pubsub not attached to this instance") + else: + pubsub = self.pubsub for topic in topic_ids: - if topic not in self.pubsub.peer_topics: + if topic not in pubsub.peer_topics: continue - for peer_id in self.pubsub.peer_topics[topic]: + for peer_id in pubsub.peer_topics[topic]: if peer_id in (msg_forwarder, origin): continue - if peer_id not in self.pubsub.peers: + if peer_id not in pubsub.peers: continue yield peer_id diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py index cbed462d..813719dd 100644 --- a/libp2p/pubsub/gossipsub.py +++ b/libp2p/pubsub/gossipsub.py @@ -67,7 +67,7 @@ logger = logging.getLogger("libp2p.pubsub.gossipsub") class GossipSub(IPubsubRouter, Service): protocols: list[TProtocol] - pubsub: Pubsub + pubsub: Pubsub | None degree: int degree_high: int @@ -98,7 +98,7 @@ class GossipSub(IPubsubRouter, Service): degree: int, degree_low: int, degree_high: int, - direct_peers: Sequence[PeerInfo] = None, + direct_peers: Sequence[PeerInfo] | None = None, time_to_live: int = 60, gossip_window: int = 3, gossip_history: int = 5, @@ -141,8 +141,6 @@ class GossipSub(IPubsubRouter, Service): self.time_since_last_publish = {} async def run(self) -> None: - if self.pubsub is None: - raise NoPubsubAttached self.manager.run_daemon_task(self.heartbeat) if len(self.direct_peers) > 0: self.manager.run_daemon_task(self.direct_connect_heartbeat) @@ -173,7 +171,7 @@ class GossipSub(IPubsubRouter, Service): logger.debug("attached to pusub") - def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None: + def add_peer(self, peer_id: ID, protocol_id: TProtocol | None) -> None: """ Notifies the router that a new peer has been connected. @@ -182,6 +180,9 @@ class GossipSub(IPubsubRouter, Service): """ logger.debug("adding peer %s with protocol %s", peer_id, protocol_id) + if protocol_id is None: + raise ValueError("Protocol cannot be None") + if protocol_id not in (PROTOCOL_ID, floodsub.PROTOCOL_ID): # We should never enter here. Becuase the `protocol_id` is registered by # your pubsub instance in multistream-select, but it is not the protocol @@ -243,6 +244,8 @@ class GossipSub(IPubsubRouter, Service): logger.debug("publishing message %s", pubsub_msg) for peer_id in peers_gen: + if self.pubsub is None: + raise NoPubsubAttached if peer_id not in self.pubsub.peers: continue stream = self.pubsub.peers[peer_id] @@ -269,6 +272,8 @@ class GossipSub(IPubsubRouter, Service): """ send_to: set[ID] = set() for topic in topic_ids: + if self.pubsub is None: + raise NoPubsubAttached if topic not in self.pubsub.peer_topics: continue @@ -318,6 +323,9 @@ class GossipSub(IPubsubRouter, Service): :param topic: topic to join """ + if self.pubsub is None: + raise NoPubsubAttached + logger.debug("joining topic %s", topic) if topic in self.mesh: @@ -468,6 +476,8 @@ class GossipSub(IPubsubRouter, Service): await trio.sleep(self.direct_connect_initial_delay) while True: for direct_peer in self.direct_peers: + if self.pubsub is None: + raise NoPubsubAttached if direct_peer not in self.pubsub.peers: try: await self.pubsub.host.connect(self.direct_peers[direct_peer]) @@ -485,6 +495,8 @@ class GossipSub(IPubsubRouter, Service): peers_to_graft: DefaultDict[ID, list[str]] = defaultdict(list) peers_to_prune: DefaultDict[ID, list[str]] = defaultdict(list) for topic in self.mesh: + if self.pubsub is None: + raise NoPubsubAttached # Skip if no peers have subscribed to the topic if topic not in self.pubsub.peer_topics: continue @@ -520,7 +532,8 @@ class GossipSub(IPubsubRouter, Service): # Note: the comments here are the exact pseudocode from the spec for topic in list(self.fanout): if ( - topic not in self.pubsub.peer_topics + self.pubsub is not None + and topic not in self.pubsub.peer_topics and self.time_since_last_publish.get(topic, 0) + self.time_to_live < int(time.time()) ): @@ -529,11 +542,14 @@ class GossipSub(IPubsubRouter, Service): else: # Check if fanout peers are still in the topic and remove the ones that are not # noqa: E501 # ref: https://github.com/libp2p/go-libp2p-pubsub/blob/01b9825fbee1848751d90a8469e3f5f43bac8466/gossipsub.go#L498-L504 # noqa: E501 - in_topic_fanout_peers = [ - peer - for peer in self.fanout[topic] - if peer in self.pubsub.peer_topics[topic] - ] + + in_topic_fanout_peers: list[ID] = [] + if self.pubsub is not None: + in_topic_fanout_peers = [ + peer + for peer in self.fanout[topic] + if peer in self.pubsub.peer_topics[topic] + ] self.fanout[topic] = set(in_topic_fanout_peers) num_fanout_peers_in_topic = len(self.fanout[topic]) @@ -553,6 +569,8 @@ class GossipSub(IPubsubRouter, Service): for topic in self.mesh: msg_ids = self.mcache.window(topic) if msg_ids: + if self.pubsub is None: + raise NoPubsubAttached # Get all pubsub peers in a topic and only add them if they are # gossipsub peers too if topic in self.pubsub.peer_topics: @@ -572,6 +590,8 @@ class GossipSub(IPubsubRouter, Service): for topic in self.fanout: msg_ids = self.mcache.window(topic) if msg_ids: + if self.pubsub is None: + raise NoPubsubAttached # Get all pubsub peers in topic and only add if they are # gossipsub peers also if topic in self.pubsub.peer_topics: @@ -620,6 +640,8 @@ class GossipSub(IPubsubRouter, Service): def _get_in_topic_gossipsub_peers_from_minus( self, topic: str, num_to_select: int, minus: Iterable[ID] ) -> list[ID]: + if self.pubsub is None: + raise NoPubsubAttached gossipsub_peers_in_topic = { peer_id for peer_id in self.pubsub.peer_topics[topic] @@ -633,6 +655,8 @@ class GossipSub(IPubsubRouter, Service): self, ihave_msg: rpc_pb2.ControlIHave, sender_peer_id: ID ) -> None: """Checks the seen set and requests unknown messages with an IWANT message.""" + if self.pubsub is None: + raise NoPubsubAttached # Get list of all seen (seqnos, from) from the (seqno, from) tuples in # seen_messages cache seen_seqnos_and_peers = [ @@ -665,7 +689,7 @@ class GossipSub(IPubsubRouter, Service): msgs_to_forward: list[rpc_pb2.Message] = [] for msg_id_iwant in msg_ids: # Check if the wanted message ID is present in mcache - msg: rpc_pb2.Message = self.mcache.get(msg_id_iwant) + msg: rpc_pb2.Message | None = self.mcache.get(msg_id_iwant) # Cache hit if msg: @@ -683,6 +707,8 @@ class GossipSub(IPubsubRouter, Service): # 2) Serialize that packet rpc_msg: bytes = packet.SerializeToString() + if self.pubsub is None: + raise NoPubsubAttached # 3) Get the stream to this peer if sender_peer_id not in self.pubsub.peers: @@ -737,9 +763,9 @@ class GossipSub(IPubsubRouter, Service): def pack_control_msgs( self, - ihave_msgs: list[rpc_pb2.ControlIHave], - graft_msgs: list[rpc_pb2.ControlGraft], - prune_msgs: list[rpc_pb2.ControlPrune], + ihave_msgs: list[rpc_pb2.ControlIHave] | None, + graft_msgs: list[rpc_pb2.ControlGraft] | None, + prune_msgs: list[rpc_pb2.ControlPrune] | None, ) -> rpc_pb2.ControlMessage: control_msg: rpc_pb2.ControlMessage = rpc_pb2.ControlMessage() if ihave_msgs: @@ -771,7 +797,7 @@ class GossipSub(IPubsubRouter, Service): await self.emit_control_message(control_msg, to_peer) - async def emit_graft(self, topic: str, to_peer: ID) -> None: + async def emit_graft(self, topic: str, id: ID) -> None: """Emit graft message, sent to to_peer, for topic.""" graft_msg: rpc_pb2.ControlGraft = rpc_pb2.ControlGraft() graft_msg.topicID = topic @@ -779,9 +805,9 @@ class GossipSub(IPubsubRouter, Service): control_msg: rpc_pb2.ControlMessage = rpc_pb2.ControlMessage() control_msg.graft.extend([graft_msg]) - await self.emit_control_message(control_msg, to_peer) + await self.emit_control_message(control_msg, id) - async def emit_prune(self, topic: str, to_peer: ID) -> None: + async def emit_prune(self, topic: str, id: ID) -> None: """Emit graft message, sent to to_peer, for topic.""" prune_msg: rpc_pb2.ControlPrune = rpc_pb2.ControlPrune() prune_msg.topicID = topic @@ -789,11 +815,13 @@ class GossipSub(IPubsubRouter, Service): control_msg: rpc_pb2.ControlMessage = rpc_pb2.ControlMessage() control_msg.prune.extend([prune_msg]) - await self.emit_control_message(control_msg, to_peer) + await self.emit_control_message(control_msg, id) async def emit_control_message( self, control_msg: rpc_pb2.ControlMessage, to_peer: ID ) -> None: + if self.pubsub is None: + raise NoPubsubAttached # Add control message to packet packet: rpc_pb2.RPC = rpc_pb2.RPC() packet.control.CopyFrom(control_msg) diff --git a/libp2p/pubsub/mcache.py b/libp2p/pubsub/mcache.py index fe1ecb29..e3776fdd 100644 --- a/libp2p/pubsub/mcache.py +++ b/libp2p/pubsub/mcache.py @@ -1,9 +1,6 @@ from collections.abc import ( Sequence, ) -from typing import ( - Optional, -) from .pb import ( rpc_pb2, @@ -66,7 +63,7 @@ class MessageCache: self.history[0].append(CacheEntry(mid, msg.topicIDs)) - def get(self, mid: tuple[bytes, bytes]) -> Optional[rpc_pb2.Message]: + def get(self, mid: tuple[bytes, bytes]) -> rpc_pb2.Message | None: """ Get a message from the mcache. diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 1f37607e..5f66f30a 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -4,6 +4,7 @@ from __future__ import ( import base64 from collections.abc import ( + Callable, KeysView, ) import functools @@ -11,7 +12,6 @@ import hashlib import logging import time from typing import ( - Callable, NamedTuple, cast, ) @@ -53,6 +53,9 @@ from libp2p.network.stream.exceptions import ( from libp2p.peer.id import ( ID, ) +from libp2p.peer.peerdata import ( + PeerDataError, +) from libp2p.tools.async_service import ( Service, ) @@ -120,7 +123,7 @@ class Pubsub(Service, IPubsub): # Indicate if we should enforce signature verification strict_signing: bool - sign_key: PrivateKey + sign_key: PrivateKey | None # Set of blacklisted peer IDs blacklisted_peers: set[ID] @@ -132,7 +135,7 @@ class Pubsub(Service, IPubsub): self, host: IHost, router: IPubsubRouter, - cache_size: int = None, + cache_size: int | None = None, seen_ttl: int = 120, sweep_interval: int = 60, strict_signing: bool = True, @@ -634,6 +637,9 @@ class Pubsub(Service, IPubsub): if self.strict_signing: priv_key = self.sign_key + if priv_key is None: + raise PeerDataError("private key not found") + signature = priv_key.sign( PUBSUB_SIGNING_PREFIX.encode() + msg.SerializeToString() ) diff --git a/libp2p/security/base_session.py b/libp2p/security/base_session.py index fa99a62a..596179a9 100644 --- a/libp2p/security/base_session.py +++ b/libp2p/security/base_session.py @@ -1,7 +1,3 @@ -from typing import ( - Optional, -) - from libp2p.abc import ( ISecureConn, ) @@ -49,5 +45,5 @@ class BaseSession(ISecureConn): def get_remote_peer(self) -> ID: return self.remote_peer - def get_remote_public_key(self) -> Optional[PublicKey]: + def get_remote_public_key(self) -> PublicKey: return self.remote_permanent_pubkey diff --git a/libp2p/security/base_transport.py b/libp2p/security/base_transport.py index 108ded01..b8fbd99f 100644 --- a/libp2p/security/base_transport.py +++ b/libp2p/security/base_transport.py @@ -1,7 +1,7 @@ -import secrets -from typing import ( +from collections.abc import ( Callable, ) +import secrets from libp2p.abc import ( ISecureTransport, diff --git a/libp2p/security/insecure/transport.py b/libp2p/security/insecure/transport.py index 4666cc78..a230e970 100644 --- a/libp2p/security/insecure/transport.py +++ b/libp2p/security/insecure/transport.py @@ -1,7 +1,3 @@ -from typing import ( - Optional, -) - from libp2p.abc import ( IRawConnection, ISecureConn, @@ -87,13 +83,13 @@ class InsecureSession(BaseSession): async def write(self, data: bytes) -> None: await self.conn.write(data) - async def read(self, n: int = None) -> bytes: + async def read(self, n: int | None = None) -> bytes: return await self.conn.read(n) async def close(self) -> None: await self.conn.close() - def get_remote_address(self) -> Optional[tuple[str, int]]: + def get_remote_address(self) -> tuple[str, int] | None: """ Delegate to the underlying connection's get_remote_address method. """ @@ -105,7 +101,7 @@ async def run_handshake( local_private_key: PrivateKey, conn: IRawConnection, is_initiator: bool, - remote_peer_id: ID, + remote_peer_id: ID | None, ) -> ISecureConn: """Raise `HandshakeFailure` when handshake failed.""" msg = make_exchange_message(local_private_key.get_public_key()) @@ -124,6 +120,15 @@ async def run_handshake( remote_msg.ParseFromString(remote_msg_bytes) received_peer_id = ID(remote_msg.id) + # Verify that `remote_peer_id` isn't `None` + # That is the only condition that `remote_peer_id` would not need to be checked + # against the `recieved_peer_id` gotten from the outbound/recieved `msg`. + # The check against `received_peer_id` happens in the next if-block + if is_initiator and remote_peer_id is None: + raise HandshakeFailure( + "remote peer ID cannot be None if `is_initiator` is set to `True`" + ) + # Verify if the receive `ID` matches the one we originally initialize the session. # We only need to check it when we are the initiator, because only in that condition # we possibly knows the `ID` of the remote. diff --git a/libp2p/security/noise/io.py b/libp2p/security/noise/io.py index f9a0260b..877aa5ab 100644 --- a/libp2p/security/noise/io.py +++ b/libp2p/security/noise/io.py @@ -1,5 +1,4 @@ from typing import ( - Optional, cast, ) @@ -10,7 +9,6 @@ from libp2p.abc import ( ) from libp2p.io.abc import ( EncryptedMsgReadWriter, - MsgReadWriteCloser, ReadWriteCloser, ) from libp2p.io.msgio import ( @@ -40,7 +38,7 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter): implemented by the subclasses. """ - read_writer: MsgReadWriteCloser + read_writer: NoisePacketReadWriter noise_state: NoiseState # FIXME: This prefix is added in msg#3 in Go. Check whether it's a desired behavior. @@ -50,12 +48,12 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter): self.read_writer = NoisePacketReadWriter(cast(ReadWriteCloser, conn)) self.noise_state = noise_state - async def write_msg(self, data: bytes, prefix_encoded: bool = False) -> None: - data_encrypted = self.encrypt(data) + async def write_msg(self, msg: bytes, prefix_encoded: bool = False) -> None: + data_encrypted = self.encrypt(msg) if prefix_encoded: - await self.read_writer.write_msg(self.prefix + data_encrypted) - else: - await self.read_writer.write_msg(data_encrypted) + # Manually add the prefix if needed + data_encrypted = self.prefix + data_encrypted + await self.read_writer.write_msg(data_encrypted) async def read_msg(self, prefix_encoded: bool = False) -> bytes: noise_msg_encrypted = await self.read_writer.read_msg() @@ -67,10 +65,11 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter): async def close(self) -> None: await self.read_writer.close() - def get_remote_address(self) -> Optional[tuple[str, int]]: + def get_remote_address(self) -> tuple[str, int] | None: # Delegate to the underlying connection if possible if hasattr(self.read_writer, "read_write_closer") and hasattr( - self.read_writer.read_write_closer, "get_remote_address" + self.read_writer.read_write_closer, + "get_remote_address", ): return self.read_writer.read_write_closer.get_remote_address() return None @@ -78,7 +77,7 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter): class NoiseHandshakeReadWriter(BaseNoiseMsgReadWriter): def encrypt(self, data: bytes) -> bytes: - return self.noise_state.write_message(data) + return bytes(self.noise_state.write_message(data)) def decrypt(self, data: bytes) -> bytes: return bytes(self.noise_state.read_message(data)) diff --git a/libp2p/security/noise/messages.py b/libp2p/security/noise/messages.py index cea5f166..309b24b0 100644 --- a/libp2p/security/noise/messages.py +++ b/libp2p/security/noise/messages.py @@ -19,7 +19,7 @@ SIGNED_DATA_PREFIX = "noise-libp2p-static-key:" class NoiseHandshakePayload: id_pubkey: PublicKey id_sig: bytes - early_data: bytes = None + early_data: bytes | None = None def serialize(self) -> bytes: msg = noise_pb.NoiseHandshakePayload( diff --git a/libp2p/security/noise/patterns.py b/libp2p/security/noise/patterns.py index 27b8d63b..00f51d06 100644 --- a/libp2p/security/noise/patterns.py +++ b/libp2p/security/noise/patterns.py @@ -7,8 +7,10 @@ from cryptography.hazmat.primitives import ( serialization, ) from noise.backends.default.keypairs import KeyPair as NoiseKeyPair -from noise.connection import Keypair as NoiseKeypairEnum -from noise.connection import NoiseConnection as NoiseState +from noise.connection import ( + Keypair as NoiseKeypairEnum, + NoiseConnection as NoiseState, +) from libp2p.abc import ( IRawConnection, @@ -47,14 +49,12 @@ from .messages import ( class IPattern(ABC): @abstractmethod - async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn: - ... + async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn: ... @abstractmethod async def handshake_outbound( self, conn: IRawConnection, remote_peer: ID - ) -> ISecureConn: - ... + ) -> ISecureConn: ... class BasePattern(IPattern): @@ -62,13 +62,15 @@ class BasePattern(IPattern): noise_static_key: PrivateKey local_peer: ID libp2p_privkey: PrivateKey - early_data: bytes + early_data: bytes | None def create_noise_state(self) -> NoiseState: noise_state = NoiseState.from_name(self.protocol_name) noise_state.set_keypair_from_private_bytes( NoiseKeypairEnum.STATIC, self.noise_static_key.to_bytes() ) + if noise_state.noise_protocol is None: + raise NoiseStateError("noise_protocol is not initialized") return noise_state def make_handshake_payload(self) -> NoiseHandshakePayload: @@ -84,7 +86,7 @@ class PatternXX(BasePattern): local_peer: ID, libp2p_privkey: PrivateKey, noise_static_key: PrivateKey, - early_data: bytes = None, + early_data: bytes | None = None, ) -> None: self.protocol_name = b"Noise_XX_25519_ChaChaPoly_SHA256" self.local_peer = local_peer @@ -96,7 +98,12 @@ class PatternXX(BasePattern): noise_state = self.create_noise_state() noise_state.set_as_responder() noise_state.start_handshake() + if noise_state.noise_protocol is None: + raise NoiseStateError("noise_protocol is not initialized") handshake_state = noise_state.noise_protocol.handshake_state + if handshake_state is None: + raise NoiseStateError("Handshake state is not initialized") + read_writer = NoiseHandshakeReadWriter(conn, noise_state) # Consume msg#1. @@ -145,7 +152,11 @@ class PatternXX(BasePattern): read_writer = NoiseHandshakeReadWriter(conn, noise_state) noise_state.set_as_initiator() noise_state.start_handshake() + if noise_state.noise_protocol is None: + raise NoiseStateError("noise_protocol is not initialized") handshake_state = noise_state.noise_protocol.handshake_state + if handshake_state is None: + raise NoiseStateError("Handshake state is not initialized") # Send msg#1, which is *not* encrypted. msg_1 = b"" @@ -195,6 +206,8 @@ class PatternXX(BasePattern): @staticmethod def _get_pubkey_from_noise_keypair(key_pair: NoiseKeyPair) -> PublicKey: # Use `Ed25519PublicKey` since 25519 is used in our pattern. + if key_pair.public is None: + raise NoiseStateError("public key is not initialized") raw_bytes = key_pair.public.public_bytes( serialization.Encoding.Raw, serialization.PublicFormat.Raw ) diff --git a/libp2p/security/noise/transport.py b/libp2p/security/noise/transport.py index e90dcc64..8fdd6b6e 100644 --- a/libp2p/security/noise/transport.py +++ b/libp2p/security/noise/transport.py @@ -26,7 +26,7 @@ class Transport(ISecureTransport): libp2p_privkey: PrivateKey noise_privkey: PrivateKey local_peer: ID - early_data: bytes + early_data: bytes | None with_noise_pipes: bool # NOTE: Implementations that support Noise Pipes must decide whether to use @@ -37,8 +37,8 @@ class Transport(ISecureTransport): def __init__( self, libp2p_keypair: KeyPair, - noise_privkey: PrivateKey = None, - early_data: bytes = None, + noise_privkey: PrivateKey, + early_data: bytes | None = None, with_noise_pipes: bool = False, ) -> None: self.libp2p_privkey = libp2p_keypair.private_key diff --git a/libp2p/security/secio/transport.py b/libp2p/security/secio/transport.py index 343c9a1a..fad2b945 100644 --- a/libp2p/security/secio/transport.py +++ b/libp2p/security/secio/transport.py @@ -2,9 +2,6 @@ from dataclasses import ( dataclass, ) import itertools -from typing import ( - Optional, -) import multihash @@ -14,14 +11,10 @@ from libp2p.abc import ( ) from libp2p.crypto.authenticated_encryption import ( EncryptionParameters as AuthenticatedEncryptionParameters, -) -from libp2p.crypto.authenticated_encryption import ( InvalidMACException, -) -from libp2p.crypto.authenticated_encryption import ( + MacAndCipher as Encrypter, initialize_pair as initialize_pair_for_encryption, ) -from libp2p.crypto.authenticated_encryption import MacAndCipher as Encrypter from libp2p.crypto.ecc import ( ECCPublicKey, ) @@ -91,6 +84,8 @@ class SecioPacketReadWriter(FixedSizeLenMsgReadWriter): class SecioMsgReadWriter(EncryptedMsgReadWriter): read_writer: SecioPacketReadWriter + local_encrypter: Encrypter + remote_encrypter: Encrypter def __init__( self, @@ -213,7 +208,8 @@ async def _response_to_msg(read_writer: SecioPacketReadWriter, msg: bytes) -> by def _mk_multihash_sha256(data: bytes) -> bytes: - return multihash.digest(data, "sha2-256") + mh = multihash.digest(data, "sha2-256") + return mh.encode() def _mk_score(public_key: PublicKey, nonce: bytes) -> bytes: @@ -270,7 +266,7 @@ def _select_encryption_parameters( async def _establish_session_parameters( local_peer: PeerID, local_private_key: PrivateKey, - remote_peer: Optional[PeerID], + remote_peer: PeerID | None, conn: SecioPacketReadWriter, nonce: bytes, ) -> tuple[SessionParameters, bytes]: @@ -399,7 +395,7 @@ async def create_secure_session( local_peer: PeerID, local_private_key: PrivateKey, conn: IRawConnection, - remote_peer: PeerID = None, + remote_peer: PeerID | None = None, ) -> ISecureConn: """ Attempt the initial `secio` handshake with the remote peer. diff --git a/libp2p/security/secure_session.py b/libp2p/security/secure_session.py index 7551bfee..ea31972a 100644 --- a/libp2p/security/secure_session.py +++ b/libp2p/security/secure_session.py @@ -1,7 +1,4 @@ import io -from typing import ( - Optional, -) from libp2p.crypto.keys import ( PrivateKey, @@ -44,7 +41,7 @@ class SecureSession(BaseSession): self._reset_internal_buffer() - def get_remote_address(self) -> Optional[tuple[str, int]]: + def get_remote_address(self) -> tuple[str, int] | None: """Delegate to the underlying connection's get_remote_address method.""" return self.conn.get_remote_address() @@ -53,7 +50,7 @@ class SecureSession(BaseSession): self.low_watermark = 0 self.high_watermark = 0 - def _drain(self, n: int) -> bytes: + def _drain(self, n: int | None) -> bytes: if self.low_watermark == self.high_watermark: return b"" @@ -75,7 +72,7 @@ class SecureSession(BaseSession): self.low_watermark = 0 self.high_watermark = len(msg) - async def read(self, n: int = None) -> bytes: + async def read(self, n: int | None = None) -> bytes: if n == 0: return b"" @@ -85,6 +82,9 @@ class SecureSession(BaseSession): msg = await self.conn.read_msg() + if n is None: + return msg + if n < len(msg): self._fill(msg) return self._drain(n) diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index e21e0768..a3548646 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -1,7 +1,4 @@ import logging -from typing import ( - Optional, -) import trio @@ -168,7 +165,7 @@ class Mplex(IMuxedConn): raise MplexUnavailable async def send_message( - self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID + self, flag: HeaderTags, data: bytes | None, stream_id: StreamID ) -> int: """ Send a message over the connection. @@ -366,6 +363,6 @@ class Mplex(IMuxedConn): self.event_closed.set() await self.new_stream_send_channel.aclose() - def get_remote_address(self) -> Optional[tuple[str, int]]: + def get_remote_address(self) -> tuple[str, int] | None: """Delegate to the underlying Mplex connection's secured_conn.""" return self.secured_conn.get_remote_address() diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index a5bce0c1..3b640df1 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -3,7 +3,6 @@ from types import ( ) from typing import ( TYPE_CHECKING, - Optional, ) import trio @@ -40,9 +39,12 @@ class MplexStream(IMuxedStream): name: str stream_id: StreamID - muxed_conn: "Mplex" - read_deadline: int - write_deadline: int + # NOTE: All methods used here are part of `Mplex` which is a derived + # class of IMuxedConn. Ignoring this type assignment should not pose + # any risk. + muxed_conn: "Mplex" # type: ignore[assignment] + read_deadline: int | None + write_deadline: int | None # TODO: Add lock for read/write to avoid interleaving receiving messages? close_lock: trio.Lock @@ -92,7 +94,7 @@ class MplexStream(IMuxedStream): self._buf = self._buf[len(payload) :] return bytes(payload) - def _read_return_when_blocked(self) -> bytes: + def _read_return_when_blocked(self) -> bytearray: buf = bytearray() while True: try: @@ -102,7 +104,7 @@ class MplexStream(IMuxedStream): break return buf - async def read(self, n: int = None) -> bytes: + async def read(self, n: int | None = None) -> bytes: """ Read up to n bytes. Read possibly returns fewer than `n` bytes, if there are not enough bytes in the Mplex buffer. If `n is None`, read @@ -257,7 +259,7 @@ class MplexStream(IMuxedStream): self.write_deadline = ttl return True - def get_remote_address(self) -> Optional[tuple[str, int]]: + def get_remote_address(self) -> tuple[str, int] | None: """Delegate to the parent Mplex connection.""" return self.muxed_conn.get_remote_address() @@ -267,9 +269,9 @@ class MplexStream(IMuxedStream): async def __aexit__( self, - exc_type: Optional[type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: """Exit the async context manager and close the stream.""" await self.close() diff --git a/libp2p/stream_muxer/muxer_multistream.py b/libp2p/stream_muxer/muxer_multistream.py index 3151b0fe..b4aa5d57 100644 --- a/libp2p/stream_muxer/muxer_multistream.py +++ b/libp2p/stream_muxer/muxer_multistream.py @@ -95,7 +95,7 @@ class MuxerMultistream: if protocol == PROTOCOL_ID: async with trio.open_nursery(): - def on_close() -> None: + async def on_close() -> None: pass return Yamux( diff --git a/libp2p/stream_muxer/yamux/yamux.py b/libp2p/stream_muxer/yamux/yamux.py index ceceb541..92123465 100644 --- a/libp2p/stream_muxer/yamux/yamux.py +++ b/libp2p/stream_muxer/yamux/yamux.py @@ -3,8 +3,10 @@ Yamux stream multiplexer implementation for py-libp2p. This is the preferred multiplexing protocol due to its performance and feature set. Mplex is also available for legacy compatibility but may be deprecated in the future. """ + from collections.abc import ( Awaitable, + Callable, ) import inspect import logging @@ -13,8 +15,7 @@ from types import ( TracebackType, ) from typing import ( - Callable, - Optional, + Any, ) import trio @@ -83,9 +84,9 @@ class YamuxStream(IMuxedStream): async def __aexit__( self, - exc_type: Optional[type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: """Exit the async context manager and close the stream.""" await self.close() @@ -126,7 +127,7 @@ class YamuxStream(IMuxedStream): if self.send_window < DEFAULT_WINDOW_SIZE // 2: await self.send_window_update() - async def send_window_update(self, increment: Optional[int] = None) -> None: + async def send_window_update(self, increment: int | None = None) -> None: """Send a window update to peer.""" if increment is None: increment = DEFAULT_WINDOW_SIZE - self.recv_window @@ -141,7 +142,7 @@ class YamuxStream(IMuxedStream): ) await self.conn.secured_conn.write(header) - async def read(self, n: int = -1) -> bytes: + async def read(self, n: int | None = -1) -> bytes: # Handle None value for n by converting it to -1 if n is None: n = -1 @@ -161,8 +162,7 @@ class YamuxStream(IMuxedStream): if buffer and len(buffer) > 0: # Wait for closure even if data is available logging.debug( - f"Stream {self.stream_id}:" - f"Waiting for FIN before returning data" + f"Stream {self.stream_id}:Waiting for FIN before returning data" ) await self.conn.stream_events[self.stream_id].wait() self.conn.stream_events[self.stream_id] = trio.Event() @@ -240,7 +240,7 @@ class YamuxStream(IMuxedStream): """ raise NotImplementedError("Yamux does not support setting read deadlines") - def get_remote_address(self) -> Optional[tuple[str, int]]: + def get_remote_address(self) -> tuple[str, int] | None: """ Returns the remote address of the underlying connection. """ @@ -268,8 +268,8 @@ class Yamux(IMuxedConn): self, secured_conn: ISecureConn, peer_id: ID, - is_initiator: Optional[bool] = None, - on_close: Optional[Callable[[], Awaitable[None]]] = None, + is_initiator: bool | None = None, + on_close: Callable[[], Awaitable[Any]] | None = None, ) -> None: self.secured_conn = secured_conn self.peer_id = peer_id @@ -283,7 +283,7 @@ class Yamux(IMuxedConn): self.is_initiator_value = ( is_initiator if is_initiator is not None else secured_conn.is_initiator ) - self.next_stream_id = 1 if self.is_initiator_value else 2 + self.next_stream_id: int = 1 if self.is_initiator_value else 2 self.streams: dict[int, YamuxStream] = {} self.streams_lock = trio.Lock() self.new_stream_send_channel: MemorySendChannel[YamuxStream] @@ -297,7 +297,7 @@ class Yamux(IMuxedConn): self.event_started = trio.Event() self.stream_buffers: dict[int, bytearray] = {} self.stream_events: dict[int, trio.Event] = {} - self._nursery: Optional[Nursery] = None + self._nursery: Nursery | None = None async def start(self) -> None: logging.debug(f"Starting Yamux for {self.peer_id}") @@ -465,8 +465,14 @@ class Yamux(IMuxedConn): # Wait for data if stream is still open logging.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}") - await self.stream_events[stream_id].wait() - self.stream_events[stream_id] = trio.Event() + try: + await self.stream_events[stream_id].wait() + self.stream_events[stream_id] = trio.Event() + except KeyError: + raise MuxedStreamEOF("Stream was removed") + + # This line should never be reached, but satisfies the type checker + raise MuxedStreamEOF("Unexpected end of read_stream") async def handle_incoming(self) -> None: while not self.event_shutting_down.is_set(): @@ -474,8 +480,7 @@ class Yamux(IMuxedConn): header = await self.secured_conn.read(HEADER_SIZE) if not header or len(header) < HEADER_SIZE: logging.debug( - f"Connection closed or" - f"incomplete header for peer {self.peer_id}" + f"Connection closed orincomplete header for peer {self.peer_id}" ) self.event_shutting_down.set() await self._cleanup_on_error() @@ -544,8 +549,7 @@ class Yamux(IMuxedConn): ) elif error_code == GO_AWAY_PROTOCOL_ERROR: logging.error( - f"Received GO_AWAY for peer" - f"{self.peer_id}: Protocol error" + f"Received GO_AWAY for peer{self.peer_id}: Protocol error" ) elif error_code == GO_AWAY_INTERNAL_ERROR: logging.error( diff --git a/libp2p/tools/async_service/_utils.py b/libp2p/tools/async_service/_utils.py index 6754e827..3be8c20b 100644 --- a/libp2p/tools/async_service/_utils.py +++ b/libp2p/tools/async_service/_utils.py @@ -1,12 +1,10 @@ # Copied from https://github.com/ethereum/async-service import os -from typing import ( - Any, -) +from typing import Any -def get_task_name(value: Any, explicit_name: str = None) -> str: +def get_task_name(value: Any, explicit_name: str | None = None) -> str: # inline import to ensure `_utils` is always importable from the rest of # the module. from .abc import ( # noqa: F401 diff --git a/libp2p/tools/async_service/abc.py b/libp2p/tools/async_service/abc.py index 95cce84e..51f23b0f 100644 --- a/libp2p/tools/async_service/abc.py +++ b/libp2p/tools/async_service/abc.py @@ -28,33 +28,27 @@ class TaskAPI(Hashable): parent: Optional["TaskWithChildrenAPI"] @abstractmethod - async def run(self) -> None: - ... + async def run(self) -> None: ... @abstractmethod - async def cancel(self) -> None: - ... + async def cancel(self) -> None: ... @property @abstractmethod - def is_done(self) -> bool: - ... + def is_done(self) -> bool: ... @abstractmethod - async def wait_done(self) -> None: - ... + async def wait_done(self) -> None: ... class TaskWithChildrenAPI(TaskAPI): children: set[TaskAPI] @abstractmethod - def add_child(self, child: TaskAPI) -> None: - ... + def add_child(self, child: TaskAPI) -> None: ... @abstractmethod - def discard_child(self, child: TaskAPI) -> None: - ... + def discard_child(self, child: TaskAPI) -> None: ... class ServiceAPI(ABC): @@ -212,7 +206,11 @@ class InternalManagerAPI(ManagerAPI): @trio_typing.takes_callable_and_args @abstractmethod def run_task( - self, async_fn: AsyncFn, *args: Any, daemon: bool = False, name: str = None + self, + async_fn: AsyncFn, + *args: Any, + daemon: bool = False, + name: str | None = None, ) -> None: """ Run a task in the background. If the function throws an exception it @@ -225,7 +223,9 @@ class InternalManagerAPI(ManagerAPI): @trio_typing.takes_callable_and_args @abstractmethod - def run_daemon_task(self, async_fn: AsyncFn, *args: Any, name: str = None) -> None: + def run_daemon_task( + self, async_fn: AsyncFn, *args: Any, name: str | None = None + ) -> None: """ Run a daemon task in the background. @@ -235,7 +235,7 @@ class InternalManagerAPI(ManagerAPI): @abstractmethod def run_child_service( - self, service: ServiceAPI, daemon: bool = False, name: str = None + self, service: ServiceAPI, daemon: bool = False, name: str | None = None ) -> "ManagerAPI": """ Run a service in the background. If the function throws an exception it @@ -248,7 +248,7 @@ class InternalManagerAPI(ManagerAPI): @abstractmethod def run_daemon_child_service( - self, service: ServiceAPI, name: str = None + self, service: ServiceAPI, name: str | None = None ) -> "ManagerAPI": """ Run a daemon service in the background. diff --git a/libp2p/tools/async_service/base.py b/libp2p/tools/async_service/base.py index 60ec654d..a23f0e75 100644 --- a/libp2p/tools/async_service/base.py +++ b/libp2p/tools/async_service/base.py @@ -9,6 +9,7 @@ from collections import ( ) from collections.abc import ( Awaitable, + Callable, Iterable, Sequence, ) @@ -16,8 +17,6 @@ import logging import sys from typing import ( Any, - Callable, - Optional, TypeVar, cast, ) @@ -98,7 +97,7 @@ def as_service(service_fn: LogicFnType) -> type[ServiceAPI]: class BaseTask(TaskAPI): def __init__( - self, name: str, daemon: bool, parent: Optional[TaskWithChildrenAPI] + self, name: str, daemon: bool, parent: TaskWithChildrenAPI | None ) -> None: # meta self.name = name @@ -125,7 +124,7 @@ class BaseTask(TaskAPI): class BaseTaskWithChildren(BaseTask, TaskWithChildrenAPI): def __init__( - self, name: str, daemon: bool, parent: Optional[TaskWithChildrenAPI] + self, name: str, daemon: bool, parent: TaskWithChildrenAPI | None ) -> None: super().__init__(name, daemon, parent) self.children = set() @@ -142,26 +141,20 @@ T = TypeVar("T", bound="BaseFunctionTask") class BaseFunctionTask(BaseTaskWithChildren): @classmethod - def iterate_tasks(cls: type[T], *tasks: TaskAPI) -> Iterable[T]: + def iterate_tasks(cls, *tasks: TaskAPI) -> Iterable["BaseFunctionTask"]: + """Iterate over all tasks of this class type and their children recursively.""" for task in tasks: - if isinstance(task, cls): + if isinstance(task, BaseFunctionTask): yield task - else: - continue - yield from cls.iterate_tasks( - *( - child_task - for child_task in task.children - if isinstance(child_task, cls) - ) - ) + if isinstance(task, TaskWithChildrenAPI): + yield from cls.iterate_tasks(*task.children) def __init__( self, name: str, daemon: bool, - parent: Optional[TaskWithChildrenAPI], + parent: TaskWithChildrenAPI | None, async_fn: AsyncFn, async_fn_args: Sequence[Any], ) -> None: @@ -259,12 +252,15 @@ class BaseManager(InternalManagerAPI): # Wait API # def run_daemon_task( - self, async_fn: Callable[..., Awaitable[Any]], *args: Any, name: str = None + self, + async_fn: Callable[..., Awaitable[Any]], + *args: Any, + name: str | None = None, ) -> None: self.run_task(async_fn, *args, daemon=True, name=name) def run_daemon_child_service( - self, service: ServiceAPI, name: str = None + self, service: ServiceAPI, name: str | None = None ) -> ManagerAPI: return self.run_child_service(service, daemon=True, name=name) @@ -286,8 +282,7 @@ class BaseManager(InternalManagerAPI): # Task Management # @abstractmethod - def _schedule_task(self, task: TaskAPI) -> None: - ... + def _schedule_task(self, task: TaskAPI) -> None: ... def _common_run_task(self, task: TaskAPI) -> None: if not self.is_running: @@ -307,7 +302,7 @@ class BaseManager(InternalManagerAPI): self._schedule_task(task) def _add_child_task( - self, parent: Optional[TaskWithChildrenAPI], task: TaskAPI + self, parent: TaskWithChildrenAPI | None, task: TaskAPI ) -> None: if parent is None: all_children = self._root_tasks diff --git a/libp2p/tools/async_service/trio_service.py b/libp2p/tools/async_service/trio_service.py index f65a5706..3fdddb81 100644 --- a/libp2p/tools/async_service/trio_service.py +++ b/libp2p/tools/async_service/trio_service.py @@ -6,7 +6,9 @@ from __future__ import ( from collections.abc import ( AsyncIterator, Awaitable, + Callable, Coroutine, + Iterable, Sequence, ) from contextlib import ( @@ -16,7 +18,6 @@ import functools import sys from typing import ( Any, - Callable, Optional, TypeVar, cast, @@ -59,6 +60,16 @@ from .typing import ( class FunctionTask(BaseFunctionTask): _trio_task: trio.lowlevel.Task | None = None + @classmethod + def iterate_tasks(cls, *tasks: TaskAPI) -> Iterable[FunctionTask]: + """Iterate over all FunctionTask instances and their children recursively.""" + for task in tasks: + if isinstance(task, FunctionTask): + yield task + + if isinstance(task, TaskWithChildrenAPI): + yield from cls.iterate_tasks(*task.children) + def __init__( self, name: str, @@ -75,7 +86,7 @@ class FunctionTask(BaseFunctionTask): # Each task gets its own `CancelScope` which is how we can manually # control cancellation order of the task DAG - self._cancel_scope = trio.CancelScope() + self._cancel_scope = trio.CancelScope() # type: ignore[call-arg] # # Trio specific API @@ -309,7 +320,7 @@ class TrioManager(BaseManager): async_fn: Callable[..., Awaitable[Any]], *args: Any, daemon: bool = False, - name: str = None, + name: str | None = None, ) -> None: task = FunctionTask( name=get_task_name(async_fn, name), @@ -322,7 +333,7 @@ class TrioManager(BaseManager): self._common_run_task(task) def run_child_service( - self, service: ServiceAPI, daemon: bool = False, name: str = None + self, service: ServiceAPI, daemon: bool = False, name: str | None = None ) -> ManagerAPI: task = ChildServiceTask( name=get_task_name(service, name), @@ -416,7 +427,12 @@ def external_api(func: TFunc) -> TFunc: async with trio.open_nursery() as nursery: # mypy's type hints for start_soon break with this invocation. nursery.start_soon( - _wait_api_fn, self, func, args, kwargs, send_channel # type: ignore + _wait_api_fn, # type: ignore + self, + func, + args, + kwargs, + send_channel, ) nursery.start_soon(_wait_finished, self, func, send_channel) result, err = await receive_channel.receive() diff --git a/libp2p/tools/async_service/typing.py b/libp2p/tools/async_service/typing.py index 616b71d9..e725d483 100644 --- a/libp2p/tools/async_service/typing.py +++ b/libp2p/tools/async_service/typing.py @@ -2,13 +2,13 @@ from collections.abc import ( Awaitable, + Callable, ) from types import ( TracebackType, ) from typing import ( Any, - Callable, ) EXC_INFO = tuple[type[BaseException], BaseException, TracebackType] diff --git a/libp2p/tools/constants.py b/libp2p/tools/constants.py index b9d5c849..a9ba4b76 100644 --- a/libp2p/tools/constants.py +++ b/libp2p/tools/constants.py @@ -32,7 +32,7 @@ class GossipsubParams(NamedTuple): degree: int = 10 degree_low: int = 9 degree_high: int = 11 - direct_peers: Sequence[PeerInfo] = None + direct_peers: Sequence[PeerInfo] = [] time_to_live: int = 30 gossip_window: int = 3 gossip_history: int = 5 diff --git a/libp2p/tools/utils.py b/libp2p/tools/utils.py index 320a46ba..48f4efcf 100644 --- a/libp2p/tools/utils.py +++ b/libp2p/tools/utils.py @@ -1,10 +1,8 @@ from collections.abc import ( Awaitable, -) -import logging -from typing import ( Callable, ) +import logging import trio @@ -63,12 +61,12 @@ async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm) -> None: logging.debug( "Swarm connection verification failed on attempt" - + f" {attempt+1}, retrying..." + + f" {attempt + 1}, retrying..." ) except Exception as e: last_error = e - logging.debug(f"Swarm connection attempt {attempt+1} failed: {e}") + logging.debug(f"Swarm connection attempt {attempt + 1} failed: {e}") await trio.sleep(retry_delay) # If we got here, all retries failed @@ -115,12 +113,12 @@ async def connect(node1: IHost, node2: IHost) -> None: return logging.debug( - f"Connection verification failed on attempt {attempt+1}, retrying..." + f"Connection verification failed on attempt {attempt + 1}, retrying..." ) except Exception as e: last_error = e - logging.debug(f"Connection attempt {attempt+1} failed: {e}") + logging.debug(f"Connection attempt {attempt + 1} failed: {e}") await trio.sleep(retry_delay) # If we got here, all retries failed diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index 4ed06c98..1598ea42 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -1,11 +1,9 @@ from collections.abc import ( Awaitable, + Callable, Sequence, ) import logging -from typing import ( - Callable, -) from multiaddr import ( Multiaddr, @@ -44,7 +42,7 @@ class TCPListener(IListener): self.handler = handler_function # TODO: Get rid of `nursery`? - async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> None: + async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool: """ Put listener in listening mode and wait for incoming connections. @@ -56,7 +54,7 @@ class TCPListener(IListener): handler: Callable[[trio.SocketStream], Awaitable[None]], port: int, host: str, - task_status: TaskStatus[Sequence[trio.SocketListener]] = None, + task_status: TaskStatus[Sequence[trio.SocketListener]], ) -> None: """Just a proxy function to add logging here.""" logger.debug("serve_tcp %s %s", host, port) @@ -67,18 +65,53 @@ class TCPListener(IListener): remote_port: int = 0 try: tcp_stream = TrioTCPStream(stream) - remote_host, remote_port = tcp_stream.get_remote_address() + remote_tuple = tcp_stream.get_remote_address() + + if remote_tuple is not None: + remote_host, remote_port = remote_tuple + await self.handler(tcp_stream) except Exception: logger.debug(f"Connection from {remote_host}:{remote_port} failed.") - listeners = await nursery.start( + tcp_port_str = maddr.value_for_protocol("tcp") + if tcp_port_str is None: + logger.error(f"Cannot listen: TCP port is missing in multiaddress {maddr}") + return False + + try: + tcp_port = int(tcp_port_str) + except ValueError: + logger.error( + f"Cannot listen: Invalid TCP port '{tcp_port_str}' " + f"in multiaddress {maddr}" + ) + return False + + ip4_host_str = maddr.value_for_protocol("ip4") + # For trio.serve_tcp, ip4_host_str (as host argument) can be None, + # which typically means listen on all available interfaces. + + started_listeners = await nursery.start( serve_tcp, handler, - int(maddr.value_for_protocol("tcp")), - maddr.value_for_protocol("ip4"), + tcp_port, + ip4_host_str, ) - self.listeners.extend(listeners) + + if started_listeners is None: + # This implies that task_status.started() was not called within serve_tcp, + # likely because trio.serve_tcp itself failed to start (e.g., port in use). + logger.error( + f"Failed to start TCP listener for {maddr}: " + f"`nursery.start` returned None. " + "This might be due to issues like the port already " + "being in use or invalid host." + ) + return False + + self.listeners.extend(started_listeners) + return True def get_addrs(self) -> tuple[Multiaddr, ...]: """ @@ -105,15 +138,42 @@ class TCP(ITransport): :return: `RawConnection` if successful :raise OpenConnectionError: raised when failed to open connection """ - self.host = maddr.value_for_protocol("ip4") - self.port = int(maddr.value_for_protocol("tcp")) + host_str = maddr.value_for_protocol("ip4") + port_str = maddr.value_for_protocol("tcp") + + if host_str is None: + raise OpenConnectionError( + f"Failed to dial {maddr}: IP address not found in multiaddr." + ) + + if port_str is None: + raise OpenConnectionError( + f"Failed to dial {maddr}: TCP port not found in multiaddr." + ) try: - stream = await trio.open_tcp_stream(self.host, self.port) - except OSError as error: - raise OpenConnectionError from error - read_write_closer = TrioTCPStream(stream) + port_int = int(port_str) + except ValueError: + raise OpenConnectionError( + f"Failed to dial {maddr}: Invalid TCP port '{port_str}'." + ) + try: + # trio.open_tcp_stream requires host to be str or bytes, not None. + stream = await trio.open_tcp_stream(host_str, port_int) + except OSError as error: + # OSError is common for network issues like "Connection refused" + # or "Host unreachable". + raise OpenConnectionError( + f"Failed to open TCP stream to {maddr}: {error}" + ) from error + except Exception as error: + # Catch other potential errors from trio.open_tcp_stream and wrap them. + raise OpenConnectionError( + f"An unexpected error occurred when dialing {maddr}: {error}" + ) from error + + read_write_closer = TrioTCPStream(stream) return RawConnection(read_write_closer, True) def create_listener(self, handler_function: THandler) -> TCPListener: diff --git a/libp2p/utils/logging.py b/libp2p/utils/logging.py index 637d028d..3458a41e 100644 --- a/libp2p/utils/logging.py +++ b/libp2p/utils/logging.py @@ -13,15 +13,13 @@ import sys import threading from typing import ( Any, - Optional, - Union, ) # Create a log queue log_queue: "queue.Queue[Any]" = queue.Queue() # Store the current listener to stop it on exit -_current_listener: Optional[logging.handlers.QueueListener] = None +_current_listener: logging.handlers.QueueListener | None = None # Event to track when the listener is ready _listener_ready = threading.Event() @@ -135,7 +133,7 @@ def setup_logging() -> None: formatter = logging.Formatter(DEFAULT_LOG_FORMAT) # Configure handlers - handlers: list[Union[logging.StreamHandler[Any], logging.FileHandler]] = [] + handlers: list[logging.StreamHandler[Any] | logging.FileHandler] = [] # Console handler console_handler = logging.StreamHandler(sys.stderr) diff --git a/newsfragments/618.internal.rst b/newsfragments/618.internal.rst new file mode 100644 index 00000000..3db303dc --- /dev/null +++ b/newsfragments/618.internal.rst @@ -0,0 +1 @@ +Modernizes several aspects of the project, notably using ``pyproject.toml`` for project info instead of ``setup.py``, using ``ruff`` to replace several separate linting tools, and ``pyrefly`` in addition to ``mypy`` for typing. Also includes changes across the codebase to conform to new linting and typing rules. diff --git a/newsfragments/618.removal.rst b/newsfragments/618.removal.rst new file mode 100644 index 00000000..64fc5134 --- /dev/null +++ b/newsfragments/618.removal.rst @@ -0,0 +1 @@ +Removes support for python 3.9 and updates some code conventions, notably using ``|`` operator in typing instead of ``Optional`` or ``Union`` diff --git a/pyproject.toml b/pyproject.toml index 8b2e3caa..9d2f47da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,20 +1,105 @@ -[tool.autoflake] -exclude = "__init__.py" -remove_all_unused_imports = true -[tool.isort] -combine_as_imports = false -extra_standard_library = "pytest" -force_grid_wrap = 1 -force_sort_within_sections = true -force_to_top = "pytest" -honor_noqa = true -known_first_party = "libp2p" -known_third_party = "anyio,factory,lru,p2pclient,pytest,noise" -multi_line_output = 3 -profile = "black" -skip_glob= "*_pb2*.py, *.pyi" -use_parentheses = true +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "libp2p" +version = "0.2.7" +description = "libp2p: The Python implementation of the libp2p networking stack" +readme = "README.md" +requires-python = ">=3.10, <4.0" +license = { text = "MIT AND Apache-2.0" } +keywords = ["libp2p", "p2p"] +authors = [ + { name = "The Ethereum Foundation", email = "snakecharmers@ethereum.org" }, +] +dependencies = [ + "base58>=1.0.3", + "coincurve>=10.0.0", + "exceptiongroup>=1.2.0; python_version < '3.11'", + "grpcio>=1.41.0", + "lru-dict>=1.1.6", + "multiaddr>=0.0.9", + "mypy-protobuf>=3.0.0", + "noiseprotocol>=0.3.0", + "protobuf>=3.20.1,<4.0.0", + "pycryptodome>=3.9.2", + "pymultihash>=0.8.2", + "pynacl>=1.3.0", + "rpcudp>=3.0.0", + "trio-typing>=0.0.4", + "trio>=0.26.0", + "fastecdsa==1.7.5; sys_platform != 'win32'", +] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Natural Language :: English", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] + +[project.urls] +Homepage = "https://github.com/libp2p/py-libp2p" + +[project.scripts] +chat-demo = "examples.chat.chat:main" +echo-demo = "examples.echo.echo:main" +ping-demo = "examples.ping.ping:main" +identify-demo = "examples.identify.identify:main" +identify-push-demo = "examples.identify_push.identify_push_demo:run_main" +identify-push-listener-dialer-demo = "examples.identify_push.identify_push_listener_dialer:main" +pubsub-demo = "examples.pubsub.pubsub:main" + +[project.optional-dependencies] +dev = [ + "build>=0.9.0", + "bump_my_version>=0.19.0", + "ipython", + "mypy>=1.15.0", + "pre-commit>=3.4.0", + "tox>=4.0.0", + "twine", + "wheel", + "setuptools>=42", + "sphinx>=6.0.0", + "sphinx_rtd_theme>=1.0.0", + "towncrier>=24,<25", + "p2pclient==0.2.0", + "pytest>=7.0.0", + "pytest-xdist>=2.4.0", + "pytest-trio>=0.5.2", + "factory-boy>=2.12.0,<3.0.0", + "ruff>=0.11.10", + "pyrefly (>=0.17.1,<0.18.0)", +] +docs = [ + "sphinx>=6.0.0", + "sphinx_rtd_theme>=1.0.0", + "towncrier>=24,<25", + "tomli; python_version < '3.11'", +] +test = [ + "p2pclient==0.2.0", + "pytest>=7.0.0", + "pytest-xdist>=2.4.0", + "pytest-trio>=0.5.2", + "factory-boy>=2.12.0,<3.0.0", +] + +[tool.setuptools] +include-package-data = true + +[tool.setuptools.packages.find] +exclude = ["scripts*", "tests*"] + +[tool.setuptools.package-data] +libp2p = ["py.typed"] + [tool.mypy] check_untyped_defs = true @@ -27,37 +112,12 @@ disallow_untyped_defs = true ignore_missing_imports = true incremental = false strict_equality = true -strict_optional = false +strict_optional = true warn_redundant_casts = true warn_return_any = false warn_unused_configs = true -warn_unused_ignores = true +warn_unused_ignores = false -[tool.pydocstyle] -# All error codes found here: -# http://www.pydocstyle.org/en/3.0.0/error_codes.html -# -# Ignored: -# D1 - Missing docstring error codes -# -# Selected: -# D2 - Whitespace error codes -# D3 - Quote error codes -# D4 - Content related error codes -select = "D2,D3,D4" - -# Extra ignores: -# D200 - One-line docstring should fit on one line with quotes -# D203 - 1 blank line required before class docstring -# D204 - 1 blank line required after class docstring -# D205 - 1 blank line required between summary line and description -# D212 - Multi-line docstring summary should start at the first line -# D302 - Use u""" for Unicode docstrings -# D400 - First line should end with a period -# D401 - First line should be in imperative mood -# D412 - No blank lines allowed between a section header and its content -# D415 - First line should end with a period, question mark, or exclamation point -add-ignore = "D200,D203,D204,D205,D212,D302,D400,D401,D412,D415" # Explanation: # D400 - Enabling this error code seems to make it a requirement that the first @@ -138,8 +198,8 @@ parse = """ )? """ serialize = [ - "{major}.{minor}.{patch}-{stage}.{devnum}", - "{major}.{minor}.{patch}", + "{major}.{minor}.{patch}-{stage}.{devnum}", + "{major}.{minor}.{patch}", ] search = "{current_version}" replace = "{new_version}" @@ -156,11 +216,7 @@ message = "Bump version: {current_version} → {new_version}" [tool.bumpversion.parts.stage] optional_value = "stable" first_value = "stable" -values = [ - "alpha", - "beta", - "stable", -] +values = ["alpha", "beta", "stable"] [tool.bumpversion.part.devnum] @@ -168,3 +224,63 @@ values = [ filename = "setup.py" search = "version=\"{current_version}\"" replace = "version=\"{new_version}\"" + +[[tool.bumpversion.files]] +filename = "pyproject.toml" # Keep pyproject.toml version in sync +search = 'version = "{current_version}"' +replace = 'version = "{new_version}"' + +[tool.ruff] +line-length = 88 +exclude = ["__init__.py", "*_pb2*.py", "*.pyi"] + +[tool.ruff.lint] +select = [ + "F", # Pyflakes + "E", # pycodestyle errors + "W", # pycodestyle warnings + "I", # isort + "D", # pydocstyle +] +# Ignores from pydocstyle and any other desired ones +ignore = [ + "D100", + "D101", + "D102", + "D103", + "D105", + "D106", + "D107", + "D200", + "D203", + "D204", + "D205", + "D212", + "D400", + "D401", + "D412", + "D415", +] + +[tool.ruff.lint.isort] +force-wrap-aliases = true +combine-as-imports = true +extra-standard-library = [] +force-sort-within-sections = true +known-first-party = ["libp2p", "tests"] +known-third-party = ["anyio", "factory", "lru", "p2pclient", "pytest", "noise"] +force-to-top = ["pytest"] + +[tool.ruff.format] +# Using Ruff's Black-compatible formatter. +# Options like quote-style = "double" or indent-style = "space" can be set here if needed. + +[tool.pyrefly] +project_includes = ["libp2p", "examples", "tests"] +project_excludes = [ + "**/.project-template/**", + "**/docs/conf.py", + "**/*pb2.py", + "**/*.pyi", + ".venv/**", +] diff --git a/setup.py b/setup.py deleted file mode 100644 index a23d811a..00000000 --- a/setup.py +++ /dev/null @@ -1,117 +0,0 @@ -#!/usr/bin/env python -import sys - -from setuptools import ( - find_packages, - setup, -) - -description = "libp2p: The Python implementation of the libp2p networking stack" - -# Platform-specific dependencies -if sys.platform == "win32": - crypto_requires = [] # We'll use coincurve instead of fastecdsa on Windows -else: - crypto_requires = ["fastecdsa==1.7.5"] - -extras_require = { - "dev": [ - "build>=0.9.0", - "bump_my_version>=0.19.0", - "ipython", - "mypy==1.10.0", - "pre-commit>=3.4.0", - "tox>=4.0.0", - "twine", - "wheel", - ], - "docs": [ - "sphinx>=6.0.0", - "sphinx_rtd_theme>=1.0.0", - "towncrier>=24,<25", - ], - "test": [ - "p2pclient==0.2.0", - "pytest>=7.0.0", - "pytest-xdist>=2.4.0", - "pytest-trio>=0.5.2", - "factory-boy>=2.12.0,<3.0.0", - ], -} - -extras_require["dev"] = ( - extras_require["dev"] + extras_require["docs"] + extras_require["test"] -) - -try: - with open("./README.md", encoding="utf-8") as readme: - long_description = readme.read() -except FileNotFoundError: - long_description = description - -install_requires = [ - "base58>=1.0.3", - "coincurve>=10.0.0", - "exceptiongroup>=1.2.0; python_version < '3.11'", - "grpcio>=1.41.0", - "lru-dict>=1.1.6", - "multiaddr>=0.0.9", - "mypy-protobuf>=3.0.0", - "noiseprotocol>=0.3.0", - "protobuf>=6.30.1", - "pycryptodome>=3.9.2", - "pymultihash>=0.8.2", - "pynacl>=1.3.0", - "rpcudp>=3.0.0", - "trio-typing>=0.0.4", - "trio>=0.26.0", -] - -# Add platform-specific dependencies -install_requires.extend(crypto_requires) - -setup( - name="libp2p", - # *IMPORTANT*: Don't manually change the version here. See Contributing docs for the release process. - version="0.2.7", - description=description, - long_description=long_description, - long_description_content_type="text/markdown", - author="The Ethereum Foundation", - author_email="snakecharmers@ethereum.org", - url="https://github.com/libp2p/py-libp2p", - include_package_data=True, - install_requires=install_requires, - python_requires=">=3.9, <4", - extras_require=extras_require, - py_modules=["libp2p"], - license="MIT AND Apache-2.0", - license_files=("LICENSE-MIT", "LICENSE-APACHE"), - zip_safe=False, - keywords="libp2p p2p", - packages=find_packages(exclude=["scripts", "scripts.*", "tests", "tests.*"]), - package_data={"libp2p": ["py.typed"]}, - classifiers=[ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "Natural Language :: English", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", - ], - platforms=["unix", "linux", "osx", "win32"], - entry_points={ - "console_scripts": [ - "chat-demo=examples.chat.chat:main", - "echo-demo=examples.echo.echo:main", - "ping-demo=examples.ping.ping:main", - "identify-demo=examples.identify.identify:main", - "identify-push-demo=examples.identify_push.identify_push_demo:run_main", - "identify-push-listener-dialer-demo=examples.identify_push.identify_push_listener_dialer:main", - "pubsub-demo=examples.pubsub.pubsub:main", - ], - }, -) diff --git a/tests/crypto/test_x25519.py b/tests/core/crypto/test_x25519.py similarity index 100% rename from tests/crypto/test_x25519.py rename to tests/core/crypto/test_x25519.py diff --git a/tests/core/examples/test_examples.py b/tests/core/examples/test_examples.py index 61ec59b1..d60327b6 100644 --- a/tests/core/examples/test_examples.py +++ b/tests/core/examples/test_examples.py @@ -209,6 +209,18 @@ async def ping_demo(host_a, host_b): async def pubsub_demo(host_a, host_b): + gossipsub_a = GossipSub( + [GOSSIPSUB_PROTOCOL_ID], + 3, + 2, + 4, + ) + gossipsub_b = GossipSub( + [GOSSIPSUB_PROTOCOL_ID], + 3, + 2, + 4, + ) gossipsub_a = GossipSub([GOSSIPSUB_PROTOCOL_ID], 3, 2, 4, None, 1, 1) gossipsub_b = GossipSub([GOSSIPSUB_PROTOCOL_ID], 3, 2, 4, None, 1, 1) pubsub_a = Pubsub(host_a, gossipsub_a) diff --git a/tests/core/host/test_autonat.py b/tests/core/host/test_autonat.py index fe394745..4c6dbaca 100644 --- a/tests/core/host/test_autonat.py +++ b/tests/core/host/test_autonat.py @@ -76,18 +76,18 @@ async def test_update_status(): # Less than 2 successful dials should result in PRIVATE status service.dial_results = { - ID("peer1"): True, - ID("peer2"): False, - ID("peer3"): False, + ID(b"peer1"): True, + ID(b"peer2"): False, + ID(b"peer3"): False, } service.update_status() assert service.status == AutoNATStatus.PRIVATE # 2 or more successful dials should result in PUBLIC status service.dial_results = { - ID("peer1"): True, - ID("peer2"): True, - ID("peer3"): False, + ID(b"peer1"): True, + ID(b"peer2"): True, + ID(b"peer3"): False, } service.update_status() assert service.status == AutoNATStatus.PUBLIC diff --git a/tests/core/host/test_routed_host.py b/tests/core/host/test_routed_host.py index 1c0d21db..ecd19ebf 100644 --- a/tests/core/host/test_routed_host.py +++ b/tests/core/host/test_routed_host.py @@ -22,9 +22,10 @@ async def test_host_routing_success(): @pytest.mark.trio async def test_host_routing_fail(): - async with RoutedHostFactory.create_batch_and_listen( - 2 - ) as routed_hosts, HostFactory.create_batch_and_listen(1) as basic_hosts: + async with ( + RoutedHostFactory.create_batch_and_listen(2) as routed_hosts, + HostFactory.create_batch_and_listen(1) as basic_hosts, + ): # routing fails because host_c does not use routing with pytest.raises(ConnectionFailure): await routed_hosts[0].connect(PeerInfo(basic_hosts[0].get_id(), [])) diff --git a/tests/core/identity/identify_push/test_identify_push.py b/tests/core/identity/identify_push/test_identify_push.py index cfceb17a..1b875e6f 100644 --- a/tests/core/identity/identify_push/test_identify_push.py +++ b/tests/core/identity/identify_push/test_identify_push.py @@ -218,7 +218,6 @@ async def test_push_identify_to_peers_with_explicit_params(security_protocol): This test ensures all parameters of push_identify_to_peers are properly tested. """ - # Create four hosts to thoroughly test selective pushing async with host_pair_factory(security_protocol=security_protocol) as ( host_a, diff --git a/tests/core/network/test_notify.py b/tests/core/network/test_notify.py index 0f2d8b44..98caaf86 100644 --- a/tests/core/network/test_notify.py +++ b/tests/core/network/test_notify.py @@ -8,23 +8,20 @@ into network after network has already started listening TODO: Add tests for closed_stream, listen_close when those features are implemented in swarm """ + import enum import pytest +from multiaddr import Multiaddr import trio from libp2p.abc import ( + INetConn, + INetStream, + INetwork, INotifee, ) -from libp2p.tools.async_service import ( - background_trio_service, -) -from libp2p.tools.constants import ( - LISTEN_MADDR, -) -from libp2p.tools.utils import ( - connect_swarm, -) +from libp2p.tools.utils import connect_swarm from tests.utils.factories import ( SwarmFactory, ) @@ -40,169 +37,94 @@ class Event(enum.Enum): class MyNotifee(INotifee): - def __init__(self, events): + def __init__(self, events: list[Event]): self.events = events - async def opened_stream(self, network, stream): + async def opened_stream(self, network: INetwork, stream: INetStream) -> None: self.events.append(Event.OpenedStream) - async def closed_stream(self, network, stream): + async def closed_stream(self, network: INetwork, stream: INetStream) -> None: # TODO: It is not implemented yet. pass - async def connected(self, network, conn): + async def connected(self, network: INetwork, conn: INetConn) -> None: self.events.append(Event.Connected) - async def disconnected(self, network, conn): + async def disconnected(self, network: INetwork, conn: INetConn) -> None: self.events.append(Event.Disconnected) - async def listen(self, network, _multiaddr): + async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None: self.events.append(Event.Listen) - async def listen_close(self, network, _multiaddr): + async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None: # TODO: It is not implemented yet. pass @pytest.mark.trio async def test_notify(security_protocol): - swarms = [SwarmFactory(security_protocol=security_protocol) for _ in range(2)] - - events_0_0 = [] - events_1_0 = [] - events_0_without_listen = [] - # Helper to wait for specific event - async def wait_for_event(events_list, expected_event, timeout=1.0): - start_time = trio.current_time() - while trio.current_time() - start_time < timeout: - if expected_event in events_list: - return True - await trio.sleep(0.01) + async def wait_for_event(events_list, event, timeout=1.0): + with trio.move_on_after(timeout): + while event not in events_list: + await trio.sleep(0.01) + return True return False - # Run swarms. - async with background_trio_service(swarms[0]), background_trio_service(swarms[1]): - # Register events before listening - swarms[0].register_notifee(MyNotifee(events_0_0)) - swarms[1].register_notifee(MyNotifee(events_1_0)) + # Event lists for notifees + events_0_0 = [] + events_0_1 = [] + events_1_0 = [] + events_1_1 = [] - # Listen - async with trio.open_nursery() as nursery: - nursery.start_soon(swarms[0].listen, LISTEN_MADDR) - nursery.start_soon(swarms[1].listen, LISTEN_MADDR) + # Create two swarms, but do not listen yet + async with SwarmFactory.create_batch_and_listen(2) as swarms: + # Register notifees before listening + notifee_0_0 = MyNotifee(events_0_0) + notifee_0_1 = MyNotifee(events_0_1) + notifee_1_0 = MyNotifee(events_1_0) + notifee_1_1 = MyNotifee(events_1_1) - # Wait for Listen events - assert await wait_for_event(events_0_0, Event.Listen) - assert await wait_for_event(events_1_0, Event.Listen) + swarms[0].register_notifee(notifee_0_0) + swarms[0].register_notifee(notifee_0_1) + swarms[1].register_notifee(notifee_1_0) + swarms[1].register_notifee(notifee_1_1) - swarms[0].register_notifee(MyNotifee(events_0_without_listen)) - - # Connected + # Connect swarms await connect_swarm(swarms[0], swarms[1]) - assert await wait_for_event(events_0_0, Event.Connected) - assert await wait_for_event(events_1_0, Event.Connected) - assert await wait_for_event(events_0_without_listen, Event.Connected) - # OpenedStream: first - await swarms[0].new_stream(swarms[1].get_peer_id()) - # OpenedStream: second - await swarms[0].new_stream(swarms[1].get_peer_id()) - # OpenedStream: third, but different direction. - await swarms[1].new_stream(swarms[0].get_peer_id()) + # Create a stream + stream = await swarms[0].new_stream(swarms[1].get_peer_id()) + await stream.close() - # Clear any duplicate events that might have occurred - events_0_0.copy() - events_1_0.copy() - events_0_without_listen.copy() - - # TODO: Check `ClosedStream` and `ListenClose` events after they are ready. - - # Disconnected + # Close peer await swarms[0].close_peer(swarms[1].get_peer_id()) - assert await wait_for_event(events_0_0, Event.Disconnected) - assert await wait_for_event(events_1_0, Event.Disconnected) - assert await wait_for_event(events_0_without_listen, Event.Disconnected) - # Connected again, but different direction. - await connect_swarm(swarms[1], swarms[0]) + # Wait for events + assert await wait_for_event(events_0_0, Event.Connected, 1.0) + assert await wait_for_event(events_0_0, Event.OpenedStream, 1.0) + # assert await wait_for_event( + # events_0_0, Event.ClosedStream, 1.0 + # ) # Not implemented + assert await wait_for_event(events_0_0, Event.Disconnected, 1.0) - # Get the index of the first disconnected event - disconnect_idx_0_0 = events_0_0.index(Event.Disconnected) - disconnect_idx_1_0 = events_1_0.index(Event.Disconnected) - disconnect_idx_without_listen = events_0_without_listen.index( - Event.Disconnected - ) + assert await wait_for_event(events_0_1, Event.Connected, 1.0) + assert await wait_for_event(events_0_1, Event.OpenedStream, 1.0) + # assert await wait_for_event( + # events_0_1, Event.ClosedStream, 1.0 + # ) # Not implemented + assert await wait_for_event(events_0_1, Event.Disconnected, 1.0) - # Check for connected event after disconnect - assert await wait_for_event( - events_0_0[disconnect_idx_0_0 + 1 :], Event.Connected - ) - assert await wait_for_event( - events_1_0[disconnect_idx_1_0 + 1 :], Event.Connected - ) - assert await wait_for_event( - events_0_without_listen[disconnect_idx_without_listen + 1 :], - Event.Connected, - ) + assert await wait_for_event(events_1_0, Event.Connected, 1.0) + assert await wait_for_event(events_1_0, Event.OpenedStream, 1.0) + # assert await wait_for_event( + # events_1_0, Event.ClosedStream, 1.0 + # ) # Not implemented + assert await wait_for_event(events_1_0, Event.Disconnected, 1.0) - # Disconnected again, but different direction. - await swarms[1].close_peer(swarms[0].get_peer_id()) - - # Find index of the second connected event - second_connect_idx_0_0 = events_0_0.index( - Event.Connected, disconnect_idx_0_0 + 1 - ) - second_connect_idx_1_0 = events_1_0.index( - Event.Connected, disconnect_idx_1_0 + 1 - ) - second_connect_idx_without_listen = events_0_without_listen.index( - Event.Connected, disconnect_idx_without_listen + 1 - ) - - # Check for second disconnected event - assert await wait_for_event( - events_0_0[second_connect_idx_0_0 + 1 :], Event.Disconnected - ) - assert await wait_for_event( - events_1_0[second_connect_idx_1_0 + 1 :], Event.Disconnected - ) - assert await wait_for_event( - events_0_without_listen[second_connect_idx_without_listen + 1 :], - Event.Disconnected, - ) - - # Verify the core sequence of events - expected_events_without_listen = [ - Event.Connected, - Event.Disconnected, - Event.Connected, - Event.Disconnected, - ] - - # Filter events to check only pattern we care about - # (skipping OpenedStream which may vary) - filtered_events_0_0 = [ - e - for e in events_0_0 - if e in [Event.Listen, Event.Connected, Event.Disconnected] - ] - filtered_events_1_0 = [ - e - for e in events_1_0 - if e in [Event.Listen, Event.Connected, Event.Disconnected] - ] - filtered_events_without_listen = [ - e - for e in events_0_without_listen - if e in [Event.Connected, Event.Disconnected] - ] - - # Check that the pattern matches - assert filtered_events_0_0[0] == Event.Listen, "First event should be Listen" - assert filtered_events_1_0[0] == Event.Listen, "First event should be Listen" - - # Check pattern: Connected -> Disconnected -> Connected -> Disconnected - assert filtered_events_0_0[1:5] == expected_events_without_listen - assert filtered_events_1_0[1:5] == expected_events_without_listen - assert filtered_events_without_listen[:4] == expected_events_without_listen + assert await wait_for_event(events_1_1, Event.Connected, 1.0) + assert await wait_for_event(events_1_1, Event.OpenedStream, 1.0) + # assert await wait_for_event( + # events_1_1, Event.ClosedStream, 1.0 + # ) # Not implemented + assert await wait_for_event(events_1_1, Event.Disconnected, 1.0) diff --git a/tests/core/network/test_swarm.py b/tests/core/network/test_swarm.py index e3204b79..6389bcb3 100644 --- a/tests/core/network/test_swarm.py +++ b/tests/core/network/test_swarm.py @@ -13,6 +13,9 @@ from libp2p import ( from libp2p.network.exceptions import ( SwarmException, ) +from libp2p.network.swarm import ( + Swarm, +) from libp2p.tools.utils import ( connect_swarm, ) @@ -166,12 +169,14 @@ async def test_swarm_multiaddr(security_protocol): def test_new_swarm_defaults_to_tcp(): swarm = new_swarm() + assert isinstance(swarm, Swarm) assert isinstance(swarm.transport, TCP) def test_new_swarm_tcp_multiaddr_supported(): addr = Multiaddr("/ip4/127.0.0.1/tcp/9999") swarm = new_swarm(listen_addrs=[addr]) + assert isinstance(swarm, Swarm) assert isinstance(swarm.transport, TCP) diff --git a/tests/core/peer/test_addrbook.py b/tests/core/peer/test_addrbook.py index 55240659..1b642cb2 100644 --- a/tests/core/peer/test_addrbook.py +++ b/tests/core/peer/test_addrbook.py @@ -1,5 +1,9 @@ import pytest +from multiaddr import ( + Multiaddr, +) +from libp2p.peer.id import ID from libp2p.peer.peerstore import ( PeerStore, PeerStoreError, @@ -11,51 +15,72 @@ from libp2p.peer.peerstore import ( def test_addrs_empty(): with pytest.raises(PeerStoreError): store = PeerStore() - val = store.addrs("peer") + val = store.addrs(ID(b"peer")) assert not val def test_add_addr_single(): store = PeerStore() - store.add_addr("peer1", "/foo", 10) - store.add_addr("peer1", "/bar", 10) - store.add_addr("peer2", "/baz", 10) + store.add_addr(ID(b"peer1"), Multiaddr("/ip4/127.0.0.1/tcp/4001"), 10) + store.add_addr(ID(b"peer1"), Multiaddr("/ip4/127.0.0.1/tcp/4002"), 10) + store.add_addr(ID(b"peer2"), Multiaddr("/ip4/127.0.0.1/tcp/4003"), 10) - assert store.addrs("peer1") == ["/foo", "/bar"] - assert store.addrs("peer2") == ["/baz"] + assert store.addrs(ID(b"peer1")) == [ + Multiaddr("/ip4/127.0.0.1/tcp/4001"), + Multiaddr("/ip4/127.0.0.1/tcp/4002"), + ] + assert store.addrs(ID(b"peer2")) == [Multiaddr("/ip4/127.0.0.1/tcp/4003")] def test_add_addrs_multiple(): store = PeerStore() - store.add_addrs("peer1", ["/foo1", "/bar1"], 10) - store.add_addrs("peer2", ["/foo2"], 10) + store.add_addrs( + ID(b"peer1"), + [Multiaddr("/ip4/127.0.0.1/tcp/40011"), Multiaddr("/ip4/127.0.0.1/tcp/40021")], + 10, + ) + store.add_addrs(ID(b"peer2"), [Multiaddr("/ip4/127.0.0.1/tcp/40012")], 10) - assert store.addrs("peer1") == ["/foo1", "/bar1"] - assert store.addrs("peer2") == ["/foo2"] + assert store.addrs(ID(b"peer1")) == [ + Multiaddr("/ip4/127.0.0.1/tcp/40011"), + Multiaddr("/ip4/127.0.0.1/tcp/40021"), + ] + assert store.addrs(ID(b"peer2")) == [Multiaddr("/ip4/127.0.0.1/tcp/40012")] def test_clear_addrs(): store = PeerStore() - store.add_addrs("peer1", ["/foo1", "/bar1"], 10) - store.add_addrs("peer2", ["/foo2"], 10) - store.clear_addrs("peer1") + store.add_addrs( + ID(b"peer1"), + [Multiaddr("/ip4/127.0.0.1/tcp/40011"), Multiaddr("/ip4/127.0.0.1/tcp/40021")], + 10, + ) + store.add_addrs(ID(b"peer2"), [Multiaddr("/ip4/127.0.0.1/tcp/40012")], 10) + store.clear_addrs(ID(b"peer1")) - assert store.addrs("peer1") == [] - assert store.addrs("peer2") == ["/foo2"] + assert store.addrs(ID(b"peer1")) == [] + assert store.addrs(ID(b"peer2")) == [Multiaddr("/ip4/127.0.0.1/tcp/40012")] - store.add_addrs("peer1", ["/foo1", "/bar1"], 10) + store.add_addrs( + ID(b"peer1"), + [Multiaddr("/ip4/127.0.0.1/tcp/40011"), Multiaddr("/ip4/127.0.0.1/tcp/40021")], + 10, + ) - assert store.addrs("peer1") == ["/foo1", "/bar1"] + assert store.addrs(ID(b"peer1")) == [ + Multiaddr("/ip4/127.0.0.1/tcp/40011"), + Multiaddr("/ip4/127.0.0.1/tcp/40021"), + ] def test_peers_with_addrs(): store = PeerStore() - store.add_addrs("peer1", [], 10) - store.add_addrs("peer2", ["/foo"], 10) - store.add_addrs("peer3", ["/bar"], 10) + store.add_addrs(ID(b"peer1"), [], 10) + store.add_addrs(ID(b"peer2"), [Multiaddr("/ip4/127.0.0.1/tcp/4001")], 10) + store.add_addrs(ID(b"peer3"), [Multiaddr("/ip4/127.0.0.1/tcp/4002")], 10) - assert set(store.peers_with_addrs()) == {"peer2", "peer3"} + assert set(store.peers_with_addrs()) == {ID(b"peer2"), ID(b"peer3")} - store.clear_addrs("peer2") + store.clear_addrs(ID(b"peer2")) - assert set(store.peers_with_addrs()) == {"peer3"} + assert set(store.peers_with_addrs()) == {ID(b"peer3")} diff --git a/tests/core/peer/test_interop.py b/tests/core/peer/test_interop.py index cda571f9..05667cdd 100644 --- a/tests/core/peer/test_interop.py +++ b/tests/core/peer/test_interop.py @@ -23,9 +23,7 @@ kBZ7WvkmPV3aPL6jnwp2pXepntdVnaTiSxJ1dkXShZ/VSSDNZMYKY306EtHrIu3NZHtXhdyHKcggDXr qkBrdgErAkAlpGPojUwemOggr4FD8sLX1ot2hDJyyV7OK2FXfajWEYJyMRL1Gm9Uk1+Un53RAkJneqp JGAzKpyttXBTIDO51AkEA98KTiROMnnU8Y6Mgcvr68/SMIsvCYMt9/mtwSBGgl80VaTQ5Hpaktl6Xbh VUt5Wv0tRxlXZiViCGCD1EtrrwTw== -""".replace( - "\n", "" -) +""".replace("\n", "") EXPECTED_PEER_ID = "QmRK3JgmVEGiewxWbhpXLJyjWuGuLeSTMTndA1coMHEy5o" diff --git a/tests/core/peer/test_peerdata.py b/tests/core/peer/test_peerdata.py index aad8c5d5..65e98959 100644 --- a/tests/core/peer/test_peerdata.py +++ b/tests/core/peer/test_peerdata.py @@ -1,4 +1,7 @@ +from collections.abc import Sequence + import pytest +from multiaddr import Multiaddr from libp2p.crypto.secp256k1 import ( create_new_key_pair, @@ -8,7 +11,7 @@ from libp2p.peer.peerdata import ( PeerDataError, ) -MOCK_ADDR = "/peer" +MOCK_ADDR = Multiaddr("/ip4/127.0.0.1/tcp/4001") MOCK_KEYPAIR = create_new_key_pair() MOCK_PUBKEY = MOCK_KEYPAIR.public_key MOCK_PRIVKEY = MOCK_KEYPAIR.private_key @@ -23,7 +26,7 @@ def test_get_protocols_empty(): # Test case when adding protocols def test_add_protocols(): peer_data = PeerData() - protocols = ["protocol1", "protocol2"] + protocols: Sequence[str] = ["protocol1", "protocol2"] peer_data.add_protocols(protocols) assert peer_data.get_protocols() == protocols @@ -31,7 +34,7 @@ def test_add_protocols(): # Test case when setting protocols def test_set_protocols(): peer_data = PeerData() - protocols = ["protocolA", "protocolB"] + protocols: Sequence[str] = ["protocol1", "protocol2"] peer_data.set_protocols(protocols) assert peer_data.get_protocols() == protocols @@ -39,7 +42,7 @@ def test_set_protocols(): # Test case when adding addresses def test_add_addrs(): peer_data = PeerData() - addresses = [MOCK_ADDR] + addresses: Sequence[Multiaddr] = [MOCK_ADDR] peer_data.add_addrs(addresses) assert peer_data.get_addrs() == addresses @@ -47,7 +50,7 @@ def test_add_addrs(): # Test case when adding same address more than once def test_add_dup_addrs(): peer_data = PeerData() - addresses = [MOCK_ADDR, MOCK_ADDR] + addresses: Sequence[Multiaddr] = [MOCK_ADDR, MOCK_ADDR] peer_data.add_addrs(addresses) peer_data.add_addrs(addresses) assert peer_data.get_addrs() == [MOCK_ADDR] @@ -56,7 +59,7 @@ def test_add_dup_addrs(): # Test case for clearing addresses def test_clear_addrs(): peer_data = PeerData() - addresses = [MOCK_ADDR] + addresses: Sequence[Multiaddr] = [MOCK_ADDR] peer_data.add_addrs(addresses) peer_data.clear_addrs() assert peer_data.get_addrs() == [] diff --git a/tests/core/peer/test_peerid.py b/tests/core/peer/test_peerid.py index b2201c09..705aa550 100644 --- a/tests/core/peer/test_peerid.py +++ b/tests/core/peer/test_peerid.py @@ -6,16 +6,12 @@ import multihash from libp2p.crypto.rsa import ( create_new_key_pair, ) -import libp2p.peer.id as PeerID from libp2p.peer.id import ( ID, ) ALPHABETS = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz" -# ensure we are not in "debug" mode for the following tests -PeerID.FRIENDLY_IDS = False - def test_eq_impl_for_bytes(): random_id_string = "" @@ -70,8 +66,8 @@ def test_eq_true(): def test_eq_false(): - peer_id = ID("efgh") - other = ID("abcd") + peer_id = ID(b"efgh") + other = ID(b"abcd") assert peer_id != other @@ -91,7 +87,7 @@ def test_id_from_base58(): for _ in range(10): random_id_string += random.choice(ALPHABETS) expected = ID(base58.b58decode(random_id_string)) - actual = ID.from_base58(random_id_string.encode()) + actual = ID.from_base58(random_id_string) assert actual == expected diff --git a/tests/core/peer/test_peerinfo.py b/tests/core/peer/test_peerinfo.py index 497060c0..5e67d022 100644 --- a/tests/core/peer/test_peerinfo.py +++ b/tests/core/peer/test_peerinfo.py @@ -17,10 +17,14 @@ VALID_MULTI_ADDR_STR = "/ip4/127.0.0.1/tcp/8000/p2p/3YgLAeMKSAPcGqZkAt8mREqhQXmJ def test_init_(): - random_addrs = [random.randint(0, 255) for r in range(4)] + random_addrs = [ + multiaddr.Multiaddr(f"/ip4/127.0.0.1/tcp/{1000 + i}") for i in range(4) + ] + random_id_string = "" for _ in range(10): random_id_string += random.SystemRandom().choice(ALPHABETS) + peer_id = ID(random_id_string.encode()) peer_info = PeerInfo(peer_id, random_addrs) diff --git a/tests/core/peer/test_peermetadata.py b/tests/core/peer/test_peermetadata.py index 0ee56f2d..e68e5108 100644 --- a/tests/core/peer/test_peermetadata.py +++ b/tests/core/peer/test_peermetadata.py @@ -1,5 +1,6 @@ import pytest +from libp2p.peer.id import ID from libp2p.peer.peerstore import ( PeerStore, PeerStoreError, @@ -11,36 +12,36 @@ from libp2p.peer.peerstore import ( def test_get_empty(): with pytest.raises(PeerStoreError): store = PeerStore() - val = store.get("peer", "key") + val = store.get(ID(b"peer"), "key") assert not val def test_put_get_simple(): store = PeerStore() - store.put("peer", "key", "val") - assert store.get("peer", "key") == "val" + store.put(ID(b"peer"), "key", "val") + assert store.get(ID(b"peer"), "key") == "val" def test_put_get_update(): store = PeerStore() - store.put("peer", "key1", "val1") - store.put("peer", "key2", "val2") - store.put("peer", "key2", "new val2") + store.put(ID(b"peer"), "key1", "val1") + store.put(ID(b"peer"), "key2", "val2") + store.put(ID(b"peer"), "key2", "new val2") - assert store.get("peer", "key1") == "val1" - assert store.get("peer", "key2") == "new val2" + assert store.get(ID(b"peer"), "key1") == "val1" + assert store.get(ID(b"peer"), "key2") == "new val2" def test_put_get_two_peers(): store = PeerStore() - store.put("peer1", "key1", "val1") - store.put("peer2", "key1", "val1 prime") + store.put(ID(b"peer1"), "key1", "val1") + store.put(ID(b"peer2"), "key1", "val1 prime") - assert store.get("peer1", "key1") == "val1" - assert store.get("peer2", "key1") == "val1 prime" + assert store.get(ID(b"peer1"), "key1") == "val1" + assert store.get(ID(b"peer2"), "key1") == "val1 prime" # Try update - store.put("peer2", "key1", "new val1") + store.put(ID(b"peer2"), "key1", "new val1") - assert store.get("peer1", "key1") == "val1" - assert store.get("peer2", "key1") == "new val1" + assert store.get(ID(b"peer1"), "key1") == "val1" + assert store.get(ID(b"peer2"), "key1") == "new val1" diff --git a/tests/core/peer/test_peerstore.py b/tests/core/peer/test_peerstore.py index 42137b3c..fcfc83a2 100644 --- a/tests/core/peer/test_peerstore.py +++ b/tests/core/peer/test_peerstore.py @@ -1,5 +1,7 @@ import pytest +from multiaddr import Multiaddr +from libp2p.peer.id import ID from libp2p.peer.peerstore import ( PeerStore, PeerStoreError, @@ -11,52 +13,52 @@ from libp2p.peer.peerstore import ( def test_peer_info_empty(): store = PeerStore() with pytest.raises(PeerStoreError): - store.peer_info("peer") + store.peer_info(ID(b"peer")) def test_peer_info_basic(): store = PeerStore() - store.add_addr("peer", "/foo", 10) - info = store.peer_info("peer") + store.add_addr(ID(b"peer"), Multiaddr("/ip4/127.0.0.1/tcp/4001"), 10) + info = store.peer_info(ID(b"peer")) - assert info.peer_id == "peer" - assert info.addrs == ["/foo"] + assert info.peer_id == ID(b"peer") + assert info.addrs == [Multiaddr("/ip4/127.0.0.1/tcp/4001")] def test_add_get_protocols_basic(): store = PeerStore() - store.add_protocols("peer1", ["p1", "p2"]) - store.add_protocols("peer2", ["p3"]) + store.add_protocols(ID(b"peer1"), ["p1", "p2"]) + store.add_protocols(ID(b"peer2"), ["p3"]) - assert set(store.get_protocols("peer1")) == {"p1", "p2"} - assert set(store.get_protocols("peer2")) == {"p3"} + assert set(store.get_protocols(ID(b"peer1"))) == {"p1", "p2"} + assert set(store.get_protocols(ID(b"peer2"))) == {"p3"} def test_add_get_protocols_extend(): store = PeerStore() - store.add_protocols("peer1", ["p1", "p2"]) - store.add_protocols("peer1", ["p3"]) + store.add_protocols(ID(b"peer1"), ["p1", "p2"]) + store.add_protocols(ID(b"peer1"), ["p3"]) - assert set(store.get_protocols("peer1")) == {"p1", "p2", "p3"} + assert set(store.get_protocols(ID(b"peer1"))) == {"p1", "p2", "p3"} def test_set_protocols(): store = PeerStore() - store.add_protocols("peer1", ["p1", "p2"]) - store.add_protocols("peer2", ["p3"]) + store.add_protocols(ID(b"peer1"), ["p1", "p2"]) + store.add_protocols(ID(b"peer2"), ["p3"]) - store.set_protocols("peer1", ["p4"]) - store.set_protocols("peer2", []) + store.set_protocols(ID(b"peer1"), ["p4"]) + store.set_protocols(ID(b"peer2"), []) - assert set(store.get_protocols("peer1")) == {"p4"} - assert set(store.get_protocols("peer2")) == set() + assert set(store.get_protocols(ID(b"peer1"))) == {"p4"} + assert set(store.get_protocols(ID(b"peer2"))) == set() # Test with methods from other Peer interfaces. def test_peers(): store = PeerStore() - store.add_protocols("peer1", []) - store.put("peer2", "key", "val") - store.add_addr("peer3", "/foo", 10) + store.add_protocols(ID(b"peer1"), []) + store.put(ID(b"peer2"), "key", "val") + store.add_addr(ID(b"peer3"), Multiaddr("/ip4/127.0.0.1/tcp/4001"), 10) - assert set(store.peer_ids()) == {"peer1", "peer2", "peer3"} + assert set(store.peer_ids()) == {ID(b"peer1"), ID(b"peer2"), ID(b"peer3")} diff --git a/tests/core/protocol_muxer/test_protocol_muxer.py b/tests/core/protocol_muxer/test_protocol_muxer.py index 98f48533..b089390b 100644 --- a/tests/core/protocol_muxer/test_protocol_muxer.py +++ b/tests/core/protocol_muxer/test_protocol_muxer.py @@ -1,10 +1,7 @@ import pytest -from trio.testing import ( - RaisesGroup, -) -from libp2p.host.exceptions import ( - StreamFailure, +from libp2p.custom_types import ( + TProtocol, ) from libp2p.tools.utils import ( create_echo_stream_handler, @@ -13,10 +10,10 @@ from tests.utils.factories import ( HostFactory, ) -PROTOCOL_ECHO = "/echo/1.0.0" -PROTOCOL_POTATO = "/potato/1.0.0" -PROTOCOL_FOO = "/foo/1.0.0" -PROTOCOL_ROCK = "/rock/1.0.0" +PROTOCOL_ECHO = TProtocol("/echo/1.0.0") +PROTOCOL_POTATO = TProtocol("/potato/1.0.0") +PROTOCOL_FOO = TProtocol("/foo/1.0.0") +PROTOCOL_ROCK = TProtocol("/rock/1.0.0") ACK_PREFIX = "ack:" @@ -61,19 +58,12 @@ async def test_single_protocol_succeeds(security_protocol): @pytest.mark.trio async def test_single_protocol_fails(security_protocol): - # using trio.testing.RaisesGroup b/c pytest.raises does not handle ExceptionGroups - # yet: https://github.com/pytest-dev/pytest/issues/11538 - # but switch to that once they do - - # the StreamFailure is within 2 nested ExceptionGroups, so we use strict=False - # to unwrap down to the core Exception - with RaisesGroup(StreamFailure, allow_unwrapped=True, flatten_subgroups=True): + # Expect that protocol negotiation fails when no common protocols exist + with pytest.raises(Exception): await perform_simple_test( "", [PROTOCOL_ECHO], [PROTOCOL_POTATO], security_protocol ) - # Cleanup not reached on error - @pytest.mark.trio async def test_multiple_protocol_first_is_valid_succeeds(security_protocol): @@ -103,16 +93,16 @@ async def test_multiple_protocol_second_is_valid_succeeds(security_protocol): @pytest.mark.trio async def test_multiple_protocol_fails(security_protocol): - protocols_for_client = [PROTOCOL_ROCK, PROTOCOL_FOO, "/bar/1.0.0"] - protocols_for_listener = ["/aspyn/1.0.0", "/rob/1.0.0", "/zx/1.0.0", "/alex/1.0.0"] + protocols_for_client = [PROTOCOL_ROCK, PROTOCOL_FOO, TProtocol("/bar/1.0.0")] + protocols_for_listener = [ + TProtocol("/aspyn/1.0.0"), + TProtocol("/rob/1.0.0"), + TProtocol("/zx/1.0.0"), + TProtocol("/alex/1.0.0"), + ] - # using trio.testing.RaisesGroup b/c pytest.raises does not handle ExceptionGroups - # yet: https://github.com/pytest-dev/pytest/issues/11538 - # but switch to that once they do - - # the StreamFailure is within 2 nested ExceptionGroups, so we use strict=False - # to unwrap down to the core Exception - with RaisesGroup(StreamFailure, allow_unwrapped=True, flatten_subgroups=True): + # Expect that protocol negotiation fails when no common protocols exist + with pytest.raises(Exception): await perform_simple_test( "", protocols_for_client, protocols_for_listener, security_protocol ) @@ -142,8 +132,8 @@ async def test_multistream_command(security_protocol): for protocol in supported_protocols: assert protocol in response - assert "/does/not/exist" not in response - assert "/foo/bar/1.2.3" not in response + assert TProtocol("/does/not/exist") not in response + assert TProtocol("/foo/bar/1.2.3") not in response # Dialer asks for unspoorted command with pytest.raises(ValueError, match="Command not supported"): diff --git a/tests/core/pubsub/test_dummyaccount_demo.py b/tests/core/pubsub/test_dummyaccount_demo.py index 417c69e4..c70ba57e 100644 --- a/tests/core/pubsub/test_dummyaccount_demo.py +++ b/tests/core/pubsub/test_dummyaccount_demo.py @@ -20,7 +20,6 @@ async def perform_test(num_nodes, adjacency_map, action_func, assertion_func): such as send crypto and set crypto :param assertion_func: assertions for testing the results of the actions are correct """ - async with DummyAccountNode.create(num_nodes) as dummy_nodes: # Create connections between nodes according to `adjacency_map` async with trio.open_nursery() as nursery: diff --git a/tests/core/pubsub/test_floodsub.py b/tests/core/pubsub/test_floodsub.py index 053dcb7f..135cbbec 100644 --- a/tests/core/pubsub/test_floodsub.py +++ b/tests/core/pubsub/test_floodsub.py @@ -46,7 +46,7 @@ async def test_simple_two_nodes(): async def test_timed_cache_two_nodes(): # Two nodes using LastSeenCache with a TTL of 120 seconds def get_msg_id(msg): - return (msg.data, msg.from_id) + return msg.data + msg.from_id async with PubsubFactory.create_batch_with_floodsub( 2, seen_ttl=120, msg_id_constructor=get_msg_id diff --git a/tests/core/pubsub/test_gossipsub.py b/tests/core/pubsub/test_gossipsub.py index 20c315ef..dffcbeac 100644 --- a/tests/core/pubsub/test_gossipsub.py +++ b/tests/core/pubsub/test_gossipsub.py @@ -5,6 +5,7 @@ import trio from libp2p.pubsub.gossipsub import ( PROTOCOL_ID, + GossipSub, ) from libp2p.tools.utils import ( connect, @@ -24,7 +25,10 @@ async def test_join(): async with PubsubFactory.create_batch_with_gossipsub( 4, degree=4, degree_low=3, degree_high=5, heartbeat_interval=1, time_to_live=1 ) as pubsubs_gsub: - gossipsubs = [pubsub.router for pubsub in pubsubs_gsub] + gossipsubs = [] + for pubsub in pubsubs_gsub: + if isinstance(pubsub.router, GossipSub): + gossipsubs.append(pubsub.router) hosts = [pubsub.host for pubsub in pubsubs_gsub] hosts_indices = list(range(len(pubsubs_gsub))) @@ -86,7 +90,9 @@ async def test_join(): @pytest.mark.trio async def test_leave(): async with PubsubFactory.create_batch_with_gossipsub(1) as pubsubs_gsub: - gossipsub = pubsubs_gsub[0].router + router = pubsubs_gsub[0].router + assert isinstance(router, GossipSub) + gossipsub = router topic = "test_leave" assert topic not in gossipsub.mesh @@ -104,7 +110,11 @@ async def test_leave(): @pytest.mark.trio async def test_handle_graft(monkeypatch): async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub: - gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub) + gossipsub_routers = [] + for pubsub in pubsubs_gsub: + if isinstance(pubsub.router, GossipSub): + gossipsub_routers.append(pubsub.router) + gossipsubs = tuple(gossipsub_routers) index_alice = 0 id_alice = pubsubs_gsub[index_alice].my_id @@ -156,7 +166,11 @@ async def test_handle_prune(): async with PubsubFactory.create_batch_with_gossipsub( 2, heartbeat_interval=3 ) as pubsubs_gsub: - gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub) + gossipsub_routers = [] + for pubsub in pubsubs_gsub: + if isinstance(pubsub.router, GossipSub): + gossipsub_routers.append(pubsub.router) + gossipsubs = tuple(gossipsub_routers) index_alice = 0 id_alice = pubsubs_gsub[index_alice].my_id @@ -382,7 +396,9 @@ async def test_mesh_heartbeat(initial_mesh_peer_count, monkeypatch): fake_peer_ids = [IDFactory() for _ in range(total_peer_count)] peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids} - monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol) + router = pubsubs_gsub[0].router + assert isinstance(router, GossipSub) + monkeypatch.setattr(router, "peer_protocol", peer_protocol) peer_topics = {topic: set(fake_peer_ids)} # Monkeypatch the peer subscriptions @@ -394,27 +410,21 @@ async def test_mesh_heartbeat(initial_mesh_peer_count, monkeypatch): mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices] router_mesh = {topic: set(mesh_peers)} # Monkeypatch our mesh peers - monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh) + monkeypatch.setattr(router, "mesh", router_mesh) - peers_to_graft, peers_to_prune = pubsubs_gsub[0].router.mesh_heartbeat() - if initial_mesh_peer_count > pubsubs_gsub[0].router.degree: + peers_to_graft, peers_to_prune = router.mesh_heartbeat() + if initial_mesh_peer_count > router.degree: # If number of initial mesh peers is more than `GossipSubDegree`, # we should PRUNE mesh peers assert len(peers_to_graft) == 0 - assert ( - len(peers_to_prune) - == initial_mesh_peer_count - pubsubs_gsub[0].router.degree - ) + assert len(peers_to_prune) == initial_mesh_peer_count - router.degree for peer in peers_to_prune: assert peer in mesh_peers - elif initial_mesh_peer_count < pubsubs_gsub[0].router.degree: + elif initial_mesh_peer_count < router.degree: # If number of initial mesh peers is less than `GossipSubDegree`, # we should GRAFT more peers assert len(peers_to_prune) == 0 - assert ( - len(peers_to_graft) - == pubsubs_gsub[0].router.degree - initial_mesh_peer_count - ) + assert len(peers_to_graft) == router.degree - initial_mesh_peer_count for peer in peers_to_graft: assert peer not in mesh_peers else: @@ -436,7 +446,10 @@ async def test_gossip_heartbeat(initial_peer_count, monkeypatch): fake_peer_ids = [IDFactory() for _ in range(total_peer_count)] peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids} - monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol) + router_obj = pubsubs_gsub[0].router + assert isinstance(router_obj, GossipSub) + router = router_obj + monkeypatch.setattr(router, "peer_protocol", peer_protocol) topic_mesh_peer_count = 14 # Split into mesh peers and fanout peers @@ -453,14 +466,14 @@ async def test_gossip_heartbeat(initial_peer_count, monkeypatch): mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices] router_mesh = {topic_mesh: set(mesh_peers)} # Monkeypatch our mesh peers - monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh) + monkeypatch.setattr(router, "mesh", router_mesh) fanout_peer_indices = random.sample( range(topic_mesh_peer_count, total_peer_count), initial_peer_count ) fanout_peers = [fake_peer_ids[i] for i in fanout_peer_indices] router_fanout = {topic_fanout: set(fanout_peers)} # Monkeypatch our fanout peers - monkeypatch.setattr(pubsubs_gsub[0].router, "fanout", router_fanout) + monkeypatch.setattr(router, "fanout", router_fanout) def window(topic): if topic == topic_mesh: @@ -471,20 +484,18 @@ async def test_gossip_heartbeat(initial_peer_count, monkeypatch): return [] # Monkeypatch the memory cache messages - monkeypatch.setattr(pubsubs_gsub[0].router.mcache, "window", window) + monkeypatch.setattr(router.mcache, "window", window) - peers_to_gossip = pubsubs_gsub[0].router.gossip_heartbeat() + peers_to_gossip = router.gossip_heartbeat() # If our mesh peer count is less than `GossipSubDegree`, we should gossip to up # to `GossipSubDegree` peers (exclude mesh peers). - if topic_mesh_peer_count - initial_peer_count < pubsubs_gsub[0].router.degree: + if topic_mesh_peer_count - initial_peer_count < router.degree: # The same goes for fanout so it's two times the number of peers to gossip. assert len(peers_to_gossip) == 2 * ( topic_mesh_peer_count - initial_peer_count ) - elif ( - topic_mesh_peer_count - initial_peer_count >= pubsubs_gsub[0].router.degree - ): - assert len(peers_to_gossip) == 2 * (pubsubs_gsub[0].router.degree) + elif topic_mesh_peer_count - initial_peer_count >= router.degree: + assert len(peers_to_gossip) == 2 * (router.degree) for peer in peers_to_gossip: if peer in peer_topics[topic_mesh]: diff --git a/tests/core/pubsub/test_gossipsub_direct_peers.py b/tests/core/pubsub/test_gossipsub_direct_peers.py index d8464a4b..adb20a80 100644 --- a/tests/core/pubsub/test_gossipsub_direct_peers.py +++ b/tests/core/pubsub/test_gossipsub_direct_peers.py @@ -4,6 +4,9 @@ import trio from libp2p.peer.peerinfo import ( info_from_p2p_addr, ) +from libp2p.pubsub.gossipsub import ( + GossipSub, +) from libp2p.tools.utils import ( connect, ) @@ -82,31 +85,33 @@ async def test_reject_graft(): await pubsubs_gsub_1[0].router.join(topic) # Pre-Graft assertions - assert ( - topic in pubsubs_gsub_0[0].router.mesh - ), "topic not in mesh for gossipsub 0" - assert ( - topic in pubsubs_gsub_1[0].router.mesh - ), "topic not in mesh for gossipsub 1" - assert ( - host_1.get_id() not in pubsubs_gsub_0[0].router.mesh[topic] - ), "gossipsub 1 in mesh topic for gossipsub 0" - assert ( - host_0.get_id() not in pubsubs_gsub_1[0].router.mesh[topic] - ), "gossipsub 0 in mesh topic for gossipsub 1" + assert topic in pubsubs_gsub_0[0].router.mesh, ( + "topic not in mesh for gossipsub 0" + ) + assert topic in pubsubs_gsub_1[0].router.mesh, ( + "topic not in mesh for gossipsub 1" + ) + assert host_1.get_id() not in pubsubs_gsub_0[0].router.mesh[topic], ( + "gossipsub 1 in mesh topic for gossipsub 0" + ) + assert host_0.get_id() not in pubsubs_gsub_1[0].router.mesh[topic], ( + "gossipsub 0 in mesh topic for gossipsub 1" + ) # Gossipsub 1 emits a graft request to Gossipsub 0 - await pubsubs_gsub_0[0].router.emit_graft(topic, host_1.get_id()) + router_obj = pubsubs_gsub_0[0].router + assert isinstance(router_obj, GossipSub) + await router_obj.emit_graft(topic, host_1.get_id()) await trio.sleep(1) # Post-Graft assertions - assert ( - host_1.get_id() not in pubsubs_gsub_0[0].router.mesh[topic] - ), "gossipsub 1 in mesh topic for gossipsub 0" - assert ( - host_0.get_id() not in pubsubs_gsub_1[0].router.mesh[topic] - ), "gossipsub 0 in mesh topic for gossipsub 1" + assert host_1.get_id() not in pubsubs_gsub_0[0].router.mesh[topic], ( + "gossipsub 1 in mesh topic for gossipsub 0" + ) + assert host_0.get_id() not in pubsubs_gsub_1[0].router.mesh[topic], ( + "gossipsub 0 in mesh topic for gossipsub 1" + ) except Exception as e: print(f"Test failed with error: {e}") @@ -139,12 +144,12 @@ async def test_heartbeat_reconnect(): await trio.sleep(1) # Verify initial connection - assert ( - host_1.get_id() in pubsubs_gsub_0[0].peers - ), "Initial connection not established for gossipsub 0" - assert ( - host_0.get_id() in pubsubs_gsub_1[0].peers - ), "Initial connection not established for gossipsub 0" + assert host_1.get_id() in pubsubs_gsub_0[0].peers, ( + "Initial connection not established for gossipsub 0" + ) + assert host_0.get_id() in pubsubs_gsub_1[0].peers, ( + "Initial connection not established for gossipsub 0" + ) # Simulate disconnection await host_0.disconnect(host_1.get_id()) @@ -153,17 +158,17 @@ async def test_heartbeat_reconnect(): await trio.sleep(1) # Verify that peers are removed after disconnection - assert ( - host_0.get_id() not in pubsubs_gsub_1[0].peers - ), "Peer 0 still in gossipsub 1 after disconnection" + assert host_0.get_id() not in pubsubs_gsub_1[0].peers, ( + "Peer 0 still in gossipsub 1 after disconnection" + ) # Wait for heartbeat to reestablish connection await trio.sleep(2) # Verify connection reestablishment - assert ( - host_0.get_id() in pubsubs_gsub_1[0].peers - ), "Reconnection not established for gossipsub 0" + assert host_0.get_id() in pubsubs_gsub_1[0].peers, ( + "Reconnection not established for gossipsub 0" + ) except Exception as e: print(f"Test failed with error: {e}") diff --git a/tests/core/pubsub/test_mcache.py b/tests/core/pubsub/test_mcache.py index 7a494259..9d73840d 100644 --- a/tests/core/pubsub/test_mcache.py +++ b/tests/core/pubsub/test_mcache.py @@ -1,15 +1,26 @@ +from collections.abc import ( + Sequence, +) + +from libp2p.peer.id import ( + ID, +) from libp2p.pubsub.mcache import ( MessageCache, ) +from libp2p.pubsub.pb import ( + rpc_pb2, +) -class Msg: - __slots__ = ["topicIDs", "seqno", "from_id"] - - def __init__(self, topicIDs, seqno, from_id): - self.topicIDs = topicIDs - self.seqno = seqno - self.from_id = from_id +def make_msg( + topic_ids: Sequence[str], + seqno: bytes, + from_id: ID, +) -> rpc_pb2.Message: + return rpc_pb2.Message( + from_id=from_id.to_bytes(), seqno=seqno, topicIDs=list(topic_ids) + ) def test_mcache(): @@ -19,7 +30,7 @@ def test_mcache(): msgs = [] for i in range(60): - msgs.append(Msg(["test"], i, "test")) + msgs.append(make_msg(["test"], i.to_bytes(1, "big"), ID(b"test"))) for i in range(10): mcache.put(msgs[i]) diff --git a/tests/core/pubsub/test_pubsub.py b/tests/core/pubsub/test_pubsub.py index ff145887..81389ed1 100644 --- a/tests/core/pubsub/test_pubsub.py +++ b/tests/core/pubsub/test_pubsub.py @@ -1,6 +1,7 @@ from contextlib import ( contextmanager, ) +import inspect from typing import ( NamedTuple, ) @@ -14,6 +15,9 @@ from libp2p.exceptions import ( from libp2p.network.stream.exceptions import ( StreamEOF, ) +from libp2p.peer.id import ( + ID, +) from libp2p.pubsub.pb import ( rpc_pb2, ) @@ -121,16 +125,18 @@ async def test_set_and_remove_topic_validator(): async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: is_sync_validator_called = False - def sync_validator(peer_id, msg): + def sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool: nonlocal is_sync_validator_called is_sync_validator_called = True + return True is_async_validator_called = False - async def async_validator(peer_id, msg): + async def async_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool: nonlocal is_async_validator_called is_async_validator_called = True await trio.lowlevel.checkpoint() + return True topic = "TEST_VALIDATOR" @@ -144,7 +150,13 @@ async def test_set_and_remove_topic_validator(): assert not topic_validator.is_async # Validate with sync validator - topic_validator.validator(peer_id=IDFactory(), msg="msg") + test_msg = make_pubsub_msg( + origin_id=IDFactory(), + topic_ids=[topic], + data=b"test", + seqno=b"\x00" * 8, + ) + topic_validator.validator(IDFactory(), test_msg) assert is_sync_validator_called assert not is_async_validator_called @@ -158,7 +170,20 @@ async def test_set_and_remove_topic_validator(): assert topic_validator.is_async # Validate with async validator - await topic_validator.validator(peer_id=IDFactory(), msg="msg") + test_msg = make_pubsub_msg( + origin_id=IDFactory(), + topic_ids=[topic], + data=b"test", + seqno=b"\x00" * 8, + ) + validator = topic_validator.validator + if topic_validator.is_async: + import inspect + + if inspect.iscoroutinefunction(validator): + await validator(IDFactory(), test_msg) + else: + validator(IDFactory(), test_msg) assert is_async_validator_called assert not is_sync_validator_called @@ -170,20 +195,18 @@ async def test_set_and_remove_topic_validator(): @pytest.mark.trio async def test_get_msg_validators(): + calls = [0, 0] # [sync, async] + + def sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool: + calls[0] += 1 + return True + + async def async_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool: + calls[1] += 1 + await trio.lowlevel.checkpoint() + return True + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: - times_sync_validator_called = 0 - - def sync_validator(peer_id, msg): - nonlocal times_sync_validator_called - times_sync_validator_called += 1 - - times_async_validator_called = 0 - - async def async_validator(peer_id, msg): - nonlocal times_async_validator_called - times_async_validator_called += 1 - await trio.lowlevel.checkpoint() - topic_1 = "TEST_VALIDATOR_1" topic_2 = "TEST_VALIDATOR_2" topic_3 = "TEST_VALIDATOR_3" @@ -204,13 +227,15 @@ async def test_get_msg_validators(): topic_validators = pubsubs_fsub[0].get_msg_validators(msg) for topic_validator in topic_validators: + validator = topic_validator.validator if topic_validator.is_async: - await topic_validator.validator(peer_id=IDFactory(), msg="msg") + if inspect.iscoroutinefunction(validator): + await validator(IDFactory(), msg) else: - topic_validator.validator(peer_id=IDFactory(), msg="msg") + validator(IDFactory(), msg) - assert times_sync_validator_called == 2 - assert times_async_validator_called == 1 + assert calls[0] == 2 + assert calls[1] == 1 @pytest.mark.parametrize( @@ -221,17 +246,17 @@ async def test_get_msg_validators(): async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed): async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: - def passed_sync_validator(peer_id, msg): + def passed_sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool: return True - def failed_sync_validator(peer_id, msg): + def failed_sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool: return False - async def passed_async_validator(peer_id, msg): + async def passed_async_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool: await trio.lowlevel.checkpoint() return True - async def failed_async_validator(peer_id, msg): + async def failed_async_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool: await trio.lowlevel.checkpoint() return False @@ -297,11 +322,12 @@ async def test_continuously_read_stream(monkeypatch, nursery, security_protocol) m.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc) yield Events(event_push_msg, event_handle_subscription, event_handle_rpc) - async with PubsubFactory.create_batch_with_floodsub( - 1, security_protocol=security_protocol - ) as pubsubs_fsub, net_stream_pair_factory( - security_protocol=security_protocol - ) as stream_pair: + async with ( + PubsubFactory.create_batch_with_floodsub( + 1, security_protocol=security_protocol + ) as pubsubs_fsub, + net_stream_pair_factory(security_protocol=security_protocol) as stream_pair, + ): await pubsubs_fsub[0].subscribe(TESTING_TOPIC) # Kick off the task `continuously_read_stream` nursery.start_soon(pubsubs_fsub[0].continuously_read_stream, stream_pair[0]) @@ -429,11 +455,12 @@ async def test_handle_talk(): @pytest.mark.trio async def test_message_all_peers(monkeypatch, security_protocol): - async with PubsubFactory.create_batch_with_floodsub( - 1, security_protocol=security_protocol - ) as pubsubs_fsub, net_stream_pair_factory( - security_protocol=security_protocol - ) as stream_pair: + async with ( + PubsubFactory.create_batch_with_floodsub( + 1, security_protocol=security_protocol + ) as pubsubs_fsub, + net_stream_pair_factory(security_protocol=security_protocol) as stream_pair, + ): peer_id = IDFactory() mock_peers = {peer_id: stream_pair[0]} with monkeypatch.context() as m: @@ -530,15 +557,15 @@ async def test_publish_push_msg_is_called(monkeypatch): await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA) await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA) - assert ( - len(msgs) == 2 - ), "`push_msg` should be called every time `publish` is called" + assert len(msgs) == 2, ( + "`push_msg` should be called every time `publish` is called" + ) assert (msg_forwarders[0] == msg_forwarders[1]) and ( msg_forwarders[1] == pubsubs_fsub[0].my_id ) - assert ( - msgs[0].seqno != msgs[1].seqno - ), "`seqno` should be different every time" + assert msgs[0].seqno != msgs[1].seqno, ( + "`seqno` should be different every time" + ) @pytest.mark.trio @@ -611,7 +638,7 @@ async def test_push_msg(monkeypatch): # Test: add a topic validator and `push_msg` the message that # does not pass the validation. # `router_publish` is not called then. - def failed_sync_validator(peer_id, msg): + def failed_sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool: return False pubsubs_fsub[0].set_topic_validator( @@ -659,6 +686,9 @@ async def test_strict_signing_failed_validation(monkeypatch): seqno=b"\x00" * 8, ) priv_key = pubsubs_fsub[0].sign_key + assert priv_key is not None, ( + "Private key should not be None when strict_signing=True" + ) signature = priv_key.sign( PUBSUB_SIGNING_PREFIX.encode() + msg.SerializeToString() ) @@ -803,15 +833,15 @@ async def test_blacklist_blocks_new_peer_connections(monkeypatch): await pubsub._handle_new_peer(blacklisted_peer) # Verify that both new_stream and router.add_peer was not called - assert ( - not new_stream_called - ), "new_stream should be not be called to get hello packet" - assert ( - not router_add_peer_called - ), "Router.add_peer should not be called for blacklisted peer" - assert ( - blacklisted_peer not in pubsub.peers - ), "Blacklisted peer should not be in peers dict" + assert not new_stream_called, ( + "new_stream should be not be called to get hello packet" + ) + assert not router_add_peer_called, ( + "Router.add_peer should not be called for blacklisted peer" + ) + assert blacklisted_peer not in pubsub.peers, ( + "Blacklisted peer should not be in peers dict" + ) @pytest.mark.trio @@ -838,7 +868,7 @@ async def test_blacklist_blocks_messages_from_blacklisted_originator(): # Track if router.publish is called router_publish_called = False - async def mock_router_publish(*args, **kwargs): + async def mock_router_publish(msg_forwarder: ID, pubsub_msg: rpc_pb2.Message): nonlocal router_publish_called router_publish_called = True await trio.lowlevel.checkpoint() @@ -851,12 +881,12 @@ async def test_blacklist_blocks_messages_from_blacklisted_originator(): await pubsub.push_msg(blacklisted_originator, msg) # Verify message was rejected - assert ( - not router_publish_called - ), "Router.publish should not be called for blacklisted originator" - assert not pubsub._is_msg_seen( - msg - ), "Message from blacklisted originator should not be marked as seen" + assert not router_publish_called, ( + "Router.publish should not be called for blacklisted originator" + ) + assert not pubsub._is_msg_seen(msg), ( + "Message from blacklisted originator should not be marked as seen" + ) finally: pubsub.router.publish = original_router_publish @@ -894,8 +924,8 @@ async def test_blacklist_allows_non_blacklisted_peers(): # Track router.publish calls router_publish_calls = [] - async def mock_router_publish(*args, **kwargs): - router_publish_calls.append(args) + async def mock_router_publish(msg_forwarder: ID, pubsub_msg: rpc_pb2.Message): + router_publish_calls.append((msg_forwarder, pubsub_msg)) await trio.lowlevel.checkpoint() original_router_publish = pubsub.router.publish @@ -909,15 +939,15 @@ async def test_blacklist_allows_non_blacklisted_peers(): await pubsub.push_msg(allowed_peer, msg_from_blacklisted) # Verify only allowed message was processed - assert ( - len(router_publish_calls) == 1 - ), "Only one message should be processed" - assert pubsub._is_msg_seen( - msg_from_allowed - ), "Allowed message should be marked as seen" - assert not pubsub._is_msg_seen( - msg_from_blacklisted - ), "Blacklisted message should not be marked as seen" + assert len(router_publish_calls) == 1, ( + "Only one message should be processed" + ) + assert pubsub._is_msg_seen(msg_from_allowed), ( + "Allowed message should be marked as seen" + ) + assert not pubsub._is_msg_seen(msg_from_blacklisted), ( + "Blacklisted message should not be marked as seen" + ) # Verify subscription received the allowed message received_msg = await sub.get() @@ -960,7 +990,7 @@ async def test_blacklist_integration_with_existing_functionality(): # due to seen cache (not blacklist) router_publish_called = False - async def mock_router_publish(*args, **kwargs): + async def mock_router_publish(msg_forwarder: ID, pubsub_msg: rpc_pb2.Message): nonlocal router_publish_called router_publish_called = True await trio.lowlevel.checkpoint() @@ -970,9 +1000,9 @@ async def test_blacklist_integration_with_existing_functionality(): try: await pubsub.push_msg(other_peer, msg) - assert ( - not router_publish_called - ), "Duplicate message should be rejected by seen cache" + assert not router_publish_called, ( + "Duplicate message should be rejected by seen cache" + ) finally: pubsub.router.publish = original_router_publish @@ -1001,7 +1031,7 @@ async def test_blacklist_blocks_messages_from_blacklisted_source(): # Track if router.publish is called (it shouldn't be for blacklisted forwarder) router_publish_called = False - async def mock_router_publish(*args, **kwargs): + async def mock_router_publish(msg_forwarder: ID, pubsub_msg: rpc_pb2.Message): nonlocal router_publish_called router_publish_called = True await trio.lowlevel.checkpoint() @@ -1014,12 +1044,12 @@ async def test_blacklist_blocks_messages_from_blacklisted_source(): await pubsub.push_msg(blacklisted_forwarder, msg) # Verify message was rejected - assert ( - not router_publish_called - ), "Router.publish should not be called for blacklisted forwarder" - assert not pubsub._is_msg_seen( - msg - ), "Message from blacklisted forwarder should not be marked as seen" + assert not router_publish_called, ( + "Router.publish should not be called for blacklisted forwarder" + ) + assert not pubsub._is_msg_seen(msg), ( + "Message from blacklisted forwarder should not be marked as seen" + ) finally: pubsub.router.publish = original_router_publish diff --git a/tests/core/security/test_secio.py b/tests/core/security/test_secio.py index ac1a03a3..55035bbf 100644 --- a/tests/core/security/test_secio.py +++ b/tests/core/security/test_secio.py @@ -1,6 +1,7 @@ import pytest import trio +from libp2p.abc import ISecureConn from libp2p.crypto.secp256k1 import ( create_new_key_pair, ) @@ -32,7 +33,8 @@ async def test_create_secure_session(nursery): async with raw_conn_factory(nursery) as conns: local_conn, remote_conn = conns - local_secure_conn, remote_secure_conn = None, None + local_secure_conn: ISecureConn | None = None + remote_secure_conn: ISecureConn | None = None async def local_create_secure_session(): nonlocal local_secure_conn @@ -54,6 +56,9 @@ async def test_create_secure_session(nursery): nursery_1.start_soon(local_create_secure_session) nursery_1.start_soon(remote_create_secure_session) + if local_secure_conn is None or remote_secure_conn is None: + raise Exception("Failed to secure connection") + msg = b"abc" await local_secure_conn.write(msg) received_msg = await remote_secure_conn.read(MAX_READ_LEN) diff --git a/tests/stream_muxer/test_async_context_manager.py b/tests/core/stream_muxer/test_async_context_manager.py similarity index 63% rename from tests/stream_muxer/test_async_context_manager.py rename to tests/core/stream_muxer/test_async_context_manager.py index a79e6a7c..08a8487a 100644 --- a/tests/stream_muxer/test_async_context_manager.py +++ b/tests/core/stream_muxer/test_async_context_manager.py @@ -1,6 +1,9 @@ import pytest import trio +from libp2p.abc import ISecureConn +from libp2p.crypto.keys import PrivateKey, PublicKey +from libp2p.peer.id import ID from libp2p.stream_muxer.exceptions import ( MuxedStreamClosed, MuxedStreamError, @@ -8,18 +11,49 @@ from libp2p.stream_muxer.exceptions import ( from libp2p.stream_muxer.mplex.datastructures import ( StreamID, ) +from libp2p.stream_muxer.mplex.mplex import Mplex from libp2p.stream_muxer.mplex.mplex_stream import ( MplexStream, ) from libp2p.stream_muxer.yamux.yamux import ( + Yamux, YamuxStream, ) +DUMMY_PEER_ID = ID(b"dummy_peer_id") -class DummySecuredConn: - async def write(self, data): + +class DummySecuredConn(ISecureConn): + def __init__(self, is_initiator: bool = False): + self.is_initiator = is_initiator + + async def write(self, data: bytes) -> None: pass + async def read(self, n: int | None = -1) -> bytes: + return b"" + + async def close(self) -> None: + pass + + def get_remote_address(self): + return None + + def get_local_address(self): + return None + + def get_local_peer(self) -> ID: + return ID(b"local") + + def get_local_private_key(self) -> PrivateKey: + return PrivateKey() # Dummy key + + def get_remote_peer(self) -> ID: + return ID(b"remote") + + def get_remote_public_key(self) -> PublicKey: + return PublicKey() # Dummy key + class MockMuxedConn: def __init__(self): @@ -37,9 +71,37 @@ class MockMuxedConn: return None +class MockMplexMuxedConn: + def __init__(self): + self.streams_lock = trio.Lock() + self.event_shutting_down = trio.Event() + self.event_closed = trio.Event() + self.event_started = trio.Event() + + async def send_message(self, flag, data, stream_id): + pass + + def get_remote_address(self): + return None + + +class MockYamuxMuxedConn: + def __init__(self): + self.secured_conn = DummySecuredConn() + self.event_shutting_down = trio.Event() + self.event_closed = trio.Event() + self.event_started = trio.Event() + + async def send_message(self, flag, data, stream_id): + pass + + def get_remote_address(self): + return None + + @pytest.mark.trio async def test_mplex_stream_async_context_manager(): - muxed_conn = MockMuxedConn() + muxed_conn = Mplex(DummySecuredConn(), DUMMY_PEER_ID) stream_id = StreamID(1, True) # Use real StreamID stream = MplexStream( name="test_stream", @@ -57,7 +119,7 @@ async def test_mplex_stream_async_context_manager(): @pytest.mark.trio async def test_yamux_stream_async_context_manager(): - muxed_conn = MockMuxedConn() + muxed_conn = Yamux(DummySecuredConn(), DUMMY_PEER_ID) stream = YamuxStream(stream_id=1, conn=muxed_conn, is_initiator=True) async with stream as s: assert s is stream @@ -69,7 +131,7 @@ async def test_yamux_stream_async_context_manager(): @pytest.mark.trio async def test_mplex_stream_async_context_manager_with_error(): - muxed_conn = MockMuxedConn() + muxed_conn = Mplex(DummySecuredConn(), DUMMY_PEER_ID) stream_id = StreamID(1, True) stream = MplexStream( name="test_stream", @@ -89,7 +151,7 @@ async def test_mplex_stream_async_context_manager_with_error(): @pytest.mark.trio async def test_yamux_stream_async_context_manager_with_error(): - muxed_conn = MockMuxedConn() + muxed_conn = Yamux(DummySecuredConn(), DUMMY_PEER_ID) stream = YamuxStream(stream_id=1, conn=muxed_conn, is_initiator=True) with pytest.raises(ValueError): async with stream as s: @@ -103,7 +165,7 @@ async def test_yamux_stream_async_context_manager_with_error(): @pytest.mark.trio async def test_mplex_stream_async_context_manager_write_after_close(): - muxed_conn = MockMuxedConn() + muxed_conn = Mplex(DummySecuredConn(), DUMMY_PEER_ID) stream_id = StreamID(1, True) stream = MplexStream( name="test_stream", @@ -119,7 +181,7 @@ async def test_mplex_stream_async_context_manager_write_after_close(): @pytest.mark.trio async def test_yamux_stream_async_context_manager_write_after_close(): - muxed_conn = MockMuxedConn() + muxed_conn = Yamux(DummySecuredConn(), DUMMY_PEER_ID) stream = YamuxStream(stream_id=1, conn=muxed_conn, is_initiator=True) async with stream as s: assert s is stream diff --git a/tests/core/stream_muxer/test_multiplexer_selection.py b/tests/core/stream_muxer/test_multiplexer_selection.py index 656713b9..b2f3e305 100644 --- a/tests/core/stream_muxer/test_multiplexer_selection.py +++ b/tests/core/stream_muxer/test_multiplexer_selection.py @@ -1,6 +1,7 @@ import logging import pytest +from multiaddr.multiaddr import Multiaddr import trio from libp2p import ( @@ -11,6 +12,8 @@ from libp2p import ( new_host, set_default_muxer, ) +from libp2p.custom_types import TProtocol +from libp2p.peer.peerinfo import PeerInfo # Enable logging for debugging logging.basicConfig(level=logging.DEBUG) @@ -24,13 +27,14 @@ async def host_pair(muxer_preference=None, muxer_opt=None): host_b = new_host(muxer_preference=muxer_preference, muxer_opt=muxer_opt) # Start both hosts - await host_a.get_network().listen("/ip4/127.0.0.1/tcp/0") - await host_b.get_network().listen("/ip4/127.0.0.1/tcp/0") + await host_a.get_network().listen(Multiaddr("/ip4/127.0.0.1/tcp/0")) + await host_b.get_network().listen(Multiaddr("/ip4/127.0.0.1/tcp/0")) # Connect hosts with a timeout listen_addrs_a = host_a.get_addrs() with trio.move_on_after(5): # 5 second timeout - await host_b.connect(host_a.get_id(), listen_addrs_a) + peer_info_a = PeerInfo(host_a.get_id(), listen_addrs_a) + await host_b.connect(peer_info_a) yield host_a, host_b @@ -57,14 +61,14 @@ async def test_multiplexer_preference_parameter(muxer_preference): try: # Start both hosts - await host_a.get_network().listen("/ip4/127.0.0.1/tcp/0") - await host_b.get_network().listen("/ip4/127.0.0.1/tcp/0") + await host_a.get_network().listen(Multiaddr("/ip4/127.0.0.1/tcp/0")) + await host_b.get_network().listen(Multiaddr("/ip4/127.0.0.1/tcp/0")) # Connect hosts with timeout listen_addrs_a = host_a.get_addrs() with trio.move_on_after(5): # 5 second timeout - await host_b.connect(host_a.get_id(), listen_addrs_a) - + peer_info_a = PeerInfo(host_a.get_id(), listen_addrs_a) + await host_b.connect(peer_info_a) # Check if connection was established connections = host_b.get_network().connections assert len(connections) > 0, "Connection not established" @@ -74,7 +78,7 @@ async def test_multiplexer_preference_parameter(muxer_preference): muxed_conn = conn.muxed_conn # Define a simple echo protocol - ECHO_PROTOCOL = "/echo/1.0.0" + ECHO_PROTOCOL = TProtocol("/echo/1.0.0") # Setup echo handler on host_a async def echo_handler(stream): @@ -89,7 +93,7 @@ async def test_multiplexer_preference_parameter(muxer_preference): # Open a stream with timeout with trio.move_on_after(5): - stream = await muxed_conn.open_stream(ECHO_PROTOCOL) + stream = await muxed_conn.open_stream() # Check stream type if muxer_preference == MUXER_YAMUX: @@ -132,13 +136,14 @@ async def test_explicit_muxer_options(muxer_option_func, expected_stream_class): try: # Start both hosts - await host_a.get_network().listen("/ip4/127.0.0.1/tcp/0") - await host_b.get_network().listen("/ip4/127.0.0.1/tcp/0") + await host_a.get_network().listen(Multiaddr("/ip4/127.0.0.1/tcp/0")) + await host_b.get_network().listen(Multiaddr("/ip4/127.0.0.1/tcp/0")) # Connect hosts with timeout listen_addrs_a = host_a.get_addrs() with trio.move_on_after(5): # 5 second timeout - await host_b.connect(host_a.get_id(), listen_addrs_a) + peer_info_a = PeerInfo(host_a.get_id(), listen_addrs_a) + await host_b.connect(peer_info_a) # Check if connection was established connections = host_b.get_network().connections @@ -149,7 +154,7 @@ async def test_explicit_muxer_options(muxer_option_func, expected_stream_class): muxed_conn = conn.muxed_conn # Define a simple echo protocol - ECHO_PROTOCOL = "/echo/1.0.0" + ECHO_PROTOCOL = TProtocol("/echo/1.0.0") # Setup echo handler on host_a async def echo_handler(stream): @@ -164,7 +169,7 @@ async def test_explicit_muxer_options(muxer_option_func, expected_stream_class): # Open a stream with timeout with trio.move_on_after(5): - stream = await muxed_conn.open_stream(ECHO_PROTOCOL) + stream = await muxed_conn.open_stream() # Check stream type assert expected_stream_class in stream.__class__.__name__ @@ -200,13 +205,14 @@ async def test_global_default_muxer(global_default): try: # Start both hosts - await host_a.get_network().listen("/ip4/127.0.0.1/tcp/0") - await host_b.get_network().listen("/ip4/127.0.0.1/tcp/0") + await host_a.get_network().listen(Multiaddr("/ip4/127.0.0.1/tcp/0")) + await host_b.get_network().listen(Multiaddr("/ip4/127.0.0.1/tcp/0")) # Connect hosts with timeout listen_addrs_a = host_a.get_addrs() with trio.move_on_after(5): # 5 second timeout - await host_b.connect(host_a.get_id(), listen_addrs_a) + peer_info_a = PeerInfo(host_a.get_id(), listen_addrs_a) + await host_b.connect(peer_info_a) # Check if connection was established connections = host_b.get_network().connections @@ -217,7 +223,7 @@ async def test_global_default_muxer(global_default): muxed_conn = conn.muxed_conn # Define a simple echo protocol - ECHO_PROTOCOL = "/echo/1.0.0" + ECHO_PROTOCOL = TProtocol("/echo/1.0.0") # Setup echo handler on host_a async def echo_handler(stream): @@ -232,7 +238,7 @@ async def test_global_default_muxer(global_default): # Open a stream with timeout with trio.move_on_after(5): - stream = await muxed_conn.open_stream(ECHO_PROTOCOL) + stream = await muxed_conn.open_stream() # Check stream type based on global default if global_default == MUXER_YAMUX: diff --git a/tests/core/stream_muxer/test_yamux.py b/tests/core/stream_muxer/test_yamux.py index fa25af9f..81d05676 100644 --- a/tests/core/stream_muxer/test_yamux.py +++ b/tests/core/stream_muxer/test_yamux.py @@ -7,6 +7,9 @@ from trio.testing import ( memory_stream_pair, ) +from libp2p.abc import ( + IRawConnection, +) from libp2p.crypto.ed25519 import ( create_new_key_pair, ) @@ -29,18 +32,19 @@ from libp2p.stream_muxer.yamux.yamux import ( ) -class TrioStreamAdapter: - def __init__(self, send_stream, receive_stream): +class TrioStreamAdapter(IRawConnection): + def __init__(self, send_stream, receive_stream, is_initiator: bool = False): self.send_stream = send_stream self.receive_stream = receive_stream + self.is_initiator = is_initiator - async def write(self, data): + async def write(self, data: bytes) -> None: logging.debug(f"Writing {len(data)} bytes") with trio.move_on_after(2): await self.send_stream.send_all(data) - async def read(self, n=-1): - if n == -1: + async def read(self, n: int | None = None) -> bytes: + if n is None or n == -1: raise ValueError("Reading unbounded not supported") logging.debug(f"Attempting to read {n} bytes") with trio.move_on_after(2): @@ -48,9 +52,13 @@ class TrioStreamAdapter: logging.debug(f"Read {len(data)} bytes") return data - async def close(self): + async def close(self) -> None: logging.debug("Closing stream") + def get_remote_address(self) -> tuple[str, int] | None: + # Return None since this is a test adapter without real network info + return None + @pytest.fixture def key_pair(): @@ -68,8 +76,8 @@ async def secure_conn_pair(key_pair, peer_id): client_send, server_receive = memory_stream_pair() server_send, client_receive = memory_stream_pair() - client_rw = TrioStreamAdapter(client_send, client_receive) - server_rw = TrioStreamAdapter(server_send, server_receive) + client_rw = TrioStreamAdapter(client_send, client_receive, is_initiator=True) + server_rw = TrioStreamAdapter(server_send, server_receive, is_initiator=False) insecure_transport = InsecureTransport(key_pair) @@ -196,9 +204,9 @@ async def test_yamux_stream_close(yamux_pair): await trio.sleep(0.1) # Now both directions are closed, so stream should be fully closed - assert ( - client_stream.closed - ), "Client stream should be fully closed after bidirectional close" + assert client_stream.closed, ( + "Client stream should be fully closed after bidirectional close" + ) # Writing should still fail with pytest.raises(MuxedStreamError): @@ -215,8 +223,12 @@ async def test_yamux_stream_reset(yamux_pair): server_stream = await server_yamux.accept_stream() await client_stream.reset() # After reset, reading should raise MuxedStreamReset or MuxedStreamEOF - with pytest.raises((MuxedStreamEOF, MuxedStreamError)): + try: await server_stream.read() + except (MuxedStreamEOF, MuxedStreamError): + pass + else: + pytest.fail("Expected MuxedStreamEOF or MuxedStreamError") # Verify subsequent operations fail with StreamReset or EOF with pytest.raises(MuxedStreamError): await server_stream.read() @@ -269,9 +281,9 @@ async def test_yamux_flow_control(yamux_pair): await client_stream.write(large_data) # Check that window was reduced - assert ( - client_stream.send_window < initial_window - ), "Window should be reduced after sending" + assert client_stream.send_window < initial_window, ( + "Window should be reduced after sending" + ) # Read the data on the server side received = b"" @@ -307,9 +319,9 @@ async def test_yamux_flow_control(yamux_pair): f" {client_stream.send_window}," f"initial half: {initial_window // 2}" ) - assert ( - client_stream.send_window > initial_window // 2 - ), "Window should be increased after update" + assert client_stream.send_window > initial_window // 2, ( + "Window should be increased after update" + ) await client_stream.close() await server_stream.close() @@ -349,17 +361,17 @@ async def test_yamux_half_close(yamux_pair): test_data = b"server response after client close" # The server shouldn't be marked as send_closed yet - assert ( - not server_stream.send_closed - ), "Server stream shouldn't be marked as send_closed" + assert not server_stream.send_closed, ( + "Server stream shouldn't be marked as send_closed" + ) await server_stream.write(test_data) # Client can still read received = await client_stream.read(len(test_data)) - assert ( - received == test_data - ), "Client should still be able to read after sending FIN" + assert received == test_data, ( + "Client should still be able to read after sending FIN" + ) # Now server closes its sending side await server_stream.close() @@ -406,9 +418,9 @@ async def test_yamux_go_away_with_error(yamux_pair): await trio.sleep(0.2) # Verify server recognized shutdown - assert ( - server_yamux.event_shutting_down.is_set() - ), "Server should be shutting down after GO_AWAY" + assert server_yamux.event_shutting_down.is_set(), ( + "Server should be shutting down after GO_AWAY" + ) logging.debug("test_yamux_go_away_with_error complete") diff --git a/tests/core/tools/async_service/test_trio_based_service.py b/tests/core/tools/async_service/test_trio_based_service.py index 599a702f..1a3db153 100644 --- a/tests/core/tools/async_service/test_trio_based_service.py +++ b/tests/core/tools/async_service/test_trio_based_service.py @@ -11,13 +11,8 @@ else: import pytest import trio -from trio.testing import ( - Matcher, - RaisesGroup, -) from libp2p.tools.async_service import ( - DaemonTaskExit, LifecycleError, Service, TrioManager, @@ -134,11 +129,7 @@ async def test_trio_service_lifecycle_run_and_exception(): manager = TrioManager(service) async def do_service_run(): - with RaisesGroup( - Matcher(RuntimeError, match="Service throwing error"), - allow_unwrapped=True, - flatten_subgroups=True, - ): + with pytest.raises(ExceptionGroup): await manager.run() await do_service_lifecycle_check( @@ -165,11 +156,7 @@ async def test_trio_service_lifecycle_run_and_task_exception(): manager = TrioManager(service) async def do_service_run(): - with RaisesGroup( - Matcher(RuntimeError, match="Service throwing error"), - allow_unwrapped=True, - flatten_subgroups=True, - ): + with pytest.raises(ExceptionGroup): await manager.run() await do_service_lifecycle_check( @@ -230,11 +217,7 @@ async def test_trio_service_lifecycle_run_and_daemon_task_exit(): manager = TrioManager(service) async def do_service_run(): - with RaisesGroup( - Matcher(DaemonTaskExit, match="Daemon task"), - allow_unwrapped=True, - flatten_subgroups=True, - ): + with pytest.raises(ExceptionGroup): await manager.run() await do_service_lifecycle_check( @@ -395,11 +378,7 @@ async def test_trio_service_manager_run_task_reraises_exceptions(): with trio.fail_after(1): await trio.sleep_forever() - with RaisesGroup( - Matcher(Exception, match="task exception in run_task"), - allow_unwrapped=True, - flatten_subgroups=True, - ): + with pytest.raises(ExceptionGroup): async with background_trio_service(RunTaskService()): task_event.set() with trio.fail_after(1): @@ -419,13 +398,7 @@ async def test_trio_service_manager_run_daemon_task_cancels_if_exits(): with trio.fail_after(1): await trio.sleep_forever() - with RaisesGroup( - Matcher( - DaemonTaskExit, match=r"Daemon task daemon_task_fn\[daemon=True\] exited" - ), - allow_unwrapped=True, - flatten_subgroups=True, - ): + with pytest.raises(ExceptionGroup): async with background_trio_service(RunTaskService()): task_event.set() with trio.fail_after(1): @@ -443,11 +416,7 @@ async def test_trio_service_manager_propogates_and_records_exceptions(): assert manager.did_error is False - with RaisesGroup( - Matcher(RuntimeError, match="this is the error"), - allow_unwrapped=True, - flatten_subgroups=True, - ): + with pytest.raises(ExceptionGroup): await manager.run() assert manager.did_error is True @@ -641,7 +610,7 @@ async def test_trio_service_with_try_finally_cleanup_with_shielded_await(): ready_cancel.set() await self.manager.wait_finished() finally: - with trio.CancelScope(shield=True): + with trio.CancelScope(shield=True): # type: ignore[call-arg] await trio.lowlevel.checkpoint() self.cleanup_up = True @@ -660,7 +629,7 @@ async def test_error_in_service_run(): self.manager.run_daemon_task(self.manager.wait_finished) raise ValueError("Exception inside run()") - with RaisesGroup(ValueError, allow_unwrapped=True, flatten_subgroups=True): + with pytest.raises(ExceptionGroup): await TrioManager.run_service(ServiceTest()) @@ -679,5 +648,5 @@ async def test_daemon_task_finishes_leaving_children(): async def run(self): self.manager.run_daemon_task(self.buggy_daemon) - with RaisesGroup(DaemonTaskExit, allow_unwrapped=True, flatten_subgroups=True): + with pytest.raises(ExceptionGroup): await TrioManager.run_service(ServiceTest()) diff --git a/tests/core/tools/async_service/test_trio_external_api.py b/tests/core/tools/async_service/test_trio_external_api.py index 3b389024..4f67d593 100644 --- a/tests/core/tools/async_service/test_trio_external_api.py +++ b/tests/core/tools/async_service/test_trio_external_api.py @@ -1,9 +1,15 @@ # Copied from https://github.com/ethereum/async-service +import sys + import pytest import trio -from trio.testing import ( - RaisesGroup, -) + +if sys.version_info >= (3, 11): + from builtins import ( + ExceptionGroup, + ) +else: + from exceptiongroup import ExceptionGroup from libp2p.tools.async_service import ( LifecycleError, @@ -50,7 +56,7 @@ async def test_trio_service_external_api_raises_when_cancelled(): service = ExternalAPIService() async with background_trio_service(service) as manager: - with RaisesGroup(LifecycleError, allow_unwrapped=True, flatten_subgroups=True): + with pytest.raises(ExceptionGroup): async with trio.open_nursery() as nursery: # an event to ensure that we are indeed within the body of the is_within_fn = trio.Event() diff --git a/tests/core/tools/async_service/test_trio_manager_stats.py b/tests/core/tools/async_service/test_trio_manager_stats.py index 659b2f8d..1f5f2a06 100644 --- a/tests/core/tools/async_service/test_trio_manager_stats.py +++ b/tests/core/tools/async_service/test_trio_manager_stats.py @@ -3,8 +3,8 @@ import trio from libp2p.tools.async_service import ( Service, - background_trio_service, ) +from libp2p.tools.async_service.trio_service import TrioManager @pytest.mark.trio @@ -33,24 +33,31 @@ async def test_trio_manager_stats(): self.manager.run_task(trio.lowlevel.checkpoint) service = StatsTest() - async with background_trio_service(service) as manager: - service.run_external_root() - assert len(manager._root_tasks) == 2 - with trio.fail_after(1): - await ready.wait() + async with trio.open_nursery() as nursery: + manager = TrioManager(service) + nursery.start_soon(manager.run) + await manager.wait_started() - # we need to yield to the event loop a few times to allow the various - # tasks to schedule themselves and get running. - for _ in range(50): - await trio.lowlevel.checkpoint() + try: + service.run_external_root() + assert len(manager._root_tasks) == 2 + with trio.fail_after(1): + await ready.wait() - assert manager.stats.tasks.total_count == 10 - assert manager.stats.tasks.finished_count == 3 - assert manager.stats.tasks.pending_count == 7 + # we need to yield to the event loop a few times to allow the various + # tasks to schedule themselves and get running. + for _ in range(50): + await trio.lowlevel.checkpoint() - # This is a simple test to ensure that finished tasks are removed from - # tracking to prevent unbounded memory growth. - assert len(manager._root_tasks) == 1 + assert manager.stats.tasks.total_count == 10 + assert manager.stats.tasks.finished_count == 3 + assert manager.stats.tasks.pending_count == 7 + + # This is a simple test to ensure that finished tasks are removed from + # tracking to prevent unbounded memory growth. + assert len(manager._root_tasks) == 1 + finally: + await manager.stop() # now check after exiting assert manager.stats.tasks.total_count == 10 @@ -67,18 +74,26 @@ async def test_trio_manager_stats_does_not_count_main_run_method(): self.manager.run_task(trio.sleep_forever) ready.set() - async with background_trio_service(StatsTest()) as manager: - with trio.fail_after(1): - await ready.wait() + service = StatsTest() + async with trio.open_nursery() as nursery: + manager = TrioManager(service) + nursery.start_soon(manager.run) + await manager.wait_started() - # we need to yield to the event loop a few times to allow the various - # tasks to schedule themselves and get running. - for _ in range(10): - await trio.lowlevel.checkpoint() + try: + with trio.fail_after(1): + await ready.wait() - assert manager.stats.tasks.total_count == 1 - assert manager.stats.tasks.finished_count == 0 - assert manager.stats.tasks.pending_count == 1 + # we need to yield to the event loop a few times to allow the various + # tasks to schedule themselves and get running. + for _ in range(10): + await trio.lowlevel.checkpoint() + + assert manager.stats.tasks.total_count == 1 + assert manager.stats.tasks.finished_count == 0 + assert manager.stats.tasks.pending_count == 1 + finally: + await manager.stop() # now check after exiting assert manager.stats.tasks.total_count == 1 diff --git a/tests/core/transport/test_tcp.py b/tests/core/transport/test_tcp.py index 0a77a78d..80c97a21 100644 --- a/tests/core/transport/test_tcp.py +++ b/tests/core/transport/test_tcp.py @@ -36,7 +36,7 @@ async def test_tcp_listener(nursery): @pytest.mark.trio async def test_tcp_dial(nursery): transport = TCP() - raw_conn_other_side = None + raw_conn_other_side: RawConnection | None = None event = trio.Event() async def handler(tcp_stream): @@ -59,5 +59,6 @@ async def test_tcp_dial(nursery): await event.wait() data = b"123" + assert raw_conn_other_side is not None await raw_conn_other_side.write(data) assert (await raw_conn.read(len(data))) == data diff --git a/tests/exceptions/test_exceptions.py b/tests/exceptions/test_exceptions.py index 09849c6d..f60cabe3 100644 --- a/tests/exceptions/test_exceptions.py +++ b/tests/exceptions/test_exceptions.py @@ -4,10 +4,14 @@ from libp2p.exceptions import ( def test_multierror_str_and_storage(): - errors = [ValueError("bad value"), KeyError("missing key"), "custom error"] + errors = [ + ValueError("bad value"), + KeyError("missing key"), + RuntimeError("custom error"), + ] multi_error = MultiError(errors) # Check for storage assert multi_error.errors == errors # Check for representation - expected = "Error 1: bad value\n" "Error 2: 'missing key'\n" "Error 3: custom error" + expected = "Error 1: bad value\nError 2: 'missing key'\nError 3: custom error" assert str(multi_error) == expected diff --git a/tests/utils/factories.py b/tests/utils/factories.py index 1fe32344..4df82033 100644 --- a/tests/utils/factories.py +++ b/tests/utils/factories.py @@ -1,5 +1,6 @@ from collections.abc import ( AsyncIterator, + Callable, Sequence, ) from contextlib import ( @@ -8,7 +9,6 @@ from contextlib import ( ) from typing import ( Any, - Callable, cast, ) @@ -88,8 +88,10 @@ from libp2p.security.noise.messages import ( NoiseHandshakePayload, make_handshake_payload_sig, ) -from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID -from libp2p.security.noise.transport import Transport as NoiseTransport +from libp2p.security.noise.transport import ( + PROTOCOL_ID as NOISE_PROTOCOL_ID, + Transport as NoiseTransport, +) import libp2p.security.secio.transport as secio from libp2p.stream_muxer.mplex.mplex import ( MPLEX_PROTOCOL_ID, @@ -134,7 +136,7 @@ class IDFactory(factory.Factory): model = ID peer_id_bytes = factory.LazyFunction( - lambda: generate_peer_id_from(default_key_pair_factory()) + lambda: generate_peer_id_from(default_key_pair_factory()).to_bytes() ) @@ -177,7 +179,7 @@ def noise_transport_factory(key_pair: KeyPair) -> ISecureTransport: def security_options_factory_factory( - protocol_id: TProtocol = None, + protocol_id: TProtocol | None = None, ) -> Callable[[KeyPair], TSecurityOptions]: if protocol_id is None: protocol_id = DEFAULT_SECURITY_PROTOCOL_ID @@ -217,8 +219,8 @@ def default_muxer_transport_factory() -> TMuxerOptions: async def raw_conn_factory( nursery: trio.Nursery, ) -> AsyncIterator[tuple[IRawConnection, IRawConnection]]: - conn_0 = None - conn_1 = None + conn_0: IRawConnection | None = None + conn_1: IRawConnection | None = None event = trio.Event() async def tcp_stream_handler(stream: ReadWriteCloser) -> None: @@ -233,6 +235,7 @@ async def raw_conn_factory( listening_maddr = listener.get_addrs()[0] conn_0 = await tcp_transport.dial(listening_maddr) await event.wait() + assert conn_0 is not None and conn_1 is not None yield conn_0, conn_1 @@ -247,8 +250,8 @@ async def noise_conn_factory( NoiseTransport, noise_transport_factory(create_secp256k1_key_pair()) ) - local_secure_conn: ISecureConn = None - remote_secure_conn: ISecureConn = None + local_secure_conn: ISecureConn | None = None + remote_secure_conn: ISecureConn | None = None async def upgrade_local_conn() -> None: nonlocal local_secure_conn @@ -299,9 +302,9 @@ class SwarmFactory(factory.Factory): @asynccontextmanager async def create_and_listen( cls, - key_pair: KeyPair = None, - security_protocol: TProtocol = None, - muxer_opt: TMuxerOptions = None, + key_pair: KeyPair | None = None, + security_protocol: TProtocol | None = None, + muxer_opt: TMuxerOptions | None = None, ) -> AsyncIterator[Swarm]: # `factory.Factory.__init__` does *not* prepare a *default value* if we pass # an argument explicitly with `None`. If an argument is `None`, we don't pass it @@ -323,8 +326,8 @@ class SwarmFactory(factory.Factory): async def create_batch_and_listen( cls, number: int, - security_protocol: TProtocol = None, - muxer_opt: TMuxerOptions = None, + security_protocol: TProtocol | None = None, + muxer_opt: TMuxerOptions | None = None, ) -> AsyncIterator[tuple[Swarm, ...]]: async with AsyncExitStack() as stack: ctx_mgrs = [ @@ -344,11 +347,11 @@ class HostFactory(factory.Factory): class Params: key_pair = factory.LazyFunction(default_key_pair_factory) - security_protocol: TProtocol = None + security_protocol: TProtocol | None = None muxer_opt = factory.LazyFunction(default_muxer_transport_factory) network = factory.LazyAttribute( - lambda o: SwarmFactory( + lambda o: SwarmFactory.build( security_protocol=o.security_protocol, muxer_opt=o.muxer_opt ) ) @@ -358,8 +361,8 @@ class HostFactory(factory.Factory): async def create_batch_and_listen( cls, number: int, - security_protocol: TProtocol = None, - muxer_opt: TMuxerOptions = None, + security_protocol: TProtocol | None = None, + muxer_opt: TMuxerOptions | None = None, ) -> AsyncIterator[tuple[BasicHost, ...]]: async with SwarmFactory.create_batch_and_listen( number, security_protocol=security_protocol, muxer_opt=muxer_opt @@ -377,7 +380,7 @@ class DummyRouter(IPeerRouting): def _add_peer(self, peer_id: ID, addrs: list[Multiaddr]) -> None: self._routing_table[peer_id] = PeerInfo(peer_id, addrs) - async def find_peer(self, peer_id: ID) -> PeerInfo: + async def find_peer(self, peer_id: ID) -> PeerInfo | None: await trio.lowlevel.checkpoint() return self._routing_table.get(peer_id, None) @@ -388,11 +391,11 @@ class RoutedHostFactory(factory.Factory): class Params: key_pair = factory.LazyFunction(default_key_pair_factory) - security_protocol: TProtocol = None + security_protocol: TProtocol | None = None muxer_opt = factory.LazyFunction(default_muxer_transport_factory) network = factory.LazyAttribute( - lambda o: HostFactory( + lambda o: HostFactory.build( security_protocol=o.security_protocol, muxer_opt=o.muxer_opt ).get_network() ) @@ -403,8 +406,8 @@ class RoutedHostFactory(factory.Factory): async def create_batch_and_listen( cls, number: int, - security_protocol: TProtocol = None, - muxer_opt: TMuxerOptions = None, + security_protocol: TProtocol | None = None, + muxer_opt: TMuxerOptions | None = None, ) -> AsyncIterator[tuple[RoutedHost, ...]]: routing_table = DummyRouter() async with HostFactory.create_batch_and_listen( @@ -447,8 +450,8 @@ class PubsubFactory(factory.Factory): model = Pubsub host = factory.SubFactory(HostFactory) - router = None - cache_size = None + router: IPubsubRouter | None = None + cache_size: int | None = None strict_signing = False @classmethod @@ -457,13 +460,15 @@ class PubsubFactory(factory.Factory): cls, host: IHost, router: IPubsubRouter, - cache_size: int, + cache_size: int | None, seen_ttl: int, sweep_interval: int, strict_signing: bool, - msg_id_constructor: Callable[[rpc_pb2.Message], bytes] = None, + msg_id_constructor: Callable[[rpc_pb2.Message], bytes] | None = None, ) -> AsyncIterator[Pubsub]: - pubsub = cls( + if msg_id_constructor is None: + msg_id_constructor = get_peer_and_seqno_msg_id + pubsub = Pubsub( host=host, router=router, cache_size=cache_size, @@ -482,13 +487,13 @@ class PubsubFactory(factory.Factory): cls, number: int, routers: Sequence[IPubsubRouter], - cache_size: int = None, + cache_size: int | None = None, seen_ttl: int = 120, sweep_interval: int = 60, strict_signing: bool = False, - security_protocol: TProtocol = None, - muxer_opt: TMuxerOptions = None, - msg_id_constructor: Callable[[rpc_pb2.Message], bytes] = None, + security_protocol: TProtocol | None = None, + muxer_opt: TMuxerOptions | None = None, + msg_id_constructor: Callable[[rpc_pb2.Message], bytes] | None = None, ) -> AsyncIterator[tuple[Pubsub, ...]]: async with HostFactory.create_batch_and_listen( number, security_protocol=security_protocol, muxer_opt=muxer_opt @@ -516,16 +521,15 @@ class PubsubFactory(factory.Factory): async def create_batch_with_floodsub( cls, number: int, - cache_size: int = None, + cache_size: int | None = None, seen_ttl: int = 120, sweep_interval: int = 60, strict_signing: bool = False, - protocols: Sequence[TProtocol] = None, - security_protocol: TProtocol = None, - muxer_opt: TMuxerOptions = None, - msg_id_constructor: Callable[ - [rpc_pb2.Message], bytes - ] = get_peer_and_seqno_msg_id, + protocols: Sequence[TProtocol] | None = None, + security_protocol: TProtocol | None = None, + muxer_opt: TMuxerOptions | None = None, + msg_id_constructor: None + | (Callable[[rpc_pb2.Message], bytes]) = get_peer_and_seqno_msg_id, ) -> AsyncIterator[tuple[Pubsub, ...]]: if protocols is not None: floodsubs = FloodsubFactory.create_batch(number, protocols=list(protocols)) @@ -550,9 +554,9 @@ class PubsubFactory(factory.Factory): cls, number: int, *, - cache_size: int = None, + cache_size: int | None = None, strict_signing: bool = False, - protocols: Sequence[TProtocol] = None, + protocols: Sequence[TProtocol] | None = None, degree: int = GOSSIPSUB_PARAMS.degree, degree_low: int = GOSSIPSUB_PARAMS.degree_low, degree_high: int = GOSSIPSUB_PARAMS.degree_high, @@ -564,11 +568,10 @@ class PubsubFactory(factory.Factory): heartbeat_initial_delay: float = GOSSIPSUB_PARAMS.heartbeat_initial_delay, direct_connect_initial_delay: float = GOSSIPSUB_PARAMS.direct_connect_initial_delay, # noqa: E501 direct_connect_interval: int = GOSSIPSUB_PARAMS.direct_connect_interval, - security_protocol: TProtocol = None, - muxer_opt: TMuxerOptions = None, - msg_id_constructor: Callable[ - [rpc_pb2.Message], bytes - ] = get_peer_and_seqno_msg_id, + security_protocol: TProtocol | None = None, + muxer_opt: TMuxerOptions | None = None, + msg_id_constructor: None + | (Callable[[rpc_pb2.Message], bytes]) = get_peer_and_seqno_msg_id, ) -> AsyncIterator[tuple[Pubsub, ...]]: if protocols is not None: gossipsubs = GossipsubFactory.create_batch( @@ -605,6 +608,8 @@ class PubsubFactory(factory.Factory): number, gossipsubs, cache_size, + 120, # seen_ttl + 60, # sweep_interval strict_signing, security_protocol=security_protocol, muxer_opt=muxer_opt, @@ -618,7 +623,8 @@ class PubsubFactory(factory.Factory): @asynccontextmanager async def swarm_pair_factory( - security_protocol: TProtocol = None, muxer_opt: TMuxerOptions = None + security_protocol: TProtocol | None = None, + muxer_opt: TMuxerOptions | None = None, ) -> AsyncIterator[tuple[Swarm, Swarm]]: async with SwarmFactory.create_batch_and_listen( 2, security_protocol=security_protocol, muxer_opt=muxer_opt @@ -629,7 +635,8 @@ async def swarm_pair_factory( @asynccontextmanager async def host_pair_factory( - security_protocol: TProtocol = None, muxer_opt: TMuxerOptions = None + security_protocol: TProtocol | None = None, + muxer_opt: TMuxerOptions | None = None, ) -> AsyncIterator[tuple[BasicHost, BasicHost]]: async with HostFactory.create_batch_and_listen( 2, security_protocol=security_protocol, muxer_opt=muxer_opt @@ -640,7 +647,8 @@ async def host_pair_factory( @asynccontextmanager async def swarm_conn_pair_factory( - security_protocol: TProtocol = None, muxer_opt: TMuxerOptions = None + security_protocol: TProtocol | None = None, + muxer_opt: TMuxerOptions | None = None, ) -> AsyncIterator[tuple[SwarmConn, SwarmConn]]: async with swarm_pair_factory( security_protocol=security_protocol, muxer_opt=muxer_opt @@ -652,7 +660,7 @@ async def swarm_conn_pair_factory( @asynccontextmanager async def mplex_conn_pair_factory( - security_protocol: TProtocol = None, + security_protocol: TProtocol | None = None, ) -> AsyncIterator[tuple[Mplex, Mplex]]: async with swarm_conn_pair_factory( security_protocol=security_protocol, @@ -666,7 +674,7 @@ async def mplex_conn_pair_factory( @asynccontextmanager async def mplex_stream_pair_factory( - security_protocol: TProtocol = None, + security_protocol: TProtocol | None = None, ) -> AsyncIterator[tuple[MplexStream, MplexStream]]: async with mplex_conn_pair_factory( security_protocol=security_protocol @@ -684,7 +692,7 @@ async def mplex_stream_pair_factory( @asynccontextmanager async def yamux_conn_pair_factory( - security_protocol: TProtocol = None, + security_protocol: TProtocol | None = None, ) -> AsyncIterator[tuple[Yamux, Yamux]]: async with swarm_conn_pair_factory( security_protocol=security_protocol, muxer_opt=default_muxer_transport_factory() @@ -697,7 +705,7 @@ async def yamux_conn_pair_factory( @asynccontextmanager async def yamux_stream_pair_factory( - security_protocol: TProtocol = None, + security_protocol: TProtocol | None = None, ) -> AsyncIterator[tuple[YamuxStream, YamuxStream]]: async with yamux_conn_pair_factory( security_protocol=security_protocol @@ -715,11 +723,12 @@ async def yamux_stream_pair_factory( @asynccontextmanager async def net_stream_pair_factory( - security_protocol: TProtocol = None, muxer_opt: TMuxerOptions = None + security_protocol: TProtocol | None = None, + muxer_opt: TMuxerOptions | None = None, ) -> AsyncIterator[tuple[INetStream, INetStream]]: protocol_id = TProtocol("/example/id/1") - stream_1: INetStream + stream_1: INetStream | None = None # Just a proxy, we only care about the stream. # Add a barrier to avoid stream being removed. @@ -736,5 +745,6 @@ async def net_stream_pair_factory( hosts[1].set_stream_handler(protocol_id, handler) stream_0 = await hosts[0].new_stream(hosts[1].get_id(), [protocol_id]) + assert stream_1 is not None yield stream_0, stream_1 event_handler_finished.set() diff --git a/tests/utils/interop/daemon.py b/tests/utils/interop/daemon.py index e55aba9f..639bd4cc 100644 --- a/tests/utils/interop/daemon.py +++ b/tests/utils/interop/daemon.py @@ -131,13 +131,13 @@ async def make_p2pd( async with p2pc.listen(): peer_id, maddrs = await p2pc.identify() - listen_maddr: Multiaddr = None + listen_maddr: Multiaddr | None = None for maddr in maddrs: try: - ip = maddr.value_for_protocol(multiaddr.protocols.P_IP4) + ip = maddr.value_for_protocol(multiaddr.multiaddr.protocols.P_IP4) # NOTE: Check if this `maddr` uses `tcp`. - maddr.value_for_protocol(multiaddr.protocols.P_TCP) - except multiaddr.exceptions.ProtocolLookupError: + maddr.value_for_protocol(multiaddr.multiaddr.protocols.P_TCP) + except multiaddr.multiaddr.exceptions.ProtocolLookupError: continue if ip == LOCALHOST_IP: listen_maddr = maddr diff --git a/tests/utils/interop/process.py b/tests/utils/interop/process.py index 0d9e2650..c655d334 100644 --- a/tests/utils/interop/process.py +++ b/tests/utils/interop/process.py @@ -14,27 +14,34 @@ TIMEOUT_DURATION = 30 class AbstractInterativeProcess(ABC): @abstractmethod - async def start(self) -> None: - ... + async def start(self) -> None: ... @abstractmethod - async def close(self) -> None: - ... + async def close(self) -> None: ... class BaseInteractiveProcess(AbstractInterativeProcess): - proc: trio.Process = None + proc: trio.Process | None = None cmd: str args: list[str] bytes_read: bytearray - patterns: Iterable[bytes] = None + patterns: Iterable[bytes] | None = None event_ready: trio.Event async def wait_until_ready(self) -> None: + if self.proc is None: + raise Exception("process is not defined") + if self.patterns is None: + raise Exception("patterns is not defined") patterns_occurred = {pat: False for pat in self.patterns} buffers = {pat: bytearray() for pat in self.patterns} async def read_from_daemon_and_check() -> None: + if self.proc is None: + raise Exception("process is not defined") + if self.proc.stdout is None: + raise Exception("process stdout is None, cannot read output") + async for data in self.proc.stdout: self.bytes_read.extend(data) for pat, occurred in patterns_occurred.items(): diff --git a/tests/utils/interop/utils.py b/tests/utils/interop/utils.py index fe0997a0..30b89197 100644 --- a/tests/utils/interop/utils.py +++ b/tests/utils/interop/utils.py @@ -5,11 +5,10 @@ from typing import ( from multiaddr import ( Multiaddr, ) +from p2pclient.libp2p_stubs.peer.id import ID as StubID import trio -from libp2p.host.host_interface import ( - IHost, -) +from libp2p.abc import IHost from libp2p.peer.id import ( ID, ) @@ -58,7 +57,10 @@ async def connect(a: TDaemonOrHost, b: TDaemonOrHost) -> None: b_peer_info = _get_peer_info(b) if isinstance(a, Daemon): - await a.control.connect(b_peer_info.peer_id, b_peer_info.addrs) + # Convert internal libp2p ID to p2pclient stub ID .connect() + await a.control.connect( + StubID(b_peer_info.peer_id.to_bytes()), b_peer_info.addrs + ) else: # isinstance(b, IHost) await a.connect(b_peer_info) # Allow additional sleep for both side to establish the connection. diff --git a/tests/utils/pubsub/dummy_account_node.py b/tests/utils/pubsub/dummy_account_node.py index a1149bd5..cefc79f9 100644 --- a/tests/utils/pubsub/dummy_account_node.py +++ b/tests/utils/pubsub/dummy_account_node.py @@ -8,6 +8,7 @@ from contextlib import ( from libp2p.abc import ( IHost, + ISubscriptionAPI, ) from libp2p.pubsub.pubsub import ( Pubsub, @@ -40,9 +41,11 @@ class DummyAccountNode(Service): """ pubsub: Pubsub + subscription: ISubscriptionAPI | None def __init__(self, pubsub: Pubsub) -> None: self.pubsub = pubsub + self.subscription = None self.balances: dict[str, int] = {} @property @@ -74,6 +77,10 @@ class DummyAccountNode(Service): async def handle_incoming_msgs(self) -> None: """Handle all incoming messages on the CRYPTO_TOPIC from peers.""" while True: + if self.subscription is None: + raise RuntimeError( + "Subscription must be set before handling incoming messages" + ) incoming = await self.subscription.get() msg_comps = incoming.data.decode("utf-8").split(",") diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py index 51dafb7f..603af5e1 100644 --- a/tests/utils/test_logging.py +++ b/tests/utils/test_logging.py @@ -1,4 +1,5 @@ import logging +import logging.handlers import os from pathlib import ( Path, diff --git a/tox.ini b/tox.ini index 347f1dd4..5ebb00ce 100644 --- a/tox.ini +++ b/tox.ini @@ -1,9 +1,9 @@ [tox] envlist= - py{39,310,311,312,313}-core - py{39,310,311,312,313}-lint - py{39,310,311,312,313}-wheel - py{39,310,311,312,313}-interop + py{310,311,312,313}-core + py{310,311,312,313}-lint + py{310,311,312,313}-wheel + py{310,311,312,313}-interop windows-wheel docs @@ -26,7 +26,6 @@ commands= basepython= docs: python windows-wheel: python - py39: python3.9 py310: python3.10 py311: python3.11 py312: python3.12 @@ -36,7 +35,7 @@ extras= docs allowlist_externals=make,pre-commit -[testenv:py{39,310,311,312,313}-lint] +[testenv:py{310,311,312,313}-lint] deps=pre-commit extras= dev @@ -44,7 +43,7 @@ commands= pre-commit install pre-commit run --all-files --show-diff-on-failure -[testenv:py{39,310,311,312,313}-wheel] +[testenv:py{310,311,312,313}-wheel] deps= wheel build[virtualenv]