From b840eaa7e176b5962e8013caaadae4ac916efadb Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Sat, 9 Aug 2025 01:22:03 +0530 Subject: [PATCH 01/71] Implement advanced network discovery example and address validation utilities - Added `network_discover.py` to demonstrate Thin Waist address handling. - Introduced `address_validation.py` with functions for discovering available network interfaces, expanding wildcard addresses, and determining optimal binding addresses. - Included fallback mechanisms for environments lacking Thin Waist support. --- examples/advanced/network_discover.py | 60 +++++++++++++ libp2p/utils/address_validation.py | 125 ++++++++++++++++++++++++++ 2 files changed, 185 insertions(+) create mode 100644 examples/advanced/network_discover.py create mode 100644 libp2p/utils/address_validation.py diff --git a/examples/advanced/network_discover.py b/examples/advanced/network_discover.py new file mode 100644 index 00000000..a1a22052 --- /dev/null +++ b/examples/advanced/network_discover.py @@ -0,0 +1,60 @@ +""" +Advanced demonstration of Thin Waist address handling. + +Run: + python -m examples.advanced.network_discovery +""" + +from __future__ import annotations + +from multiaddr import Multiaddr + +try: + from libp2p.utils.address_validation import ( + get_available_interfaces, + expand_wildcard_address, + get_optimal_binding_address, + ) +except ImportError: + # Fallbacks if utilities are missing + def get_available_interfaces(port: int, protocol: str = "tcp"): + return [Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")] + + def expand_wildcard_address(addr: Multiaddr, port: int | None = None): + return [addr if port is None else Multiaddr(str(addr).rsplit("/", 1)[0] + f"/{port}")] + + def get_optimal_binding_address(port: int, protocol: str = "tcp"): + return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}") + + +def main() -> None: + port = 8080 + interfaces = get_available_interfaces(port) + print(f"Discovered interfaces for port {port}:") + for a in interfaces: + print(f" - {a}") + + wildcard_v4 = Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") + expanded_v4 = expand_wildcard_address(wildcard_v4) + print("\nExpanded IPv4 wildcard:") + for a in expanded_v4: + print(f" - {a}") + + wildcard_v6 = Multiaddr(f"/ip6/::/tcp/{port}") + expanded_v6 = expand_wildcard_address(wildcard_v6) + print("\nExpanded IPv6 wildcard:") + for a in expanded_v6: + print(f" - {a}") + + print("\nOptimal binding address heuristic result:") + print(f" -> {get_optimal_binding_address(port)}") + + override_port = 9000 + overridden = expand_wildcard_address(wildcard_v4, port=override_port) + print(f"\nPort override expansion to {override_port}:") + for a in overridden: + print(f" - {a}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/libp2p/utils/address_validation.py b/libp2p/utils/address_validation.py new file mode 100644 index 00000000..be7f8082 --- /dev/null +++ b/libp2p/utils/address_validation.py @@ -0,0 +1,125 @@ +from __future__ import annotations +from typing import List, Optional +from multiaddr import Multiaddr + +try: + from multiaddr.utils import get_thin_waist_addresses, get_network_addrs # type: ignore + _HAS_THIN_WAIST = True +except ImportError: # pragma: no cover - only executed in older environments + _HAS_THIN_WAIST = False + get_thin_waist_addresses = None # type: ignore + get_network_addrs = None # type: ignore + + +def _safe_get_network_addrs(ip_version: int) -> List[str]: + """ + Internal safe wrapper. 
Returns a list of IP addresses for the requested IP version. + Falls back to minimal defaults when Thin Waist helpers are missing. + + :param ip_version: 4 or 6 + """ + if _HAS_THIN_WAIST and get_network_addrs: + try: + return get_network_addrs(ip_version) or [] + except Exception: # pragma: no cover - defensive + return [] + # Fallback behavior (very conservative) + if ip_version == 4: + return ["127.0.0.1"] + if ip_version == 6: + return ["::1"] + return [] + + +def _safe_expand(addr: Multiaddr, port: Optional[int] = None) -> List[Multiaddr]: + """ + Internal safe expansion wrapper. Returns a list of Multiaddr objects. + If Thin Waist isn't available, returns [addr] (identity). + """ + if _HAS_THIN_WAIST and get_thin_waist_addresses: + try: + if port is not None: + return get_thin_waist_addresses(addr, port=port) or [] + return get_thin_waist_addresses(addr) or [] + except Exception: # pragma: no cover - defensive + return [addr] + return [addr] + + +def get_available_interfaces(port: int, protocol: str = "tcp") -> List[Multiaddr]: + """ + Discover available network interfaces (IPv4 + IPv6 if supported) for binding. + + :param port: Port number to bind to. + :param protocol: Transport protocol (e.g., "tcp" or "udp"). + :return: List of Multiaddr objects representing candidate interface addresses. + """ + addrs: List[Multiaddr] = [] + + # IPv4 enumeration + for ip in _safe_get_network_addrs(4): + addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}")) + + # IPv6 enumeration (optional: only include if we have at least one global or loopback) + for ip in _safe_get_network_addrs(6): + # Avoid returning unusable wildcard expansions if the environment does not support IPv6 + addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}")) + + # Fallback if nothing discovered + if not addrs: + addrs.append(Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")) + + return addrs + + +def expand_wildcard_address(addr: Multiaddr, port: Optional[int] = None) -> List[Multiaddr]: + """ + Expand a wildcard (e.g. /ip4/0.0.0.0/tcp/0) into all concrete interfaces. + + :param addr: Multiaddr to expand. + :param port: Optional override for port selection. + :return: List of concrete Multiaddr instances. + """ + expanded = _safe_expand(addr, port=port) + if not expanded: # Safety fallback + return [addr] + return expanded + + +def get_optimal_binding_address(port: int, protocol: str = "tcp") -> Multiaddr: + """ + Choose an optimal address for an example to bind to: + - Prefer non-loopback IPv4 + - Then non-loopback IPv6 + - Fallback to loopback + - Fallback to wildcard + + :param port: Port number. + :param protocol: Transport protocol. + :return: A single Multiaddr chosen heuristically. + """ + candidates = get_available_interfaces(port, protocol) + + def is_non_loopback(ma: Multiaddr) -> bool: + s = str(ma) + return not ("/ip4/127." in s or "/ip6/::1" in s) + + for c in candidates: + if "/ip4/" in str(c) and is_non_loopback(c): + return c + for c in candidates: + if "/ip6/" in str(c) and is_non_loopback(c): + return c + for c in candidates: + if "/ip4/127." 
in str(c) or "/ip6/::1" in str(c): + return c + + # As a final fallback, produce a wildcard + return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}") + + +__all__ = [ + "get_available_interfaces", + "get_optimal_binding_address", + "expand_wildcard_address", +] \ No newline at end of file From fa174230baa836cd83da09a7a505b5221f7cad36 Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Sat, 9 Aug 2025 01:22:17 +0530 Subject: [PATCH 02/71] Refactor echo example to use optimal binding address - Replaced hardcoded listen address with `get_optimal_binding_address` for improved flexibility. - Imported address validation utilities in `echo.py` and updated `__init__.py` to include new functions. --- examples/echo/echo.py | 10 ++++++++-- libp2p/utils/__init__.py | 9 +++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/examples/echo/echo.py b/examples/echo/echo.py index 126a7da2..15c40c25 100644 --- a/examples/echo/echo.py +++ b/examples/echo/echo.py @@ -19,6 +19,11 @@ from libp2p.peer.peerinfo import ( info_from_p2p_addr, ) +from libp2p.utils.address_validation import ( + get_optimal_binding_address, + get_available_interfaces, +) + PROTOCOL_ID = TProtocol("/echo/1.0.0") MAX_READ_LEN = 2**32 - 1 @@ -31,8 +36,9 @@ async def _echo_stream_handler(stream: INetStream) -> None: async def run(port: int, destination: str, seed: int | None = None) -> None: - listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") - + # CHANGED: previously hardcoded 0.0.0.0 + listen_addr = get_optimal_binding_address(port) + if seed: import random diff --git a/libp2p/utils/__init__.py b/libp2p/utils/__init__.py index 0f78bfcb..0f68e701 100644 --- a/libp2p/utils/__init__.py +++ b/libp2p/utils/__init__.py @@ -15,6 +15,12 @@ from libp2p.utils.version import ( get_agent_version, ) +from libp2p.utils.address_validation import ( + get_available_interfaces, + get_optimal_binding_address, + expand_wildcard_address, +) + __all__ = [ "decode_uvarint_from_stream", "encode_delim", @@ -26,4 +32,7 @@ __all__ = [ "decode_varint_from_bytes", "decode_varint_with_size", "read_length_prefixed_protobuf", + "get_available_interfaces", + "get_optimal_binding_address", + "expand_wildcard_address", ] From 59a898c8cee163af6b5dd85787fc17c03c43f20d Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Sat, 9 Aug 2025 01:24:14 +0530 Subject: [PATCH 03/71] Add tests for echo example and address validation utilities - Introduced `test_echo_thin_waist.py` to validate the echo example's output for Thin Waist lines. - Added `test_address_validation.py` to cover functions for available interfaces, optimal binding addresses, and wildcard address expansion. - Included parameterized tests and environment checks for IPv6 support. --- tests/examples/test_echo_thin_waist.py | 51 +++++++++++++++++++++++ tests/utils/test_address_validation.py | 56 ++++++++++++++++++++++++++ 2 files changed, 107 insertions(+) create mode 100644 tests/examples/test_echo_thin_waist.py create mode 100644 tests/utils/test_address_validation.py diff --git a/tests/examples/test_echo_thin_waist.py b/tests/examples/test_echo_thin_waist.py new file mode 100644 index 00000000..9da85928 --- /dev/null +++ b/tests/examples/test_echo_thin_waist.py @@ -0,0 +1,51 @@ +import asyncio +import contextlib +import subprocess +import sys +import time +from pathlib import Path + +import pytest + +# This test is intentionally lightweight and can be marked as 'integration'. +# It ensures the echo example runs and prints the new Thin Waist lines. 
+ +EXAMPLES_DIR = Path(__file__).parent.parent.parent / "examples" / "echo" + + +@pytest.mark.timeout(20) +def test_echo_example_starts_and_prints_thin_waist(monkeypatch, tmp_path): + # We run: python examples/echo/echo.py -p 0 + cmd = [sys.executable, str(EXAMPLES_DIR / "echo.py"), "-p", "0"] + proc = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + assert proc.stdout is not None + + found_selected = False + found_interfaces = False + start = time.time() + + try: + while time.time() - start < 10: + line = proc.stdout.readline() + if not line: + time.sleep(0.1) + continue + if "Selected binding address:" in line: + found_selected = True + if "Available candidate interfaces:" in line: + found_interfaces = True + if "Waiting for incoming connections..." in line: + break + finally: + with contextlib.suppress(ProcessLookupError): + proc.terminate() + with contextlib.suppress(ProcessLookupError): + proc.kill() + + assert found_selected, "Did not capture Thin Waist binding log line" + assert found_interfaces, "Did not capture Thin Waist interfaces log line" \ No newline at end of file diff --git a/tests/utils/test_address_validation.py b/tests/utils/test_address_validation.py new file mode 100644 index 00000000..80ae27e8 --- /dev/null +++ b/tests/utils/test_address_validation.py @@ -0,0 +1,56 @@ +import os + +import pytest +from multiaddr import Multiaddr + +from libp2p.utils.address_validation import ( + get_available_interfaces, + get_optimal_binding_address, + expand_wildcard_address, +) + + +@pytest.mark.parametrize("proto", ["tcp"]) +def test_get_available_interfaces(proto: str) -> None: + interfaces = get_available_interfaces(0, protocol=proto) + assert len(interfaces) > 0 + for addr in interfaces: + assert isinstance(addr, Multiaddr) + assert f"/{proto}/" in str(addr) + + +def test_get_optimal_binding_address() -> None: + addr = get_optimal_binding_address(0) + assert isinstance(addr, Multiaddr) + # At least IPv4 or IPv6 prefix present + s = str(addr) + assert ("/ip4/" in s) or ("/ip6/" in s) + + +def test_expand_wildcard_address_ipv4() -> None: + wildcard = Multiaddr("/ip4/0.0.0.0/tcp/0") + expanded = expand_wildcard_address(wildcard) + assert len(expanded) > 0 + for e in expanded: + assert isinstance(e, Multiaddr) + assert "/tcp/" in str(e) + + +def test_expand_wildcard_address_port_override() -> None: + wildcard = Multiaddr("/ip4/0.0.0.0/tcp/7000") + overridden = expand_wildcard_address(wildcard, port=9001) + assert len(overridden) > 0 + for e in overridden: + assert str(e).endswith("/tcp/9001") + + +@pytest.mark.skipif( + os.environ.get("NO_IPV6") == "1", + reason="Environment disallows IPv6", +) +def test_expand_wildcard_address_ipv6() -> None: + wildcard = Multiaddr("/ip6/::/tcp/0") + expanded = expand_wildcard_address(wildcard) + assert len(expanded) > 0 + for e in expanded: + assert "/ip6/" in str(e) \ No newline at end of file From b838a0e3b672eb875047acdf3449e65702f5c0ee Mon Sep 17 00:00:00 2001 From: unniznd Date: Tue, 12 Aug 2025 21:50:10 +0530 Subject: [PATCH 04/71] added none type to return value of negotiate and changed caller handles to handle none. Added newsfragment. 
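For reference, the caller-side contract this change introduces reduces to the sketch below. It is a standalone illustration, not code from the patch: TProtocol is replaced by a plain string stand-in, and the real callers raise StreamFailure, SecurityUpgradeFailure, or MuxerUpgradeFailure (as the diffs that follow show) rather than RuntimeError.

    from __future__ import annotations

    TProtocol = str  # stand-in for libp2p.custom_types.TProtocol in this sketch


    def negotiate() -> tuple[TProtocol | None, object | None]:
        """Stand-in for Multiselect.negotiate(); it may select no protocol."""
        return None, None


    def select_transport(transports: dict[TProtocol, object]) -> object:
        protocol, _ = negotiate()
        if protocol is None:
            # Fail loudly instead of indexing the transport table with None.
            raise RuntimeError("No protocol selected")
        return transports[protocol]
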
--- libp2p/host/basic_host.py | 3 +++ libp2p/protocol_muxer/multiselect.py | 2 +- libp2p/security/security_multistream.py | 7 ++++++- libp2p/stream_muxer/muxer_multistream.py | 7 ++++++- newsfragments/837.fix.rst | 1 + 5 files changed, 17 insertions(+), 3 deletions(-) create mode 100644 newsfragments/837.fix.rst diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 70e41953..008fe7e5 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -288,6 +288,9 @@ class BasicHost(IHost): protocol, handler = await self.multiselect.negotiate( MultiselectCommunicator(net_stream), self.negotiate_timeout ) + if protocol is None: + await net_stream.reset() + raise StreamFailure("No protocol selected") except MultiselectError as error: peer_id = net_stream.muxed_conn.peer_id logger.debug( diff --git a/libp2p/protocol_muxer/multiselect.py b/libp2p/protocol_muxer/multiselect.py index 8d311391..e58c0981 100644 --- a/libp2p/protocol_muxer/multiselect.py +++ b/libp2p/protocol_muxer/multiselect.py @@ -53,7 +53,7 @@ class Multiselect(IMultiselectMuxer): self, communicator: IMultiselectCommunicator, negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, - ) -> tuple[TProtocol, StreamHandlerFn | None]: + ) -> tuple[TProtocol | None, StreamHandlerFn | None]: """ Negotiate performs protocol selection. diff --git a/libp2p/security/security_multistream.py b/libp2p/security/security_multistream.py index 193cc092..d15dbbd9 100644 --- a/libp2p/security/security_multistream.py +++ b/libp2p/security/security_multistream.py @@ -26,6 +26,9 @@ from libp2p.protocol_muxer.multiselect_client import ( from libp2p.protocol_muxer.multiselect_communicator import ( MultiselectCommunicator, ) +from libp2p.transport.exceptions import ( + SecurityUpgradeFailure, +) """ Represents a secured connection object, which includes a connection and details about @@ -104,7 +107,7 @@ class SecurityMultistream(ABC): :param is_initiator: true if we are the initiator, false otherwise :return: selected secure transport """ - protocol: TProtocol + protocol: TProtocol | None communicator = MultiselectCommunicator(conn) if is_initiator: # Select protocol if initiator @@ -114,5 +117,7 @@ class SecurityMultistream(ABC): else: # Select protocol if non-initiator protocol, _ = await self.multiselect.negotiate(communicator) + if protocol is None: + raise SecurityUpgradeFailure("No protocol selected") # Return transport from protocol return self.transports[protocol] diff --git a/libp2p/stream_muxer/muxer_multistream.py b/libp2p/stream_muxer/muxer_multistream.py index 76699c67..d96820a4 100644 --- a/libp2p/stream_muxer/muxer_multistream.py +++ b/libp2p/stream_muxer/muxer_multistream.py @@ -30,6 +30,9 @@ from libp2p.stream_muxer.yamux.yamux import ( PROTOCOL_ID, Yamux, ) +from libp2p.transport.exceptions import ( + MuxerUpgradeFailure, +) class MuxerMultistream: @@ -73,7 +76,7 @@ class MuxerMultistream: :param conn: conn to choose a transport over :return: selected muxer transport """ - protocol: TProtocol + protocol: TProtocol | None communicator = MultiselectCommunicator(conn) if conn.is_initiator: protocol = await self.multiselect_client.select_one_of( @@ -81,6 +84,8 @@ class MuxerMultistream: ) else: protocol, _ = await self.multiselect.negotiate(communicator) + if protocol is None: + raise MuxerUpgradeFailure("No protocol selected") return self.transports[protocol] async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn: diff --git a/newsfragments/837.fix.rst b/newsfragments/837.fix.rst new file mode 100644 
index 00000000..47919c23 --- /dev/null +++ b/newsfragments/837.fix.rst @@ -0,0 +1 @@ +Added multiselect type consistency in negotiate method. Updates all the usages of the method. From 1ecff5437ce8bbd6c2edf66f12f02466d5d3ad7c Mon Sep 17 00:00:00 2001 From: unniznd Date: Thu, 14 Aug 2025 07:29:06 +0530 Subject: [PATCH 05/71] fixed newsfragment filename issue. --- newsfragments/{837.fix.rst => 837.bugfix.rst} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename newsfragments/{837.fix.rst => 837.bugfix.rst} (100%) diff --git a/newsfragments/837.fix.rst b/newsfragments/837.bugfix.rst similarity index 100% rename from newsfragments/837.fix.rst rename to newsfragments/837.bugfix.rst From b363d1d6d0d04223551affb52243caa6732b0967 Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Mon, 18 Aug 2025 12:38:04 +0530 Subject: [PATCH 06/71] fix: update listening address handling to use all available interfaces --- examples/echo/echo.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/examples/echo/echo.py b/examples/echo/echo.py index 15c40c25..ba52fe76 100644 --- a/examples/echo/echo.py +++ b/examples/echo/echo.py @@ -36,22 +36,20 @@ async def _echo_stream_handler(stream: INetStream) -> None: async def run(port: int, destination: str, seed: int | None = None) -> None: - # CHANGED: previously hardcoded 0.0.0.0 - listen_addr = get_optimal_binding_address(port) - + # Use all available interfaces for listening (JS parity) + listen_addrs = get_available_interfaces(port) + if seed: import random - random.seed(seed) secret_number = random.getrandbits(32 * 8) secret = secret_number.to_bytes(length=32, byteorder="big") else: import secrets - secret = secrets.token_bytes(32) host = new_host(key_pair=create_new_key_pair(secret)) - async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery: + async with host.run(listen_addrs=listen_addrs), trio.open_nursery() as nursery: # Start the peer-store cleanup task nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) @@ -60,8 +58,14 @@ async def run(port: int, destination: str, seed: int | None = None) -> None: if not destination: # its the server host.set_stream_handler(PROTOCOL_ID, _echo_stream_handler) + # Print all listen addresses with peer ID (JS parity) + print("Listener ready, listening on:") + peer_id = host.get_id().to_string() + for addr in listen_addrs: + print(f"{addr}/p2p/{peer_id}") + print( - "Run this from the same folder in another console:\n\n" + "\nRun this from the same folder in another console:\n\n" f"echo-demo " f"-d {host.get_addrs()[0]}\n" ) From a2fcf33bc173fb12d7dd2fb8524490583329c840 Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Mon, 18 Aug 2025 12:38:10 +0530 Subject: [PATCH 07/71] refactor: migrate echo example test to use Trio for process handling --- tests/examples/test_echo_thin_waist.py | 78 +++++++++++++++----------- 1 file changed, 46 insertions(+), 32 deletions(-) diff --git a/tests/examples/test_echo_thin_waist.py b/tests/examples/test_echo_thin_waist.py index 9da85928..c861f547 100644 --- a/tests/examples/test_echo_thin_waist.py +++ b/tests/examples/test_echo_thin_waist.py @@ -1,51 +1,65 @@ -import asyncio import contextlib -import subprocess import sys -import time from pathlib import Path import pytest +import trio # This test is intentionally lightweight and can be marked as 'integration'. -# It ensures the echo example runs and prints the new Thin Waist lines. 
+# It ensures the echo example runs and prints the new Thin Waist lines using Trio primitives. EXAMPLES_DIR = Path(__file__).parent.parent.parent / "examples" / "echo" -@pytest.mark.timeout(20) -def test_echo_example_starts_and_prints_thin_waist(monkeypatch, tmp_path): - # We run: python examples/echo/echo.py -p 0 +@pytest.mark.trio +async def test_echo_example_starts_and_prints_thin_waist() -> None: cmd = [sys.executable, str(EXAMPLES_DIR / "echo.py"), "-p", "0"] - proc = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - ) - assert proc.stdout is not None found_selected = False found_interfaces = False - start = time.time() - try: - while time.time() - start < 10: - line = proc.stdout.readline() - if not line: - time.sleep(0.1) - continue - if "Selected binding address:" in line: - found_selected = True - if "Available candidate interfaces:" in line: - found_interfaces = True - if "Waiting for incoming connections..." in line: - break - finally: - with contextlib.suppress(ProcessLookupError): - proc.terminate() - with contextlib.suppress(ProcessLookupError): - proc.kill() + # Use a cancellation scope as timeout (similar to previous 10s loop) + with trio.move_on_after(10) as cancel_scope: + # Start process streaming stdout + proc = await trio.open_process( + cmd, + stdout=trio.SUBPROCESS_PIPE, + stderr=trio.STDOUT, + ) + + assert proc.stdout is not None # for type checkers + buffer = b"" + + try: + while not (found_selected and found_interfaces): + # Read some bytes (non-blocking with timeout scope) + data = await proc.stdout.receive_some(1024) + if not data: + # Process might still be starting; yield control + await trio.sleep(0.05) + continue + buffer += data + # Process complete lines + *lines, buffer = buffer.split(b"\n") if b"\n" in buffer else ([], buffer) + for raw in lines: + line = raw.decode(errors="ignore") + if "Selected binding address:" in line: + found_selected = True + if "Available candidate interfaces:" in line: + found_interfaces = True + if "Waiting for incoming connections..." 
in line: + # We have reached steady state; can stop reading further + if found_selected and found_interfaces: + break + finally: + # Terminate the long-running echo example + with contextlib.suppress(Exception): + proc.terminate() + with contextlib.suppress(Exception): + await trio.move_on_after(2)(proc.wait) # best-effort wait + if cancel_scope.cancelled_caught: + # Timeout occurred + pass assert found_selected, "Did not capture Thin Waist binding log line" assert found_interfaces, "Did not capture Thin Waist interfaces log line" \ No newline at end of file From 9378490dcb4b333d558ccebaed26fc67b6a3d799 Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Mon, 18 Aug 2025 12:40:38 +0530 Subject: [PATCH 08/71] fix: ensure loopback addresses are included in available interfaces --- libp2p/utils/address_validation.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/libp2p/utils/address_validation.py b/libp2p/utils/address_validation.py index be7f8082..493ed120 100644 --- a/libp2p/utils/address_validation.py +++ b/libp2p/utils/address_validation.py @@ -57,13 +57,23 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> List[Multiaddr addrs: List[Multiaddr] = [] # IPv4 enumeration + seen_v4: set[str] = set() for ip in _safe_get_network_addrs(4): + seen_v4.add(ip) addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}")) + # Ensure loopback IPv4 explicitly present (JS echo parity) even if not returned + if "127.0.0.1" not in seen_v4: + addrs.append(Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}")) + # IPv6 enumeration (optional: only include if we have at least one global or loopback) + seen_v6: set[str] = set() for ip in _safe_get_network_addrs(6): - # Avoid returning unusable wildcard expansions if the environment does not support IPv6 + seen_v6.add(ip) addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}")) + # Optionally ensure IPv6 loopback when any IPv6 present but loopback missing + if seen_v6 and "::1" not in seen_v6: + addrs.append(Multiaddr(f"/ip6/::1/{protocol}/{port}")) # Fallback if nothing discovered if not addrs: From 05b372b1eb50a0f266851b6cdf4f4c6e5ecdb202 Mon Sep 17 00:00:00 2001 From: acul71 Date: Tue, 19 Aug 2025 01:11:48 +0200 Subject: [PATCH 09/71] Fix linting and type checking issues for Thin Waist feature --- examples/advanced/network_discover.py | 9 ++- examples/echo/echo.py | 4 +- libp2p/utils/address_validation.py | 28 ++++--- tests/examples/test_echo_thin_waist.py | 101 +++++++++++++++++++------ tests/utils/test_address_validation.py | 4 +- 5 files changed, 106 insertions(+), 40 deletions(-) diff --git a/examples/advanced/network_discover.py b/examples/advanced/network_discover.py index a1a22052..87b44ddf 100644 --- a/examples/advanced/network_discover.py +++ b/examples/advanced/network_discover.py @@ -11,8 +11,8 @@ from multiaddr import Multiaddr try: from libp2p.utils.address_validation import ( - get_available_interfaces, expand_wildcard_address, + get_available_interfaces, get_optimal_binding_address, ) except ImportError: @@ -21,7 +21,10 @@ except ImportError: return [Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")] def expand_wildcard_address(addr: Multiaddr, port: int | None = None): - return [addr if port is None else Multiaddr(str(addr).rsplit("/", 1)[0] + f"/{port}")] + if port is None: + return [addr] + addr_str = str(addr).rsplit("/", 1)[0] + return [Multiaddr(addr_str + f"/{port}")] def get_optimal_binding_address(port: int, protocol: str = "tcp"): return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}") @@ -57,4 +60,4 @@ 
def main() -> None: if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/echo/echo.py b/examples/echo/echo.py index 15c40c25..caf80b37 100644 --- a/examples/echo/echo.py +++ b/examples/echo/echo.py @@ -18,10 +18,8 @@ from libp2p.network.stream.net_stream import ( from libp2p.peer.peerinfo import ( info_from_p2p_addr, ) - from libp2p.utils.address_validation import ( get_optimal_binding_address, - get_available_interfaces, ) PROTOCOL_ID = TProtocol("/echo/1.0.0") @@ -38,7 +36,7 @@ async def _echo_stream_handler(stream: INetStream) -> None: async def run(port: int, destination: str, seed: int | None = None) -> None: # CHANGED: previously hardcoded 0.0.0.0 listen_addr = get_optimal_binding_address(port) - + if seed: import random diff --git a/libp2p/utils/address_validation.py b/libp2p/utils/address_validation.py index be7f8082..c0709920 100644 --- a/libp2p/utils/address_validation.py +++ b/libp2p/utils/address_validation.py @@ -1,9 +1,13 @@ from __future__ import annotations -from typing import List, Optional + from multiaddr import Multiaddr try: - from multiaddr.utils import get_thin_waist_addresses, get_network_addrs # type: ignore + from multiaddr.utils import ( # type: ignore + get_network_addrs, + get_thin_waist_addresses, + ) + _HAS_THIN_WAIST = True except ImportError: # pragma: no cover - only executed in older environments _HAS_THIN_WAIST = False @@ -11,7 +15,7 @@ except ImportError: # pragma: no cover - only executed in older environments get_network_addrs = None # type: ignore -def _safe_get_network_addrs(ip_version: int) -> List[str]: +def _safe_get_network_addrs(ip_version: int) -> list[str]: """ Internal safe wrapper. Returns a list of IP addresses for the requested IP version. Falls back to minimal defaults when Thin Waist helpers are missing. @@ -31,7 +35,7 @@ def _safe_get_network_addrs(ip_version: int) -> List[str]: return [] -def _safe_expand(addr: Multiaddr, port: Optional[int] = None) -> List[Multiaddr]: +def _safe_expand(addr: Multiaddr, port: int | None = None) -> list[Multiaddr]: """ Internal safe expansion wrapper. Returns a list of Multiaddr objects. If Thin Waist isn't available, returns [addr] (identity). @@ -46,7 +50,7 @@ def _safe_expand(addr: Multiaddr, port: Optional[int] = None) -> List[Multiaddr] return [addr] -def get_available_interfaces(port: int, protocol: str = "tcp") -> List[Multiaddr]: +def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr]: """ Discover available network interfaces (IPv4 + IPv6 if supported) for binding. @@ -54,15 +58,17 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> List[Multiaddr :param protocol: Transport protocol (e.g., "tcp" or "udp"). :return: List of Multiaddr objects representing candidate interface addresses. 
""" - addrs: List[Multiaddr] = [] + addrs: list[Multiaddr] = [] # IPv4 enumeration for ip in _safe_get_network_addrs(4): addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}")) - # IPv6 enumeration (optional: only include if we have at least one global or loopback) + # IPv6 enumeration (optional: only include if we have at least one global or + # loopback) for ip in _safe_get_network_addrs(6): - # Avoid returning unusable wildcard expansions if the environment does not support IPv6 + # Avoid returning unusable wildcard expansions if the environment does not + # support IPv6 addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}")) # Fallback if nothing discovered @@ -72,7 +78,9 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> List[Multiaddr return addrs -def expand_wildcard_address(addr: Multiaddr, port: Optional[int] = None) -> List[Multiaddr]: +def expand_wildcard_address( + addr: Multiaddr, port: int | None = None +) -> list[Multiaddr]: """ Expand a wildcard (e.g. /ip4/0.0.0.0/tcp/0) into all concrete interfaces. @@ -122,4 +130,4 @@ __all__ = [ "get_available_interfaces", "get_optimal_binding_address", "expand_wildcard_address", -] \ No newline at end of file +] diff --git a/tests/examples/test_echo_thin_waist.py b/tests/examples/test_echo_thin_waist.py index 9da85928..47e5e495 100644 --- a/tests/examples/test_echo_thin_waist.py +++ b/tests/examples/test_echo_thin_waist.py @@ -1,45 +1,60 @@ -import asyncio import contextlib +import os +from pathlib import Path import subprocess import sys import time -from pathlib import Path -import pytest +from multiaddr import Multiaddr +from multiaddr.protocols import P_IP4, P_IP6, P_P2P, P_TCP + +# pytestmark = pytest.mark.timeout(20) # Temporarily disabled for debugging # This test is intentionally lightweight and can be marked as 'integration'. # It ensures the echo example runs and prints the new Thin Waist lines. -EXAMPLES_DIR = Path(__file__).parent.parent.parent / "examples" / "echo" +current_file = Path(__file__) +project_root = current_file.parent.parent.parent +EXAMPLES_DIR: Path = project_root / "examples" / "echo" -@pytest.mark.timeout(20) def test_echo_example_starts_and_prints_thin_waist(monkeypatch, tmp_path): - # We run: python examples/echo/echo.py -p 0 - cmd = [sys.executable, str(EXAMPLES_DIR / "echo.py"), "-p", "0"] - proc = subprocess.Popen( + """Run echo server and validate printed multiaddr and peer id.""" + # Run echo example as server + cmd = [sys.executable, "-u", str(EXAMPLES_DIR / "echo.py"), "-p", "0"] + env = {**os.environ, "PYTHONUNBUFFERED": "1"} + proc: subprocess.Popen[str] = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, + env=env, ) - assert proc.stdout is not None - found_selected = False - found_interfaces = False + if proc.stdout is None: + proc.terminate() + raise RuntimeError("Process stdout is None") + out_stream = proc.stdout + + peer_id: str | None = None + printed_multiaddr: str | None = None + saw_waiting = False + start = time.time() - + timeout_s = 8.0 try: - while time.time() - start < 10: - line = proc.stdout.readline() + while time.time() - start < timeout_s: + line = out_stream.readline() if not line: - time.sleep(0.1) + time.sleep(0.05) continue - if "Selected binding address:" in line: - found_selected = True - if "Available candidate interfaces:" in line: - found_interfaces = True - if "Waiting for incoming connections..." 
in line: + s = line.strip() + if s.startswith("I am "): + peer_id = s.partition("I am ")[2] + if s.startswith("echo-demo -d "): + printed_multiaddr = s.partition("echo-demo -d ")[2] + if "Waiting for incoming connections..." in s: + saw_waiting = True break finally: with contextlib.suppress(ProcessLookupError): @@ -47,5 +62,47 @@ def test_echo_example_starts_and_prints_thin_waist(monkeypatch, tmp_path): with contextlib.suppress(ProcessLookupError): proc.kill() - assert found_selected, "Did not capture Thin Waist binding log line" - assert found_interfaces, "Did not capture Thin Waist interfaces log line" \ No newline at end of file + assert peer_id, "Did not capture peer ID line" + assert printed_multiaddr, "Did not capture multiaddr line" + assert saw_waiting, "Did not capture waiting-for-connections line" + + # Validate multiaddr structure using py-multiaddr protocol methods + ma = Multiaddr(printed_multiaddr) # should parse without error + + # Check that the multiaddr contains the p2p protocol + try: + peer_id_from_multiaddr = ma.value_for_protocol("p2p") + assert peer_id_from_multiaddr is not None, ( + "Multiaddr missing p2p protocol value" + ) + assert peer_id_from_multiaddr == peer_id, ( + f"Peer ID mismatch: {peer_id_from_multiaddr} != {peer_id}" + ) + except Exception as e: + raise AssertionError(f"Failed to extract p2p protocol value: {e}") + + # Validate the multiaddr structure by checking protocols + protocols = ma.protocols() + + # Should have at least IP, TCP, and P2P protocols + assert any(p.code == P_IP4 or p.code == P_IP6 for p in protocols), ( + "Missing IP protocol" + ) + assert any(p.code == P_TCP for p in protocols), "Missing TCP protocol" + assert any(p.code == P_P2P for p in protocols), "Missing P2P protocol" + + # Extract the p2p part and validate it matches the captured peer ID + p2p_part = Multiaddr(f"/p2p/{peer_id}") + try: + # Decapsulate the p2p part to get the transport address + transport_addr = ma.decapsulate(p2p_part) + # Verify the decapsulated address doesn't contain p2p + transport_protocols = transport_addr.protocols() + assert not any(p.code == P_P2P for p in transport_protocols), ( + "Decapsulation failed - still contains p2p" + ) + # Verify the original multiaddr can be reconstructed + reconstructed = transport_addr.encapsulate(p2p_part) + assert str(reconstructed) == str(ma), "Reconstruction failed" + except Exception as e: + raise AssertionError(f"Multiaddr decapsulation failed: {e}") diff --git a/tests/utils/test_address_validation.py b/tests/utils/test_address_validation.py index 80ae27e8..5b108d09 100644 --- a/tests/utils/test_address_validation.py +++ b/tests/utils/test_address_validation.py @@ -4,9 +4,9 @@ import pytest from multiaddr import Multiaddr from libp2p.utils.address_validation import ( + expand_wildcard_address, get_available_interfaces, get_optimal_binding_address, - expand_wildcard_address, ) @@ -53,4 +53,4 @@ def test_expand_wildcard_address_ipv6() -> None: expanded = expand_wildcard_address(wildcard) assert len(expanded) > 0 for e in expanded: - assert "/ip6/" in str(e) \ No newline at end of file + assert "/ip6/" in str(e) From a1b16248d3ca941ad16f624867387faabfa60a06 Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Tue, 19 Aug 2025 20:47:18 +0530 Subject: [PATCH 10/71] fix: correct listening address variable in echo example and streamline address printing --- examples/echo/echo.py | 5 ++--- libp2p/utils/address_validation.py | 10 ++++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/echo/echo.py 
b/examples/echo/echo.py index 67e82e07..0cf8c449 100644 --- a/examples/echo/echo.py +++ b/examples/echo/echo.py @@ -47,7 +47,7 @@ async def run(port: int, destination: str, seed: int | None = None) -> None: secret = secrets.token_bytes(32) host = new_host(key_pair=create_new_key_pair(secret)) - async with host.run(listen_addrs=listen_addrs), trio.open_nursery() as nursery: + async with host.run(listen_addr=listen_addr), trio.open_nursery() as nursery: # Start the peer-store cleanup task nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) @@ -59,8 +59,7 @@ async def run(port: int, destination: str, seed: int | None = None) -> None: # Print all listen addresses with peer ID (JS parity) print("Listener ready, listening on:") peer_id = host.get_id().to_string() - for addr in listen_addrs: - print(f"{addr}/p2p/{peer_id}") + print(f"{listen_addr}/p2p/{peer_id}") print( "\nRun this from the same folder in another console:\n\n" diff --git a/libp2p/utils/address_validation.py b/libp2p/utils/address_validation.py index e323dbd5..67299270 100644 --- a/libp2p/utils/address_validation.py +++ b/libp2p/utils/address_validation.py @@ -62,16 +62,18 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr # IPv4 enumeration seen_v4: set[str] = set() + for ip in _safe_get_network_addrs(4): seen_v4.add(ip) addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}")) + seen_v6: set[str] = set() + for ip in _safe_get_network_addrs(6): + seen_v6.add(ip) + addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}")) + # IPv6 enumeration (optional: only include if we have at least one global or # loopback) - for ip in _safe_get_network_addrs(6): - # Avoid returning unusable wildcard expansions if the environment does not - # support IPv6 - addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}")) # Optionally ensure IPv6 loopback when any IPv6 present but loopback missing if seen_v6 and "::1" not in seen_v6: addrs.append(Multiaddr(f"/ip6/::1/{protocol}/{port}")) From 69d52748913406f83fab0c23acfcdf22b8057371 Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Tue, 19 Aug 2025 22:32:26 +0530 Subject: [PATCH 11/71] fix: update listening address parameter in echo example to accept a list --- examples/echo/echo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/echo/echo.py b/examples/echo/echo.py index 0cf8c449..8075f125 100644 --- a/examples/echo/echo.py +++ b/examples/echo/echo.py @@ -47,7 +47,7 @@ async def run(port: int, destination: str, seed: int | None = None) -> None: secret = secrets.token_bytes(32) host = new_host(key_pair=create_new_key_pair(secret)) - async with host.run(listen_addr=listen_addr), trio.open_nursery() as nursery: + async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery: # Start the peer-store cleanup task nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) From c2c91b8c58ca8eae4032272316aa41b300db7731 Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Wed, 20 Aug 2025 18:05:20 +0530 Subject: [PATCH 12/71] refactor: Improve comment formatting in test_echo_thin_waist.py for clarity --- tests/examples/test_echo_thin_waist.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/examples/test_echo_thin_waist.py b/tests/examples/test_echo_thin_waist.py index e9401225..2bcb52b1 100644 --- a/tests/examples/test_echo_thin_waist.py +++ b/tests/examples/test_echo_thin_waist.py @@ -11,7 +11,8 @@ from multiaddr.protocols import P_IP4, P_IP6, P_P2P, P_TCP # pytestmark = 
pytest.mark.timeout(20) # Temporarily disabled for debugging # This test is intentionally lightweight and can be marked as 'integration'. -# It ensures the echo example runs and prints the new Thin Waist lines using Trio primitives. +# It ensures the echo example runs and prints the new Thin Waist lines using +# Trio primitives. current_file = Path(__file__) project_root = current_file.parent.parent.parent From 5b9bec8e28820d66d859b6bc3f40fb3b70b80dbb Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Wed, 20 Aug 2025 18:29:35 +0530 Subject: [PATCH 13/71] fix: Enhance error handling in echo stream handler to manage stream closure and exceptions --- examples/echo/echo.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/examples/echo/echo.py b/examples/echo/echo.py index 8075f125..73d30df9 100644 --- a/examples/echo/echo.py +++ b/examples/echo/echo.py @@ -15,6 +15,9 @@ from libp2p.custom_types import ( from libp2p.network.stream.net_stream import ( INetStream, ) +from libp2p.network.stream.exceptions import ( + StreamEOF, +) from libp2p.peer.peerinfo import ( info_from_p2p_addr, ) @@ -27,10 +30,19 @@ MAX_READ_LEN = 2**32 - 1 async def _echo_stream_handler(stream: INetStream) -> None: - # Wait until EOF - msg = await stream.read(MAX_READ_LEN) - await stream.write(msg) - await stream.close() + try: + peer_id = stream.muxed_conn.peer_id + print(f"Received connection from {peer_id}") + # Wait until EOF + msg = await stream.read(MAX_READ_LEN) + print(f"Echoing message: {msg.decode('utf-8')}") + await stream.write(msg) + except StreamEOF: + print("Stream closed by remote peer.") + except Exception as e: + print(f"Error in echo handler: {e}") + finally: + await stream.close() async def run(port: int, destination: str, seed: int | None = None) -> None: @@ -63,8 +75,7 @@ async def run(port: int, destination: str, seed: int | None = None) -> None: print( "\nRun this from the same folder in another console:\n\n" - f"echo-demo " - f"-d {host.get_addrs()[0]}\n" + f"echo-demo -d {host.get_addrs()[0]}\n" ) print("Waiting for incoming connections...") await trio.sleep_forever() From ed2716c1bf6ab339569be8277ce3bcdc93e58de0 Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Fri, 22 Aug 2025 11:48:37 +0530 Subject: [PATCH 14/71] feat: Enhance echo example to dynamically find free ports and improve address handling - Added a function to find a free port on localhost. - Updated the run function to use the new port finding logic when a non-positive port is provided. - Modified address printing to handle multiple listen addresses correctly. - Improved the get_available_interfaces function to ensure the IPv4 loopback address is included. 
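The free-port helper added below relies on the standard bind-to-port-0 trick: the kernel assigns an unused port, which is read back before the socket is closed. A minimal standalone sketch, with the usual caveat that the port is released before it is reused, so another process could grab it in between:

    import socket

    from multiaddr import Multiaddr


    def find_free_port() -> int:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))  # port 0 lets the OS pick an unused port
            return s.getsockname()[1]


    port = find_free_port()
    # One listen multiaddr per discovered interface, loopback always included:
    print(Multiaddr(f"/ip4/127.0.0.1/tcp/{port}"))
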
--- examples/echo/echo.py | 31 ++++++++++++++++++++---------- libp2p/utils/address_validation.py | 4 ++++ 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/examples/echo/echo.py b/examples/echo/echo.py index 73d30df9..fe59e6df 100644 --- a/examples/echo/echo.py +++ b/examples/echo/echo.py @@ -1,4 +1,7 @@ import argparse +import random +import secrets +import socket import multiaddr import trio @@ -12,23 +15,30 @@ from libp2p.crypto.secp256k1 import ( from libp2p.custom_types import ( TProtocol, ) -from libp2p.network.stream.net_stream import ( - INetStream, -) from libp2p.network.stream.exceptions import ( StreamEOF, ) +from libp2p.network.stream.net_stream import ( + INetStream, +) from libp2p.peer.peerinfo import ( info_from_p2p_addr, ) from libp2p.utils.address_validation import ( - get_optimal_binding_address, + get_available_interfaces, ) PROTOCOL_ID = TProtocol("/echo/1.0.0") MAX_READ_LEN = 2**32 - 1 +def find_free_port(): + """Find a free port on localhost.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) # Bind to a free port provided by the OS + return s.getsockname()[1] + + async def _echo_stream_handler(stream: INetStream) -> None: try: peer_id = stream.muxed_conn.peer_id @@ -47,19 +57,19 @@ async def _echo_stream_handler(stream: INetStream) -> None: async def run(port: int, destination: str, seed: int | None = None) -> None: # CHANGED: previously hardcoded 0.0.0.0 - listen_addr = get_optimal_binding_address(port) + if port <= 0: + port = find_free_port() + listen_addr = get_available_interfaces(port) if seed: - import random random.seed(seed) secret_number = random.getrandbits(32 * 8) secret = secret_number.to_bytes(length=32, byteorder="big") else: - import secrets secret = secrets.token_bytes(32) host = new_host(key_pair=create_new_key_pair(secret)) - async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery: + async with host.run(listen_addrs=listen_addr), trio.open_nursery() as nursery: # Start the peer-store cleanup task nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) @@ -69,9 +79,10 @@ async def run(port: int, destination: str, seed: int | None = None) -> None: host.set_stream_handler(PROTOCOL_ID, _echo_stream_handler) # Print all listen addresses with peer ID (JS parity) - print("Listener ready, listening on:") + print("Listener ready, listening on:\n") peer_id = host.get_id().to_string() - print(f"{listen_addr}/p2p/{peer_id}") + for addr in listen_addr: + print(f"{addr}/p2p/{peer_id}") print( "\nRun this from the same folder in another console:\n\n" diff --git a/libp2p/utils/address_validation.py b/libp2p/utils/address_validation.py index 67299270..565bef28 100644 --- a/libp2p/utils/address_validation.py +++ b/libp2p/utils/address_validation.py @@ -67,6 +67,10 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr seen_v4.add(ip) addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}")) + # Ensure IPv4 loopback is always included when IPv4 interfaces are discovered + if seen_v4 and "127.0.0.1" not in seen_v4: + addrs.append(Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}")) + seen_v6: set[str] = set() for ip in _safe_get_network_addrs(6): seen_v6.add(ip) From b6cbd78943a51af5a1f67665a3efea07979e307b Mon Sep 17 00:00:00 2001 From: acul71 Date: Sun, 24 Aug 2025 01:49:42 +0200 Subject: [PATCH 15/71] Fix multi-address listening bug in swarm.listen() - Fix early return in swarm.listen() that prevented listening on all addresses - Add comprehensive tests for 
multi-address listening functionality - Ensure all available interfaces are properly bound and connectable --- libp2p/network/swarm.py | 11 +-- tests/core/network/test_swarm.py | 116 +++++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 4 deletions(-) diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 0aa60514..67d46279 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -249,9 +249,11 @@ class Swarm(Service, INetworkService): # We need to wait until `self.listener_nursery` is created. await self.event_listener_nursery_created.wait() + success_count = 0 for maddr in multiaddrs: if str(maddr) in self.listeners: - return True + success_count += 1 + continue async def conn_handler( read_write_closer: ReadWriteCloser, maddr: Multiaddr = maddr @@ -302,13 +304,14 @@ class Swarm(Service, INetworkService): # Call notifiers since event occurred await self.notify_listen(maddr) - return True + success_count += 1 + logger.debug("successfully started listening on: %s", maddr) except OSError: # Failed. Continue looping. logger.debug("fail to listen on: %s", maddr) - # No maddr succeeded - return False + # Return true if at least one address succeeded + return success_count > 0 async def close(self) -> None: """ diff --git a/tests/core/network/test_swarm.py b/tests/core/network/test_swarm.py index 6389bcb3..605913ec 100644 --- a/tests/core/network/test_swarm.py +++ b/tests/core/network/test_swarm.py @@ -16,6 +16,9 @@ from libp2p.network.exceptions import ( from libp2p.network.swarm import ( Swarm, ) +from libp2p.tools.async_service import ( + background_trio_service, +) from libp2p.tools.utils import ( connect_swarm, ) @@ -184,3 +187,116 @@ def test_new_swarm_quic_multiaddr_raises(): addr = Multiaddr("/ip4/127.0.0.1/udp/9999/quic") with pytest.raises(ValueError, match="QUIC not yet supported"): new_swarm(listen_addrs=[addr]) + + +@pytest.mark.trio +async def test_swarm_listen_multiple_addresses(security_protocol): + """Test that swarm can listen on multiple addresses simultaneously.""" + from libp2p.utils.address_validation import get_available_interfaces + + # Get multiple addresses to listen on + listen_addrs = get_available_interfaces(0) # Let OS choose ports + + # Create a swarm and listen on multiple addresses + swarm = SwarmFactory.build(security_protocol=security_protocol) + async with background_trio_service(swarm): + # Listen on all addresses + success = await swarm.listen(*listen_addrs) + assert success, "Should successfully listen on at least one address" + + # Check that we have listeners for the addresses + actual_listeners = list(swarm.listeners.keys()) + assert len(actual_listeners) > 0, "Should have at least one listener" + + # Verify that all successful listeners are in the listeners dict + successful_count = 0 + for addr in listen_addrs: + addr_str = str(addr) + if addr_str in actual_listeners: + successful_count += 1 + # This address successfully started listening + listener = swarm.listeners[addr_str] + listener_addrs = listener.get_addrs() + assert len(listener_addrs) > 0, ( + f"Listener for {addr} should have addresses" + ) + + # Check that the listener address matches the expected address + # (port might be different if we used port 0) + expected_ip = addr.value_for_protocol("ip4") + expected_protocol = addr.value_for_protocol("tcp") + if expected_ip and expected_protocol: + found_matching = False + for listener_addr in listener_addrs: + if ( + listener_addr.value_for_protocol("ip4") == expected_ip + and 
listener_addr.value_for_protocol("tcp") is not None + ): + found_matching = True + break + assert found_matching, ( + f"Listener for {addr} should have matching IP" + ) + + assert successful_count == len(listen_addrs), ( + f"All {len(listen_addrs)} addresses should be listening, " + f"but only {successful_count} succeeded" + ) + + +@pytest.mark.trio +async def test_swarm_listen_multiple_addresses_connectivity(security_protocol): + """Test that real libp2p connections can be established to all listening addresses.""" # noqa: E501 + from libp2p.peer.peerinfo import info_from_p2p_addr + from libp2p.utils.address_validation import get_available_interfaces + + # Get multiple addresses to listen on + listen_addrs = get_available_interfaces(0) # Let OS choose ports + + # Create a swarm and listen on multiple addresses + swarm1 = SwarmFactory.build(security_protocol=security_protocol) + async with background_trio_service(swarm1): + # Listen on all addresses + success = await swarm1.listen(*listen_addrs) + assert success, "Should successfully listen on at least one address" + + # Verify all available interfaces are listening + assert len(swarm1.listeners) == len(listen_addrs), ( + f"All {len(listen_addrs)} interfaces should be listening, " + f"but only {len(swarm1.listeners)} are" + ) + + # Create a second swarm to test connections + swarm2 = SwarmFactory.build(security_protocol=security_protocol) + async with background_trio_service(swarm2): + # Test connectivity to each listening address using real libp2p connections + for addr_str, listener in swarm1.listeners.items(): + listener_addrs = listener.get_addrs() + for listener_addr in listener_addrs: + # Create a full multiaddr with peer ID for libp2p connection + peer_id = swarm1.get_peer_id() + full_addr = listener_addr.encapsulate(f"/p2p/{peer_id}") + + # Test real libp2p connection + try: + peer_info = info_from_p2p_addr(full_addr) + + # Add the peer info to swarm2's peerstore so it knows where to connect # noqa: E501 + swarm2.peerstore.add_addrs( + peer_info.peer_id, [listener_addr], 10000 + ) + + await swarm2.dial_peer(peer_info.peer_id) + + # Verify connection was established + assert peer_info.peer_id in swarm2.connections, ( + f"Connection to {full_addr} should be established" + ) + assert swarm2.get_peer_id() in swarm1.connections, ( + f"Connection from {full_addr} should be established" + ) + + except Exception as e: + pytest.fail( + f"Failed to establish libp2p connection to {full_addr}: {e}" + ) From 3bd6d1f579d454b68853d7bb03c2615d064b3f8b Mon Sep 17 00:00:00 2001 From: acul71 Date: Sun, 24 Aug 2025 02:29:23 +0200 Subject: [PATCH 16/71] doc: add newsfragment --- newsfragments/863.bugfix.rst | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 newsfragments/863.bugfix.rst diff --git a/newsfragments/863.bugfix.rst b/newsfragments/863.bugfix.rst new file mode 100644 index 00000000..64de57b4 --- /dev/null +++ b/newsfragments/863.bugfix.rst @@ -0,0 +1,5 @@ +Fix multi-address listening bug in swarm.listen() + +- Fix early return in swarm.listen() that prevented listening on all addresses +- Add comprehensive tests for multi-address listening functionality +- Ensure all available interfaces are properly bound and connectable From 88a1f0a390b5716ba7287788a64391e78613fcc7 Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Sun, 24 Aug 2025 21:17:29 +0530 Subject: [PATCH 17/71] cherry pick https://github.com/acul71/py-libp2p-fork/blob/7a1198c8c6e9a69c1ab5044adf04a859828c0a95/libp2p/utils/address_validation.py --- 
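Note for reviewers: with IPv6 enumeration switched off in this patch, get_available_interfaces() only yields IPv4 candidates. A caller that knows its environment handles IPv6 can still opt in by appending an explicit multiaddr, bearing in mind the handshake issues that motivated disabling it here (illustrative only, not part of this patch):

    from multiaddr import Multiaddr

    from libp2p.utils.address_validation import get_available_interfaces

    port = 8000  # arbitrary port chosen for illustration
    listen_addrs = get_available_interfaces(port)
    listen_addrs.append(Multiaddr(f"/ip6/::1/tcp/{port}"))  # manual IPv6 loopback opt-in
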
libp2p/utils/address_validation.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/libp2p/utils/address_validation.py b/libp2p/utils/address_validation.py index 565bef28..189f00cc 100644 --- a/libp2p/utils/address_validation.py +++ b/libp2p/utils/address_validation.py @@ -71,16 +71,22 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr if seen_v4 and "127.0.0.1" not in seen_v4: addrs.append(Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}")) - seen_v6: set[str] = set() - for ip in _safe_get_network_addrs(6): - seen_v6.add(ip) - addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}")) - - # IPv6 enumeration (optional: only include if we have at least one global or - # loopback) - # Optionally ensure IPv6 loopback when any IPv6 present but loopback missing - if seen_v6 and "::1" not in seen_v6: - addrs.append(Multiaddr(f"/ip6/::1/{protocol}/{port}")) + # TODO: IPv6 support temporarily disabled due to libp2p handshake issues + # IPv6 connections fail during protocol negotiation (SecurityUpgradeFailure) + # Re-enable IPv6 support once the following issues are resolved: + # - libp2p security handshake over IPv6 + # - multiselect protocol over IPv6 + # - connection establishment over IPv6 + # + # seen_v6: set[str] = set() + # for ip in _safe_get_network_addrs(6): + # seen_v6.add(ip) + # addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}")) + # + # # Always include IPv6 loopback for testing purposes when IPv6 is available + # # This ensures IPv6 functionality can be tested even without global IPv6 addresses + # if "::1" not in seen_v6: + # addrs.append(Multiaddr(f"/ip6/::1/{protocol}/{port}")) # Fallback if nothing discovered if not addrs: @@ -141,4 +147,4 @@ __all__ = [ "get_available_interfaces", "get_optimal_binding_address", "expand_wildcard_address", -] +] \ No newline at end of file From cf48d2e9a4ed61349ee88684d28206fb23e08ea0 Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Sun, 24 Aug 2025 22:03:31 +0530 Subject: [PATCH 18/71] chore(app): Add 811.internal.rst --- newsfragments/811.internal.rst | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 newsfragments/811.internal.rst diff --git a/newsfragments/811.internal.rst b/newsfragments/811.internal.rst new file mode 100644 index 00000000..8d86c55d --- /dev/null +++ b/newsfragments/811.internal.rst @@ -0,0 +1,7 @@ +Add Thin Waist address validation utilities and integrate into echo example + +- Add ``libp2p/utils/address_validation.py`` with dynamic interface discovery +- Implement ``get_available_interfaces()``, ``get_optimal_binding_address()``, and ``expand_wildcard_address()`` +- Update echo example to use dynamic address discovery instead of hardcoded wildcard +- Add safe fallbacks for environments lacking Thin Waist support +- Temporarily disable IPv6 support due to libp2p handshake issues (TODO: re-enable when resolved) \ No newline at end of file From 75ffb791acd57051ce1aa4db06f7b1da6823ae2a Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Sun, 24 Aug 2025 22:06:07 +0530 Subject: [PATCH 19/71] fix: Ensure newline at end of file in address_validation.py and update news fragment formatting --- libp2p/utils/address_validation.py | 2 +- newsfragments/811.internal.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/libp2p/utils/address_validation.py b/libp2p/utils/address_validation.py index 189f00cc..c0da78a6 100644 --- a/libp2p/utils/address_validation.py +++ b/libp2p/utils/address_validation.py @@ -147,4 +147,4 @@ 
__all__ = [ "get_available_interfaces", "get_optimal_binding_address", "expand_wildcard_address", -] \ No newline at end of file +] diff --git a/newsfragments/811.internal.rst b/newsfragments/811.internal.rst index 8d86c55d..59804430 100644 --- a/newsfragments/811.internal.rst +++ b/newsfragments/811.internal.rst @@ -4,4 +4,4 @@ Add Thin Waist address validation utilities and integrate into echo example - Implement ``get_available_interfaces()``, ``get_optimal_binding_address()``, and ``expand_wildcard_address()`` - Update echo example to use dynamic address discovery instead of hardcoded wildcard - Add safe fallbacks for environments lacking Thin Waist support -- Temporarily disable IPv6 support due to libp2p handshake issues (TODO: re-enable when resolved) \ No newline at end of file +- Temporarily disable IPv6 support due to libp2p handshake issues (TODO: re-enable when resolved) From ed91ee0c311c74f4fb9bb4edb4acf07887a49521 Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Sun, 24 Aug 2025 23:28:02 +0530 Subject: [PATCH 20/71] refactor(app): 804 refactored find_free_port() in address_validation.py --- examples/echo/echo.py | 9 +-------- examples/pubsub/pubsub.py | 11 +++-------- libp2p/utils/address_validation.py | 10 ++++++++++ 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/examples/echo/echo.py b/examples/echo/echo.py index fe59e6df..e713a759 100644 --- a/examples/echo/echo.py +++ b/examples/echo/echo.py @@ -1,7 +1,6 @@ import argparse import random import secrets -import socket import multiaddr import trio @@ -25,6 +24,7 @@ from libp2p.peer.peerinfo import ( info_from_p2p_addr, ) from libp2p.utils.address_validation import ( + find_free_port, get_available_interfaces, ) @@ -32,13 +32,6 @@ PROTOCOL_ID = TProtocol("/echo/1.0.0") MAX_READ_LEN = 2**32 - 1 -def find_free_port(): - """Find a free port on localhost.""" - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) # Bind to a free port provided by the OS - return s.getsockname()[1] - - async def _echo_stream_handler(stream: INetStream) -> None: try: peer_id = stream.muxed_conn.peer_id diff --git a/examples/pubsub/pubsub.py b/examples/pubsub/pubsub.py index 1ab6d650..41545658 100644 --- a/examples/pubsub/pubsub.py +++ b/examples/pubsub/pubsub.py @@ -1,6 +1,5 @@ import argparse import logging -import socket import base58 import multiaddr @@ -31,6 +30,9 @@ from libp2p.stream_muxer.mplex.mplex import ( from libp2p.tools.async_service.trio_service import ( background_trio_service, ) +from libp2p.utils.address_validation import ( + find_free_port, +) # Configure logging logging.basicConfig( @@ -77,13 +79,6 @@ async def publish_loop(pubsub, topic, termination_event): await trio.sleep(1) # Avoid tight loop on error -def find_free_port(): - """Find a free port on localhost.""" - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) # Bind to a free port provided by the OS - return s.getsockname()[1] - - async def monitor_peer_topics(pubsub, nursery, termination_event): """ Monitor for new topics that peers are subscribed to and diff --git a/libp2p/utils/address_validation.py b/libp2p/utils/address_validation.py index c0da78a6..77b797a1 100644 --- a/libp2p/utils/address_validation.py +++ b/libp2p/utils/address_validation.py @@ -1,5 +1,7 @@ from __future__ import annotations +import socket + from multiaddr import Multiaddr try: @@ -35,6 +37,13 @@ def _safe_get_network_addrs(ip_version: int) -> list[str]: return [] +def find_free_port() -> int: + """Find a free port on 
localhost.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) # Bind to a free port provided by the OS + return s.getsockname()[1] + + def _safe_expand(addr: Multiaddr, port: int | None = None) -> list[Multiaddr]: """ Internal safe expansion wrapper. Returns a list of Multiaddr objects. @@ -147,4 +156,5 @@ __all__ = [ "get_available_interfaces", "get_optimal_binding_address", "expand_wildcard_address", + "find_free_port", ] From 63a8458d451fe8d73ab8fca2502b42eaffb5a77b Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Sun, 24 Aug 2025 23:40:05 +0530 Subject: [PATCH 21/71] add import to __init__ --- libp2p/utils/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/libp2p/utils/__init__.py b/libp2p/utils/__init__.py index 0f68e701..b881eb92 100644 --- a/libp2p/utils/__init__.py +++ b/libp2p/utils/__init__.py @@ -19,6 +19,7 @@ from libp2p.utils.address_validation import ( get_available_interfaces, get_optimal_binding_address, expand_wildcard_address, + find_free_port, ) __all__ = [ @@ -35,4 +36,5 @@ __all__ = [ "get_available_interfaces", "get_optimal_binding_address", "expand_wildcard_address", + "find_free_port", ] From 6a0a7c21e85a589550ea3b83a0b77d93735769b8 Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Mon, 25 Aug 2025 01:31:30 +0530 Subject: [PATCH 22/71] chore(app): Add newsfragment for 811.feature.rst --- newsfragments/811.feature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 newsfragments/811.feature.rst diff --git a/newsfragments/811.feature.rst b/newsfragments/811.feature.rst new file mode 100644 index 00000000..47a0aa68 --- /dev/null +++ b/newsfragments/811.feature.rst @@ -0,0 +1 @@ + Added Thin Waist address validation utilities (with support for interface enumeration, optimal binding, and wildcard expansion). 
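A minimal usage sketch of the utilities this series exports from libp2p.utils (find_free_port, get_available_interfaces, get_optimal_binding_address). It is illustrative only, not part of any patch; it follows the pattern the echo example adopts in PATCH 20 and assumes the fallback behaviour described earlier (IPv4 plus loopback, or a wildcard address, while IPv6 enumeration stays disabled):

    from libp2p.utils import (
        find_free_port,
        get_available_interfaces,
        get_optimal_binding_address,
    )

    # Let the OS pick a free TCP port, as the examples do when port <= 0.
    port = find_free_port()

    # Enumerate concrete interface addresses to bind to; with IPv6 currently
    # disabled this yields discovered IPv4 addresses plus loopback, falling
    # back to /ip4/0.0.0.0/tcp/<port> if nothing is discovered.
    listen_addrs = get_available_interfaces(port)
    for addr in listen_addrs:
        print(f"candidate listen address: {addr}")

    # A single heuristic choice (non-loopback IPv4 preferred) suitable for
    # display or for advertising to peers.
    print(f"optimal binding address: {get_optimal_binding_address(port)}")

The list returned by get_available_interfaces is what the updated echo example passes as its listen addresses, so the sketch mirrors that call sequence rather than introducing any new API.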
From 6c6adf7459dbeb12f5a3ff9804bf52775da532a4 Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Mon, 25 Aug 2025 12:43:18 +0530 Subject: [PATCH 23/71] chore(app): 804 Suggested changes - Remove the comment --- examples/echo/echo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/echo/echo.py b/examples/echo/echo.py index e713a759..19e98377 100644 --- a/examples/echo/echo.py +++ b/examples/echo/echo.py @@ -49,7 +49,6 @@ async def _echo_stream_handler(stream: INetStream) -> None: async def run(port: int, destination: str, seed: int | None = None) -> None: - # CHANGED: previously hardcoded 0.0.0.0 if port <= 0: port = find_free_port() listen_addr = get_available_interfaces(port) From 621469734949df6e7b9abecfb4edc585f97766d2 Mon Sep 17 00:00:00 2001 From: unniznd Date: Mon, 25 Aug 2025 23:01:35 +0530 Subject: [PATCH 24/71] removed redundant imports --- libp2p/security/security_multistream.py | 3 --- libp2p/stream_muxer/muxer_multistream.py | 3 --- 2 files changed, 6 deletions(-) diff --git a/libp2p/security/security_multistream.py b/libp2p/security/security_multistream.py index 9b341ed7..a9c4b19c 100644 --- a/libp2p/security/security_multistream.py +++ b/libp2p/security/security_multistream.py @@ -29,9 +29,6 @@ from libp2p.protocol_muxer.multiselect_client import ( from libp2p.protocol_muxer.multiselect_communicator import ( MultiselectCommunicator, ) -from libp2p.transport.exceptions import ( - SecurityUpgradeFailure, -) """ Represents a secured connection object, which includes a connection and details about diff --git a/libp2p/stream_muxer/muxer_multistream.py b/libp2p/stream_muxer/muxer_multistream.py index 4a07b261..322db912 100644 --- a/libp2p/stream_muxer/muxer_multistream.py +++ b/libp2p/stream_muxer/muxer_multistream.py @@ -33,9 +33,6 @@ from libp2p.stream_muxer.yamux.yamux import ( PROTOCOL_ID, Yamux, ) -from libp2p.transport.exceptions import ( - MuxerUpgradeFailure, -) class MuxerMultistream: From 53db128f6984d6d4f38dd8a9195b66a475f9b9f8 Mon Sep 17 00:00:00 2001 From: lla-dane Date: Tue, 12 Aug 2025 13:57:16 +0530 Subject: [PATCH 25/71] fix typos --- libp2p/identity/identify/identify.py | 13 +- libp2p/kad_dht/kad_dht.py | 223 +++++++++++++++++++ libp2p/kad_dht/pb/kademlia.proto | 3 + libp2p/kad_dht/pb/kademlia_pb2.py | 31 +-- libp2p/kad_dht/pb/kademlia_pb2.pyi | 197 ++++++---------- libp2p/kad_dht/peer_routing.py | 91 ++++++++ libp2p/kad_dht/provider_store.py | 89 +++++++- libp2p/kad_dht/value_store.py | 60 ++++- libp2p/peer/peerstore.py | 14 +- tests/core/kad_dht/test_unit_peer_routing.py | 5 +- 10 files changed, 569 insertions(+), 157 deletions(-) diff --git a/libp2p/identity/identify/identify.py b/libp2p/identity/identify/identify.py index b2811ff9..4e8931ba 100644 --- a/libp2p/identity/identify/identify.py +++ b/libp2p/identity/identify/identify.py @@ -15,8 +15,7 @@ from libp2p.custom_types import ( from libp2p.network.stream.exceptions import ( StreamClosed, ) -from libp2p.peer.envelope import seal_record -from libp2p.peer.peer_record import PeerRecord +from libp2p.peer.peerstore import create_signed_peer_record from libp2p.utils import ( decode_varint_with_size, get_agent_version, @@ -66,9 +65,11 @@ def _mk_identify_protobuf( protocols = tuple(str(p) for p in host.get_mux().get_protocols() if p is not None) # Create a signed peer-record for the remote peer - record = PeerRecord(host.get_id(), host.get_addrs()) - envelope = seal_record(record, host.get_private_key()) - protobuf = envelope.marshal_envelope() + envelope = create_signed_peer_record( + host.get_id(), 
+ host.get_addrs(), + host.get_private_key(), + ) observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b"" return Identify( @@ -78,7 +79,7 @@ def _mk_identify_protobuf( listen_addrs=map(_multiaddr_to_bytes, laddrs), observed_addr=observed_addr, protocols=protocols, - signedPeerRecord=protobuf, + signedPeerRecord=envelope.marshal_envelope(), ) diff --git a/libp2p/kad_dht/kad_dht.py b/libp2p/kad_dht/kad_dht.py index 097b6c48..2fb42662 100644 --- a/libp2p/kad_dht/kad_dht.py +++ b/libp2p/kad_dht/kad_dht.py @@ -25,12 +25,14 @@ from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager from libp2p.network.stream.net_stream import ( INetStream, ) +from libp2p.peer.envelope import Envelope, consume_envelope from libp2p.peer.id import ( ID, ) from libp2p.peer.peerinfo import ( PeerInfo, ) +from libp2p.peer.peerstore import create_signed_peer_record from libp2p.tools.async_service import ( Service, ) @@ -234,6 +236,9 @@ class KadDHT(Service): await self.add_peer(peer_id) logger.debug(f"Added peer {peer_id} to routing table") + closer_peer_envelope: Envelope | None = None + provider_peer_envelope: Envelope | None = None + try: # Read varint-prefixed length for the message length_prefix = b"" @@ -266,6 +271,7 @@ class KadDHT(Service): # Handle FIND_NODE message if message.type == Message.MessageType.FIND_NODE: # Get target key directly from protobuf + print("FIND NODE RECEIVED") target_key = message.key # Find closest peers to the target key @@ -274,6 +280,26 @@ class KadDHT(Service): ) logger.debug(f"Found {len(closest_peers)} peers close to target") + # Consume the source signed_peer_record if sent + if message.HasField("senderRecord"): + try: + # Convert the signed-peer-record(Envelope) from + # protobuf bytes + envelope, _ = consume_envelope( + message.senderRecord, "libp2p-peer-record" + ) + # Use the defualt TTL of 2 hours (7200 seconds) + if not self.host.get_peerstore().consume_peer_record( + envelope, 7200 + ): + logger.error( + "Updating the Certified-Addr-Book was unsuccessful" + ) + except Exception as e: + logger.error( + "Error updating the certified addr book for peer: %s", e + ) + # Build response message with protobuf response = Message() response.type = Message.MessageType.FIND_NODE @@ -298,6 +324,25 @@ class KadDHT(Service): except Exception: pass + # Add the signed-peer-record for each peer in the peer-proto + # if cached in the peerstore + closer_peer_envelope = ( + self.host.get_peerstore().get_peer_record(peer) + ) + + if closer_peer_envelope is not None: + peer_proto.signedRecord = ( + closer_peer_envelope.marshal_envelope() + ) + + # Create sender_signed_peer_record + envelope = create_signed_peer_record( + self.host.get_id(), + self.host.get_addrs(), + self.host.get_private_key(), + ) + response.senderRecord = envelope.marshal_envelope() + # Serialize and send response response_bytes = response.SerializeToString() await stream.write(varint.encode(len(response_bytes))) @@ -312,6 +357,26 @@ class KadDHT(Service): key = message.key logger.debug(f"Received ADD_PROVIDER for key {key.hex()}") + # Consume the source signed-peer-record if sent + if message.HasField("senderRecord"): + try: + # Convert the signed-peer-record(Envelope) from + # protobuf bytes + envelope, _ = consume_envelope( + message.senderRecord, "libp2p-peer-record" + ) + # Use the default TTL of 2 hours (72000 seconds) + if not self.host.get_peerstore().consume_peer_record( + envelope, 7200 + ): + logger.error( + "Updating the Certified-Addr-Book was unsuccessful" + ) + except 
Exception as e: + logger.error( + "Error updating the certified addr book for peer: %s", e + ) + # Extract provider information for provider_proto in message.providerPeers: try: @@ -341,11 +406,42 @@ class KadDHT(Service): except Exception as e: logger.warning(f"Failed to process provider info: {e}") + # Process the signed-records of provider if sent + if provider_proto.HasField("signedRecord"): + try: + # Convert the signed-peer-record(Envelope) from + # protobuf bytes + envelope, _ = consume_envelope( + provider_proto.signedRecord, + "libp2p-peer-record", + ) + # Use the default TTL of 2 hours (7200 seconds) + if not self.host.get_peerstore().consume_peer_record( # noqa + envelope, 7200 + ): + logger.error( + "Failed to update the Certified-Addr-Book" + ) + except Exception as e: + logger.error( + "Error updating the certified-addr-book for peer %s: %s", # noqa + provider_id, + e, + ) + # Send acknowledgement response = Message() response.type = Message.MessageType.ADD_PROVIDER response.key = key + # Add sender's signed-peer-record + envelope = create_signed_peer_record( + self.host.get_id(), + self.host.get_addrs(), + self.host.get_private_key(), + ) + response.senderRecord = envelope.marshal_envelope() + response_bytes = response.SerializeToString() await stream.write(varint.encode(len(response_bytes))) await stream.write(response_bytes) @@ -357,6 +453,26 @@ class KadDHT(Service): key = message.key logger.debug(f"Received GET_PROVIDERS request for key {key.hex()}") + # Consume the source signed_peer_record if sent + if message.HasField("senderRecord"): + try: + # Convert the signed-peer-record(Envelope) from + # protobuf bytes + envelope, _ = consume_envelope( + message.senderRecord, "libp2p-peer-record" + ) + # Use the defualt TTL of 2 hours (7200 seconds) + if not self.host.get_peerstore().consume_peer_record( + envelope, 7200 + ): + logger.error( + "Updating the Certified-Addr-Book was unsuccessful" + ) + except Exception as e: + logger.error( + "Error updating the certified addr book for peer: %s", e + ) + # Find providers for the key providers = self.provider_store.get_providers(key) logger.debug( @@ -368,12 +484,32 @@ class KadDHT(Service): response.type = Message.MessageType.GET_PROVIDERS response.key = key + # Create sender_signed_peer_record for the response + envelope = create_signed_peer_record( + self.host.get_id(), + self.host.get_addrs(), + self.host.get_private_key(), + ) + response.senderRecord = envelope.marshal_envelope() + # Add provider information to response for provider_info in providers: provider_proto = response.providerPeers.add() provider_proto.id = provider_info.peer_id.to_bytes() provider_proto.connection = Message.ConnectionType.CAN_CONNECT + # Add provider signed-records if cached + provider_peer_envelope = ( + self.host.get_peerstore().get_peer_record( + provider_info.peer_id + ) + ) + + if provider_peer_envelope is not None: + provider_proto.signedRecord = ( + provider_peer_envelope.marshal_envelope() + ) + # Add addresses if available for addr in provider_info.addrs: provider_proto.addrs.append(addr.to_bytes()) @@ -397,6 +533,16 @@ class KadDHT(Service): peer_proto.id = peer.to_bytes() peer_proto.connection = Message.ConnectionType.CAN_CONNECT + # Add the signed-records of closest_peers if cached + closer_peer_envelope = ( + self.host.get_peerstore().get_peer_record(peer) + ) + + if closer_peer_envelope is not None: + peer_proto.signedRecord = ( + closer_peer_envelope.marshal_envelope() + ) + # Add addresses if available try: addrs = 
self.host.get_peerstore().addrs(peer) @@ -417,6 +563,26 @@ class KadDHT(Service): key = message.key logger.debug(f"Received GET_VALUE request for key {key.hex()}") + # Consume the sender_signed_peer_record + if message.HasField("senderRecord"): + try: + # Convert the signed-peer-record(Envelope) from + # protobuf bytes + envelope, _ = consume_envelope( + message.senderRecord, "libp2p-peer-record" + ) + # Use the default TTL of 2 hours (7200 seconds) + if not self.host.get_peerstore().consume_peer_record( + envelope, 7200 + ): + logger.error( + "Updating teh Certified-Addr-Book was unsuccessful" + ) + except Exception as e: + logger.error( + "Error updating the certified addr book for peer: %s", e + ) + value = self.value_store.get(key) if value: logger.debug(f"Found value for key {key.hex()}") @@ -431,6 +597,14 @@ class KadDHT(Service): response.record.value = value response.record.timeReceived = str(time.time()) + # Create sender_signed_peer_record + envelope = create_signed_peer_record( + self.host.get_id(), + self.host.get_addrs(), + self.host.get_private_key(), + ) + response.senderRecord = envelope.marshal_envelope() + # Serialize and send response response_bytes = response.SerializeToString() await stream.write(varint.encode(len(response_bytes))) @@ -444,6 +618,14 @@ class KadDHT(Service): response.type = Message.MessageType.GET_VALUE response.key = key + # Create sender_signed_peer_record for the response + envelope = create_signed_peer_record( + self.host.get_id(), + self.host.get_addrs(), + self.host.get_private_key(), + ) + response.senderRecord = envelope.marshal_envelope() + # Add closest peers to key closest_peers = self.routing_table.find_local_closest_peers( key, 20 @@ -462,6 +644,16 @@ class KadDHT(Service): peer_proto.id = peer.to_bytes() peer_proto.connection = Message.ConnectionType.CAN_CONNECT + # Add signed-records of closer-peers if cached + closer_peer_envelope = ( + self.host.get_peerstore().get_peer_record(peer) + ) + + if closer_peer_envelope is not None: + peer_proto.signedRecord = ( + closer_peer_envelope.marshal_envelope() + ) + # Add addresses if available try: addrs = self.host.get_peerstore().addrs(peer) @@ -484,6 +676,27 @@ class KadDHT(Service): key = message.record.key value = message.record.value success = False + + # Consume the source signed_peer_record if sent + if message.HasField("senderRecord"): + try: + # Convert the signed-peer-record(Envelope) from + # protobuf bytes + envelope, _ = consume_envelope( + message.senderRecord, "libp2p-peer-record" + ) + # Use the default TTL of 2 hours (7200 seconds) + if not self.host.get_peerstore().consume_peer_record( + envelope, 7200 + ): + logger.error( + "Updating the certified-addr-book was unsuccessful" + ) + except Exception as e: + logger.error( + "Error updating the certified addr book for peer: %s", e + ) + try: if not (key and value): raise ValueError( @@ -504,6 +717,16 @@ class KadDHT(Service): response.type = Message.MessageType.PUT_VALUE if success: response.key = key + + # Create sender_signed_peer_record for the response + envelope = create_signed_peer_record( + self.host.get_id(), + self.host.get_addrs(), + self.host.get_private_key(), + ) + response.senderRecord = envelope.marshal_envelope() + + # Serialize and send response response_bytes = response.SerializeToString() await stream.write(varint.encode(len(response_bytes))) await stream.write(response_bytes) diff --git a/libp2p/kad_dht/pb/kademlia.proto b/libp2p/kad_dht/pb/kademlia.proto index fd198d28..7c3e5bad 100644 --- 
a/libp2p/kad_dht/pb/kademlia.proto +++ b/libp2p/kad_dht/pb/kademlia.proto @@ -27,6 +27,7 @@ message Message { bytes id = 1; repeated bytes addrs = 2; ConnectionType connection = 3; + optional bytes signedRecord = 4; // Envelope(PeerRecord) encoded } MessageType type = 1; @@ -35,4 +36,6 @@ message Message { Record record = 3; repeated Peer closerPeers = 8; repeated Peer providerPeers = 9; + + optional bytes senderRecord = 11; // Envelope(PeerRecord) encoded } diff --git a/libp2p/kad_dht/pb/kademlia_pb2.py b/libp2p/kad_dht/pb/kademlia_pb2.py index 781333bf..ac23169c 100644 --- a/libp2p/kad_dht/pb/kademlia_pb2.py +++ b/libp2p/kad_dht/pb/kademlia_pb2.py @@ -1,11 +1,12 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: libp2p/kad_dht/pb/kademlia.proto +# Protobuf Python Version: 4.25.3 """Generated protocol buffer code.""" -from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -13,21 +14,21 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xca\x03\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x1aN\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xa2\x04\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x12\x19\n\x0csenderRecord\x18\x0b \x01(\x0cH\x00\x88\x01\x01\x1az\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\x12\x19\n\x0csignedRecord\x18\x04 
\x01(\x0cH\x00\x88\x01\x01\x42\x0f\n\r_signedRecord\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x42\x0f\n\r_senderRecordb\x06proto3') -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', globals()) +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _RECORD._serialized_start=36 - _RECORD._serialized_end=94 - _MESSAGE._serialized_start=97 - _MESSAGE._serialized_end=555 - _MESSAGE_PEER._serialized_start=281 - _MESSAGE_PEER._serialized_end=359 - _MESSAGE_MESSAGETYPE._serialized_start=361 - _MESSAGE_MESSAGETYPE._serialized_end=466 - _MESSAGE_CONNECTIONTYPE._serialized_start=468 - _MESSAGE_CONNECTIONTYPE._serialized_end=555 + _globals['_RECORD']._serialized_start=36 + _globals['_RECORD']._serialized_end=94 + _globals['_MESSAGE']._serialized_start=97 + _globals['_MESSAGE']._serialized_end=643 + _globals['_MESSAGE_PEER']._serialized_start=308 + _globals['_MESSAGE_PEER']._serialized_end=430 + _globals['_MESSAGE_MESSAGETYPE']._serialized_start=432 + _globals['_MESSAGE_MESSAGETYPE']._serialized_end=537 + _globals['_MESSAGE_CONNECTIONTYPE']._serialized_start=539 + _globals['_MESSAGE_CONNECTIONTYPE']._serialized_end=626 # @@protoc_insertion_point(module_scope) diff --git a/libp2p/kad_dht/pb/kademlia_pb2.pyi b/libp2p/kad_dht/pb/kademlia_pb2.pyi index c8f16db2..6d80d77d 100644 --- a/libp2p/kad_dht/pb/kademlia_pb2.pyi +++ b/libp2p/kad_dht/pb/kademlia_pb2.pyi @@ -1,133 +1,70 @@ -""" -@generated by mypy-protobuf. Do not edit manually! -isort:skip_file -""" +from google.protobuf.internal import containers as _containers +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union -import builtins -import collections.abc -import google.protobuf.descriptor -import google.protobuf.internal.containers -import google.protobuf.internal.enum_type_wrapper -import google.protobuf.message -import sys -import typing +DESCRIPTOR: _descriptor.FileDescriptor -if sys.version_info >= (3, 10): - import typing as typing_extensions -else: - import typing_extensions +class Record(_message.Message): + __slots__ = ("key", "value", "timeReceived") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + TIMERECEIVED_FIELD_NUMBER: _ClassVar[int] + key: bytes + value: bytes + timeReceived: str + def __init__(self, key: _Optional[bytes] = ..., value: _Optional[bytes] = ..., timeReceived: _Optional[str] = ...) -> None: ... 
-DESCRIPTOR: google.protobuf.descriptor.FileDescriptor - -@typing.final -class Record(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - TIMERECEIVED_FIELD_NUMBER: builtins.int - key: builtins.bytes - value: builtins.bytes - timeReceived: builtins.str - def __init__( - self, - *, - key: builtins.bytes = ..., - value: builtins.bytes = ..., - timeReceived: builtins.str = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["key", b"key", "timeReceived", b"timeReceived", "value", b"value"]) -> None: ... - -global___Record = Record - -@typing.final -class Message(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - class _MessageType: - ValueType = typing.NewType("ValueType", builtins.int) - V: typing_extensions.TypeAlias = ValueType - - class _MessageTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._MessageType.ValueType], builtins.type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor - PUT_VALUE: Message._MessageType.ValueType # 0 - GET_VALUE: Message._MessageType.ValueType # 1 - ADD_PROVIDER: Message._MessageType.ValueType # 2 - GET_PROVIDERS: Message._MessageType.ValueType # 3 - FIND_NODE: Message._MessageType.ValueType # 4 - PING: Message._MessageType.ValueType # 5 - - class MessageType(_MessageType, metaclass=_MessageTypeEnumTypeWrapper): ... - PUT_VALUE: Message.MessageType.ValueType # 0 - GET_VALUE: Message.MessageType.ValueType # 1 - ADD_PROVIDER: Message.MessageType.ValueType # 2 - GET_PROVIDERS: Message.MessageType.ValueType # 3 - FIND_NODE: Message.MessageType.ValueType # 4 - PING: Message.MessageType.ValueType # 5 - - class _ConnectionType: - ValueType = typing.NewType("ValueType", builtins.int) - V: typing_extensions.TypeAlias = ValueType - - class _ConnectionTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._ConnectionType.ValueType], builtins.type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor - NOT_CONNECTED: Message._ConnectionType.ValueType # 0 - CONNECTED: Message._ConnectionType.ValueType # 1 - CAN_CONNECT: Message._ConnectionType.ValueType # 2 - CANNOT_CONNECT: Message._ConnectionType.ValueType # 3 - - class ConnectionType(_ConnectionType, metaclass=_ConnectionTypeEnumTypeWrapper): ... - NOT_CONNECTED: Message.ConnectionType.ValueType # 0 - CONNECTED: Message.ConnectionType.ValueType # 1 - CAN_CONNECT: Message.ConnectionType.ValueType # 2 - CANNOT_CONNECT: Message.ConnectionType.ValueType # 3 - - @typing.final - class Peer(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - ID_FIELD_NUMBER: builtins.int - ADDRS_FIELD_NUMBER: builtins.int - CONNECTION_FIELD_NUMBER: builtins.int - id: builtins.bytes - connection: global___Message.ConnectionType.ValueType - @property - def addrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... - def __init__( - self, - *, - id: builtins.bytes = ..., - addrs: collections.abc.Iterable[builtins.bytes] | None = ..., - connection: global___Message.ConnectionType.ValueType = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["addrs", b"addrs", "connection", b"connection", "id", b"id"]) -> None: ... 
- - TYPE_FIELD_NUMBER: builtins.int - CLUSTERLEVELRAW_FIELD_NUMBER: builtins.int - KEY_FIELD_NUMBER: builtins.int - RECORD_FIELD_NUMBER: builtins.int - CLOSERPEERS_FIELD_NUMBER: builtins.int - PROVIDERPEERS_FIELD_NUMBER: builtins.int - type: global___Message.MessageType.ValueType - clusterLevelRaw: builtins.int - key: builtins.bytes - @property - def record(self) -> global___Record: ... - @property - def closerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ... - @property - def providerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ... - def __init__( - self, - *, - type: global___Message.MessageType.ValueType = ..., - clusterLevelRaw: builtins.int = ..., - key: builtins.bytes = ..., - record: global___Record | None = ..., - closerPeers: collections.abc.Iterable[global___Message.Peer] | None = ..., - providerPeers: collections.abc.Iterable[global___Message.Peer] | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["record", b"record"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["closerPeers", b"closerPeers", "clusterLevelRaw", b"clusterLevelRaw", "key", b"key", "providerPeers", b"providerPeers", "record", b"record", "type", b"type"]) -> None: ... - -global___Message = Message +class Message(_message.Message): + __slots__ = ("type", "clusterLevelRaw", "key", "record", "closerPeers", "providerPeers", "senderRecord") + class MessageType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + PUT_VALUE: _ClassVar[Message.MessageType] + GET_VALUE: _ClassVar[Message.MessageType] + ADD_PROVIDER: _ClassVar[Message.MessageType] + GET_PROVIDERS: _ClassVar[Message.MessageType] + FIND_NODE: _ClassVar[Message.MessageType] + PING: _ClassVar[Message.MessageType] + PUT_VALUE: Message.MessageType + GET_VALUE: Message.MessageType + ADD_PROVIDER: Message.MessageType + GET_PROVIDERS: Message.MessageType + FIND_NODE: Message.MessageType + PING: Message.MessageType + class ConnectionType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + NOT_CONNECTED: _ClassVar[Message.ConnectionType] + CONNECTED: _ClassVar[Message.ConnectionType] + CAN_CONNECT: _ClassVar[Message.ConnectionType] + CANNOT_CONNECT: _ClassVar[Message.ConnectionType] + NOT_CONNECTED: Message.ConnectionType + CONNECTED: Message.ConnectionType + CAN_CONNECT: Message.ConnectionType + CANNOT_CONNECT: Message.ConnectionType + class Peer(_message.Message): + __slots__ = ("id", "addrs", "connection", "signedRecord") + ID_FIELD_NUMBER: _ClassVar[int] + ADDRS_FIELD_NUMBER: _ClassVar[int] + CONNECTION_FIELD_NUMBER: _ClassVar[int] + SIGNEDRECORD_FIELD_NUMBER: _ClassVar[int] + id: bytes + addrs: _containers.RepeatedScalarFieldContainer[bytes] + connection: Message.ConnectionType + signedRecord: bytes + def __init__(self, id: _Optional[bytes] = ..., addrs: _Optional[_Iterable[bytes]] = ..., connection: _Optional[_Union[Message.ConnectionType, str]] = ..., signedRecord: _Optional[bytes] = ...) -> None: ... 
+ TYPE_FIELD_NUMBER: _ClassVar[int] + CLUSTERLEVELRAW_FIELD_NUMBER: _ClassVar[int] + KEY_FIELD_NUMBER: _ClassVar[int] + RECORD_FIELD_NUMBER: _ClassVar[int] + CLOSERPEERS_FIELD_NUMBER: _ClassVar[int] + PROVIDERPEERS_FIELD_NUMBER: _ClassVar[int] + SENDERRECORD_FIELD_NUMBER: _ClassVar[int] + type: Message.MessageType + clusterLevelRaw: int + key: bytes + record: Record + closerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer] + providerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer] + senderRecord: bytes + def __init__(self, type: _Optional[_Union[Message.MessageType, str]] = ..., clusterLevelRaw: _Optional[int] = ..., key: _Optional[bytes] = ..., record: _Optional[_Union[Record, _Mapping]] = ..., closerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping]]] = ..., providerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping]]] = ..., senderRecord: _Optional[bytes] = ...) -> None: ... # type: ignore diff --git a/libp2p/kad_dht/peer_routing.py b/libp2p/kad_dht/peer_routing.py index c4a066f7..dc3190a5 100644 --- a/libp2p/kad_dht/peer_routing.py +++ b/libp2p/kad_dht/peer_routing.py @@ -15,12 +15,14 @@ from libp2p.abc import ( INetStream, IPeerRouting, ) +from libp2p.peer.envelope import Envelope, consume_envelope from libp2p.peer.id import ( ID, ) from libp2p.peer.peerinfo import ( PeerInfo, ) +from libp2p.peer.peerstore import create_signed_peer_record from .common import ( ALPHA, @@ -255,6 +257,14 @@ class PeerRouting(IPeerRouting): find_node_msg.type = Message.MessageType.FIND_NODE find_node_msg.key = target_key # Set target key directly as bytes + print("MESSAGE GOING TO BE CREATED") + + # Create sender_signed_peer_record + envelope = create_signed_peer_record( + self.host.get_id(), self.host.get_addrs(), self.host.get_private_key() + ) + find_node_msg.senderRecord = envelope.marshal_envelope() + # Serialize and send the protobuf message with varint length prefix proto_bytes = find_node_msg.SerializeToString() logger.debug( @@ -299,6 +309,26 @@ class PeerRouting(IPeerRouting): # Process closest peers from response if response_msg.type == Message.MessageType.FIND_NODE: + # Consume the sender_signed_peer_record + if response_msg.HasField("senderRecord"): + try: + # Convert the signed-peer-record(Envelope) from + # protobuf bytes + envelope, _ = consume_envelope( + response_msg.senderRecord, "libp2p-peer-record" + ) + # Use the default TTL of 2 hours (7200 seconds) + if not self.host.get_peerstore().consume_peer_record( + envelope, 7200 + ): + logger.error( + "Updating teh Certified-Addr-Book was unsuccessful" + ) + except Exception as e: + logger.error( + "Error updating the certified addr book for peer: %s", e + ) + for peer_data in response_msg.closerPeers: new_peer_id = ID(peer_data.id) if new_peer_id not in results: @@ -311,7 +341,29 @@ class PeerRouting(IPeerRouting): addrs = [Multiaddr(addr) for addr in peer_data.addrs] self.host.get_peerstore().add_addrs(new_peer_id, addrs, 3600) + # Consume the received closer_peers signed-records + if peer_data.HasField("signedRecord"): + try: + # Convert the signed-peer-record(Envelope) from + # protobuf bytes + envelope, _ = consume_envelope( + peer_data.signedRecord, + "libp2p-peer-record", + ) + # Use the default TTL of 2 hours (7200 seconds) + if not self.host.get_peerstore().consume_peer_record( + envelope, 7200 + ): + logger.error("Failed to update certified-addr-book") + except Exception as e: + logger.error( + "Error updating the certified-addr-book for peer %s: %s", # noqa + new_peer_id, + e, + ) + 
except Exception as e: + print("EXCEPTION CAME") logger.debug(f"Error querying peer {peer} for closest: {e}") finally: @@ -345,10 +397,31 @@ class PeerRouting(IPeerRouting): # Parse protobuf message kad_message = Message() + closer_peer_envelope: Envelope | None = None try: kad_message.ParseFromString(message_bytes) if kad_message.type == Message.MessageType.FIND_NODE: + # Consume the sender's signed-peer-record if sent + if kad_message.HasField("senderRecord"): + try: + # Convert the signed-peer-record(Envelope) from + # protobuf bytes + envelope, _ = consume_envelope( + kad_message.senderRecord, "libp2p-peer-record" + ) + # Use the default TTL of 2 hours (7200 seconds) + if not self.host.get_peerstore().consume_peer_record( + envelope, 7200 + ): + logger.error( + "Updating the Certified-Addr-Book was unsuccessful" + ) + except Exception as e: + logger.error( + "Error updating the certified addr book for peer: %s", e + ) + # Get target key directly from protobuf message target_key = kad_message.key @@ -361,12 +434,30 @@ class PeerRouting(IPeerRouting): response = Message() response.type = Message.MessageType.FIND_NODE + # Create sender_signed_peer_record for the response + envelope = create_signed_peer_record( + self.host.get_id(), + self.host.get_addrs(), + self.host.get_private_key(), + ) + response.senderRecord = envelope.marshal_envelope() + # Add peer information to response for peer_id in closest_peers: peer_proto = response.closerPeers.add() peer_proto.id = peer_id.to_bytes() peer_proto.connection = Message.ConnectionType.CAN_CONNECT + # Add the signed-records of closest_peers if cached + closer_peer_envelope = ( + self.host.get_peerstore().get_peer_record(peer_id) + ) + + if isinstance(closer_peer_envelope, Envelope): + peer_proto.signedRecord = ( + closer_peer_envelope.marshal_envelope() + ) + # Add addresses if available try: addrs = self.host.get_peerstore().addrs(peer_id) diff --git a/libp2p/kad_dht/provider_store.py b/libp2p/kad_dht/provider_store.py index 5c34f0c7..c5800914 100644 --- a/libp2p/kad_dht/provider_store.py +++ b/libp2p/kad_dht/provider_store.py @@ -22,12 +22,14 @@ from libp2p.abc import ( from libp2p.custom_types import ( TProtocol, ) +from libp2p.peer.envelope import consume_envelope from libp2p.peer.id import ( ID, ) from libp2p.peer.peerinfo import ( PeerInfo, ) +from libp2p.peer.peerstore import create_signed_peer_record from .common import ( ALPHA, @@ -240,11 +242,22 @@ class ProviderStore: message.type = Message.MessageType.ADD_PROVIDER message.key = key + # Create sender's signed-peer-record + envelope = create_signed_peer_record( + self.host.get_id(), + self.host.get_addrs(), + self.host.get_private_key(), + ) + message.senderRecord = envelope.marshal_envelope() + # Add our provider info provider = message.providerPeers.add() provider.id = self.local_peer_id.to_bytes() provider.addrs.extend(addrs) + # Add the provider's signed-peer-record + provider.signedRecord = envelope.marshal_envelope() + # Serialize and send the message proto_bytes = message.SerializeToString() await stream.write(varint.encode(len(proto_bytes))) @@ -276,9 +289,27 @@ class ProviderStore: response = Message() response.ParseFromString(response_bytes) - # Check response type - response.type == Message.MessageType.ADD_PROVIDER - if response.type: + if response.type == Message.MessageType.ADD_PROVIDER: + # Consume the sender's signed-peer-record if sent + if response.HasField("senderRecord"): + try: + # Convert the signed-peer-record(Envelope) from + # protobuf bytes + envelope, _ = 
consume_envelope( + response.senderRecord, "libp2p-peer-record" + ) + # Use the defualt TTL of 2 hours (7200 seconds) + if not self.host.get_peerstore().consume_peer_record( + envelope, 7200 + ): + logger.error( + "Updating the Certified-Addr-Book was unsuccessful" + ) + except Exception as e: + logger.error( + "Error updating the certified addr book for peer: %s", e + ) + result = True except Exception as e: @@ -380,6 +411,14 @@ class ProviderStore: message.type = Message.MessageType.GET_PROVIDERS message.key = key + # Create sender's signed-peer-record + envelope = create_signed_peer_record( + self.host.get_id(), + self.host.get_addrs(), + self.host.get_private_key(), + ) + message.senderRecord = envelope.marshal_envelope() + # Serialize and send the message proto_bytes = message.SerializeToString() await stream.write(varint.encode(len(proto_bytes))) @@ -414,6 +453,26 @@ class ProviderStore: if response.type != Message.MessageType.GET_PROVIDERS: return [] + # Consume the sender's signed-peer-record if sent + if response.HasField("senderRecord"): + try: + # Convert the signed-peer-record(Envelope) from + # protobuf bytes + envelope, _ = consume_envelope( + response.senderRecord, "libp2p-peer-record" + ) + # Use the defualt TTL of 2 hours (7200 seconds) + if not self.host.get_peerstore().consume_peer_record( + envelope, 7200 + ): + logger.error( + "Updating the Certified-Addr-Book was unsuccessful" + ) + except Exception as e: + logger.error( + "Error updating the certified addr book for peer: %s", e + ) + # Extract provider information providers = [] for provider_proto in response.providerPeers: @@ -431,6 +490,30 @@ class ProviderStore: # Create PeerInfo and add to result providers.append(PeerInfo(provider_id, addrs)) + + # Consume the provider's signed-peer-record if sent + if provider_proto.HasField("signedRecord"): + try: + # Convert the signed-peer-record(Envelope) from + # protobuf bytes + envelope, _ = consume_envelope( + provider_proto.signedRecord, + "libp2p-peer-record", + ) + # Use the default TTL of 2 hours (7200 seconds) + if not self.host.get_peerstore().consume_peer_record( # noqa + envelope, 7200 + ): + logger.error( + "Failed to update the Certified-Addr-Book" + ) + except Exception as e: + logger.error( + "Error updating the certified-addr-book for peer %s: %s", # noqa + provider_id, + e, + ) + except Exception as e: logger.warning(f"Failed to parse provider info: {e}") diff --git a/libp2p/kad_dht/value_store.py b/libp2p/kad_dht/value_store.py index b79425fd..28cc6d8c 100644 --- a/libp2p/kad_dht/value_store.py +++ b/libp2p/kad_dht/value_store.py @@ -15,9 +15,11 @@ from libp2p.abc import ( from libp2p.custom_types import ( TProtocol, ) +from libp2p.peer.envelope import consume_envelope from libp2p.peer.id import ( ID, ) +from libp2p.peer.peerstore import create_signed_peer_record from .common import ( DEFAULT_TTL, @@ -110,6 +112,14 @@ class ValueStore: message = Message() message.type = Message.MessageType.PUT_VALUE + # Create sender's signed-peer-record + envelope = create_signed_peer_record( + self.host.get_id(), + self.host.get_addrs(), + self.host.get_private_key(), + ) + message.senderRecord = envelope.marshal_envelope() + # Set message fields message.key = key message.record.key = key @@ -155,7 +165,27 @@ class ValueStore: # Check if response is valid if response.type == Message.MessageType.PUT_VALUE: - if response.key: + # Consume the sender's signed-peer-record if sent + if response.HasField("senderRecord"): + try: + # Convert the signed-peer-record(Envelope) from 
+ # protobuf bytes + envelope, _ = consume_envelope( + response.senderRecord, "libp2p-peer-record" + ) + # Use the default TTL of 2 hours (7200 seconds) + if not self.host.get_peerstore().consume_peer_record( + envelope, 7200 + ): + logger.error( + "Updating the certified-addr-book was unsuccessful" + ) + except Exception as e: + logger.error( + "Error updating the certified addr book for peer: %s", e + ) + + if response.key == key: result = True return result @@ -231,6 +261,14 @@ class ValueStore: message.type = Message.MessageType.GET_VALUE message.key = key + # Create sender's signed-peer-record + envelope = create_signed_peer_record( + self.host.get_id(), + self.host.get_addrs(), + self.host.get_private_key(), + ) + message.senderRecord = envelope.marshal_envelope() + # Serialize and send the protobuf message proto_bytes = message.SerializeToString() await stream.write(varint.encode(len(proto_bytes))) @@ -275,6 +313,26 @@ class ValueStore: and response.HasField("record") and response.record.value ): + # Consume the sender's signed-peer-record + if response.HasField("senderRecord"): + try: + # Convert the signed-peer-record(Envelope) from + # protobuf bytes + envelope, _ = consume_envelope( + response.senderRecord, "libp2p-peer-record" + ) + # Use the default TTL of 2 hours (7200 seconds) + if not self.host.get_peerstore().consume_peer_record( + envelope, 7200 + ): + logger.error( + "Updating the certified-addr-book was unsuccessful" + ) + except Exception as e: + logger.error( + "Error updating the certified addr book for peer: %s", e + ) + logger.debug( f"Received value for key {key.hex()} from peer {peer_id}" ) diff --git a/libp2p/peer/peerstore.py b/libp2p/peer/peerstore.py index 043aaf0d..4669e9ec 100644 --- a/libp2p/peer/peerstore.py +++ b/libp2p/peer/peerstore.py @@ -23,7 +23,8 @@ from libp2p.crypto.keys import ( PrivateKey, PublicKey, ) -from libp2p.peer.envelope import Envelope +from libp2p.peer.envelope import Envelope, seal_record +from libp2p.peer.peer_record import PeerRecord from .id import ( ID, @@ -39,6 +40,17 @@ from .peerinfo import ( PERMANENT_ADDR_TTL = 0 +def create_signed_peer_record( + peer_id: ID, addrs: list[Multiaddr], pvt_key: PrivateKey +) -> Envelope: + """Creates a signed_peer_record wrapped in an Envelope""" + record = PeerRecord(peer_id, addrs) + envelope = seal_record(record, pvt_key) + + print(envelope) + return envelope + + class PeerRecordState: envelope: Envelope seq: int diff --git a/tests/core/kad_dht/test_unit_peer_routing.py b/tests/core/kad_dht/test_unit_peer_routing.py index ffe20655..6e15ce7e 100644 --- a/tests/core/kad_dht/test_unit_peer_routing.py +++ b/tests/core/kad_dht/test_unit_peer_routing.py @@ -57,7 +57,10 @@ class TestPeerRouting: def mock_host(self): """Create a mock host for testing.""" host = Mock() - host.get_id.return_value = create_valid_peer_id("local") + key_pair = create_new_key_pair() + host.get_id.return_value = ID.from_pubkey(key_pair.public_key) + host.get_public_key.return_value = key_pair.public_key + host.get_private_key.return_value = key_pair.private_key host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")] host.get_peerstore.return_value = Mock() host.new_stream = AsyncMock() From d1792588f9fcfb074bdbb873447d726358087d97 Mon Sep 17 00:00:00 2001 From: lla-dane Date: Sun, 10 Aug 2025 14:59:55 +0530 Subject: [PATCH 26/71] added tests for signed-peee-record transfer in kad-dht --- libp2p/kad_dht/kad_dht.py | 4 +- libp2p/kad_dht/peer_routing.py | 2 - libp2p/peer/peerstore.py | 2 - 
tests/core/kad_dht/test_kad_dht.py | 135 ++++++++++++++++++++++++++++- 4 files changed, 137 insertions(+), 6 deletions(-) diff --git a/libp2p/kad_dht/kad_dht.py b/libp2p/kad_dht/kad_dht.py index 2fb42662..f510390d 100644 --- a/libp2p/kad_dht/kad_dht.py +++ b/libp2p/kad_dht/kad_dht.py @@ -271,7 +271,6 @@ class KadDHT(Service): # Handle FIND_NODE message if message.type == Message.MessageType.FIND_NODE: # Get target key directly from protobuf - print("FIND NODE RECEIVED") target_key = message.key # Find closest peers to the target key @@ -353,6 +352,7 @@ class KadDHT(Service): # Handle ADD_PROVIDER message elif message.type == Message.MessageType.ADD_PROVIDER: + print("ADD_PROVIDER REQ RECEIVED") # Process ADD_PROVIDER key = message.key logger.debug(f"Received ADD_PROVIDER for key {key.hex()}") @@ -449,6 +449,7 @@ class KadDHT(Service): # Handle GET_PROVIDERS message elif message.type == Message.MessageType.GET_PROVIDERS: + print("GET_PROVIDERS REQ RECIEVED") # Process GET_PROVIDERS key = message.key logger.debug(f"Received GET_PROVIDERS request for key {key.hex()}") @@ -559,6 +560,7 @@ class KadDHT(Service): # Handle GET_VALUE message elif message.type == Message.MessageType.GET_VALUE: + print("GET VALUE REQ RECEIVED") # Process GET_VALUE key = message.key logger.debug(f"Received GET_VALUE request for key {key.hex()}") diff --git a/libp2p/kad_dht/peer_routing.py b/libp2p/kad_dht/peer_routing.py index dc3190a5..a2f3d193 100644 --- a/libp2p/kad_dht/peer_routing.py +++ b/libp2p/kad_dht/peer_routing.py @@ -257,8 +257,6 @@ class PeerRouting(IPeerRouting): find_node_msg.type = Message.MessageType.FIND_NODE find_node_msg.key = target_key # Set target key directly as bytes - print("MESSAGE GOING TO BE CREATED") - # Create sender_signed_peer_record envelope = create_signed_peer_record( self.host.get_id(), self.host.get_addrs(), self.host.get_private_key() diff --git a/libp2p/peer/peerstore.py b/libp2p/peer/peerstore.py index 4669e9ec..0faccb45 100644 --- a/libp2p/peer/peerstore.py +++ b/libp2p/peer/peerstore.py @@ -46,8 +46,6 @@ def create_signed_peer_record( """Creates a signed_peer_record wrapped in an Envelope""" record = PeerRecord(peer_id, addrs) envelope = seal_record(record, pvt_key) - - print(envelope) return envelope diff --git a/tests/core/kad_dht/test_kad_dht.py b/tests/core/kad_dht/test_kad_dht.py index a6f73074..eaf9a956 100644 --- a/tests/core/kad_dht/test_kad_dht.py +++ b/tests/core/kad_dht/test_kad_dht.py @@ -21,6 +21,7 @@ from libp2p.kad_dht.kad_dht import ( from libp2p.kad_dht.utils import ( create_key_from_binary, ) +from libp2p.peer.envelope import Envelope from libp2p.peer.peerinfo import ( PeerInfo, ) @@ -80,6 +81,16 @@ async def test_find_node(dht_pair: tuple[KadDHT, KadDHT]): with trio.fail_after(TEST_TIMEOUT): found_info = await dht_a.find_peer(dht_b.host.get_id()) + # Verifies if the senderRecord in the FIND_NODE request is correctly processed + assert isinstance( + dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id()), Envelope + ) + + # Verifies if the senderRecord in the FIND_NODE response is correctly proccessed + assert isinstance( + dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()), Envelope + ) + # Verify that the found peer has the correct peer ID assert found_info is not None, "Failed to find the target peer" assert found_info.peer_id == dht_b.host.get_id(), "Found incorrect peer ID" @@ -104,14 +115,44 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]): await dht_a.routing_table.add_peer(peer_b_info) print("Routing table of 
a has ", dht_a.routing_table.get_peer_ids()) + # An extra FIND_NODE req is sent between the 2 nodes while dht creation, + # so both the nodes will have records of each other before PUT_VALUE req is sent + envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()) + envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id()) + + assert isinstance(envelope_a, Envelope) + assert isinstance(envelope_b, Envelope) + + record_a = envelope_a.record() + record_b = envelope_b.record() + # Store the value using the first node (this will also store locally) with trio.fail_after(TEST_TIMEOUT): await dht_a.put_value(key, value) + # These are the records that were sent betweeen the peers during the PUT_VALUE req + envelope_a_put_value = dht_a.host.get_peerstore().get_peer_record( + dht_b.host.get_id() + ) + envelope_b_put_value = dht_b.host.get_peerstore().get_peer_record( + dht_a.host.get_id() + ) + + assert isinstance(envelope_a_put_value, Envelope) + assert isinstance(envelope_b_put_value, Envelope) + + record_a_put_value = envelope_a_put_value.record() + record_b_put_value = envelope_b_put_value.record() + + # This proves that both the records are different, and a new signed record + # was passed between the peers during PUT_VALUE exceution, which proves the + # signed-record transfer works correctly in PUT_VALUE executions. + assert record_a.seq < record_a_put_value.seq + assert record_b.seq < record_b_put_value.seq + # # Log debugging information logger.debug("Put value with key %s...", key.hex()[:10]) logger.debug("Node A value store: %s", dht_a.value_store.store) - print("hello test") # # Allow more time for the value to propagate await trio.sleep(0.5) @@ -126,6 +167,26 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]): print("the value stored in node b is", dht_b.get_value_store_size()) logger.debug("Retrieved value: %s", retrieved_value) + # These are the records that were sent betweeen the peers during the PUT_VALUE req + envelope_a_get_value = dht_a.host.get_peerstore().get_peer_record( + dht_b.host.get_id() + ) + envelope_b_get_value = dht_b.host.get_peerstore().get_peer_record( + dht_a.host.get_id() + ) + + assert isinstance(envelope_a_get_value, Envelope) + assert isinstance(envelope_b_get_value, Envelope) + + record_a_get_value = envelope_a_get_value.record() + record_b_get_value = envelope_b_get_value.record() + + # This proves that there was no record exchange between the nodes during GET_VALUE + # execution, as dht_b already had the key/value pair stored locally after the + # PUT_VALUE execution. 
+ assert record_a_get_value.seq == record_a_put_value.seq + assert record_b_get_value.seq == record_b_put_value.seq + # Verify that the retrieved value matches the original assert retrieved_value == value, "Retrieved value does not match the stored value" @@ -142,11 +203,43 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]): # Store content on the first node dht_a.value_store.put(content_id, content) + # An extra FIND_NODE req is sent between the 2 nodes while dht creation, + # so both the nodes will have records of each other before PUT_VALUE req is sent + envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()) + envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id()) + + assert isinstance(envelope_a, Envelope) + assert isinstance(envelope_b, Envelope) + + record_a = envelope_a.record() + record_b = envelope_b.record() + # Advertise the first node as a provider with trio.fail_after(TEST_TIMEOUT): success = await dht_a.provide(content_id) assert success, "Failed to advertise as provider" + # These are the records that were sent betweeen the peers during + # the ADD_PROVIDER req + envelope_a_add_prov = dht_a.host.get_peerstore().get_peer_record( + dht_b.host.get_id() + ) + envelope_b_add_prov = dht_b.host.get_peerstore().get_peer_record( + dht_a.host.get_id() + ) + + assert isinstance(envelope_a_add_prov, Envelope) + assert isinstance(envelope_b_add_prov, Envelope) + + record_a_add_prov = envelope_a_add_prov.record() + record_b_add_prov = envelope_b_add_prov.record() + + # This proves that both the records are different, and a new signed record + # was passed between the peers during ADD_PROVIDER exceution, which proves the + # signed-record transfer works correctly in ADD_PROVIDER executions. 
+ assert record_a.seq < record_a_add_prov.seq + assert record_b.seq < record_b_add_prov.seq + # Allow time for the provider record to propagate await trio.sleep(0.1) @@ -154,6 +247,26 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]): with trio.fail_after(TEST_TIMEOUT): providers = await dht_b.find_providers(content_id) + # These are the records in each peer after the find_provider execution + envelope_a_find_prov = dht_a.host.get_peerstore().get_peer_record( + dht_b.host.get_id() + ) + envelope_b_find_prov = dht_b.host.get_peerstore().get_peer_record( + dht_a.host.get_id() + ) + + assert isinstance(envelope_a_find_prov, Envelope) + assert isinstance(envelope_b_find_prov, Envelope) + + record_a_find_prov = envelope_a_find_prov.record() + record_b_find_prov = envelope_b_find_prov.record() + + # This proves that both the records are same, as the dht_b already + # has the provider record for the content_id, after the ADD_PROVIDER + # advertisement by dht_a + assert record_a_find_prov.seq == record_a_add_prov.seq + assert record_b_find_prov.seq == record_b_add_prov.seq + # Verify that we found the first node as a provider assert providers, "No providers found" assert any(p.peer_id == dht_a.local_peer_id for p in providers), ( @@ -166,3 +279,23 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]): assert retrieved_value == content, ( "Retrieved content does not match the original" ) + + # These are the record state of each peer aftet the GET_VALUE execution + envelope_a_get_value = dht_a.host.get_peerstore().get_peer_record( + dht_b.host.get_id() + ) + envelope_b_get_value = dht_b.host.get_peerstore().get_peer_record( + dht_a.host.get_id() + ) + + assert isinstance(envelope_a_get_value, Envelope) + assert isinstance(envelope_b_get_value, Envelope) + + record_a_get_value = envelope_a_get_value.record() + record_b_get_value = envelope_b_get_value.record() + + # This proves that both the records are different, meaning that there was + # a new signed-record tranfer during the GET_VALUE execution by dht_b, which means + # the signed-record transfer works correctly in GET_VALUE executions. 
+ assert record_a_find_prov.seq < record_a_get_value.seq + assert record_b_find_prov.seq < record_b_get_value.seq From 5ab68026d639be4617ebe4411897b20b68965762 Mon Sep 17 00:00:00 2001 From: lla-dane Date: Sun, 10 Aug 2025 15:02:39 +0530 Subject: [PATCH 27/71] removed redundant logs --- libp2p/kad_dht/kad_dht.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/libp2p/kad_dht/kad_dht.py b/libp2p/kad_dht/kad_dht.py index f510390d..2a3a2b1a 100644 --- a/libp2p/kad_dht/kad_dht.py +++ b/libp2p/kad_dht/kad_dht.py @@ -352,7 +352,6 @@ class KadDHT(Service): # Handle ADD_PROVIDER message elif message.type == Message.MessageType.ADD_PROVIDER: - print("ADD_PROVIDER REQ RECEIVED") # Process ADD_PROVIDER key = message.key logger.debug(f"Received ADD_PROVIDER for key {key.hex()}") @@ -449,7 +448,6 @@ class KadDHT(Service): # Handle GET_PROVIDERS message elif message.type == Message.MessageType.GET_PROVIDERS: - print("GET_PROVIDERS REQ RECIEVED") # Process GET_PROVIDERS key = message.key logger.debug(f"Received GET_PROVIDERS request for key {key.hex()}") @@ -560,7 +558,6 @@ class KadDHT(Service): # Handle GET_VALUE message elif message.type == Message.MessageType.GET_VALUE: - print("GET VALUE REQ RECEIVED") # Process GET_VALUE key = message.key logger.debug(f"Received GET_VALUE request for key {key.hex()}") From a21d9e878bc97f0f51ea756438c129df5f057e38 Mon Sep 17 00:00:00 2001 From: lla-dane Date: Mon, 11 Aug 2025 09:48:45 +0530 Subject: [PATCH 28/71] recompile protobuf schema and remove typos --- libp2p/kad_dht/kad_dht.py | 2 +- libp2p/kad_dht/peer_routing.py | 1 - tests/core/kad_dht/test_kad_dht.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/libp2p/kad_dht/kad_dht.py b/libp2p/kad_dht/kad_dht.py index 2a3a2b1a..78cf50e2 100644 --- a/libp2p/kad_dht/kad_dht.py +++ b/libp2p/kad_dht/kad_dht.py @@ -364,7 +364,7 @@ class KadDHT(Service): envelope, _ = consume_envelope( message.senderRecord, "libp2p-peer-record" ) - # Use the default TTL of 2 hours (72000 seconds) + # Use the default TTL of 2 hours (7200 seconds) if not self.host.get_peerstore().consume_peer_record( envelope, 7200 ): diff --git a/libp2p/kad_dht/peer_routing.py b/libp2p/kad_dht/peer_routing.py index a2f3d193..58406f05 100644 --- a/libp2p/kad_dht/peer_routing.py +++ b/libp2p/kad_dht/peer_routing.py @@ -361,7 +361,6 @@ class PeerRouting(IPeerRouting): ) except Exception as e: - print("EXCEPTION CAME") logger.debug(f"Error querying peer {peer} for closest: {e}") finally: diff --git a/tests/core/kad_dht/test_kad_dht.py b/tests/core/kad_dht/test_kad_dht.py index eaf9a956..70d9a5e9 100644 --- a/tests/core/kad_dht/test_kad_dht.py +++ b/tests/core/kad_dht/test_kad_dht.py @@ -86,7 +86,7 @@ async def test_find_node(dht_pair: tuple[KadDHT, KadDHT]): dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id()), Envelope ) - # Verifies if the senderRecord in the FIND_NODE response is correctly proccessed + # Verifies if the senderRecord in the FIND_NODE response is correctly processed assert isinstance( dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()), Envelope ) From 702ad4876e8c925284b6e4faf4e63342a03f8b4a Mon Sep 17 00:00:00 2001 From: lla-dane Date: Tue, 12 Aug 2025 13:53:40 +0530 Subject: [PATCH 29/71] remove too much repeatitive code --- libp2p/kad_dht/kad_dht.py | 126 +++---------------------------- libp2p/kad_dht/peer_routing.py | 61 ++------------- libp2p/kad_dht/provider_store.py | 63 +--------------- libp2p/kad_dht/utils.py | 44 +++++++++++ libp2p/kad_dht/value_store.py | 40 +--------- 5 files 
changed, 68 insertions(+), 266 deletions(-) diff --git a/libp2p/kad_dht/kad_dht.py b/libp2p/kad_dht/kad_dht.py index 78cf50e2..db0e635e 100644 --- a/libp2p/kad_dht/kad_dht.py +++ b/libp2p/kad_dht/kad_dht.py @@ -22,10 +22,11 @@ from libp2p.abc import ( IHost, ) from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager +from libp2p.kad_dht.utils import maybe_consume_signed_record from libp2p.network.stream.net_stream import ( INetStream, ) -from libp2p.peer.envelope import Envelope, consume_envelope +from libp2p.peer.envelope import Envelope from libp2p.peer.id import ( ID, ) @@ -280,24 +281,7 @@ class KadDHT(Service): logger.debug(f"Found {len(closest_peers)} peers close to target") # Consume the source signed_peer_record if sent - if message.HasField("senderRecord"): - try: - # Convert the signed-peer-record(Envelope) from - # protobuf bytes - envelope, _ = consume_envelope( - message.senderRecord, "libp2p-peer-record" - ) - # Use the defualt TTL of 2 hours (7200 seconds) - if not self.host.get_peerstore().consume_peer_record( - envelope, 7200 - ): - logger.error( - "Updating the Certified-Addr-Book was unsuccessful" - ) - except Exception as e: - logger.error( - "Error updating the certified addr book for peer: %s", e - ) + success = maybe_consume_signed_record(message, self.host) # Build response message with protobuf response = Message() @@ -357,24 +341,7 @@ class KadDHT(Service): logger.debug(f"Received ADD_PROVIDER for key {key.hex()}") # Consume the source signed-peer-record if sent - if message.HasField("senderRecord"): - try: - # Convert the signed-peer-record(Envelope) from - # protobuf bytes - envelope, _ = consume_envelope( - message.senderRecord, "libp2p-peer-record" - ) - # Use the default TTL of 2 hours (7200 seconds) - if not self.host.get_peerstore().consume_peer_record( - envelope, 7200 - ): - logger.error( - "Updating the Certified-Addr-Book was unsuccessful" - ) - except Exception as e: - logger.error( - "Error updating the certified addr book for peer: %s", e - ) + success = maybe_consume_signed_record(message, self.host) # Extract provider information for provider_proto in message.providerPeers: @@ -402,31 +369,13 @@ class KadDHT(Service): logger.debug( f"Added provider {provider_id} for key {key.hex()}" ) - except Exception as e: - logger.warning(f"Failed to process provider info: {e}") # Process the signed-records of provider if sent - if provider_proto.HasField("signedRecord"): - try: - # Convert the signed-peer-record(Envelope) from - # protobuf bytes - envelope, _ = consume_envelope( - provider_proto.signedRecord, - "libp2p-peer-record", - ) - # Use the default TTL of 2 hours (7200 seconds) - if not self.host.get_peerstore().consume_peer_record( # noqa - envelope, 7200 - ): - logger.error( - "Failed to update the Certified-Addr-Book" - ) - except Exception as e: - logger.error( - "Error updating the certified-addr-book for peer %s: %s", # noqa - provider_id, - e, - ) + success = maybe_consume_signed_record( + provider_proto, self.host + ) + except Exception as e: + logger.warning(f"Failed to process provider info: {e}") # Send acknowledgement response = Message() @@ -453,24 +402,7 @@ class KadDHT(Service): logger.debug(f"Received GET_PROVIDERS request for key {key.hex()}") # Consume the source signed_peer_record if sent - if message.HasField("senderRecord"): - try: - # Convert the signed-peer-record(Envelope) from - # protobuf bytes - envelope, _ = consume_envelope( - message.senderRecord, "libp2p-peer-record" - ) - # Use the defualt TTL of 2 
hours (7200 seconds) - if not self.host.get_peerstore().consume_peer_record( - envelope, 7200 - ): - logger.error( - "Updating the Certified-Addr-Book was unsuccessful" - ) - except Exception as e: - logger.error( - "Error updating the certified addr book for peer: %s", e - ) + success = maybe_consume_signed_record(message, self.host) # Find providers for the key providers = self.provider_store.get_providers(key) @@ -563,24 +495,7 @@ class KadDHT(Service): logger.debug(f"Received GET_VALUE request for key {key.hex()}") # Consume the sender_signed_peer_record - if message.HasField("senderRecord"): - try: - # Convert the signed-peer-record(Envelope) from - # protobuf bytes - envelope, _ = consume_envelope( - message.senderRecord, "libp2p-peer-record" - ) - # Use the default TTL of 2 hours (7200 seconds) - if not self.host.get_peerstore().consume_peer_record( - envelope, 7200 - ): - logger.error( - "Updating teh Certified-Addr-Book was unsuccessful" - ) - except Exception as e: - logger.error( - "Error updating the certified addr book for peer: %s", e - ) + success = maybe_consume_signed_record(message, self.host) value = self.value_store.get(key) if value: @@ -677,24 +592,7 @@ class KadDHT(Service): success = False # Consume the source signed_peer_record if sent - if message.HasField("senderRecord"): - try: - # Convert the signed-peer-record(Envelope) from - # protobuf bytes - envelope, _ = consume_envelope( - message.senderRecord, "libp2p-peer-record" - ) - # Use the default TTL of 2 hours (7200 seconds) - if not self.host.get_peerstore().consume_peer_record( - envelope, 7200 - ): - logger.error( - "Updating the certified-addr-book was unsuccessful" - ) - except Exception as e: - logger.error( - "Error updating the certified addr book for peer: %s", e - ) + success = maybe_consume_signed_record(message, self.host) try: if not (key and value): diff --git a/libp2p/kad_dht/peer_routing.py b/libp2p/kad_dht/peer_routing.py index 58406f05..e36f7caf 100644 --- a/libp2p/kad_dht/peer_routing.py +++ b/libp2p/kad_dht/peer_routing.py @@ -15,7 +15,7 @@ from libp2p.abc import ( INetStream, IPeerRouting, ) -from libp2p.peer.envelope import Envelope, consume_envelope +from libp2p.peer.envelope import Envelope from libp2p.peer.id import ( ID, ) @@ -35,6 +35,7 @@ from .routing_table import ( RoutingTable, ) from .utils import ( + maybe_consume_signed_record, sort_peer_ids_by_distance, ) @@ -308,24 +309,7 @@ class PeerRouting(IPeerRouting): # Process closest peers from response if response_msg.type == Message.MessageType.FIND_NODE: # Consume the sender_signed_peer_record - if response_msg.HasField("senderRecord"): - try: - # Convert the signed-peer-record(Envelope) from - # protobuf bytes - envelope, _ = consume_envelope( - response_msg.senderRecord, "libp2p-peer-record" - ) - # Use the default TTL of 2 hours (7200 seconds) - if not self.host.get_peerstore().consume_peer_record( - envelope, 7200 - ): - logger.error( - "Updating teh Certified-Addr-Book was unsuccessful" - ) - except Exception as e: - logger.error( - "Error updating the certified addr book for peer: %s", e - ) + _ = maybe_consume_signed_record(response_msg, self.host) for peer_data in response_msg.closerPeers: new_peer_id = ID(peer_data.id) @@ -340,25 +324,7 @@ class PeerRouting(IPeerRouting): self.host.get_peerstore().add_addrs(new_peer_id, addrs, 3600) # Consume the received closer_peers signed-records - if peer_data.HasField("signedRecord"): - try: - # Convert the signed-peer-record(Envelope) from - # protobuf bytes - envelope, _ = 
consume_envelope( - peer_data.signedRecord, - "libp2p-peer-record", - ) - # Use the default TTL of 2 hours (7200 seconds) - if not self.host.get_peerstore().consume_peer_record( - envelope, 7200 - ): - logger.error("Failed to update certified-addr-book") - except Exception as e: - logger.error( - "Error updating the certified-addr-book for peer %s: %s", # noqa - new_peer_id, - e, - ) + _ = maybe_consume_signed_record(peer_data, self.host) except Exception as e: logger.debug(f"Error querying peer {peer} for closest: {e}") @@ -400,24 +366,7 @@ class PeerRouting(IPeerRouting): if kad_message.type == Message.MessageType.FIND_NODE: # Consume the sender's signed-peer-record if sent - if kad_message.HasField("senderRecord"): - try: - # Convert the signed-peer-record(Envelope) from - # protobuf bytes - envelope, _ = consume_envelope( - kad_message.senderRecord, "libp2p-peer-record" - ) - # Use the default TTL of 2 hours (7200 seconds) - if not self.host.get_peerstore().consume_peer_record( - envelope, 7200 - ): - logger.error( - "Updating the Certified-Addr-Book was unsuccessful" - ) - except Exception as e: - logger.error( - "Error updating the certified addr book for peer: %s", e - ) + _ = maybe_consume_signed_record(kad_message, self.host) # Get target key directly from protobuf message target_key = kad_message.key diff --git a/libp2p/kad_dht/provider_store.py b/libp2p/kad_dht/provider_store.py index c5800914..21bd1c80 100644 --- a/libp2p/kad_dht/provider_store.py +++ b/libp2p/kad_dht/provider_store.py @@ -22,7 +22,7 @@ from libp2p.abc import ( from libp2p.custom_types import ( TProtocol, ) -from libp2p.peer.envelope import consume_envelope +from libp2p.kad_dht.utils import maybe_consume_signed_record from libp2p.peer.id import ( ID, ) @@ -291,25 +291,7 @@ class ProviderStore: if response.type == Message.MessageType.ADD_PROVIDER: # Consume the sender's signed-peer-record if sent - if response.HasField("senderRecord"): - try: - # Convert the signed-peer-record(Envelope) from - # protobuf bytes - envelope, _ = consume_envelope( - response.senderRecord, "libp2p-peer-record" - ) - # Use the defualt TTL of 2 hours (7200 seconds) - if not self.host.get_peerstore().consume_peer_record( - envelope, 7200 - ): - logger.error( - "Updating the Certified-Addr-Book was unsuccessful" - ) - except Exception as e: - logger.error( - "Error updating the certified addr book for peer: %s", e - ) - + _ = maybe_consume_signed_record(response, self.host) result = True except Exception as e: @@ -454,24 +436,7 @@ class ProviderStore: return [] # Consume the sender's signed-peer-record if sent - if response.HasField("senderRecord"): - try: - # Convert the signed-peer-record(Envelope) from - # protobuf bytes - envelope, _ = consume_envelope( - response.senderRecord, "libp2p-peer-record" - ) - # Use the defualt TTL of 2 hours (7200 seconds) - if not self.host.get_peerstore().consume_peer_record( - envelope, 7200 - ): - logger.error( - "Updating the Certified-Addr-Book was unsuccessful" - ) - except Exception as e: - logger.error( - "Error updating the certified addr book for peer: %s", e - ) + _ = maybe_consume_signed_record(response, self.host) # Extract provider information providers = [] @@ -492,27 +457,7 @@ class ProviderStore: providers.append(PeerInfo(provider_id, addrs)) # Consume the provider's signed-peer-record if sent - if provider_proto.HasField("signedRecord"): - try: - # Convert the signed-peer-record(Envelope) from - # protobuf bytes - envelope, _ = consume_envelope( - provider_proto.signedRecord, - 
"libp2p-peer-record", - ) - # Use the default TTL of 2 hours (7200 seconds) - if not self.host.get_peerstore().consume_peer_record( # noqa - envelope, 7200 - ): - logger.error( - "Failed to update the Certified-Addr-Book" - ) - except Exception as e: - logger.error( - "Error updating the certified-addr-book for peer %s: %s", # noqa - provider_id, - e, - ) + _ = maybe_consume_signed_record(provider_proto, self.host) except Exception as e: logger.warning(f"Failed to parse provider info: {e}") diff --git a/libp2p/kad_dht/utils.py b/libp2p/kad_dht/utils.py index 61158320..64976cb3 100644 --- a/libp2p/kad_dht/utils.py +++ b/libp2p/kad_dht/utils.py @@ -2,13 +2,57 @@ Utility functions for Kademlia DHT implementation. """ +import logging + import base58 import multihash +from libp2p.abc import IHost +from libp2p.peer.envelope import consume_envelope from libp2p.peer.id import ( ID, ) +from .pb.kademlia_pb2 import ( + Message, +) + +logger = logging.getLogger("kademlia-example.utils") + + +def maybe_consume_signed_record(msg: Message | Message.Peer, host: IHost) -> bool: + if isinstance(msg, Message): + if msg.HasField("senderRecord"): + try: + # Convert the signed-peer-record(Envelope) from + # protobuf bytes + envelope, _ = consume_envelope(msg.senderRecord, "libp2p-peer-record") + # Use the default TTL of 2 hours (7200 seconds) + if not host.get_peerstore().consume_peer_record(envelope, 7200): + logger.error("Updating the certified-addr-book was unsuccessful") + except Exception as e: + logger.error("Error updating teh certified addr book for peer: %s", e) + return False + else: + if msg.HasField("signedRecord"): + try: + # Convert the signed-peer-record(Envelope) from + # protobuf bytes + envelope, _ = consume_envelope( + msg.signedRecord, + "libp2p-peer-record", + ) + # Use the default TTL of 2 hours (7200 seconds) + if not host.get_peerstore().consume_peer_record(envelope, 7200): + logger.error("Failed to update the Certified-Addr-Book") + except Exception as e: + logger.error( + "Error updating the certified-addr-book: %s", + e, + ) + + return True + def create_key_from_binary(binary_data: bytes) -> bytes: """ diff --git a/libp2p/kad_dht/value_store.py b/libp2p/kad_dht/value_store.py index 28cc6d8c..adc37b72 100644 --- a/libp2p/kad_dht/value_store.py +++ b/libp2p/kad_dht/value_store.py @@ -15,7 +15,7 @@ from libp2p.abc import ( from libp2p.custom_types import ( TProtocol, ) -from libp2p.peer.envelope import consume_envelope +from libp2p.kad_dht.utils import maybe_consume_signed_record from libp2p.peer.id import ( ID, ) @@ -166,24 +166,7 @@ class ValueStore: # Check if response is valid if response.type == Message.MessageType.PUT_VALUE: # Consume the sender's signed-peer-record if sent - if response.HasField("senderRecord"): - try: - # Convert the signed-peer-record(Envelope) from - # protobuf bytes - envelope, _ = consume_envelope( - response.senderRecord, "libp2p-peer-record" - ) - # Use the default TTL of 2 hours (7200 seconds) - if not self.host.get_peerstore().consume_peer_record( - envelope, 7200 - ): - logger.error( - "Updating the certified-addr-book was unsuccessful" - ) - except Exception as e: - logger.error( - "Error updating the certified addr book for peer: %s", e - ) + _ = maybe_consume_signed_record(response, self.host) if response.key == key: result = True @@ -314,24 +297,7 @@ class ValueStore: and response.record.value ): # Consume the sender's signed-peer-record - if response.HasField("senderRecord"): - try: - # Convert the signed-peer-record(Envelope) from - # protobuf 
bytes - envelope, _ = consume_envelope( - response.senderRecord, "libp2p-peer-record" - ) - # Use the default TTL of 2 hours (7200 seconds) - if not self.host.get_peerstore().consume_peer_record( - envelope, 7200 - ): - logger.error( - "Updating the certified-addr-book was unsuccessful" - ) - except Exception as e: - logger.error( - "Error updating the certified addr book for peer: %s", e - ) + _ = maybe_consume_signed_record(response, self.host) logger.debug( f"Received value for key {key.hex()} from peer {peer_id}" From cea1985c5c7b8aed6ea2b202b775adc949ad682b Mon Sep 17 00:00:00 2001 From: lla-dane Date: Thu, 14 Aug 2025 10:39:48 +0530 Subject: [PATCH 30/71] add reissuing mechanism of records if addrs dont change --- libp2p/abc.py | 7 +++ libp2p/host/basic_host.py | 9 +++ libp2p/kad_dht/kad_dht.py | 51 ++++------------ libp2p/kad_dht/peer_routing.py | 16 ++--- libp2p/kad_dht/provider_store.py | 21 ++----- libp2p/kad_dht/utils.py | 29 +++++++++ libp2p/kad_dht/value_store.py | 19 ++---- libp2p/peer/envelope.py | 5 ++ libp2p/peer/peerstore.py | 9 +++ tests/core/kad_dht/test_kad_dht.py | 95 ++++++++++++++++++++++++++---- 10 files changed, 170 insertions(+), 91 deletions(-) diff --git a/libp2p/abc.py b/libp2p/abc.py index 90ad6a45..614af8bf 100644 --- a/libp2p/abc.py +++ b/libp2p/abc.py @@ -970,6 +970,13 @@ class IPeerStore( # --------CERTIFIED-ADDR-BOOK---------- + @abstractmethod + def get_local_record(self) -> Optional["Envelope"]: + """Get the local-peer-record wrapped in Envelope""" + + def set_local_record(self, envelope: "Envelope") -> None: + """Set the local-peer-record wrapped in Envelope""" + @abstractmethod def consume_peer_record(self, envelope: "Envelope", ttl: int) -> bool: """ diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index b40b0128..a0311bd8 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -43,6 +43,7 @@ from libp2p.peer.id import ( from libp2p.peer.peerinfo import ( PeerInfo, ) +from libp2p.peer.peerstore import create_signed_peer_record from libp2p.protocol_muxer.exceptions import ( MultiselectClientError, MultiselectError, @@ -110,6 +111,14 @@ class BasicHost(IHost): if bootstrap: self.bootstrap = BootstrapDiscovery(network, bootstrap) + # Cache a signed-record if the local-node in the PeerStore + envelope = create_signed_peer_record( + self.get_id(), + self.get_addrs(), + self.get_private_key(), + ) + self.get_peerstore().set_local_record(envelope) + def get_id(self) -> ID: """ :return: peer_id of host diff --git a/libp2p/kad_dht/kad_dht.py b/libp2p/kad_dht/kad_dht.py index db0e635e..f93aa75e 100644 --- a/libp2p/kad_dht/kad_dht.py +++ b/libp2p/kad_dht/kad_dht.py @@ -22,7 +22,7 @@ from libp2p.abc import ( IHost, ) from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager -from libp2p.kad_dht.utils import maybe_consume_signed_record +from libp2p.kad_dht.utils import env_to_send_in_RPC, maybe_consume_signed_record from libp2p.network.stream.net_stream import ( INetStream, ) @@ -33,7 +33,6 @@ from libp2p.peer.id import ( from libp2p.peer.peerinfo import ( PeerInfo, ) -from libp2p.peer.peerstore import create_signed_peer_record from libp2p.tools.async_service import ( Service, ) @@ -319,12 +318,8 @@ class KadDHT(Service): ) # Create sender_signed_peer_record - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - response.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + response.senderRecord 
= envelope_bytes # Serialize and send response response_bytes = response.SerializeToString() @@ -383,12 +378,8 @@ class KadDHT(Service): response.key = key # Add sender's signed-peer-record - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - response.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes response_bytes = response.SerializeToString() await stream.write(varint.encode(len(response_bytes))) @@ -416,12 +407,8 @@ class KadDHT(Service): response.key = key # Create sender_signed_peer_record for the response - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - response.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes # Add provider information to response for provider_info in providers: @@ -512,12 +499,8 @@ class KadDHT(Service): response.record.timeReceived = str(time.time()) # Create sender_signed_peer_record - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - response.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes # Serialize and send response response_bytes = response.SerializeToString() @@ -533,12 +516,8 @@ class KadDHT(Service): response.key = key # Create sender_signed_peer_record for the response - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - response.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes # Add closest peers to key closest_peers = self.routing_table.find_local_closest_peers( @@ -616,12 +595,8 @@ class KadDHT(Service): response.key = key # Create sender_signed_peer_record for the response - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - response.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes # Serialize and send response response_bytes = response.SerializeToString() diff --git a/libp2p/kad_dht/peer_routing.py b/libp2p/kad_dht/peer_routing.py index e36f7caf..4362ffea 100644 --- a/libp2p/kad_dht/peer_routing.py +++ b/libp2p/kad_dht/peer_routing.py @@ -22,7 +22,6 @@ from libp2p.peer.id import ( from libp2p.peer.peerinfo import ( PeerInfo, ) -from libp2p.peer.peerstore import create_signed_peer_record from .common import ( ALPHA, @@ -35,6 +34,7 @@ from .routing_table import ( RoutingTable, ) from .utils import ( + env_to_send_in_RPC, maybe_consume_signed_record, sort_peer_ids_by_distance, ) @@ -259,10 +259,8 @@ class PeerRouting(IPeerRouting): find_node_msg.key = target_key # Set target key directly as bytes # Create sender_signed_peer_record - envelope = create_signed_peer_record( - self.host.get_id(), self.host.get_addrs(), self.host.get_private_key() - ) - find_node_msg.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + find_node_msg.senderRecord = envelope_bytes # Serialize and send the protobuf message with varint length prefix proto_bytes = find_node_msg.SerializeToString() @@ -381,12 +379,8 @@ class 
PeerRouting(IPeerRouting): response.type = Message.MessageType.FIND_NODE # Create sender_signed_peer_record for the response - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - response.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes # Add peer information to response for peer_id in closest_peers: diff --git a/libp2p/kad_dht/provider_store.py b/libp2p/kad_dht/provider_store.py index 21bd1c80..4c6a8e06 100644 --- a/libp2p/kad_dht/provider_store.py +++ b/libp2p/kad_dht/provider_store.py @@ -22,14 +22,13 @@ from libp2p.abc import ( from libp2p.custom_types import ( TProtocol, ) -from libp2p.kad_dht.utils import maybe_consume_signed_record +from libp2p.kad_dht.utils import env_to_send_in_RPC, maybe_consume_signed_record from libp2p.peer.id import ( ID, ) from libp2p.peer.peerinfo import ( PeerInfo, ) -from libp2p.peer.peerstore import create_signed_peer_record from .common import ( ALPHA, @@ -243,12 +242,8 @@ class ProviderStore: message.key = key # Create sender's signed-peer-record - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - message.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + message.senderRecord = envelope_bytes # Add our provider info provider = message.providerPeers.add() @@ -256,7 +251,7 @@ class ProviderStore: provider.addrs.extend(addrs) # Add the provider's signed-peer-record - provider.signedRecord = envelope.marshal_envelope() + provider.signedRecord = envelope_bytes # Serialize and send the message proto_bytes = message.SerializeToString() @@ -394,12 +389,8 @@ class ProviderStore: message.key = key # Create sender's signed-peer-record - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - message.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + message.senderRecord = envelope_bytes # Serialize and send the message proto_bytes = message.SerializeToString() diff --git a/libp2p/kad_dht/utils.py b/libp2p/kad_dht/utils.py index 64976cb3..3cf79efd 100644 --- a/libp2p/kad_dht/utils.py +++ b/libp2p/kad_dht/utils.py @@ -12,6 +12,7 @@ from libp2p.peer.envelope import consume_envelope from libp2p.peer.id import ( ID, ) +from libp2p.peer.peerstore import create_signed_peer_record from .pb.kademlia_pb2 import ( Message, @@ -54,6 +55,34 @@ def maybe_consume_signed_record(msg: Message | Message.Peer, host: IHost) -> boo return True +def env_to_send_in_RPC(host: IHost) -> tuple[bytes, bool]: + listen_addrs_set = {addr for addr in host.get_addrs()} + local_env = host.get_peerstore().get_local_record() + + if local_env is None: + # No cached SPR yet -> create one + return issue_and_cache_local_record(host), True + else: + record_addrs_set = local_env._env_addrs_set() + if record_addrs_set == listen_addrs_set: + # Perfect match -> reuse cached envelope + return local_env.marshal_envelope(), False + else: + # Addresses changed -> issue a new SPR and cache it + return issue_and_cache_local_record(host), True + + +def issue_and_cache_local_record(host: IHost) -> bytes: + env = create_signed_peer_record( + host.get_id(), + host.get_addrs(), + host.get_private_key(), + ) + # Cache it for nexxt time use + host.get_peerstore().set_local_record(env) + return env.marshal_envelope() + + def 
create_key_from_binary(binary_data: bytes) -> bytes: """ Creates a key for the DHT by hashing binary data with SHA-256. diff --git a/libp2p/kad_dht/value_store.py b/libp2p/kad_dht/value_store.py index adc37b72..bb143dcd 100644 --- a/libp2p/kad_dht/value_store.py +++ b/libp2p/kad_dht/value_store.py @@ -15,11 +15,10 @@ from libp2p.abc import ( from libp2p.custom_types import ( TProtocol, ) -from libp2p.kad_dht.utils import maybe_consume_signed_record +from libp2p.kad_dht.utils import env_to_send_in_RPC, maybe_consume_signed_record from libp2p.peer.id import ( ID, ) -from libp2p.peer.peerstore import create_signed_peer_record from .common import ( DEFAULT_TTL, @@ -113,12 +112,8 @@ class ValueStore: message.type = Message.MessageType.PUT_VALUE # Create sender's signed-peer-record - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - message.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + message.senderRecord = envelope_bytes # Set message fields message.key = key @@ -245,12 +240,8 @@ class ValueStore: message.key = key # Create sender's signed-peer-record - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - message.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + message.senderRecord = envelope_bytes # Serialize and send the protobuf message proto_bytes = message.SerializeToString() diff --git a/libp2p/peer/envelope.py b/libp2p/peer/envelope.py index e93a8280..f8bf9f43 100644 --- a/libp2p/peer/envelope.py +++ b/libp2p/peer/envelope.py @@ -1,5 +1,7 @@ from typing import Any, cast +import multiaddr + from libp2p.crypto.ed25519 import Ed25519PublicKey from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.crypto.rsa import RSAPublicKey @@ -131,6 +133,9 @@ class Envelope: ) return False + def _env_addrs_set(self) -> set[multiaddr.Multiaddr]: + return {b for b in self.record().addrs} + def pub_key_to_protobuf(pub_key: PublicKey) -> cryto_pb.PublicKey: """ diff --git a/libp2p/peer/peerstore.py b/libp2p/peer/peerstore.py index 0faccb45..ad6f08db 100644 --- a/libp2p/peer/peerstore.py +++ b/libp2p/peer/peerstore.py @@ -65,8 +65,17 @@ class PeerStore(IPeerStore): self.peer_data_map = defaultdict(PeerData) self.addr_update_channels: dict[ID, MemorySendChannel[Multiaddr]] = {} self.peer_record_map: dict[ID, PeerRecordState] = {} + self.local_peer_record: Envelope | None = None self.max_records = max_records + def get_local_record(self) -> Envelope | None: + """Get the local-signed-record wrapped in Envelope""" + return self.local_peer_record + + def set_local_record(self, envelope: Envelope) -> None: + """Set the local-signed-record wrapped in Envelope""" + self.local_peer_record = envelope + def peer_info(self, peer_id: ID) -> PeerInfo: """ :param peer_id: peer ID to get info for diff --git a/tests/core/kad_dht/test_kad_dht.py b/tests/core/kad_dht/test_kad_dht.py index 70d9a5e9..a2e9ec4c 100644 --- a/tests/core/kad_dht/test_kad_dht.py +++ b/tests/core/kad_dht/test_kad_dht.py @@ -9,9 +9,12 @@ This module tests core functionality of the Kademlia DHT including: import hashlib import logging +import os +from unittest.mock import patch import uuid import pytest +import multiaddr import trio from libp2p.kad_dht.kad_dht import ( @@ -77,6 +80,18 @@ async def test_find_node(dht_pair: tuple[KadDHT, KadDHT]): """Test that nodes can find each other in the DHT.""" 
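The reuse-or-reissue decision that env_to_send_in_RPC makes above can be pictured with a small self-contained model. This is an illustrative sketch only: FakeRecord, FakePeerstore, and record_to_send are stand-ins for PeerRecord, PeerStore, and env_to_send_in_RPC, and the integer seq mimics the sequence number carried by a real signed peer record.

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class FakeRecord:
    addrs: frozenset[str]
    seq: int  # stand-in for the monotonically increasing seq of a PeerRecord


@dataclass
class FakePeerstore:
    local_record: FakeRecord | None = None


def record_to_send(listen_addrs: set[str], store: FakePeerstore) -> FakeRecord:
    cached = store.local_record
    if cached is not None and cached.addrs == frozenset(listen_addrs):
        # Addresses unchanged -> reuse the cached record, seq stays the same
        return cached
    # No cached record yet, or addresses changed -> issue a fresh record
    new_seq = 1 if cached is None else cached.seq + 1
    store.local_record = FakeRecord(frozenset(listen_addrs), new_seq)
    return store.local_record


store = FakePeerstore()
first = record_to_send({"/ip4/127.0.0.1/tcp/8000"}, store)
again = record_to_send({"/ip4/127.0.0.1/tcp/8000"}, store)
moved = record_to_send({"/ip4/127.0.0.1/tcp/9000"}, store)
assert first is again              # reused -> equal seq values after ordinary RPCs
assert moved.seq == first.seq + 1  # reissued once the listen addresses change

This is why the updated tests compare seq for equality after ordinary RPCs and only expect a strictly larger seq in the address-change case.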
dht_a, dht_b = dht_pair + # An extra FIND_NODE req is sent between the 2 nodes during DHT creation, + # so both the nodes will have records of each other before the next FIND_NODE + # req is sent + envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()) + envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id()) + + assert isinstance(envelope_a, Envelope) + assert isinstance(envelope_b, Envelope) + + record_a = envelope_a.record() + record_b = envelope_b.record() + # Node A should be able to find Node B with trio.fail_after(TEST_TIMEOUT): found_info = await dht_a.find_peer(dht_b.host.get_id()) @@ -91,6 +106,26 @@ async def test_find_node(dht_pair: tuple[KadDHT, KadDHT]): dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()), Envelope ) + # These are the records that were sent between the peers during the FIND_NODE req + envelope_a_find_peer = dht_a.host.get_peerstore().get_peer_record( + dht_b.host.get_id() + ) + envelope_b_find_peer = dht_b.host.get_peerstore().get_peer_record( + dht_a.host.get_id() + ) + + assert isinstance(envelope_a_find_peer, Envelope) + assert isinstance(envelope_b_find_peer, Envelope) + + record_a_find_peer = envelope_a_find_peer.record() + record_b_find_peer = envelope_b_find_peer.record() + + # This proves that both records are the same, and the latest cached signed record + # was passed between the peers during FIND_NODE execution, which proves the + # signed-record transfer/re-issuing works correctly in FIND_NODE executions. + assert record_a.seq == record_a_find_peer.seq + assert record_b.seq == record_b_find_peer.seq + # Verify that the found peer has the correct peer ID assert found_info is not None, "Failed to find the target peer" assert found_info.peer_id == dht_b.host.get_id(), "Found incorrect peer ID" @@ -144,11 +179,11 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]): record_a_put_value = envelope_a_put_value.record() record_b_put_value = envelope_b_put_value.record() - # This proves that both the records are different, and a new signed record + # This proves that both records are the same, and the latest cached signed record # was passed between the peers during PUT_VALUE exceution, which proves the - # signed-record transfer works correctly in PUT_VALUE executions. - assert record_a.seq < record_a_put_value.seq - assert record_b.seq < record_b_put_value.seq + # signed-record transfer/re-issuing works correctly in PUT_VALUE executions. + assert record_a.seq == record_a_put_value.seq + assert record_b.seq == record_b_put_value.seq # # Log debugging information logger.debug("Put value with key %s...", key.hex()[:10]) @@ -234,11 +269,12 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]): record_a_add_prov = envelope_a_add_prov.record() record_b_add_prov = envelope_b_add_prov.record() - # This proves that both the records are different, and a new signed record + # This proves that both records are the same, and the latest cached signed record # was passed between the peers during ADD_PROVIDER exceution, which proves the - # signed-record transfer works correctly in ADD_PROVIDER executions. + # signed-record transfer/re-issuing of the latest record works correctly in + # ADD_PROVIDER executions.
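The seq checks in these tests exercise the receive side of the exchange: every handler now funnels incoming records through maybe_consume_signed_record, which unmarshals the envelope and, if it verifies, caches it in the certified addr book with a TTL. A condensed sketch of that flow is shown below; it only uses calls that appear in the diffs (consume_envelope and PeerStore.consume_peer_record), and sender_record is a hypothetical raw payload taken from a received protobuf message.

import logging

from libp2p.abc import IHost
from libp2p.peer.envelope import consume_envelope

logger = logging.getLogger("kademlia-example.utils")


def consume_sender_record(sender_record: bytes, host: IHost) -> bool:
    """Validate a sender's signed peer record and cache it for 2 hours."""
    try:
        envelope, _record = consume_envelope(sender_record, "libp2p-peer-record")
    except Exception as exc:
        logger.error("Invalid signed peer record: %s", exc)
        return False
    # consume_peer_record stores the verified record with the given TTL and
    # returns False when the peerstore rejects it.
    return host.get_peerstore().consume_peer_record(envelope, 7200)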
+ assert record_a.seq == record_a_add_prov.seq + assert record_b.seq == record_b_add_prov.seq # Allow time for the provider record to propagate await trio.sleep(0.1) @@ -294,8 +330,41 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]): record_a_get_value = envelope_a_get_value.record() record_b_get_value = envelope_b_get_value.record() - # This proves that both the records are different, meaning that there was - # a new signed-record tranfer during the GET_VALUE execution by dht_b, which means - # the signed-record transfer works correctly in GET_VALUE executions. - assert record_a_find_prov.seq < record_a_get_value.seq - assert record_b_find_prov.seq < record_b_get_value.seq + # This proves that both the records are same, meaning that the latest cached + # signed-record tranfer happened during the GET_VALUE execution by dht_b, + # which means the signed-record transfer/re-issuing works correctly + # in GET_VALUE executions. + assert record_a_find_prov.seq == record_a_get_value.seq + assert record_b_find_prov.seq == record_b_get_value.seq + + +@pytest.mark.trio +async def test_reissue_when_listen_addrs_change(dht_pair: tuple[KadDHT, KadDHT]): + dht_a, dht_b = dht_pair + + # Warm-up: A stores B's current record + with trio.fail_after(10): + await dht_a.find_peer(dht_b.host.get_id()) + + env0 = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()) + assert isinstance(env0, Envelope) + seq0 = env0.record().seq + + # Simulate B's listen addrs changing (different port) + new_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123") + + # Patch just for the duration we force B to respond: + with patch.object(dht_b.host, "get_addrs", return_value=[new_addr]): + # Force B to send a response (which should include a fresh SPR) + with trio.fail_after(10): + await dht_a.peer_routing._query_peer_for_closest( + dht_b.host.get_id(), os.urandom(32) + ) + + # A should now hold B's new record with a bumped seq + env1 = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()) + assert isinstance(env1, Envelope) + seq1 = env1.record().seq + + # This proves that upon the change in listen_addrs, we issue new records + assert seq1 > seq0, f"Expected seq to bump after addr change, got {seq0} -> {seq1}" From efc899e8725f9bdc4b7c27d0aad66c4ff2b214d5 Mon Sep 17 00:00:00 2001 From: lla-dane Date: Thu, 14 Aug 2025 11:34:40 +0530 Subject: [PATCH 31/71] fix abc.py file --- libp2p/abc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/libp2p/abc.py b/libp2p/abc.py index 614af8bf..a9748339 100644 --- a/libp2p/abc.py +++ b/libp2p/abc.py @@ -974,6 +974,7 @@ class IPeerStore( def get_local_record(self) -> Optional["Envelope"]: """Get the local-peer-record wrapped in Envelope""" + @abstractmethod def set_local_record(self, envelope: "Envelope") -> None: """Set the local-peer-record wrapped in Envelope""" From 57d1c9d80784e229ade251c58aa30fbdedc5fbf9 Mon Sep 17 00:00:00 2001 From: lla-dane Date: Fri, 15 Aug 2025 16:11:27 +0530 Subject: [PATCH 32/71] reject dht-msgs upon receiving invalid records --- libp2p/kad_dht/kad_dht.py | 45 ++++++++++++++++++++++++++------ libp2p/kad_dht/peer_routing.py | 22 ++++++++++++---- libp2p/kad_dht/provider_store.py | 26 +++++++++++++----- libp2p/kad_dht/utils.py | 7 +++-- libp2p/kad_dht/value_store.py | 13 ++++++--- 5 files changed, 89 insertions(+), 24 deletions(-) diff --git a/libp2p/kad_dht/kad_dht.py b/libp2p/kad_dht/kad_dht.py index f93aa75e..adfd7400 100644 --- a/libp2p/kad_dht/kad_dht.py +++ b/libp2p/kad_dht/kad_dht.py @@ -280,7 +280,12 @@ class 
KadDHT(Service): logger.debug(f"Found {len(closest_peers)} peers close to target") # Consume the source signed_peer_record if sent - success = maybe_consume_signed_record(message, self.host) + if not maybe_consume_signed_record(message, self.host): + logger.error( + "Received an invalid-signed-record, dropping the stream" + ) + await stream.close() + return # Build response message with protobuf response = Message() @@ -336,7 +341,12 @@ class KadDHT(Service): logger.debug(f"Received ADD_PROVIDER for key {key.hex()}") # Consume the source signed-peer-record if sent - success = maybe_consume_signed_record(message, self.host) + if not maybe_consume_signed_record(message, self.host): + logger.error( + "Received an invalid-signed-record, dropping the stream" + ) + await stream.close() + return # Extract provider information for provider_proto in message.providerPeers: @@ -366,9 +376,13 @@ class KadDHT(Service): ) # Process the signed-records of provider if sent - success = maybe_consume_signed_record( - provider_proto, self.host - ) + if not maybe_consume_signed_record(message, self.host): + logger.error( + "Received an invalid-signed-record," + "dropping the stream" + ) + await stream.close() + return except Exception as e: logger.warning(f"Failed to process provider info: {e}") @@ -393,7 +407,12 @@ class KadDHT(Service): logger.debug(f"Received GET_PROVIDERS request for key {key.hex()}") # Consume the source signed_peer_record if sent - success = maybe_consume_signed_record(message, self.host) + if not maybe_consume_signed_record(message, self.host): + logger.error( + "Received an invalid-signed-record, dropping the stream" + ) + await stream.close() + return # Find providers for the key providers = self.provider_store.get_providers(key) @@ -482,7 +501,12 @@ class KadDHT(Service): logger.debug(f"Received GET_VALUE request for key {key.hex()}") # Consume the sender_signed_peer_record - success = maybe_consume_signed_record(message, self.host) + if not maybe_consume_signed_record(message, self.host): + logger.error( + "Received an invalid-signed-record, dropping the stream" + ) + await stream.close() + return value = self.value_store.get(key) if value: @@ -571,7 +595,12 @@ class KadDHT(Service): success = False # Consume the source signed_peer_record if sent - success = maybe_consume_signed_record(message, self.host) + if not maybe_consume_signed_record(message, self.host): + logger.error( + "Received an invalid-signed-record, dropping the stream" + ) + await stream.close() + return try: if not (key and value): diff --git a/libp2p/kad_dht/peer_routing.py b/libp2p/kad_dht/peer_routing.py index 4362ffea..cd1611ed 100644 --- a/libp2p/kad_dht/peer_routing.py +++ b/libp2p/kad_dht/peer_routing.py @@ -307,9 +307,20 @@ class PeerRouting(IPeerRouting): # Process closest peers from response if response_msg.type == Message.MessageType.FIND_NODE: # Consume the sender_signed_peer_record - _ = maybe_consume_signed_record(response_msg, self.host) + if not maybe_consume_signed_record(response_msg, self.host): + logger.error( + "Received an invalid-signed-record,ignoring the response" + ) + return [] for peer_data in response_msg.closerPeers: + # Consume the received closer_peers signed-records + if not maybe_consume_signed_record(peer_data, self.host): + logger.error( + "Received an invalid-signed-record,ignoring the response" + ) + return [] + new_peer_id = ID(peer_data.id) if new_peer_id not in results: results.append(new_peer_id) @@ -321,9 +332,6 @@ class PeerRouting(IPeerRouting): addrs = [Multiaddr(addr) 
for addr in peer_data.addrs] self.host.get_peerstore().add_addrs(new_peer_id, addrs, 3600) - # Consume the received closer_peers signed-records - _ = maybe_consume_signed_record(peer_data, self.host) - except Exception as e: logger.debug(f"Error querying peer {peer} for closest: {e}") @@ -364,7 +372,11 @@ class PeerRouting(IPeerRouting): if kad_message.type == Message.MessageType.FIND_NODE: # Consume the sender's signed-peer-record if sent - _ = maybe_consume_signed_record(kad_message, self.host) + if not maybe_consume_signed_record(kad_message, self.host): + logger.error( + "Receivedf an invalid-signed-record, dropping the stream" + ) + return # Get target key directly from protobuf message target_key = kad_message.key diff --git a/libp2p/kad_dht/provider_store.py b/libp2p/kad_dht/provider_store.py index 4c6a8e06..ee7adfe8 100644 --- a/libp2p/kad_dht/provider_store.py +++ b/libp2p/kad_dht/provider_store.py @@ -286,8 +286,13 @@ class ProviderStore: if response.type == Message.MessageType.ADD_PROVIDER: # Consume the sender's signed-peer-record if sent - _ = maybe_consume_signed_record(response, self.host) - result = True + if not maybe_consume_signed_record(response, self.host): + logger.error( + "Received an invalid-signed-record, ignoring the response" + ) + result = False + else: + result = True except Exception as e: logger.warning(f"Error sending ADD_PROVIDER to {peer_id}: {e}") @@ -427,12 +432,24 @@ class ProviderStore: return [] # Consume the sender's signed-peer-record if sent - _ = maybe_consume_signed_record(response, self.host) + if not maybe_consume_signed_record(response, self.host): + logger.error( + "Recieved an invalid-signed-record, ignoring the response" + ) + return [] # Extract provider information providers = [] for provider_proto in response.providerPeers: try: + # Consume the provider's signed-peer-record if sent + if not maybe_consume_signed_record(provider_proto, self.host): + logger.error( + "Recieved an invalid-signed-record, " + "ignoring the response" + ) + return [] + # Create peer ID from bytes provider_id = ID(provider_proto.id) @@ -447,9 +464,6 @@ class ProviderStore: # Create PeerInfo and add to result providers.append(PeerInfo(provider_id, addrs)) - # Consume the provider's signed-peer-record if sent - _ = maybe_consume_signed_record(provider_proto, self.host) - except Exception as e: logger.warning(f"Failed to parse provider info: {e}") diff --git a/libp2p/kad_dht/utils.py b/libp2p/kad_dht/utils.py index 3cf79efd..6d65d1af 100644 --- a/libp2p/kad_dht/utils.py +++ b/libp2p/kad_dht/utils.py @@ -27,7 +27,10 @@ def maybe_consume_signed_record(msg: Message | Message.Peer, host: IHost) -> boo try: # Convert the signed-peer-record(Envelope) from # protobuf bytes - envelope, _ = consume_envelope(msg.senderRecord, "libp2p-peer-record") + envelope, _ = consume_envelope( + msg.senderRecord, + "libp2p-peer-record", + ) # Use the default TTL of 2 hours (7200 seconds) if not host.get_peerstore().consume_peer_record(envelope, 7200): logger.error("Updating the certified-addr-book was unsuccessful") @@ -51,7 +54,7 @@ def maybe_consume_signed_record(msg: Message | Message.Peer, host: IHost) -> boo "Error updating the certified-addr-book: %s", e, ) - + return False return True diff --git a/libp2p/kad_dht/value_store.py b/libp2p/kad_dht/value_store.py index bb143dcd..aa545797 100644 --- a/libp2p/kad_dht/value_store.py +++ b/libp2p/kad_dht/value_store.py @@ -161,8 +161,11 @@ class ValueStore: # Check if response is valid if response.type == Message.MessageType.PUT_VALUE: # 
Consume the sender's signed-peer-record if sent - _ = maybe_consume_signed_record(response, self.host) - + if not maybe_consume_signed_record(response, self.host): + logger.error( + "Received an invalid-signed-record, ignoring the response" + ) + return False if response.key == key: result = True return result @@ -288,7 +291,11 @@ class ValueStore: and response.record.value ): # Consume the sender's signed-peer-record - _ = maybe_consume_signed_record(response, self.host) + if not maybe_consume_signed_record(response, self.host): + logger.error( + "Received an invalid-signed-record, ignoring the response" + ) + return None logger.debug( f"Received value for key {key.hex()} from peer {peer_id}" From ba39e91a2ee6f63b6a122d11334a32459732b260 Mon Sep 17 00:00:00 2001 From: lla-dane Date: Sun, 17 Aug 2025 12:10:08 +0530 Subject: [PATCH 33/71] added test for req rejection upon invalid record transfer --- tests/core/kad_dht/test_kad_dht.py | 40 ++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/core/kad_dht/test_kad_dht.py b/tests/core/kad_dht/test_kad_dht.py index a2e9ec4c..05a31468 100644 --- a/tests/core/kad_dht/test_kad_dht.py +++ b/tests/core/kad_dht/test_kad_dht.py @@ -17,6 +17,7 @@ import pytest import multiaddr import trio +from libp2p.crypto.rsa import create_new_key_pair from libp2p.kad_dht.kad_dht import ( DHTMode, KadDHT, @@ -368,3 +369,42 @@ async def test_reissue_when_listen_addrs_change(dht_pair: tuple[KadDHT, KadDHT]) # This proves that upon the change in listen_addrs, we issue new records assert seq1 > seq0, f"Expected seq to bump after addr change, got {seq0} -> {seq1}" + + +@pytest.mark.trio +async def test_dht_req_fail_with_invalid_record_transfer( + dht_pair: tuple[KadDHT, KadDHT], +): + """ + Testing showing failure of storing and retrieving values in the DHT, + if invalid signed-records are sent. 
+ """ + dht_a, dht_b = dht_pair + peer_b_info = PeerInfo(dht_b.host.get_id(), dht_b.host.get_addrs()) + + # Generate a random key and value + key = create_key_from_binary(b"test-key") + value = b"test-value" + + # First add the value directly to node A's store to verify storage works + dht_a.value_store.put(key, value) + local_value = dht_a.value_store.get(key) + assert local_value == value, "Local value storage failed" + await dht_a.routing_table.add_peer(peer_b_info) + + # Corrupt dht_a's local peer_record + envelope = dht_a.host.get_peerstore().get_local_record() + key_pair = create_new_key_pair() + + if envelope is not None: + envelope.public_key = key_pair.public_key + dht_a.host.get_peerstore().set_local_record(envelope) + + with trio.fail_after(TEST_TIMEOUT): + await dht_a.put_value(key, value) + + value = dht_b.value_store.get(key) + + # This proves that DHT_B rejected DHT_A PUT_RECORD req upon receiving + # the corrupted invalid record + assert value is None From 3aacb3a391015e710cef24b3973c195b69c4ff25 Mon Sep 17 00:00:00 2001 From: lla-dane Date: Sun, 17 Aug 2025 12:21:39 +0530 Subject: [PATCH 34/71] remove the timeout bound from the kad-dht test --- tests/core/kad_dht/test_kad_dht.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/core/kad_dht/test_kad_dht.py b/tests/core/kad_dht/test_kad_dht.py index 05a31468..37730308 100644 --- a/tests/core/kad_dht/test_kad_dht.py +++ b/tests/core/kad_dht/test_kad_dht.py @@ -400,8 +400,7 @@ async def test_dht_req_fail_with_invalid_record_transfer( envelope.public_key = key_pair.public_key dht_a.host.get_peerstore().set_local_record(envelope) - with trio.fail_after(TEST_TIMEOUT): - await dht_a.put_value(key, value) + await dht_a.put_value(key, value) value = dht_b.value_store.get(key) From 3917d7b5967bae22655bd1054a7e0a5175b42e35 Mon Sep 17 00:00:00 2001 From: lla-dane Date: Wed, 20 Aug 2025 18:07:32 +0530 Subject: [PATCH 35/71] verify peer_id in signed-record matches authenticated sender --- libp2p/kad_dht/kad_dht.py | 14 ++++++++------ libp2p/kad_dht/peer_routing.py | 8 +++++--- libp2p/kad_dht/provider_store.py | 7 ++++--- libp2p/kad_dht/utils.py | 13 ++++++++++--- libp2p/kad_dht/value_store.py | 4 ++-- tests/core/kad_dht/test_kad_dht.py | 24 ++++++++++++++++++++---- 6 files changed, 49 insertions(+), 21 deletions(-) diff --git a/libp2p/kad_dht/kad_dht.py b/libp2p/kad_dht/kad_dht.py index adfd7400..44787690 100644 --- a/libp2p/kad_dht/kad_dht.py +++ b/libp2p/kad_dht/kad_dht.py @@ -280,7 +280,7 @@ class KadDHT(Service): logger.debug(f"Found {len(closest_peers)} peers close to target") # Consume the source signed_peer_record if sent - if not maybe_consume_signed_record(message, self.host): + if not maybe_consume_signed_record(message, self.host, peer_id): logger.error( "Received an invalid-signed-record, dropping the stream" ) @@ -341,7 +341,7 @@ class KadDHT(Service): logger.debug(f"Received ADD_PROVIDER for key {key.hex()}") # Consume the source signed-peer-record if sent - if not maybe_consume_signed_record(message, self.host): + if not maybe_consume_signed_record(message, self.host, peer_id): logger.error( "Received an invalid-signed-record, dropping the stream" ) @@ -376,7 +376,9 @@ class KadDHT(Service): ) # Process the signed-records of provider if sent - if not maybe_consume_signed_record(message, self.host): + if not maybe_consume_signed_record( + message, self.host, peer_id + ): logger.error( "Received an invalid-signed-record," "dropping the stream" @@ -407,7 +409,7 @@ class KadDHT(Service): 
logger.debug(f"Received GET_PROVIDERS request for key {key.hex()}") # Consume the source signed_peer_record if sent - if not maybe_consume_signed_record(message, self.host): + if not maybe_consume_signed_record(message, self.host, peer_id): logger.error( "Received an invalid-signed-record, dropping the stream" ) @@ -501,7 +503,7 @@ class KadDHT(Service): logger.debug(f"Received GET_VALUE request for key {key.hex()}") # Consume the sender_signed_peer_record - if not maybe_consume_signed_record(message, self.host): + if not maybe_consume_signed_record(message, self.host, peer_id): logger.error( "Received an invalid-signed-record, dropping the stream" ) @@ -595,7 +597,7 @@ class KadDHT(Service): success = False # Consume the source signed_peer_record if sent - if not maybe_consume_signed_record(message, self.host): + if not maybe_consume_signed_record(message, self.host, peer_id): logger.error( "Received an invalid-signed-record, dropping the stream" ) diff --git a/libp2p/kad_dht/peer_routing.py b/libp2p/kad_dht/peer_routing.py index cd1611ed..34b95902 100644 --- a/libp2p/kad_dht/peer_routing.py +++ b/libp2p/kad_dht/peer_routing.py @@ -307,14 +307,15 @@ class PeerRouting(IPeerRouting): # Process closest peers from response if response_msg.type == Message.MessageType.FIND_NODE: # Consume the sender_signed_peer_record - if not maybe_consume_signed_record(response_msg, self.host): + if not maybe_consume_signed_record(response_msg, self.host, peer): logger.error( "Received an invalid-signed-record,ignoring the response" ) return [] for peer_data in response_msg.closerPeers: - # Consume the received closer_peers signed-records + # Consume the received closer_peers signed-records, peer-id is + # sent with the peer-data if not maybe_consume_signed_record(peer_data, self.host): logger.error( "Received an invalid-signed-record,ignoring the response" @@ -353,6 +354,7 @@ class PeerRouting(IPeerRouting): """ try: # Read message length + peer_id = stream.muxed_conn.peer_id length_bytes = await stream.read(4) if not length_bytes: return @@ -372,7 +374,7 @@ class PeerRouting(IPeerRouting): if kad_message.type == Message.MessageType.FIND_NODE: # Consume the sender's signed-peer-record if sent - if not maybe_consume_signed_record(kad_message, self.host): + if not maybe_consume_signed_record(kad_message, self.host, peer_id): logger.error( "Receivedf an invalid-signed-record, dropping the stream" ) diff --git a/libp2p/kad_dht/provider_store.py b/libp2p/kad_dht/provider_store.py index ee7adfe8..1aae23f7 100644 --- a/libp2p/kad_dht/provider_store.py +++ b/libp2p/kad_dht/provider_store.py @@ -286,7 +286,7 @@ class ProviderStore: if response.type == Message.MessageType.ADD_PROVIDER: # Consume the sender's signed-peer-record if sent - if not maybe_consume_signed_record(response, self.host): + if not maybe_consume_signed_record(response, self.host, peer_id): logger.error( "Received an invalid-signed-record, ignoring the response" ) @@ -432,7 +432,7 @@ class ProviderStore: return [] # Consume the sender's signed-peer-record if sent - if not maybe_consume_signed_record(response, self.host): + if not maybe_consume_signed_record(response, self.host, peer_id): logger.error( "Recieved an invalid-signed-record, ignoring the response" ) @@ -442,7 +442,8 @@ class ProviderStore: providers = [] for provider_proto in response.providerPeers: try: - # Consume the provider's signed-peer-record if sent + # Consume the provider's signed-peer-record if sent, peer-id + # already sent with the provider-proto if not 
maybe_consume_signed_record(provider_proto, self.host): logger.error( "Recieved an invalid-signed-record, " diff --git a/libp2p/kad_dht/utils.py b/libp2p/kad_dht/utils.py index 6d65d1af..6c406587 100644 --- a/libp2p/kad_dht/utils.py +++ b/libp2p/kad_dht/utils.py @@ -21,16 +21,20 @@ from .pb.kademlia_pb2 import ( logger = logging.getLogger("kademlia-example.utils") -def maybe_consume_signed_record(msg: Message | Message.Peer, host: IHost) -> bool: +def maybe_consume_signed_record( + msg: Message | Message.Peer, host: IHost, peer_id: ID | None = None +) -> bool: if isinstance(msg, Message): if msg.HasField("senderRecord"): try: # Convert the signed-peer-record(Envelope) from # protobuf bytes - envelope, _ = consume_envelope( + envelope, record = consume_envelope( msg.senderRecord, "libp2p-peer-record", ) + if not (isinstance(peer_id, ID) and record.peer_id == peer_id): + return False # Use the default TTL of 2 hours (7200 seconds) if not host.get_peerstore().consume_peer_record(envelope, 7200): logger.error("Updating the certified-addr-book was unsuccessful") @@ -39,13 +43,16 @@ def maybe_consume_signed_record(msg: Message | Message.Peer, host: IHost) -> boo return False else: if msg.HasField("signedRecord"): + # TODO: Check in with the Message.Peer id with the record's id try: # Convert the signed-peer-record(Envelope) from # protobuf bytes - envelope, _ = consume_envelope( + envelope, record = consume_envelope( msg.signedRecord, "libp2p-peer-record", ) + if not record.peer_id.to_bytes() == msg.id: + return False # Use the default TTL of 2 hours (7200 seconds) if not host.get_peerstore().consume_peer_record(envelope, 7200): logger.error("Failed to update the Certified-Addr-Book") diff --git a/libp2p/kad_dht/value_store.py b/libp2p/kad_dht/value_store.py index aa545797..c0241528 100644 --- a/libp2p/kad_dht/value_store.py +++ b/libp2p/kad_dht/value_store.py @@ -161,7 +161,7 @@ class ValueStore: # Check if response is valid if response.type == Message.MessageType.PUT_VALUE: # Consume the sender's signed-peer-record if sent - if not maybe_consume_signed_record(response, self.host): + if not maybe_consume_signed_record(response, self.host, peer_id): logger.error( "Received an invalid-signed-record, ignoring the response" ) @@ -291,7 +291,7 @@ class ValueStore: and response.record.value ): # Consume the sender's signed-peer-record - if not maybe_consume_signed_record(response, self.host): + if not maybe_consume_signed_record(response, self.host, peer_id): logger.error( "Received an invalid-signed-record, ignoring the response" ) diff --git a/tests/core/kad_dht/test_kad_dht.py b/tests/core/kad_dht/test_kad_dht.py index 37730308..0d9a29f7 100644 --- a/tests/core/kad_dht/test_kad_dht.py +++ b/tests/core/kad_dht/test_kad_dht.py @@ -25,7 +25,9 @@ from libp2p.kad_dht.kad_dht import ( from libp2p.kad_dht.utils import ( create_key_from_binary, ) -from libp2p.peer.envelope import Envelope +from libp2p.peer.envelope import Envelope, seal_record +from libp2p.peer.id import ID +from libp2p.peer.peer_record import PeerRecord from libp2p.peer.peerinfo import ( PeerInfo, ) @@ -394,6 +396,8 @@ async def test_dht_req_fail_with_invalid_record_transfer( # Corrupt dht_a's local peer_record envelope = dht_a.host.get_peerstore().get_local_record() + if envelope is not None: + true_record = envelope.record() key_pair = create_new_key_pair() if envelope is not None: @@ -401,9 +405,21 @@ async def test_dht_req_fail_with_invalid_record_transfer( dht_a.host.get_peerstore().set_local_record(envelope) await 
dht_a.put_value(key, value) - - value = dht_b.value_store.get(key) + retrieved_value = dht_b.value_store.get(key) # This proves that DHT_B rejected DHT_A PUT_RECORD req upon receiving # the corrupted invalid record - assert value is None + assert retrieved_value is None + + # Create a corrupt envelope with correct signature but false peer_id + false_record = PeerRecord(ID.from_pubkey(key_pair.public_key), true_record.addrs) + false_envelope = seal_record(false_record, dht_a.host.get_private_key()) + + dht_a.host.get_peerstore().set_local_record(false_envelope) + + await dht_a.put_value(key, value) + retrieved_value = dht_b.value_store.get(key) + + # This proves that DHT_B rejected DHT_A PUT_RECORD req upon receving + # the record with a different peer_id regardless of a valid signature + assert retrieved_value is None From 15f4a399ec3ba5b955b0fdca2e80887bf02b6b1c Mon Sep 17 00:00:00 2001 From: lla-dane Date: Sat, 23 Aug 2025 15:36:57 +0530 Subject: [PATCH 36/71] Added and docstrings and removed typos --- libp2p/kad_dht/provider_store.py | 4 ++-- libp2p/kad_dht/utils.py | 34 ++++++++++++++++++++++++++++---- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/libp2p/kad_dht/provider_store.py b/libp2p/kad_dht/provider_store.py index 1aae23f7..3f912ace 100644 --- a/libp2p/kad_dht/provider_store.py +++ b/libp2p/kad_dht/provider_store.py @@ -434,7 +434,7 @@ class ProviderStore: # Consume the sender's signed-peer-record if sent if not maybe_consume_signed_record(response, self.host, peer_id): logger.error( - "Recieved an invalid-signed-record, ignoring the response" + "Received an invalid-signed-record, ignoring the response" ) return [] @@ -446,7 +446,7 @@ class ProviderStore: # already sent with the provider-proto if not maybe_consume_signed_record(provider_proto, self.host): logger.error( - "Recieved an invalid-signed-record, " + "Received an invalid-signed-record, " "ignoring the response" ) return [] diff --git a/libp2p/kad_dht/utils.py b/libp2p/kad_dht/utils.py index 6c406587..839efb10 100644 --- a/libp2p/kad_dht/utils.py +++ b/libp2p/kad_dht/utils.py @@ -24,6 +24,31 @@ logger = logging.getLogger("kademlia-example.utils") def maybe_consume_signed_record( msg: Message | Message.Peer, host: IHost, peer_id: ID | None = None ) -> bool: + """ + Attempt to parse and store a signed-peer-record (Envelope) received during + DHT communication. If the record is invalid, the peer-id does not match, or + updating the peerstore fails, the function logs an error and returns False. + + Parameters + ---------- + msg : Message | Message.Peer + The protobuf message received during DHT communication. Can either be a + top-level `Message` containing `senderRecord` or a `Message.Peer` + containing `signedRecord`. + host : IHost + The local host instance, providing access to the peerstore for storing + verified peer records. + peer_id : ID | None, optional + The expected peer ID for record validation. If provided, the peer ID + inside the record must match this value. + + Returns + ------- + bool + True if a valid signed peer record was successfully consumed and stored, + False otherwise. 
+ + """ if isinstance(msg, Message): if msg.HasField("senderRecord"): try: @@ -37,13 +62,13 @@ def maybe_consume_signed_record( return False # Use the default TTL of 2 hours (7200 seconds) if not host.get_peerstore().consume_peer_record(envelope, 7200): - logger.error("Updating the certified-addr-book was unsuccessful") + logger.error("Failed to update the Certified-Addr-Book") + return False except Exception as e: - logger.error("Error updating teh certified addr book for peer: %s", e) + logger.error("Failed to update the Certified-Addr-Book: %s", e) return False else: if msg.HasField("signedRecord"): - # TODO: Check in with the Message.Peer id with the record's id try: # Convert the signed-peer-record(Envelope) from # protobuf bytes @@ -56,9 +81,10 @@ def maybe_consume_signed_record( # Use the default TTL of 2 hours (7200 seconds) if not host.get_peerstore().consume_peer_record(envelope, 7200): logger.error("Failed to update the Certified-Addr-Book") + return False except Exception as e: logger.error( - "Error updating the certified-addr-book: %s", + "Failed to update the Certified-Addr-Book: %s", e, ) return False From 091ac082b9e61d77214b10570086885f83ecdde4 Mon Sep 17 00:00:00 2001 From: lla-dane Date: Sat, 23 Aug 2025 15:52:43 +0530 Subject: [PATCH 37/71] Commented out the bool variable from env_to_send_in_RPC() at places --- libp2p/kad_dht/kad_dht.py | 12 ++++++------ libp2p/kad_dht/peer_routing.py | 4 ++-- libp2p/kad_dht/provider_store.py | 4 ++-- libp2p/kad_dht/value_store.py | 4 ++-- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/libp2p/kad_dht/kad_dht.py b/libp2p/kad_dht/kad_dht.py index 44787690..701d8415 100644 --- a/libp2p/kad_dht/kad_dht.py +++ b/libp2p/kad_dht/kad_dht.py @@ -323,7 +323,7 @@ class KadDHT(Service): ) # Create sender_signed_peer_record - envelope_bytes, bool = env_to_send_in_RPC(self.host) + envelope_bytes, _ = env_to_send_in_RPC(self.host) response.senderRecord = envelope_bytes # Serialize and send response @@ -394,7 +394,7 @@ class KadDHT(Service): response.key = key # Add sender's signed-peer-record - envelope_bytes, bool = env_to_send_in_RPC(self.host) + envelope_bytes, _ = env_to_send_in_RPC(self.host) response.senderRecord = envelope_bytes response_bytes = response.SerializeToString() @@ -428,7 +428,7 @@ class KadDHT(Service): response.key = key # Create sender_signed_peer_record for the response - envelope_bytes, bool = env_to_send_in_RPC(self.host) + envelope_bytes, _ = env_to_send_in_RPC(self.host) response.senderRecord = envelope_bytes # Add provider information to response @@ -525,7 +525,7 @@ class KadDHT(Service): response.record.timeReceived = str(time.time()) # Create sender_signed_peer_record - envelope_bytes, bool = env_to_send_in_RPC(self.host) + envelope_bytes, _ = env_to_send_in_RPC(self.host) response.senderRecord = envelope_bytes # Serialize and send response @@ -542,7 +542,7 @@ class KadDHT(Service): response.key = key # Create sender_signed_peer_record for the response - envelope_bytes, bool = env_to_send_in_RPC(self.host) + envelope_bytes, _ = env_to_send_in_RPC(self.host) response.senderRecord = envelope_bytes # Add closest peers to key @@ -626,7 +626,7 @@ class KadDHT(Service): response.key = key # Create sender_signed_peer_record for the response - envelope_bytes, bool = env_to_send_in_RPC(self.host) + envelope_bytes, _ = env_to_send_in_RPC(self.host) response.senderRecord = envelope_bytes # Serialize and send response diff --git a/libp2p/kad_dht/peer_routing.py b/libp2p/kad_dht/peer_routing.py index 
34b95902..cf96dd7b 100644 --- a/libp2p/kad_dht/peer_routing.py +++ b/libp2p/kad_dht/peer_routing.py @@ -259,7 +259,7 @@ class PeerRouting(IPeerRouting): find_node_msg.key = target_key # Set target key directly as bytes # Create sender_signed_peer_record - envelope_bytes, bool = env_to_send_in_RPC(self.host) + envelope_bytes, _ = env_to_send_in_RPC(self.host) find_node_msg.senderRecord = envelope_bytes # Serialize and send the protobuf message with varint length prefix @@ -393,7 +393,7 @@ class PeerRouting(IPeerRouting): response.type = Message.MessageType.FIND_NODE # Create sender_signed_peer_record for the response - envelope_bytes, bool = env_to_send_in_RPC(self.host) + envelope_bytes, _ = env_to_send_in_RPC(self.host) response.senderRecord = envelope_bytes # Add peer information to response diff --git a/libp2p/kad_dht/provider_store.py b/libp2p/kad_dht/provider_store.py index 3f912ace..fd780840 100644 --- a/libp2p/kad_dht/provider_store.py +++ b/libp2p/kad_dht/provider_store.py @@ -242,7 +242,7 @@ class ProviderStore: message.key = key # Create sender's signed-peer-record - envelope_bytes, bool = env_to_send_in_RPC(self.host) + envelope_bytes, _ = env_to_send_in_RPC(self.host) message.senderRecord = envelope_bytes # Add our provider info @@ -394,7 +394,7 @@ class ProviderStore: message.key = key # Create sender's signed-peer-record - envelope_bytes, bool = env_to_send_in_RPC(self.host) + envelope_bytes, _ = env_to_send_in_RPC(self.host) message.senderRecord = envelope_bytes # Serialize and send the message diff --git a/libp2p/kad_dht/value_store.py b/libp2p/kad_dht/value_store.py index c0241528..7ada100f 100644 --- a/libp2p/kad_dht/value_store.py +++ b/libp2p/kad_dht/value_store.py @@ -112,7 +112,7 @@ class ValueStore: message.type = Message.MessageType.PUT_VALUE # Create sender's signed-peer-record - envelope_bytes, bool = env_to_send_in_RPC(self.host) + envelope_bytes, _ = env_to_send_in_RPC(self.host) message.senderRecord = envelope_bytes # Set message fields @@ -243,7 +243,7 @@ class ValueStore: message.key = key # Create sender's signed-peer-record - envelope_bytes, bool = env_to_send_in_RPC(self.host) + envelope_bytes, _ = env_to_send_in_RPC(self.host) message.senderRecord = envelope_bytes # Serialize and send the protobuf message From 8958c0fac39421a58fc3fe99f1d40a7db2aa1c7d Mon Sep 17 00:00:00 2001 From: lla-dane Date: Sat, 23 Aug 2025 16:05:08 +0530 Subject: [PATCH 38/71] Moved env_to_send_in_RPC function to libp2p/init.py --- libp2p/__init__.py | 70 ++++++++++++++++++++++++++++++++ libp2p/kad_dht/kad_dht.py | 3 +- libp2p/kad_dht/peer_routing.py | 2 +- libp2p/kad_dht/provider_store.py | 3 +- libp2p/kad_dht/utils.py | 29 ------------- libp2p/kad_dht/value_store.py | 3 +- 6 files changed, 77 insertions(+), 33 deletions(-) diff --git a/libp2p/__init__.py b/libp2p/__init__.py index d2ce122a..5942cd2e 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -49,6 +49,7 @@ from libp2p.peer.id import ( ) from libp2p.peer.peerstore import ( PeerStore, + create_signed_peer_record, ) from libp2p.security.insecure.transport import ( PLAINTEXT_PROTOCOL_ID, @@ -155,6 +156,75 @@ def get_default_muxer_options() -> TMuxerOptions: else: # YAMUX is default return create_yamux_muxer_option() +def env_to_send_in_RPC(host: IHost) -> tuple[bytes, bool]: + """ + Returns the signed peer record (Envelope) to be sent in an RPC, + by checking whether the host already has a cached signed peer record. 
+ If one exists and its addresses match the host's current listen addresses, + the cached envelope is reused. Otherwise, a new signed peer record is created, + cached, and returned. + + Parameters + ---------- + host : IHost + The local host instance, providing access to peer ID, listen addresses, + private key, and the peerstore. + + Returns + ------- + tuple[bytes, bool] + A tuple containing: + - The serialized envelope (bytes) for the signed peer record. + - A boolean flag indicating whether a new record was created (True) + or an existing cached one was reused (False). + + """ + + listen_addrs_set = {addr for addr in host.get_addrs()} + local_env = host.get_peerstore().get_local_record() + + if local_env is None: + # No cached SPR yet -> create one + return issue_and_cache_local_record(host), True + else: + record_addrs_set = local_env._env_addrs_set() + if record_addrs_set == listen_addrs_set: + # Perfect match -> reuse cached envelope + return local_env.marshal_envelope(), False + else: + # Addresses changed -> issue a new SPR and cache it + return issue_and_cache_local_record(host), True + + +def issue_and_cache_local_record(host: IHost) -> bytes: + """ + Create and cache a new signed peer record (Envelope) for the host. + + This function generates a new signed peer record from the host’s peer ID, + listen addresses, and private key. The resulting envelope is stored in + the peerstore as the local record for future reuse. + + Parameters + ---------- + host : IHost + The local host instance, providing access to peer ID, listen addresses, + private key, and the peerstore. + + Returns + ------- + bytes + The serialized envelope (bytes) representing the newly created signed + peer record. + """ + env = create_signed_peer_record( + host.get_id(), + host.get_addrs(), + host.get_private_key(), + ) + # Cache it for nexxt time use + host.get_peerstore().set_local_record(env) + return env.marshal_envelope() + def new_swarm( key_pair: KeyPair | None = None, diff --git a/libp2p/kad_dht/kad_dht.py b/libp2p/kad_dht/kad_dht.py index 701d8415..39de7cc0 100644 --- a/libp2p/kad_dht/kad_dht.py +++ b/libp2p/kad_dht/kad_dht.py @@ -18,11 +18,12 @@ from multiaddr import ( import trio import varint +from libp2p import env_to_send_in_RPC from libp2p.abc import ( IHost, ) from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager -from libp2p.kad_dht.utils import env_to_send_in_RPC, maybe_consume_signed_record +from libp2p.kad_dht.utils import maybe_consume_signed_record from libp2p.network.stream.net_stream import ( INetStream, ) diff --git a/libp2p/kad_dht/peer_routing.py b/libp2p/kad_dht/peer_routing.py index cf96dd7b..9dc18c83 100644 --- a/libp2p/kad_dht/peer_routing.py +++ b/libp2p/kad_dht/peer_routing.py @@ -10,6 +10,7 @@ import logging import trio import varint +from libp2p import env_to_send_in_RPC from libp2p.abc import ( IHost, INetStream, @@ -34,7 +35,6 @@ from .routing_table import ( RoutingTable, ) from .utils import ( - env_to_send_in_RPC, maybe_consume_signed_record, sort_peer_ids_by_distance, ) diff --git a/libp2p/kad_dht/provider_store.py b/libp2p/kad_dht/provider_store.py index fd780840..45be2dba 100644 --- a/libp2p/kad_dht/provider_store.py +++ b/libp2p/kad_dht/provider_store.py @@ -16,13 +16,14 @@ from multiaddr import ( import trio import varint +from libp2p import env_to_send_in_RPC from libp2p.abc import ( IHost, ) from libp2p.custom_types import ( TProtocol, ) -from libp2p.kad_dht.utils import env_to_send_in_RPC, maybe_consume_signed_record +from libp2p.kad_dht.utils 
import maybe_consume_signed_record from libp2p.peer.id import ( ID, ) diff --git a/libp2p/kad_dht/utils.py b/libp2p/kad_dht/utils.py index 839efb10..fe768723 100644 --- a/libp2p/kad_dht/utils.py +++ b/libp2p/kad_dht/utils.py @@ -12,7 +12,6 @@ from libp2p.peer.envelope import consume_envelope from libp2p.peer.id import ( ID, ) -from libp2p.peer.peerstore import create_signed_peer_record from .pb.kademlia_pb2 import ( Message, @@ -91,34 +90,6 @@ def maybe_consume_signed_record( return True -def env_to_send_in_RPC(host: IHost) -> tuple[bytes, bool]: - listen_addrs_set = {addr for addr in host.get_addrs()} - local_env = host.get_peerstore().get_local_record() - - if local_env is None: - # No cached SPR yet -> create one - return issue_and_cache_local_record(host), True - else: - record_addrs_set = local_env._env_addrs_set() - if record_addrs_set == listen_addrs_set: - # Perfect match -> reuse cached envelope - return local_env.marshal_envelope(), False - else: - # Addresses changed -> issue a new SPR and cache it - return issue_and_cache_local_record(host), True - - -def issue_and_cache_local_record(host: IHost) -> bytes: - env = create_signed_peer_record( - host.get_id(), - host.get_addrs(), - host.get_private_key(), - ) - # Cache it for nexxt time use - host.get_peerstore().set_local_record(env) - return env.marshal_envelope() - - def create_key_from_binary(binary_data: bytes) -> bytes: """ Creates a key for the DHT by hashing binary data with SHA-256. diff --git a/libp2p/kad_dht/value_store.py b/libp2p/kad_dht/value_store.py index 7ada100f..39223f02 100644 --- a/libp2p/kad_dht/value_store.py +++ b/libp2p/kad_dht/value_store.py @@ -9,13 +9,14 @@ import time import varint +from libp2p import env_to_send_in_RPC from libp2p.abc import ( IHost, ) from libp2p.custom_types import ( TProtocol, ) -from libp2p.kad_dht.utils import env_to_send_in_RPC, maybe_consume_signed_record +from libp2p.kad_dht.utils import maybe_consume_signed_record from libp2p.peer.id import ( ID, ) From 5bf9c7b5379357da21b0d900855d7f81c1197dbf Mon Sep 17 00:00:00 2001 From: lla-dane Date: Sat, 23 Aug 2025 16:07:10 +0530 Subject: [PATCH 39/71] Fix spinx error --- libp2p/__init__.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/libp2p/__init__.py b/libp2p/__init__.py index 5942cd2e..e95bacc0 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -158,11 +158,12 @@ def get_default_muxer_options() -> TMuxerOptions: def env_to_send_in_RPC(host: IHost) -> tuple[bytes, bool]: """ - Returns the signed peer record (Envelope) to be sent in an RPC, - by checking whether the host already has a cached signed peer record. - If one exists and its addresses match the host's current listen addresses, - the cached envelope is reused. Otherwise, a new signed peer record is created, - cached, and returned. + Return the signed peer record (Envelope) to be sent in an RPC. + + This function checks whether the host already has a cached signed peer record + (SPR). If one exists and its addresses match the host's current listen + addresses, the cached envelope is reused. Otherwise, a new signed peer record + is created, cached, and returned. Parameters ---------- @@ -173,13 +174,11 @@ def env_to_send_in_RPC(host: IHost) -> tuple[bytes, bool]: Returns ------- tuple[bytes, bool] - A tuple containing: - - The serialized envelope (bytes) for the signed peer record. - - A boolean flag indicating whether a new record was created (True) - or an existing cached one was reused (False). 
- + A 2-tuple where the first element is the serialized envelope (bytes) + for the signed peer record, and the second element is a boolean flag + indicating whether a new record was created (True) or an existing cached + one was reused (False). """ - listen_addrs_set = {addr for addr in host.get_addrs()} local_env = host.get_peerstore().get_local_record() From 91bee9df8915817f742f61bac3ff1da04f929167 Mon Sep 17 00:00:00 2001 From: lla-dane Date: Sat, 23 Aug 2025 16:20:24 +0530 Subject: [PATCH 40/71] Moved env_to_send_in_RPC function to libp2p/peer/peerstore.py --- libp2p/__init__.py | 69 ------------------------------ libp2p/kad_dht/kad_dht.py | 2 +- libp2p/kad_dht/peer_routing.py | 2 +- libp2p/kad_dht/provider_store.py | 2 +- libp2p/kad_dht/value_store.py | 2 +- libp2p/peer/peerstore.py | 72 ++++++++++++++++++++++++++++++++ 6 files changed, 76 insertions(+), 73 deletions(-) diff --git a/libp2p/__init__.py b/libp2p/__init__.py index e95bacc0..350ae46b 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -156,75 +156,6 @@ def get_default_muxer_options() -> TMuxerOptions: else: # YAMUX is default return create_yamux_muxer_option() -def env_to_send_in_RPC(host: IHost) -> tuple[bytes, bool]: - """ - Return the signed peer record (Envelope) to be sent in an RPC. - - This function checks whether the host already has a cached signed peer record - (SPR). If one exists and its addresses match the host's current listen - addresses, the cached envelope is reused. Otherwise, a new signed peer record - is created, cached, and returned. - - Parameters - ---------- - host : IHost - The local host instance, providing access to peer ID, listen addresses, - private key, and the peerstore. - - Returns - ------- - tuple[bytes, bool] - A 2-tuple where the first element is the serialized envelope (bytes) - for the signed peer record, and the second element is a boolean flag - indicating whether a new record was created (True) or an existing cached - one was reused (False). - """ - listen_addrs_set = {addr for addr in host.get_addrs()} - local_env = host.get_peerstore().get_local_record() - - if local_env is None: - # No cached SPR yet -> create one - return issue_and_cache_local_record(host), True - else: - record_addrs_set = local_env._env_addrs_set() - if record_addrs_set == listen_addrs_set: - # Perfect match -> reuse cached envelope - return local_env.marshal_envelope(), False - else: - # Addresses changed -> issue a new SPR and cache it - return issue_and_cache_local_record(host), True - - -def issue_and_cache_local_record(host: IHost) -> bytes: - """ - Create and cache a new signed peer record (Envelope) for the host. - - This function generates a new signed peer record from the host’s peer ID, - listen addresses, and private key. The resulting envelope is stored in - the peerstore as the local record for future reuse. - - Parameters - ---------- - host : IHost - The local host instance, providing access to peer ID, listen addresses, - private key, and the peerstore. - - Returns - ------- - bytes - The serialized envelope (bytes) representing the newly created signed - peer record. 
- """ - env = create_signed_peer_record( - host.get_id(), - host.get_addrs(), - host.get_private_key(), - ) - # Cache it for nexxt time use - host.get_peerstore().set_local_record(env) - return env.marshal_envelope() - - def new_swarm( key_pair: KeyPair | None = None, muxer_opt: TMuxerOptions | None = None, diff --git a/libp2p/kad_dht/kad_dht.py b/libp2p/kad_dht/kad_dht.py index 39de7cc0..b5064e23 100644 --- a/libp2p/kad_dht/kad_dht.py +++ b/libp2p/kad_dht/kad_dht.py @@ -18,7 +18,6 @@ from multiaddr import ( import trio import varint -from libp2p import env_to_send_in_RPC from libp2p.abc import ( IHost, ) @@ -34,6 +33,7 @@ from libp2p.peer.id import ( from libp2p.peer.peerinfo import ( PeerInfo, ) +from libp2p.peer.peerstore import env_to_send_in_RPC from libp2p.tools.async_service import ( Service, ) diff --git a/libp2p/kad_dht/peer_routing.py b/libp2p/kad_dht/peer_routing.py index 9dc18c83..195209a2 100644 --- a/libp2p/kad_dht/peer_routing.py +++ b/libp2p/kad_dht/peer_routing.py @@ -10,7 +10,6 @@ import logging import trio import varint -from libp2p import env_to_send_in_RPC from libp2p.abc import ( IHost, INetStream, @@ -23,6 +22,7 @@ from libp2p.peer.id import ( from libp2p.peer.peerinfo import ( PeerInfo, ) +from libp2p.peer.peerstore import env_to_send_in_RPC from .common import ( ALPHA, diff --git a/libp2p/kad_dht/provider_store.py b/libp2p/kad_dht/provider_store.py index 45be2dba..77bb464f 100644 --- a/libp2p/kad_dht/provider_store.py +++ b/libp2p/kad_dht/provider_store.py @@ -16,7 +16,6 @@ from multiaddr import ( import trio import varint -from libp2p import env_to_send_in_RPC from libp2p.abc import ( IHost, ) @@ -30,6 +29,7 @@ from libp2p.peer.id import ( from libp2p.peer.peerinfo import ( PeerInfo, ) +from libp2p.peer.peerstore import env_to_send_in_RPC from .common import ( ALPHA, diff --git a/libp2p/kad_dht/value_store.py b/libp2p/kad_dht/value_store.py index 39223f02..2002965f 100644 --- a/libp2p/kad_dht/value_store.py +++ b/libp2p/kad_dht/value_store.py @@ -9,7 +9,6 @@ import time import varint -from libp2p import env_to_send_in_RPC from libp2p.abc import ( IHost, ) @@ -20,6 +19,7 @@ from libp2p.kad_dht.utils import maybe_consume_signed_record from libp2p.peer.id import ( ID, ) +from libp2p.peer.peerstore import env_to_send_in_RPC from .common import ( DEFAULT_TTL, diff --git a/libp2p/peer/peerstore.py b/libp2p/peer/peerstore.py index ad6f08db..993a8523 100644 --- a/libp2p/peer/peerstore.py +++ b/libp2p/peer/peerstore.py @@ -16,6 +16,7 @@ import trio from trio import MemoryReceiveChannel, MemorySendChannel from libp2p.abc import ( + IHost, IPeerStore, ) from libp2p.crypto.keys import ( @@ -49,6 +50,77 @@ def create_signed_peer_record( return envelope +def env_to_send_in_RPC(host: IHost) -> tuple[bytes, bool]: + """ + Return the signed peer record (Envelope) to be sent in an RPC. + + This function checks whether the host already has a cached signed peer record + (SPR). If one exists and its addresses match the host's current listen + addresses, the cached envelope is reused. Otherwise, a new signed peer record + is created, cached, and returned. + + Parameters + ---------- + host : IHost + The local host instance, providing access to peer ID, listen addresses, + private key, and the peerstore. + + Returns + ------- + tuple[bytes, bool] + A 2-tuple where the first element is the serialized envelope (bytes) + for the signed peer record, and the second element is a boolean flag + indicating whether a new record was created (True) or an existing cached + one was reused (False). 
+ + """ + listen_addrs_set = {addr for addr in host.get_addrs()} + local_env = host.get_peerstore().get_local_record() + + if local_env is None: + # No cached SPR yet -> create one + return issue_and_cache_local_record(host), True + else: + record_addrs_set = local_env._env_addrs_set() + if record_addrs_set == listen_addrs_set: + # Perfect match -> reuse cached envelope + return local_env.marshal_envelope(), False + else: + # Addresses changed -> issue a new SPR and cache it + return issue_and_cache_local_record(host), True + + +def issue_and_cache_local_record(host: IHost) -> bytes: + """ + Create and cache a new signed peer record (Envelope) for the host. + + This function generates a new signed peer record from the host’s peer ID, + listen addresses, and private key. The resulting envelope is stored in + the peerstore as the local record for future reuse. + + Parameters + ---------- + host : IHost + The local host instance, providing access to peer ID, listen addresses, + private key, and the peerstore. + + Returns + ------- + bytes + The serialized envelope (bytes) representing the newly created signed + peer record. + + """ + env = create_signed_peer_record( + host.get_id(), + host.get_addrs(), + host.get_private_key(), + ) + # Cache it for nexxt time use + host.get_peerstore().set_local_record(env) + return env.marshal_envelope() + + class PeerRecordState: envelope: Envelope seq: int From 7b2d637382d34be4e1ba6621f080965fba5eb5aa Mon Sep 17 00:00:00 2001 From: lla-dane Date: Sat, 23 Aug 2025 16:30:34 +0530 Subject: [PATCH 41/71] Now using env_to_send_in_RPC for issuing records in Identify rpc messages --- libp2p/identity/identify/identify.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/libp2p/identity/identify/identify.py b/libp2p/identity/identify/identify.py index 4e8931ba..146fbd2d 100644 --- a/libp2p/identity/identify/identify.py +++ b/libp2p/identity/identify/identify.py @@ -15,7 +15,7 @@ from libp2p.custom_types import ( from libp2p.network.stream.exceptions import ( StreamClosed, ) -from libp2p.peer.peerstore import create_signed_peer_record +from libp2p.peer.peerstore import env_to_send_in_RPC from libp2p.utils import ( decode_varint_with_size, get_agent_version, @@ -65,11 +65,7 @@ def _mk_identify_protobuf( protocols = tuple(str(p) for p in host.get_mux().get_protocols() if p is not None) # Create a signed peer-record for the remote peer - envelope = create_signed_peer_record( - host.get_id(), - host.get_addrs(), - host.get_private_key(), - ) + envelope_bytes, _ = env_to_send_in_RPC(host) observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b"" return Identify( @@ -79,7 +75,7 @@ def _mk_identify_protobuf( listen_addrs=map(_multiaddr_to_bytes, laddrs), observed_addr=observed_addr, protocols=protocols, - signedPeerRecord=envelope.marshal_envelope(), + signedPeerRecord=envelope_bytes, ) From fe3f7adc1b66f0e97a45f257ff353259354e1147 Mon Sep 17 00:00:00 2001 From: lla-dane Date: Tue, 26 Aug 2025 12:49:33 +0530 Subject: [PATCH 42/71] fix typos --- libp2p/kad_dht/peer_routing.py | 2 +- libp2p/peer/peerstore.py | 2 +- tests/core/kad_dht/test_kad_dht.py | 14 +++++++------- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/libp2p/kad_dht/peer_routing.py b/libp2p/kad_dht/peer_routing.py index 195209a2..f5313cb6 100644 --- a/libp2p/kad_dht/peer_routing.py +++ b/libp2p/kad_dht/peer_routing.py @@ -376,7 +376,7 @@ class PeerRouting(IPeerRouting): # Consume the sender's signed-peer-record if sent if not 
maybe_consume_signed_record(kad_message, self.host, peer_id): logger.error( - "Receivedf an invalid-signed-record, dropping the stream" + "Received an invalid-signed-record, dropping the stream" ) return diff --git a/libp2p/peer/peerstore.py b/libp2p/peer/peerstore.py index 993a8523..ddf1af1f 100644 --- a/libp2p/peer/peerstore.py +++ b/libp2p/peer/peerstore.py @@ -116,7 +116,7 @@ def issue_and_cache_local_record(host: IHost) -> bytes: host.get_addrs(), host.get_private_key(), ) - # Cache it for nexxt time use + # Cache it for next time use host.get_peerstore().set_local_record(env) return env.marshal_envelope() diff --git a/tests/core/kad_dht/test_kad_dht.py b/tests/core/kad_dht/test_kad_dht.py index 0d9a29f7..285268d9 100644 --- a/tests/core/kad_dht/test_kad_dht.py +++ b/tests/core/kad_dht/test_kad_dht.py @@ -109,7 +109,7 @@ async def test_find_node(dht_pair: tuple[KadDHT, KadDHT]): dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()), Envelope ) - # These are the records that were sent betweeen the peers during the FIND_NODE req + # These are the records that were sent between the peers during the FIND_NODE req envelope_a_find_peer = dht_a.host.get_peerstore().get_peer_record( dht_b.host.get_id() ) @@ -124,7 +124,7 @@ async def test_find_node(dht_pair: tuple[KadDHT, KadDHT]): record_b_find_peer = envelope_b_find_peer.record() # This proves that both the records are same, and a latest cached signed record - # was passed between the peers during FIND_NODE exceution, which proves the + # was passed between the peers during FIND_NODE execution, which proves the # signed-record transfer/re-issuing works correctly in FIND_NODE executions. assert record_a.seq == record_a_find_peer.seq assert record_b.seq == record_b_find_peer.seq @@ -168,7 +168,7 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]): with trio.fail_after(TEST_TIMEOUT): await dht_a.put_value(key, value) - # These are the records that were sent betweeen the peers during the PUT_VALUE req + # These are the records that were sent between the peers during the PUT_VALUE req envelope_a_put_value = dht_a.host.get_peerstore().get_peer_record( dht_b.host.get_id() ) @@ -183,7 +183,7 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]): record_b_put_value = envelope_b_put_value.record() # This proves that both the records are same, and a latest cached signed record - # was passed between the peers during PUT_VALUE exceution, which proves the + # was passed between the peers during PUT_VALUE execution, which proves the # signed-record transfer/re-issuing works correctly in PUT_VALUE executions. 
assert record_a.seq == record_a_put_value.seq assert record_b.seq == record_b_put_value.seq @@ -205,7 +205,7 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]): print("the value stored in node b is", dht_b.get_value_store_size()) logger.debug("Retrieved value: %s", retrieved_value) - # These are the records that were sent betweeen the peers during the PUT_VALUE req + # These are the records that were sent between the peers during the PUT_VALUE req envelope_a_get_value = dht_a.host.get_peerstore().get_peer_record( dht_b.host.get_id() ) @@ -257,7 +257,7 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]): success = await dht_a.provide(content_id) assert success, "Failed to advertise as provider" - # These are the records that were sent betweeen the peers during + # These are the records that were sent between the peers during # the ADD_PROVIDER req envelope_a_add_prov = dht_a.host.get_peerstore().get_peer_record( dht_b.host.get_id() @@ -273,7 +273,7 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]): record_b_add_prov = envelope_b_add_prov.record() # This proves that both the records are same, the latest cached signed record - # was passed between the peers during ADD_PROVIDER exceution, which proves the + # was passed between the peers during ADD_PROVIDER execution, which proves the # signed-record transfer/re-issuing of the latest record works correctly in # ADD_PROVIDER executions. assert record_a.seq == record_a_add_prov.seq From 2006b2c92cf24263336e50d4e2ccb6b056716f75 Mon Sep 17 00:00:00 2001 From: lla-dane Date: Tue, 26 Aug 2025 12:59:18 +0530 Subject: [PATCH 43/71] added newsfragment --- newsfragments/815.feature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 newsfragments/815.feature.rst diff --git a/newsfragments/815.feature.rst b/newsfragments/815.feature.rst new file mode 100644 index 00000000..8fcf6fea --- /dev/null +++ b/newsfragments/815.feature.rst @@ -0,0 +1 @@ +KAD-DHT now include signed-peer-records in its protobuf message schema, for more secure peer-discovery. 
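Taken together, the patches above establish a single convention for every KAD-DHT RPC: the sender attaches its signed peer record via env_to_send_in_RPC (reusing the cached envelope whenever the listen addresses are unchanged), and the receiver validates and stores the record with maybe_consume_signed_record before acting on the message. The minimal sketch below illustrates that flow for a FIND_NODE exchange. It is illustrative only: the build_find_node and handle_rpc helper names are hypothetical, the protobuf import path is assumed from the package layout shown in these patches, and the real handlers in peer_routing.py and kad_dht.py additionally frame messages with varint length prefixes over a stream.

from libp2p.abc import IHost
from libp2p.kad_dht.pb.kademlia_pb2 import Message  # path assumed from libp2p/kad_dht/pb
from libp2p.kad_dht.utils import maybe_consume_signed_record
from libp2p.peer.id import ID
from libp2p.peer.peerstore import env_to_send_in_RPC


def build_find_node(host: IHost, target_key: bytes) -> bytes:
    # Sender side: attach the cached (or freshly issued) signed peer record.
    msg = Message()
    msg.type = Message.MessageType.FIND_NODE
    msg.key = target_key
    envelope_bytes, _created_new = env_to_send_in_RPC(host)
    msg.senderRecord = envelope_bytes
    return msg.SerializeToString()


def handle_rpc(host: IHost, payload: bytes, remote_peer_id: ID) -> Message | None:
    # Receiver side: drop the request if the attached record does not verify,
    # mirroring the error paths added to the handlers in the patches above.
    msg = Message()
    msg.ParseFromString(payload)
    if not maybe_consume_signed_record(msg, host, remote_peer_id):
        return None
    # On success the envelope is cached in the certified addr book and can be
    # read back with host.get_peerstore().get_peer_record(remote_peer_id); since
    # env_to_send_in_RPC only re-signs when the listen addresses change, repeated
    # RPCs carry the same seq, which is what the seq assertions in
    # test_kad_dht.py rely on.
    return msg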
From 943bcc4d36455026e08152b06f967eafe4df2e6f Mon Sep 17 00:00:00 2001 From: lla-dane Date: Wed, 27 Aug 2025 10:17:40 +0530 Subject: [PATCH 44/71] fix the logic error in add_provider handling --- libp2p/kad_dht/kad_dht.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libp2p/kad_dht/kad_dht.py b/libp2p/kad_dht/kad_dht.py index b5064e23..0d05aaf8 100644 --- a/libp2p/kad_dht/kad_dht.py +++ b/libp2p/kad_dht/kad_dht.py @@ -378,7 +378,7 @@ class KadDHT(Service): # Process the signed-records of provider if sent if not maybe_consume_signed_record( - message, self.host, peer_id + provider_proto, self.host ): logger.error( "Received an invalid-signed-record," From c2c4228591eadac7bb0d1c3bdd5aa0a697fe7d7f Mon Sep 17 00:00:00 2001 From: lla-dane Date: Wed, 27 Aug 2025 13:02:32 +0530 Subject: [PATCH 45/71] added test for ADD_PROVIDER record processing --- tests/core/kad_dht/test_kad_dht.py | 36 ++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/core/kad_dht/test_kad_dht.py b/tests/core/kad_dht/test_kad_dht.py index 285268d9..5bf4f3e8 100644 --- a/tests/core/kad_dht/test_kad_dht.py +++ b/tests/core/kad_dht/test_kad_dht.py @@ -31,6 +31,7 @@ from libp2p.peer.peer_record import PeerRecord from libp2p.peer.peerinfo import ( PeerInfo, ) +from libp2p.peer.peerstore import create_signed_peer_record from libp2p.tools.async_service import ( background_trio_service, ) @@ -340,6 +341,41 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]): assert record_a_find_prov.seq == record_a_get_value.seq assert record_b_find_prov.seq == record_b_get_value.seq + # Create a new provider record in dht_a + provider_key_pair = create_new_key_pair() + provider_peer_id = ID.from_pubkey(provider_key_pair.public_key) + provider_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123") + provider_peer_info = PeerInfo(peer_id=provider_peer_id, addrs=[provider_addr]) + + # Generate a random content ID + content_2 = f"random-content-{uuid.uuid4()}".encode() + content_id_2 = hashlib.sha256(content_2).digest() + + provider_signed_envelope = create_signed_peer_record( + provider_peer_id, [provider_addr], provider_key_pair.private_key + ) + assert ( + dht_a.host.get_peerstore().consume_peer_record(provider_signed_envelope, 7200) + is True + ) + + # Store this provider record in dht_a + dht_a.provider_store.add_provider(content_id_2, provider_peer_info) + + # Fetch the provider-record via peer-discovery at dht_b's end + peerinfo = await dht_b.provider_store.find_providers(content_id_2) + + assert len(peerinfo) == 1 + assert peerinfo[0].peer_id == provider_peer_id + provider_envelope = dht_b.host.get_peerstore().get_peer_record(provider_peer_id) + + # This proves that the signed-envelope of provider is consumed on dht_b's end + assert provider_envelope is not None + assert ( + provider_signed_envelope.marshal_envelope() + == provider_envelope.marshal_envelope() + ) + @pytest.mark.trio async def test_reissue_when_listen_addrs_change(dht_pair: tuple[KadDHT, KadDHT]): From c08007feda758dfc16efb29940f153e751d8922c Mon Sep 17 00:00:00 2001 From: unniznd Date: Wed, 27 Aug 2025 21:54:05 +0530 Subject: [PATCH 46/71] improve error message in basic host --- libp2p/host/basic_host.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 1ef5dda2..ee1bb04d 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -290,7 +290,9 @@ class BasicHost(IHost): ) if protocol is None: await 
net_stream.reset() - raise StreamFailure("No protocol selected") + raise StreamFailure( + "Failed to negotiate protocol: no protocol selected" + ) except MultiselectError as error: peer_id = net_stream.muxed_conn.peer_id logger.debug( From 9f80dbae12920622416cd774b6db1198965cb718 Mon Sep 17 00:00:00 2001 From: unniznd Date: Wed, 27 Aug 2025 22:05:19 +0530 Subject: [PATCH 47/71] added the testcase for StreamFailure --- tests/core/host/test_basic_host.py | 37 ++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/core/host/test_basic_host.py b/tests/core/host/test_basic_host.py index ed21ad80..635f2863 100644 --- a/tests/core/host/test_basic_host.py +++ b/tests/core/host/test_basic_host.py @@ -1,3 +1,10 @@ +from unittest.mock import ( + AsyncMock, + MagicMock, +) + +import pytest + from libp2p import ( new_swarm, ) @@ -10,6 +17,9 @@ from libp2p.host.basic_host import ( from libp2p.host.defaults import ( get_default_protocols, ) +from libp2p.host.exceptions import ( + StreamFailure, +) def test_default_protocols(): @@ -22,3 +32,30 @@ def test_default_protocols(): # NOTE: comparing keys for equality as handlers may be closures that do not compare # in the way this test is concerned with assert handlers.keys() == get_default_protocols(host).keys() + + +@pytest.mark.trio +async def test_swarm_stream_handler_no_protocol_selected(monkeypatch): + key_pair = create_new_key_pair() + swarm = new_swarm(key_pair) + host = BasicHost(swarm) + + # Create a mock net_stream + net_stream = MagicMock() + net_stream.reset = AsyncMock() + net_stream.muxed_conn.peer_id = "peer-test" + + # Monkeypatch negotiate to simulate "no protocol selected" + async def fake_negotiate(comm, timeout): + return None, None + + monkeypatch.setattr(host.multiselect, "negotiate", fake_negotiate) + + # Now run the handler and expect StreamFailure + with pytest.raises( + StreamFailure, match="Failed to negotiate protocol: no protocol selected" + ): + await host._swarm_stream_handler(net_stream) + + # Ensure reset was called since negotiation failed + net_stream.reset.assert_awaited() From c577fd2f7133d7fa7e9e80920db15f9eb23e15be Mon Sep 17 00:00:00 2001 From: bomanaps Date: Thu, 28 Aug 2025 20:59:36 +0100 Subject: [PATCH 48/71] feat(swarm): enhance swarm with retry backoff --- examples/enhanced_swarm_example.py | 220 +++++++++++ libp2p/__init__.py | 38 +- libp2p/network/swarm.py | 349 +++++++++++++++++- newsfragments/874.feature.rst | 1 + tests/core/network/test_enhanced_swarm.py | 428 ++++++++++++++++++++++ 5 files changed, 1015 insertions(+), 21 deletions(-) create mode 100644 examples/enhanced_swarm_example.py create mode 100644 newsfragments/874.feature.rst create mode 100644 tests/core/network/test_enhanced_swarm.py diff --git a/examples/enhanced_swarm_example.py b/examples/enhanced_swarm_example.py new file mode 100644 index 00000000..37770411 --- /dev/null +++ b/examples/enhanced_swarm_example.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python3 +""" +Example demonstrating the enhanced Swarm with retry logic, exponential backoff, +and multi-connection support. + +This example shows how to: +1. Configure retry behavior with exponential backoff +2. Enable multi-connection support with connection pooling +3. Use different load balancing strategies +4. 
Maintain backward compatibility +""" + +import asyncio +import logging + +from libp2p import new_swarm +from libp2p.network.swarm import ConnectionConfig, RetryConfig + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def example_basic_enhanced_swarm() -> None: + """Example of basic enhanced Swarm usage.""" + logger.info("Creating enhanced Swarm with default configuration...") + + # Create enhanced swarm with default retry and connection config + swarm = new_swarm() + # Use default configuration values directly + default_retry = RetryConfig() + default_connection = ConnectionConfig() + + logger.info(f"Swarm created with peer ID: {swarm.get_peer_id()}") + logger.info( + f"Retry config: max_retries={default_retry.max_retries}" + ) + logger.info( + f"Connection config: max_connections_per_peer=" + f"{default_connection.max_connections_per_peer}" + ) + logger.info( + f"Connection pool enabled: {default_connection.enable_connection_pool}" + ) + + await swarm.close() + logger.info("Basic enhanced Swarm example completed") + + +async def example_custom_retry_config() -> None: + """Example of custom retry configuration.""" + logger.info("Creating enhanced Swarm with custom retry configuration...") + + # Custom retry configuration for aggressive retry behavior + retry_config = RetryConfig( + max_retries=5, # More retries + initial_delay=0.05, # Faster initial retry + max_delay=10.0, # Lower max delay + backoff_multiplier=1.5, # Less aggressive backoff + jitter_factor=0.2 # More jitter + ) + + # Create swarm with custom retry config + swarm = new_swarm(retry_config=retry_config) + + logger.info("Custom retry config applied:") + logger.info( + f" Max retries: {retry_config.max_retries}" + ) + logger.info( + f" Initial delay: {retry_config.initial_delay}s" + ) + logger.info( + f" Max delay: {retry_config.max_delay}s" + ) + logger.info( + f" Backoff multiplier: {retry_config.backoff_multiplier}" + ) + logger.info( + f" Jitter factor: {retry_config.jitter_factor}" + ) + + await swarm.close() + logger.info("Custom retry config example completed") + + +async def example_custom_connection_config() -> None: + """Example of custom connection configuration.""" + logger.info("Creating enhanced Swarm with custom connection configuration...") + + # Custom connection configuration for high-performance scenarios + connection_config = ConnectionConfig( + max_connections_per_peer=5, # More connections per peer + connection_timeout=60.0, # Longer timeout + enable_connection_pool=True, # Enable connection pooling + load_balancing_strategy="least_loaded" # Use least loaded strategy + ) + + # Create swarm with custom connection config + swarm = new_swarm(connection_config=connection_config) + + logger.info("Custom connection config applied:") + logger.info( + f" Max connections per peer: " + f"{connection_config.max_connections_per_peer}" + ) + logger.info( + f" Connection timeout: {connection_config.connection_timeout}s" + ) + logger.info( + f" Connection pool enabled: " + f"{connection_config.enable_connection_pool}" + ) + logger.info( + f" Load balancing strategy: " + f"{connection_config.load_balancing_strategy}" + ) + + await swarm.close() + logger.info("Custom connection config example completed") + + +async def example_backward_compatibility() -> None: + """Example showing backward compatibility.""" + logger.info("Creating enhanced Swarm with backward compatibility...") + + # Disable connection pool to maintain original behavior + connection_config = 
ConnectionConfig(enable_connection_pool=False) + + # Create swarm with connection pool disabled + swarm = new_swarm(connection_config=connection_config) + + logger.info("Backward compatibility mode:") + logger.info( + f" Connection pool enabled: {connection_config.enable_connection_pool}" + ) + logger.info( + f" Connections dict type: {type(swarm.connections)}" + ) + logger.info( + " Retry logic still available: 3 max retries" + ) + + await swarm.close() + logger.info("Backward compatibility example completed") + + +async def example_production_ready_config() -> None: + """Example of production-ready configuration.""" + logger.info("Creating enhanced Swarm with production-ready configuration...") + + # Production-ready retry configuration + retry_config = RetryConfig( + max_retries=3, # Reasonable retry limit + initial_delay=0.1, # Quick initial retry + max_delay=30.0, # Cap exponential backoff + backoff_multiplier=2.0, # Standard exponential backoff + jitter_factor=0.1 # Small jitter to prevent thundering herd + ) + + # Production-ready connection configuration + connection_config = ConnectionConfig( + max_connections_per_peer=3, # Balance between performance and resource usage + connection_timeout=30.0, # Reasonable timeout + enable_connection_pool=True, # Enable for better performance + load_balancing_strategy="round_robin" # Simple, predictable strategy + ) + + # Create swarm with production config + swarm = new_swarm( + retry_config=retry_config, + connection_config=connection_config + ) + + logger.info("Production-ready configuration applied:") + logger.info( + f" Retry: {retry_config.max_retries} retries, " + f"{retry_config.max_delay}s max delay" + ) + logger.info( + f" Connections: {connection_config.max_connections_per_peer} per peer" + ) + logger.info( + f" Load balancing: {connection_config.load_balancing_strategy}" + ) + + await swarm.close() + logger.info("Production-ready configuration example completed") + + +async def main() -> None: + """Run all examples.""" + logger.info("Enhanced Swarm Examples") + logger.info("=" * 50) + + try: + await example_basic_enhanced_swarm() + logger.info("-" * 30) + + await example_custom_retry_config() + logger.info("-" * 30) + + await example_custom_connection_config() + logger.info("-" * 30) + + await example_backward_compatibility() + logger.info("-" * 30) + + await example_production_ready_config() + logger.info("-" * 30) + + logger.info("All examples completed successfully!") + + except Exception as e: + logger.error(f"Example failed: {e}") + raise + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/libp2p/__init__.py b/libp2p/__init__.py index d2ce122a..ff3a70fc 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -1,3 +1,5 @@ +"""Libp2p Python implementation.""" + from collections.abc import ( Mapping, Sequence, @@ -6,15 +8,12 @@ from importlib.metadata import version as __version from typing import ( Literal, Optional, - Type, - cast, ) import multiaddr from libp2p.abc import ( IHost, - IMuxedConn, INetworkService, IPeerRouting, IPeerStore, @@ -32,9 +31,6 @@ from libp2p.custom_types import ( TProtocol, TSecurityOptions, ) -from libp2p.discovery.mdns.mdns import ( - MDNSDiscovery, -) from libp2p.host.basic_host import ( BasicHost, ) @@ -42,6 +38,8 @@ from libp2p.host.routed_host import ( RoutedHost, ) from libp2p.network.swarm import ( + ConnectionConfig, + RetryConfig, Swarm, ) from libp2p.peer.id import ( @@ -54,17 +52,19 @@ from libp2p.security.insecure.transport import ( PLAINTEXT_PROTOCOL_ID, 
InsecureTransport, ) -from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID -from libp2p.security.noise.transport import Transport as NoiseTransport +from libp2p.security.noise.transport import ( + PROTOCOL_ID as NOISE_PROTOCOL_ID, + Transport as NoiseTransport, +) import libp2p.security.secio.transport as secio from libp2p.stream_muxer.mplex.mplex import ( MPLEX_PROTOCOL_ID, Mplex, ) from libp2p.stream_muxer.yamux.yamux import ( + PROTOCOL_ID as YAMUX_PROTOCOL_ID, Yamux, ) -from libp2p.stream_muxer.yamux.yamux import PROTOCOL_ID as YAMUX_PROTOCOL_ID from libp2p.transport.tcp.tcp import ( TCP, ) @@ -87,7 +87,6 @@ MUXER_MPLEX = "MPLEX" DEFAULT_NEGOTIATE_TIMEOUT = 5 - def set_default_muxer(muxer_name: Literal["YAMUX", "MPLEX"]) -> None: """ Set the default multiplexer protocol to use. @@ -163,6 +162,8 @@ def new_swarm( peerstore_opt: IPeerStore | None = None, muxer_preference: Literal["YAMUX", "MPLEX"] | None = None, listen_addrs: Sequence[multiaddr.Multiaddr] | None = None, + retry_config: Optional["RetryConfig"] = None, + connection_config: Optional["ConnectionConfig"] = None, ) -> INetworkService: """ Create a swarm instance based on the parameters. @@ -239,7 +240,14 @@ def new_swarm( # Store our key pair in peerstore peerstore.add_key_pair(id_opt, key_pair) - return Swarm(id_opt, peerstore, upgrader, transport) + return Swarm( + id_opt, + peerstore, + upgrader, + transport, + retry_config=retry_config, + connection_config=connection_config + ) def new_host( @@ -279,6 +287,12 @@ def new_host( if disc_opt is not None: return RoutedHost(swarm, disc_opt, enable_mDNS, bootstrap) - return BasicHost(network=swarm,enable_mDNS=enable_mDNS , bootstrap=bootstrap, negotitate_timeout=negotiate_timeout) + return BasicHost( + network=swarm, + enable_mDNS=enable_mDNS, + bootstrap=bootstrap, + negotitate_timeout=negotiate_timeout + ) + __version__ = __version("libp2p") diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 67d46279..77fe2b6d 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -2,7 +2,9 @@ from collections.abc import ( Awaitable, Callable, ) +from dataclasses import dataclass import logging +import random from multiaddr import ( Multiaddr, @@ -59,6 +61,188 @@ from .exceptions import ( logger = logging.getLogger("libp2p.network.swarm") +@dataclass +class RetryConfig: + """Configuration for retry logic with exponential backoff.""" + + max_retries: int = 3 + initial_delay: float = 0.1 + max_delay: float = 30.0 + backoff_multiplier: float = 2.0 + jitter_factor: float = 0.1 + + +@dataclass +class ConnectionConfig: + """Configuration for connection pool and multi-connection support.""" + + max_connections_per_peer: int = 3 + connection_timeout: float = 30.0 + enable_connection_pool: bool = True + load_balancing_strategy: str = "round_robin" # or "least_loaded" + + +@dataclass +class ConnectionInfo: + """Information about a connection in the pool.""" + + connection: INetConn + address: str + established_at: float + last_used: float + stream_count: int + is_healthy: bool + + +class ConnectionPool: + """Manages multiple connections per peer with load balancing.""" + + def __init__(self, max_connections_per_peer: int = 3): + self.max_connections_per_peer = max_connections_per_peer + self.peer_connections: dict[ID, list[ConnectionInfo]] = {} + self._round_robin_index: dict[ID, int] = {} + + def add_connection(self, peer_id: ID, connection: INetConn, address: str) -> None: + """Add a connection to the pool with deduplication.""" + if peer_id not 
in self.peer_connections: + self.peer_connections[peer_id] = [] + + # Check for duplicate connections to the same address + for conn_info in self.peer_connections[peer_id]: + if conn_info.address == address: + logger.debug( + f"Connection to {address} already exists for peer {peer_id}" + ) + return + + # Add new connection + try: + current_time = trio.current_time() + except RuntimeError: + # Fallback for testing contexts where trio is not running + import time + + current_time = time.time() + + conn_info = ConnectionInfo( + connection=connection, + address=address, + established_at=current_time, + last_used=current_time, + stream_count=0, + is_healthy=True, + ) + + self.peer_connections[peer_id].append(conn_info) + + # Trim if we exceed max connections + if len(self.peer_connections[peer_id]) > self.max_connections_per_peer: + self._trim_connections(peer_id) + + def get_connection( + self, peer_id: ID, strategy: str = "round_robin" + ) -> INetConn | None: + """Get a connection using the specified load balancing strategy.""" + if peer_id not in self.peer_connections or not self.peer_connections[peer_id]: + return None + + connections = self.peer_connections[peer_id] + + if strategy == "round_robin": + if peer_id not in self._round_robin_index: + self._round_robin_index[peer_id] = 0 + + index = self._round_robin_index[peer_id] % len(connections) + self._round_robin_index[peer_id] += 1 + + conn_info = connections[index] + try: + conn_info.last_used = trio.current_time() + except RuntimeError: + import time + + conn_info.last_used = time.time() + return conn_info.connection + + elif strategy == "least_loaded": + # Find connection with least streams + # Note: stream_count is a custom attribute we add to connections + conn_info = min( + connections, key=lambda c: getattr(c.connection, "stream_count", 0) + ) + try: + conn_info.last_used = trio.current_time() + except RuntimeError: + import time + + conn_info.last_used = time.time() + return conn_info.connection + + else: + # Default to first connection + conn_info = connections[0] + try: + conn_info.last_used = trio.current_time() + except RuntimeError: + import time + + conn_info.last_used = time.time() + return conn_info.connection + + def has_connection(self, peer_id: ID) -> bool: + """Check if we have any connections to the peer.""" + return ( + peer_id in self.peer_connections and len(self.peer_connections[peer_id]) > 0 + ) + + def remove_connection(self, peer_id: ID, connection: INetConn) -> None: + """Remove a connection from the pool.""" + if peer_id in self.peer_connections: + self.peer_connections[peer_id] = [ + conn_info + for conn_info in self.peer_connections[peer_id] + if conn_info.connection != connection + ] + + # Clean up empty peer entries + if not self.peer_connections[peer_id]: + del self.peer_connections[peer_id] + if peer_id in self._round_robin_index: + del self._round_robin_index[peer_id] + + def _trim_connections(self, peer_id: ID) -> None: + """Remove oldest connections when limit is exceeded.""" + connections = self.peer_connections[peer_id] + if len(connections) <= self.max_connections_per_peer: + return + + # Sort by last used time and remove oldest + connections.sort(key=lambda c: c.last_used) + connections_to_remove = connections[: -self.max_connections_per_peer] + + for conn_info in connections_to_remove: + logger.debug( + f"Trimming old connection to {conn_info.address} for peer {peer_id}" + ) + try: + # Close the connection asynchronously + trio.lowlevel.spawn_system_task( + self._close_connection_async, 
conn_info.connection + ) + except Exception as e: + logger.warning(f"Error closing trimmed connection: {e}") + + # Keep only the most recently used connections + self.peer_connections[peer_id] = connections[-self.max_connections_per_peer :] + + async def _close_connection_async(self, connection: INetConn) -> None: + """Close a connection asynchronously.""" + try: + await connection.close() + except Exception as e: + logger.warning(f"Error closing connection: {e}") + + def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn: async def stream_handler(stream: INetStream) -> None: await network.get_manager().wait_finished() @@ -71,9 +255,8 @@ class Swarm(Service, INetworkService): peerstore: IPeerStore upgrader: TransportUpgrader transport: ITransport - # TODO: Connection and `peer_id` are 1-1 mapping in our implementation, - # whereas in Go one `peer_id` may point to multiple connections. - connections: dict[ID, INetConn] + # Enhanced: Support for multiple connections per peer + connections: dict[ID, INetConn] # Backward compatibility listeners: dict[str, IListener] common_stream_handler: StreamHandlerFn listener_nursery: trio.Nursery | None @@ -81,17 +264,38 @@ class Swarm(Service, INetworkService): notifees: list[INotifee] + # Enhanced: New configuration and connection pool + retry_config: RetryConfig + connection_config: ConnectionConfig + connection_pool: ConnectionPool | None + def __init__( self, peer_id: ID, peerstore: IPeerStore, upgrader: TransportUpgrader, transport: ITransport, + retry_config: RetryConfig | None = None, + connection_config: ConnectionConfig | None = None, ): self.self_id = peer_id self.peerstore = peerstore self.upgrader = upgrader self.transport = transport + + # Enhanced: Initialize retry and connection configuration + self.retry_config = retry_config or RetryConfig() + self.connection_config = connection_config or ConnectionConfig() + + # Enhanced: Initialize connection pool if enabled + if self.connection_config.enable_connection_pool: + self.connection_pool = ConnectionPool( + self.connection_config.max_connections_per_peer + ) + else: + self.connection_pool = None + + # Backward compatibility: Keep existing connections dict self.connections = dict() self.listeners = dict() @@ -124,12 +328,20 @@ class Swarm(Service, INetworkService): async def dial_peer(self, peer_id: ID) -> INetConn: """ - Try to create a connection to peer_id. + Try to create a connection to peer_id with enhanced retry logic. 
:param peer_id: peer if we want to dial :raises SwarmException: raised when an error occurs :return: muxed connection """ + # Enhanced: Check connection pool first if enabled + if self.connection_pool and self.connection_pool.has_connection(peer_id): + connection = self.connection_pool.get_connection(peer_id) + if connection: + logger.debug(f"Reusing existing connection to peer {peer_id}") + return connection + + # Enhanced: Check existing single connection for backward compatibility if peer_id in self.connections: # If muxed connection already exists for peer_id, # set muxed connection equal to existing muxed connection @@ -148,10 +360,21 @@ class Swarm(Service, INetworkService): exceptions: list[SwarmException] = [] - # Try all known addresses + # Enhanced: Try all known addresses with retry logic for multiaddr in addrs: try: - return await self.dial_addr(multiaddr, peer_id) + connection = await self._dial_with_retry(multiaddr, peer_id) + + # Enhanced: Add to connection pool if enabled + if self.connection_pool: + self.connection_pool.add_connection( + peer_id, connection, str(multiaddr) + ) + + # Backward compatibility: Keep existing connections dict + self.connections[peer_id] = connection + + return connection except SwarmException as e: exceptions.append(e) logger.debug( @@ -167,9 +390,64 @@ class Swarm(Service, INetworkService): "connection (with exceptions)" ) from MultiError(exceptions) - async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn: + async def _dial_with_retry(self, addr: Multiaddr, peer_id: ID) -> INetConn: """ - Try to create a connection to peer_id with addr. + Enhanced: Dial with retry logic and exponential backoff. + + :param addr: the address to dial + :param peer_id: the peer we want to connect to + :raises SwarmException: raised when all retry attempts fail + :return: network connection + """ + last_exception = None + + for attempt in range(self.retry_config.max_retries + 1): + try: + return await self._dial_addr_single_attempt(addr, peer_id) + except Exception as e: + last_exception = e + if attempt < self.retry_config.max_retries: + delay = self._calculate_backoff_delay(attempt) + logger.debug( + f"Connection attempt {attempt + 1} failed, " + f"retrying in {delay:.2f}s: {e}" + ) + await trio.sleep(delay) + else: + logger.debug(f"All {self.retry_config.max_retries} attempts failed") + + # Convert the last exception to SwarmException for consistency + if last_exception is not None: + if isinstance(last_exception, SwarmException): + raise last_exception + else: + raise SwarmException( + f"Failed to connect after {self.retry_config.max_retries} attempts" + ) from last_exception + + # This should never be reached, but mypy requires it + raise SwarmException("Unexpected error in retry logic") + + def _calculate_backoff_delay(self, attempt: int) -> float: + """ + Enhanced: Calculate backoff delay with jitter to prevent thundering herd. + + :param attempt: the current attempt number (0-based) + :return: delay in seconds + """ + delay = min( + self.retry_config.initial_delay + * (self.retry_config.backoff_multiplier**attempt), + self.retry_config.max_delay, + ) + + # Add jitter to prevent synchronized retries + jitter = delay * self.retry_config.jitter_factor + return delay + random.uniform(-jitter, jitter) + + async def _dial_addr_single_attempt(self, addr: Multiaddr, peer_id: ID) -> INetConn: + """ + Enhanced: Single attempt to dial an address (extracted from original dial_addr). 
:param addr: the address we want to connect with :param peer_id: the peer we want to connect to @@ -216,14 +494,49 @@ class Swarm(Service, INetworkService): return swarm_conn + async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn: + """ + Enhanced: Try to create a connection to peer_id with addr using retry logic. + + :param addr: the address we want to connect with + :param peer_id: the peer we want to connect to + :raises SwarmException: raised when an error occurs + :return: network connection + """ + return await self._dial_with_retry(addr, peer_id) + async def new_stream(self, peer_id: ID) -> INetStream: """ + Enhanced: Create a new stream with load balancing across multiple connections. + :param peer_id: peer_id of destination :raises SwarmException: raised when an error occurs :return: net stream instance """ logger.debug("attempting to open a stream to peer %s", peer_id) + # Enhanced: Try to get existing connection from pool first + if self.connection_pool and self.connection_pool.has_connection(peer_id): + connection = self.connection_pool.get_connection( + peer_id, self.connection_config.load_balancing_strategy + ) + if connection: + try: + net_stream = await connection.new_stream() + logger.debug( + "successfully opened a stream to peer %s " + "using existing connection", + peer_id, + ) + return net_stream + except Exception as e: + logger.debug( + f"Failed to create stream on existing connection, " + f"will dial new connection: {e}" + ) + # Fall through to dial new connection + + # Fall back to existing logic: dial peer and create stream swarm_conn = await self.dial_peer(peer_id) net_stream = await swarm_conn.new_stream() @@ -359,6 +672,11 @@ class Swarm(Service, INetworkService): if peer_id not in self.connections: return connection = self.connections[peer_id] + + # Enhanced: Remove from connection pool if enabled + if self.connection_pool: + self.connection_pool.remove_connection(peer_id, connection) + # NOTE: `connection.close` will delete `peer_id` from `self.connections` # and `notify_disconnected` for us. await connection.close() @@ -380,7 +698,15 @@ class Swarm(Service, INetworkService): await muxed_conn.event_started.wait() self.manager.run_task(swarm_conn.start) await swarm_conn.event_started.wait() - # Store muxed_conn with peer id + # Enhanced: Add to connection pool if enabled + if self.connection_pool: + # For incoming connections, we don't have a specific address + # Use a placeholder that will be updated when we get more info + self.connection_pool.add_connection( + muxed_conn.peer_id, swarm_conn, "incoming" + ) + + # Store muxed_conn with peer id (backward compatibility) self.connections[muxed_conn.peer_id] = swarm_conn # Call notifiers since event occurred await self.notify_connected(swarm_conn) @@ -392,6 +718,11 @@ class Swarm(Service, INetworkService): the connection. """ peer_id = swarm_conn.muxed_conn.peer_id + + # Enhanced: Remove from connection pool if enabled + if self.connection_pool: + self.connection_pool.remove_connection(peer_id, swarm_conn) + if peer_id not in self.connections: return del self.connections[peer_id] diff --git a/newsfragments/874.feature.rst b/newsfragments/874.feature.rst new file mode 100644 index 00000000..bef1d3bc --- /dev/null +++ b/newsfragments/874.feature.rst @@ -0,0 +1 @@ +Enhanced Swarm networking with retry logic, exponential backoff, and multi-connection support. 
Added configurable retry mechanisms that automatically recover from transient connection failures using exponential backoff with jitter to prevent thundering herd problems. Introduced connection pooling that allows multiple concurrent connections per peer for improved performance and fault tolerance. Added load balancing across connections and automatic connection health management. All enhancements are fully backward compatible and can be configured through new RetryConfig and ConnectionConfig classes. diff --git a/tests/core/network/test_enhanced_swarm.py b/tests/core/network/test_enhanced_swarm.py new file mode 100644 index 00000000..c076729b --- /dev/null +++ b/tests/core/network/test_enhanced_swarm.py @@ -0,0 +1,428 @@ +import time +from unittest.mock import Mock + +import pytest +from multiaddr import Multiaddr + +from libp2p.abc import INetConn, INetStream +from libp2p.network.exceptions import SwarmException +from libp2p.network.swarm import ( + ConnectionConfig, + ConnectionPool, + RetryConfig, + Swarm, +) +from libp2p.peer.id import ID + + +class MockConnection(INetConn): + """Mock connection for testing.""" + + def __init__(self, peer_id: ID, is_closed: bool = False): + self.peer_id = peer_id + self._is_closed = is_closed + self.stream_count = 0 + # Mock the muxed_conn attribute that Swarm expects + self.muxed_conn = Mock() + self.muxed_conn.peer_id = peer_id + + async def close(self): + self._is_closed = True + + @property + def is_closed(self) -> bool: + return self._is_closed + + async def new_stream(self) -> INetStream: + self.stream_count += 1 + return Mock(spec=INetStream) + + def get_streams(self) -> tuple[INetStream, ...]: + """Mock implementation of get_streams.""" + return tuple() + + def get_transport_addresses(self) -> list[Multiaddr]: + """Mock implementation of get_transport_addresses.""" + return [] + + +class MockNetStream(INetStream): + """Mock network stream for testing.""" + + def __init__(self, peer_id: ID): + self.peer_id = peer_id + + +@pytest.mark.trio +async def test_retry_config_defaults(): + """Test RetryConfig default values.""" + config = RetryConfig() + assert config.max_retries == 3 + assert config.initial_delay == 0.1 + assert config.max_delay == 30.0 + assert config.backoff_multiplier == 2.0 + assert config.jitter_factor == 0.1 + + +@pytest.mark.trio +async def test_connection_config_defaults(): + """Test ConnectionConfig default values.""" + config = ConnectionConfig() + assert config.max_connections_per_peer == 3 + assert config.connection_timeout == 30.0 + assert config.enable_connection_pool is True + assert config.load_balancing_strategy == "round_robin" + + +@pytest.mark.trio +async def test_connection_pool_basic_operations(): + """Test basic ConnectionPool operations.""" + pool = ConnectionPool(max_connections_per_peer=2) + peer_id = ID(b"QmTest") + + # Test empty pool + assert not pool.has_connection(peer_id) + assert pool.get_connection(peer_id) is None + + # Add connection + conn1 = MockConnection(peer_id) + pool.add_connection(peer_id, conn1, "addr1") + assert pool.has_connection(peer_id) + assert pool.get_connection(peer_id) == conn1 + + # Add second connection + conn2 = MockConnection(peer_id) + pool.add_connection(peer_id, conn2, "addr2") + assert len(pool.peer_connections[peer_id]) == 2 + + # Test round-robin - should cycle through connections + first_conn = pool.get_connection(peer_id, "round_robin") + second_conn = pool.get_connection(peer_id, "round_robin") + third_conn = pool.get_connection(peer_id, "round_robin") + + # Should 
cycle through both connections + assert first_conn in [conn1, conn2] + assert second_conn in [conn1, conn2] + assert third_conn in [conn1, conn2] + assert first_conn != second_conn or second_conn != third_conn + + # Test least loaded - set different stream counts + conn1.stream_count = 5 + conn2.stream_count = 1 + least_loaded_conn = pool.get_connection(peer_id, "least_loaded") + assert least_loaded_conn == conn2 # conn2 has fewer streams + + +@pytest.mark.trio +async def test_connection_pool_deduplication(): + """Test connection deduplication by address.""" + pool = ConnectionPool(max_connections_per_peer=3) + peer_id = ID(b"QmTest") + + conn1 = MockConnection(peer_id) + pool.add_connection(peer_id, conn1, "addr1") + + # Try to add connection with same address + conn2 = MockConnection(peer_id) + pool.add_connection(peer_id, conn2, "addr1") + + # Should only have one connection + assert len(pool.peer_connections[peer_id]) == 1 + assert pool.get_connection(peer_id) == conn1 + + +@pytest.mark.trio +async def test_connection_pool_trimming(): + """Test connection trimming when limit is exceeded.""" + pool = ConnectionPool(max_connections_per_peer=2) + peer_id = ID(b"QmTest") + + # Add 3 connections + conn1 = MockConnection(peer_id) + conn2 = MockConnection(peer_id) + conn3 = MockConnection(peer_id) + + pool.add_connection(peer_id, conn1, "addr1") + pool.add_connection(peer_id, conn2, "addr2") + pool.add_connection(peer_id, conn3, "addr3") + + # Should trim to 2 connections + assert len(pool.peer_connections[peer_id]) == 2 + + # The oldest connections should be removed + remaining_connections = [c.connection for c in pool.peer_connections[peer_id]] + assert conn3 in remaining_connections # Most recent should remain + + +@pytest.mark.trio +async def test_connection_pool_remove_connection(): + """Test removing connections from pool.""" + pool = ConnectionPool(max_connections_per_peer=3) + peer_id = ID(b"QmTest") + + conn1 = MockConnection(peer_id) + conn2 = MockConnection(peer_id) + + pool.add_connection(peer_id, conn1, "addr1") + pool.add_connection(peer_id, conn2, "addr2") + + assert len(pool.peer_connections[peer_id]) == 2 + + # Remove connection + pool.remove_connection(peer_id, conn1) + assert len(pool.peer_connections[peer_id]) == 1 + assert pool.get_connection(peer_id) == conn2 + + # Remove last connection + pool.remove_connection(peer_id, conn2) + assert not pool.has_connection(peer_id) + + +@pytest.mark.trio +async def test_enhanced_swarm_constructor(): + """Test enhanced Swarm constructor with new configuration.""" + # Create mock dependencies + peer_id = ID(b"QmTest") + peerstore = Mock() + upgrader = Mock() + transport = Mock() + + # Test with default config + swarm = Swarm(peer_id, peerstore, upgrader, transport) + assert swarm.retry_config.max_retries == 3 + assert swarm.connection_config.max_connections_per_peer == 3 + assert swarm.connection_pool is not None + + # Test with custom config + custom_retry = RetryConfig(max_retries=5, initial_delay=0.5) + custom_conn = ConnectionConfig( + max_connections_per_peer=5, + enable_connection_pool=False + ) + + swarm = Swarm(peer_id, peerstore, upgrader, transport, custom_retry, custom_conn) + assert swarm.retry_config.max_retries == 5 + assert swarm.retry_config.initial_delay == 0.5 + assert swarm.connection_config.max_connections_per_peer == 5 + assert swarm.connection_pool is None + + +@pytest.mark.trio +async def test_swarm_backoff_calculation(): + """Test exponential backoff calculation with jitter.""" + peer_id = ID(b"QmTest") + 
peerstore = Mock() + upgrader = Mock() + transport = Mock() + + retry_config = RetryConfig( + initial_delay=0.1, + max_delay=1.0, + backoff_multiplier=2.0, + jitter_factor=0.1 + ) + + swarm = Swarm(peer_id, peerstore, upgrader, transport, retry_config) + + # Test backoff calculation + delay1 = swarm._calculate_backoff_delay(0) + delay2 = swarm._calculate_backoff_delay(1) + delay3 = swarm._calculate_backoff_delay(2) + + # Should increase exponentially + assert delay2 > delay1 + assert delay3 > delay2 + + # Should respect max delay + assert delay1 <= 1.0 + assert delay2 <= 1.0 + assert delay3 <= 1.0 + + # Should have jitter + assert delay1 != 0.1 # Should have jitter added + + +@pytest.mark.trio +async def test_swarm_retry_logic(): + """Test retry logic in dial operations.""" + peer_id = ID(b"QmTest") + peerstore = Mock() + upgrader = Mock() + transport = Mock() + + # Configure for fast testing + retry_config = RetryConfig( + max_retries=2, + initial_delay=0.01, # Very short for testing + max_delay=0.1 + ) + + swarm = Swarm(peer_id, peerstore, upgrader, transport, retry_config) + + # Mock the single attempt method to fail twice then succeed + attempt_count = [0] + + async def mock_single_attempt(addr, peer_id): + attempt_count[0] += 1 + if attempt_count[0] < 3: + raise SwarmException(f"Attempt {attempt_count[0]} failed") + return MockConnection(peer_id) + + swarm._dial_addr_single_attempt = mock_single_attempt + + # Test retry logic + start_time = time.time() + result = await swarm._dial_with_retry(Mock(spec=Multiaddr), peer_id) + end_time = time.time() + + # Should have succeeded after 3 attempts + assert attempt_count[0] == 3 + assert result is not None + + # Should have taken some time due to retries + assert end_time - start_time > 0.02 # At least 2 delays + + +@pytest.mark.trio +async def test_swarm_multi_connection_support(): + """Test multi-connection support in Swarm.""" + peer_id = ID(b"QmTest") + peerstore = Mock() + upgrader = Mock() + transport = Mock() + + connection_config = ConnectionConfig( + max_connections_per_peer=3, + enable_connection_pool=True, + load_balancing_strategy="round_robin" + ) + + swarm = Swarm( + peer_id, + peerstore, + upgrader, + transport, + connection_config=connection_config + ) + + # Mock connection pool methods + assert swarm.connection_pool is not None + connection_pool = swarm.connection_pool + connection_pool.has_connection = Mock(return_value=True) + connection_pool.get_connection = Mock(return_value=MockConnection(peer_id)) + + # Test that new_stream uses connection pool + result = await swarm.new_stream(peer_id) + assert result is not None + # Use the mocked method directly to avoid type checking issues + get_connection_mock = connection_pool.get_connection + assert get_connection_mock.call_count == 1 + + +@pytest.mark.trio +async def test_swarm_backward_compatibility(): + """Test that enhanced Swarm maintains backward compatibility.""" + peer_id = ID(b"QmTest") + peerstore = Mock() + upgrader = Mock() + transport = Mock() + + # Create swarm with connection pool disabled + connection_config = ConnectionConfig(enable_connection_pool=False) + swarm = Swarm( + peer_id, peerstore, upgrader, transport, + connection_config=connection_config + ) + + # Should behave like original swarm + assert swarm.connection_pool is None + assert isinstance(swarm.connections, dict) + + # Test that dial_peer still works (will fail due to mocks, but structure is correct) + peerstore.addrs.return_value = [Mock(spec=Multiaddr)] + transport.dial.side_effect = 
Exception("Transport error") + + with pytest.raises(SwarmException): + await swarm.dial_peer(peer_id) + + +@pytest.mark.trio +async def test_swarm_connection_pool_integration(): + """Test integration between Swarm and ConnectionPool.""" + peer_id = ID(b"QmTest") + peerstore = Mock() + upgrader = Mock() + transport = Mock() + + connection_config = ConnectionConfig( + max_connections_per_peer=2, + enable_connection_pool=True + ) + + swarm = Swarm( + peer_id, peerstore, upgrader, transport, + connection_config=connection_config + ) + + # Mock successful connection creation + mock_conn = MockConnection(peer_id) + peerstore.addrs.return_value = [Mock(spec=Multiaddr)] + + async def mock_dial_with_retry(addr, peer_id): + return mock_conn + + swarm._dial_with_retry = mock_dial_with_retry + + # Test dial_peer adds to connection pool + result = await swarm.dial_peer(peer_id) + assert result == mock_conn + assert swarm.connection_pool is not None + assert swarm.connection_pool.has_connection(peer_id) + + # Test that subsequent calls reuse connection + result2 = await swarm.dial_peer(peer_id) + assert result2 == mock_conn + + +@pytest.mark.trio +async def test_swarm_connection_cleanup(): + """Test connection cleanup in enhanced Swarm.""" + peer_id = ID(b"QmTest") + peerstore = Mock() + upgrader = Mock() + transport = Mock() + + connection_config = ConnectionConfig(enable_connection_pool=True) + swarm = Swarm( + peer_id, peerstore, upgrader, transport, + connection_config=connection_config + ) + + # Add a connection + mock_conn = MockConnection(peer_id) + swarm.connections[peer_id] = mock_conn + assert swarm.connection_pool is not None + swarm.connection_pool.add_connection(peer_id, mock_conn, "test_addr") + + # Test close_peer removes from pool + await swarm.close_peer(peer_id) + assert swarm.connection_pool is not None + assert not swarm.connection_pool.has_connection(peer_id) + + # Test remove_conn removes from pool + mock_conn2 = MockConnection(peer_id) + swarm.connections[peer_id] = mock_conn2 + assert swarm.connection_pool is not None + connection_pool = swarm.connection_pool + connection_pool.add_connection(peer_id, mock_conn2, "test_addr2") + + # Note: remove_conn expects SwarmConn, but for testing we'll just + # remove from pool directly + connection_pool = swarm.connection_pool + connection_pool.remove_connection(peer_id, mock_conn2) + assert connection_pool is not None + assert not connection_pool.has_connection(peer_id) + + +if __name__ == "__main__": + pytest.main([__file__]) From 9fa3afbb0496270d39de29dc163e85591ad5f701 Mon Sep 17 00:00:00 2001 From: bomanaps Date: Thu, 28 Aug 2025 22:18:33 +0100 Subject: [PATCH 49/71] fix: format code to pass CI lint --- examples/enhanced_swarm_example.py | 96 ++++++++--------------- tests/core/network/test_enhanced_swarm.py | 30 +++---- 2 files changed, 42 insertions(+), 84 deletions(-) diff --git a/examples/enhanced_swarm_example.py b/examples/enhanced_swarm_example.py index 37770411..b5367af8 100644 --- a/examples/enhanced_swarm_example.py +++ b/examples/enhanced_swarm_example.py @@ -32,16 +32,12 @@ async def example_basic_enhanced_swarm() -> None: default_connection = ConnectionConfig() logger.info(f"Swarm created with peer ID: {swarm.get_peer_id()}") - logger.info( - f"Retry config: max_retries={default_retry.max_retries}" - ) + logger.info(f"Retry config: max_retries={default_retry.max_retries}") logger.info( f"Connection config: max_connections_per_peer=" f"{default_connection.max_connections_per_peer}" ) - logger.info( - f"Connection pool 
enabled: {default_connection.enable_connection_pool}" - ) + logger.info(f"Connection pool enabled: {default_connection.enable_connection_pool}") await swarm.close() logger.info("Basic enhanced Swarm example completed") @@ -53,32 +49,22 @@ async def example_custom_retry_config() -> None: # Custom retry configuration for aggressive retry behavior retry_config = RetryConfig( - max_retries=5, # More retries - initial_delay=0.05, # Faster initial retry - max_delay=10.0, # Lower max delay + max_retries=5, # More retries + initial_delay=0.05, # Faster initial retry + max_delay=10.0, # Lower max delay backoff_multiplier=1.5, # Less aggressive backoff - jitter_factor=0.2 # More jitter + jitter_factor=0.2, # More jitter ) # Create swarm with custom retry config swarm = new_swarm(retry_config=retry_config) logger.info("Custom retry config applied:") - logger.info( - f" Max retries: {retry_config.max_retries}" - ) - logger.info( - f" Initial delay: {retry_config.initial_delay}s" - ) - logger.info( - f" Max delay: {retry_config.max_delay}s" - ) - logger.info( - f" Backoff multiplier: {retry_config.backoff_multiplier}" - ) - logger.info( - f" Jitter factor: {retry_config.jitter_factor}" - ) + logger.info(f" Max retries: {retry_config.max_retries}") + logger.info(f" Initial delay: {retry_config.initial_delay}s") + logger.info(f" Max delay: {retry_config.max_delay}s") + logger.info(f" Backoff multiplier: {retry_config.backoff_multiplier}") + logger.info(f" Jitter factor: {retry_config.jitter_factor}") await swarm.close() logger.info("Custom retry config example completed") @@ -90,10 +76,10 @@ async def example_custom_connection_config() -> None: # Custom connection configuration for high-performance scenarios connection_config = ConnectionConfig( - max_connections_per_peer=5, # More connections per peer - connection_timeout=60.0, # Longer timeout - enable_connection_pool=True, # Enable connection pooling - load_balancing_strategy="least_loaded" # Use least loaded strategy + max_connections_per_peer=5, # More connections per peer + connection_timeout=60.0, # Longer timeout + enable_connection_pool=True, # Enable connection pooling + load_balancing_strategy="least_loaded", # Use least loaded strategy ) # Create swarm with custom connection config @@ -101,19 +87,14 @@ async def example_custom_connection_config() -> None: logger.info("Custom connection config applied:") logger.info( - f" Max connections per peer: " - f"{connection_config.max_connections_per_peer}" + f" Max connections per peer: {connection_config.max_connections_per_peer}" + ) + logger.info(f" Connection timeout: {connection_config.connection_timeout}s") + logger.info( + f" Connection pool enabled: {connection_config.enable_connection_pool}" ) logger.info( - f" Connection timeout: {connection_config.connection_timeout}s" - ) - logger.info( - f" Connection pool enabled: " - f"{connection_config.enable_connection_pool}" - ) - logger.info( - f" Load balancing strategy: " - f"{connection_config.load_balancing_strategy}" + f" Load balancing strategy: {connection_config.load_balancing_strategy}" ) await swarm.close() @@ -134,12 +115,8 @@ async def example_backward_compatibility() -> None: logger.info( f" Connection pool enabled: {connection_config.enable_connection_pool}" ) - logger.info( - f" Connections dict type: {type(swarm.connections)}" - ) - logger.info( - " Retry logic still available: 3 max retries" - ) + logger.info(f" Connections dict type: {type(swarm.connections)}") + logger.info(" Retry logic still available: 3 max retries") await 
swarm.close() logger.info("Backward compatibility example completed") @@ -151,38 +128,31 @@ async def example_production_ready_config() -> None: # Production-ready retry configuration retry_config = RetryConfig( - max_retries=3, # Reasonable retry limit - initial_delay=0.1, # Quick initial retry - max_delay=30.0, # Cap exponential backoff + max_retries=3, # Reasonable retry limit + initial_delay=0.1, # Quick initial retry + max_delay=30.0, # Cap exponential backoff backoff_multiplier=2.0, # Standard exponential backoff - jitter_factor=0.1 # Small jitter to prevent thundering herd + jitter_factor=0.1, # Small jitter to prevent thundering herd ) # Production-ready connection configuration connection_config = ConnectionConfig( max_connections_per_peer=3, # Balance between performance and resource usage - connection_timeout=30.0, # Reasonable timeout - enable_connection_pool=True, # Enable for better performance - load_balancing_strategy="round_robin" # Simple, predictable strategy + connection_timeout=30.0, # Reasonable timeout + enable_connection_pool=True, # Enable for better performance + load_balancing_strategy="round_robin", # Simple, predictable strategy ) # Create swarm with production config - swarm = new_swarm( - retry_config=retry_config, - connection_config=connection_config - ) + swarm = new_swarm(retry_config=retry_config, connection_config=connection_config) logger.info("Production-ready configuration applied:") logger.info( f" Retry: {retry_config.max_retries} retries, " f"{retry_config.max_delay}s max delay" ) - logger.info( - f" Connections: {connection_config.max_connections_per_peer} per peer" - ) - logger.info( - f" Load balancing: {connection_config.load_balancing_strategy}" - ) + logger.info(f" Connections: {connection_config.max_connections_per_peer} per peer") + logger.info(f" Load balancing: {connection_config.load_balancing_strategy}") await swarm.close() logger.info("Production-ready configuration example completed") diff --git a/tests/core/network/test_enhanced_swarm.py b/tests/core/network/test_enhanced_swarm.py index c076729b..9b100ad9 100644 --- a/tests/core/network/test_enhanced_swarm.py +++ b/tests/core/network/test_enhanced_swarm.py @@ -196,8 +196,7 @@ async def test_enhanced_swarm_constructor(): # Test with custom config custom_retry = RetryConfig(max_retries=5, initial_delay=0.5) custom_conn = ConnectionConfig( - max_connections_per_peer=5, - enable_connection_pool=False + max_connections_per_peer=5, enable_connection_pool=False ) swarm = Swarm(peer_id, peerstore, upgrader, transport, custom_retry, custom_conn) @@ -216,10 +215,7 @@ async def test_swarm_backoff_calculation(): transport = Mock() retry_config = RetryConfig( - initial_delay=0.1, - max_delay=1.0, - backoff_multiplier=2.0, - jitter_factor=0.1 + initial_delay=0.1, max_delay=1.0, backoff_multiplier=2.0, jitter_factor=0.1 ) swarm = Swarm(peer_id, peerstore, upgrader, transport, retry_config) @@ -254,7 +250,7 @@ async def test_swarm_retry_logic(): retry_config = RetryConfig( max_retries=2, initial_delay=0.01, # Very short for testing - max_delay=0.1 + max_delay=0.1, ) swarm = Swarm(peer_id, peerstore, upgrader, transport, retry_config) @@ -294,15 +290,11 @@ async def test_swarm_multi_connection_support(): connection_config = ConnectionConfig( max_connections_per_peer=3, enable_connection_pool=True, - load_balancing_strategy="round_robin" + load_balancing_strategy="round_robin", ) swarm = Swarm( - peer_id, - peerstore, - upgrader, - transport, - connection_config=connection_config + peer_id, 
peerstore, upgrader, transport, connection_config=connection_config ) # Mock connection pool methods @@ -330,8 +322,7 @@ async def test_swarm_backward_compatibility(): # Create swarm with connection pool disabled connection_config = ConnectionConfig(enable_connection_pool=False) swarm = Swarm( - peer_id, peerstore, upgrader, transport, - connection_config=connection_config + peer_id, peerstore, upgrader, transport, connection_config=connection_config ) # Should behave like original swarm @@ -355,13 +346,11 @@ async def test_swarm_connection_pool_integration(): transport = Mock() connection_config = ConnectionConfig( - max_connections_per_peer=2, - enable_connection_pool=True + max_connections_per_peer=2, enable_connection_pool=True ) swarm = Swarm( - peer_id, peerstore, upgrader, transport, - connection_config=connection_config + peer_id, peerstore, upgrader, transport, connection_config=connection_config ) # Mock successful connection creation @@ -394,8 +383,7 @@ async def test_swarm_connection_cleanup(): connection_config = ConnectionConfig(enable_connection_pool=True) swarm = Swarm( - peer_id, peerstore, upgrader, transport, - connection_config=connection_config + peer_id, peerstore, upgrader, transport, connection_config=connection_config ) # Add a connection From 3c52b859baca1af6ade433569fbd57d083d8f432 Mon Sep 17 00:00:00 2001 From: unniznd Date: Fri, 29 Aug 2025 11:30:17 +0530 Subject: [PATCH 50/71] improved the error message --- libp2p/security/security_multistream.py | 4 +++- libp2p/stream_muxer/muxer_multistream.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/libp2p/security/security_multistream.py b/libp2p/security/security_multistream.py index a9c4b19c..f7c81de1 100644 --- a/libp2p/security/security_multistream.py +++ b/libp2p/security/security_multistream.py @@ -118,6 +118,8 @@ class SecurityMultistream(ABC): # Select protocol if non-initiator protocol, _ = await self.multiselect.negotiate(communicator) if protocol is None: - raise MultiselectError("fail to negotiate a security protocol") + raise MultiselectError( + "fail to negotiate a security protocol: no protocol selected" + ) # Return transport from protocol return self.transports[protocol] diff --git a/libp2p/stream_muxer/muxer_multistream.py b/libp2p/stream_muxer/muxer_multistream.py index 322db912..76689c17 100644 --- a/libp2p/stream_muxer/muxer_multistream.py +++ b/libp2p/stream_muxer/muxer_multistream.py @@ -85,7 +85,9 @@ class MuxerMultistream: else: protocol, _ = await self.multiselect.negotiate(communicator) if protocol is None: - raise MultiselectError("fail to negotiate a stream muxer protocol") + raise MultiselectError( + "fail to negotiate a stream muxer protocol: no protocol selected" + ) return self.transports[protocol] async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn: From 56526b48707de39da8c74e68c31775f38a8352be Mon Sep 17 00:00:00 2001 From: lla-dane Date: Mon, 11 Aug 2025 18:27:11 +0530 Subject: [PATCH 51/71] signed-peer-record transfer integrated with pubsub rpc message transfer --- libp2p/pubsub/floodsub.py | 10 + libp2p/pubsub/gossipsub.py | 53 +++++ libp2p/pubsub/pb/rpc.proto | 1 + libp2p/pubsub/pb/rpc_pb2.py | 67 +++--- libp2p/pubsub/pb/rpc_pb2.pyi | 435 ++++++++++------------------------- libp2p/pubsub/pubsub.py | 46 ++++ 6 files changed, 266 insertions(+), 346 deletions(-) diff --git a/libp2p/pubsub/floodsub.py b/libp2p/pubsub/floodsub.py index 3e0d454f..170f558d 100644 --- a/libp2p/pubsub/floodsub.py +++ b/libp2p/pubsub/floodsub.py @@ -15,6 +15,7 @@
from libp2p.custom_types import ( from libp2p.peer.id import ( ID, ) +from libp2p.peer.peerstore import create_signed_peer_record from .exceptions import ( PubsubRouterError, @@ -103,6 +104,15 @@ class FloodSub(IPubsubRouter): ) rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg]) + # Add the senderRecord of the peer in the RPC msg + if isinstance(self.pubsub, Pubsub): + envelope = create_signed_peer_record( + self.pubsub.host.get_id(), + self.pubsub.host.get_addrs(), + self.pubsub.host.get_private_key(), + ) + rpc_msg.senderRecord = envelope.marshal_envelope() + logger.debug("publishing message %s", pubsub_msg) if self.pubsub is None: diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py index c345c138..b7c70c55 100644 --- a/libp2p/pubsub/gossipsub.py +++ b/libp2p/pubsub/gossipsub.py @@ -24,6 +24,7 @@ from libp2p.abc import ( from libp2p.custom_types import ( TProtocol, ) +from libp2p.peer.envelope import consume_envelope from libp2p.peer.id import ( ID, ) @@ -34,6 +35,7 @@ from libp2p.peer.peerinfo import ( ) from libp2p.peer.peerstore import ( PERMANENT_ADDR_TTL, + create_signed_peer_record, ) from libp2p.pubsub import ( floodsub, @@ -226,6 +228,27 @@ class GossipSub(IPubsubRouter, Service): :param rpc: RPC message :param sender_peer_id: id of the peer who sent the message """ + # Process the senderRecord if sent + if isinstance(self.pubsub, Pubsub): + if rpc.HasField("senderRecord"): + try: + # Convert the signed-peer-record(Envelope) from + # protobuf bytes + envelope, _ = consume_envelope( + rpc.senderRecord, "libp2p-peer-record" + ) + # Use the default TTL of 2 hours (7200 seconds) + if self.pubsub.host.get_peerstore().consume_peer_record( + envelope, 7200 + ): + logger.error( + "Updating the Certified-Addr-Book was unsuccessful" + ) + except Exception as e: + logger.error( + "Error updating the certified addr book for peer: %s", e + ) + control_message = rpc.control # Relay each rpc control message to the appropriate handler @@ -253,6 +276,15 @@ class GossipSub(IPubsubRouter, Service): ) rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg]) + # Add the senderRecord of the peer in the RPC msg + if isinstance(self.pubsub, Pubsub): + envelope = create_signed_peer_record( + self.pubsub.host.get_id(), + self.pubsub.host.get_addrs(), + self.pubsub.host.get_private_key(), + ) + rpc_msg.senderRecord = envelope.marshal_envelope() + logger.debug("publishing message %s", pubsub_msg) for peer_id in peers_gen: @@ -818,6 +850,17 @@ class GossipSub(IPubsubRouter, Service): # 1) Package these messages into a single packet packet: rpc_pb2.RPC = rpc_pb2.RPC() + # Here the an RPC message is being created and published in response + # to the iwant control msg, so we will send a freshly created senderRecord + # with the RPC msg + if isinstance(self.pubsub, Pubsub): + envelope = create_signed_peer_record( + self.pubsub.host.get_id(), + self.pubsub.host.get_addrs(), + self.pubsub.host.get_private_key(), + ) + packet.senderRecord = envelope.marshal_envelope() + packet.publish.extend(msgs_to_forward) if self.pubsub is None: @@ -973,6 +1016,16 @@ class GossipSub(IPubsubRouter, Service): raise NoPubsubAttached # Add control message to packet packet: rpc_pb2.RPC = rpc_pb2.RPC() + + # Add the sender's peer-record in the RPC msg + if isinstance(self.pubsub, Pubsub): + envelope = create_signed_peer_record( + self.pubsub.host.get_id(), + self.pubsub.host.get_addrs(), + self.pubsub.host.get_private_key(), + ) + packet.senderRecord = envelope.marshal_envelope() + packet.control.CopyFrom(control_msg) # Get stream for 
peer from pubsub diff --git a/libp2p/pubsub/pb/rpc.proto b/libp2p/pubsub/pb/rpc.proto index 7abce0d6..d24db281 100644 --- a/libp2p/pubsub/pb/rpc.proto +++ b/libp2p/pubsub/pb/rpc.proto @@ -14,6 +14,7 @@ message RPC { } optional ControlMessage control = 3; + optional bytes senderRecord = 4; } message Message { diff --git a/libp2p/pubsub/pb/rpc_pb2.py b/libp2p/pubsub/pb/rpc_pb2.py index 30f0281b..e4a35745 100644 --- a/libp2p/pubsub/pb/rpc_pb2.py +++ b/libp2p/pubsub/pb/rpc_pb2.py @@ -1,11 +1,12 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: libp2p/pubsub/pb/rpc.proto +# Protobuf Python Version: 4.25.3 """Generated protocol buffer code.""" -from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -13,39 +14,39 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1alibp2p/pubsub/pb/rpc.proto\x12\tpubsub.pb\"\xb4\x01\n\x03RPC\x12-\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x16.pubsub.pb.RPC.SubOpts\x12#\n\x07publish\x18\x02 \x03(\x0b\x32\x12.pubsub.pb.Message\x12*\n\x07\x63ontrol\x18\x03 \x01(\x0b\x32\x19.pubsub.pb.ControlMessage\x1a-\n\x07SubOpts\x12\x11\n\tsubscribe\x18\x01 \x01(\x08\x12\x0f\n\x07topicid\x18\x02 \x01(\t\"i\n\x07Message\x12\x0f\n\x07\x66rom_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\r\n\x05seqno\x18\x03 \x01(\x0c\x12\x10\n\x08topicIDs\x18\x04 \x03(\t\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x12\x0b\n\x03key\x18\x06 \x01(\x0c\"\xb0\x01\n\x0e\x43ontrolMessage\x12&\n\x05ihave\x18\x01 \x03(\x0b\x32\x17.pubsub.pb.ControlIHave\x12&\n\x05iwant\x18\x02 \x03(\x0b\x32\x17.pubsub.pb.ControlIWant\x12&\n\x05graft\x18\x03 \x03(\x0b\x32\x17.pubsub.pb.ControlGraft\x12&\n\x05prune\x18\x04 \x03(\x0b\x32\x17.pubsub.pb.ControlPrune\"3\n\x0c\x43ontrolIHave\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\x12\n\nmessageIDs\x18\x02 \x03(\t\"\"\n\x0c\x43ontrolIWant\x12\x12\n\nmessageIDs\x18\x01 \x03(\t\"\x1f\n\x0c\x43ontrolGraft\x12\x0f\n\x07topicID\x18\x01 \x01(\t\"T\n\x0c\x43ontrolPrune\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\"\n\x05peers\x18\x02 \x03(\x0b\x32\x13.pubsub.pb.PeerInfo\x12\x0f\n\x07\x62\x61\x63koff\x18\x03 \x01(\x04\"4\n\x08PeerInfo\x12\x0e\n\x06peerID\x18\x01 \x01(\x0c\x12\x18\n\x10signedPeerRecord\x18\x02 \x01(\x0c\"\x87\x03\n\x0fTopicDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\x04\x61uth\x18\x02 \x01(\x0b\x32#.pubsub.pb.TopicDescriptor.AuthOpts\x12/\n\x03\x65nc\x18\x03 \x01(\x0b\x32\".pubsub.pb.TopicDescriptor.EncOpts\x1a|\n\x08\x41uthOpts\x12:\n\x04mode\x18\x01 \x01(\x0e\x32,.pubsub.pb.TopicDescriptor.AuthOpts.AuthMode\x12\x0c\n\x04keys\x18\x02 \x03(\x0c\"&\n\x08\x41uthMode\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03KEY\x10\x01\x12\x07\n\x03WOT\x10\x02\x1a\x83\x01\n\x07\x45ncOpts\x12\x38\n\x04mode\x18\x01 \x01(\x0e\x32*.pubsub.pb.TopicDescriptor.EncOpts.EncMode\x12\x11\n\tkeyHashes\x18\x02 \x03(\x0c\"+\n\x07\x45ncMode\x12\x08\n\x04NONE\x10\x00\x12\r\n\tSHAREDKEY\x10\x01\x12\x07\n\x03WOT\x10\x02') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1alibp2p/pubsub/pb/rpc.proto\x12\tpubsub.pb\"\xca\x01\n\x03RPC\x12-\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x16.pubsub.pb.RPC.SubOpts\x12#\n\x07publish\x18\x02 
\x03(\x0b\x32\x12.pubsub.pb.Message\x12*\n\x07\x63ontrol\x18\x03 \x01(\x0b\x32\x19.pubsub.pb.ControlMessage\x12\x14\n\x0csenderRecord\x18\x04 \x01(\x0c\x1a-\n\x07SubOpts\x12\x11\n\tsubscribe\x18\x01 \x01(\x08\x12\x0f\n\x07topicid\x18\x02 \x01(\t\"i\n\x07Message\x12\x0f\n\x07\x66rom_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\r\n\x05seqno\x18\x03 \x01(\x0c\x12\x10\n\x08topicIDs\x18\x04 \x03(\t\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x12\x0b\n\x03key\x18\x06 \x01(\x0c\"\xb0\x01\n\x0e\x43ontrolMessage\x12&\n\x05ihave\x18\x01 \x03(\x0b\x32\x17.pubsub.pb.ControlIHave\x12&\n\x05iwant\x18\x02 \x03(\x0b\x32\x17.pubsub.pb.ControlIWant\x12&\n\x05graft\x18\x03 \x03(\x0b\x32\x17.pubsub.pb.ControlGraft\x12&\n\x05prune\x18\x04 \x03(\x0b\x32\x17.pubsub.pb.ControlPrune\"3\n\x0c\x43ontrolIHave\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\x12\n\nmessageIDs\x18\x02 \x03(\t\"\"\n\x0c\x43ontrolIWant\x12\x12\n\nmessageIDs\x18\x01 \x03(\t\"\x1f\n\x0c\x43ontrolGraft\x12\x0f\n\x07topicID\x18\x01 \x01(\t\"T\n\x0c\x43ontrolPrune\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\"\n\x05peers\x18\x02 \x03(\x0b\x32\x13.pubsub.pb.PeerInfo\x12\x0f\n\x07\x62\x61\x63koff\x18\x03 \x01(\x04\"4\n\x08PeerInfo\x12\x0e\n\x06peerID\x18\x01 \x01(\x0c\x12\x18\n\x10signedPeerRecord\x18\x02 \x01(\x0c\"\x87\x03\n\x0fTopicDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\x04\x61uth\x18\x02 \x01(\x0b\x32#.pubsub.pb.TopicDescriptor.AuthOpts\x12/\n\x03\x65nc\x18\x03 \x01(\x0b\x32\".pubsub.pb.TopicDescriptor.EncOpts\x1a|\n\x08\x41uthOpts\x12:\n\x04mode\x18\x01 \x01(\x0e\x32,.pubsub.pb.TopicDescriptor.AuthOpts.AuthMode\x12\x0c\n\x04keys\x18\x02 \x03(\x0c\"&\n\x08\x41uthMode\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03KEY\x10\x01\x12\x07\n\x03WOT\x10\x02\x1a\x83\x01\n\x07\x45ncOpts\x12\x38\n\x04mode\x18\x01 \x01(\x0e\x32*.pubsub.pb.TopicDescriptor.EncOpts.EncMode\x12\x11\n\tkeyHashes\x18\x02 \x03(\x0c\"+\n\x07\x45ncMode\x12\x08\n\x04NONE\x10\x00\x12\r\n\tSHAREDKEY\x10\x01\x12\x07\n\x03WOT\x10\x02') -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.pubsub.pb.rpc_pb2', globals()) +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.pubsub.pb.rpc_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _RPC._serialized_start=42 - _RPC._serialized_end=222 - _RPC_SUBOPTS._serialized_start=177 - _RPC_SUBOPTS._serialized_end=222 - _MESSAGE._serialized_start=224 - _MESSAGE._serialized_end=329 - _CONTROLMESSAGE._serialized_start=332 - _CONTROLMESSAGE._serialized_end=508 - _CONTROLIHAVE._serialized_start=510 - _CONTROLIHAVE._serialized_end=561 - _CONTROLIWANT._serialized_start=563 - _CONTROLIWANT._serialized_end=597 - _CONTROLGRAFT._serialized_start=599 - _CONTROLGRAFT._serialized_end=630 - _CONTROLPRUNE._serialized_start=632 - _CONTROLPRUNE._serialized_end=716 - _PEERINFO._serialized_start=718 - _PEERINFO._serialized_end=770 - _TOPICDESCRIPTOR._serialized_start=773 - _TOPICDESCRIPTOR._serialized_end=1164 - _TOPICDESCRIPTOR_AUTHOPTS._serialized_start=906 - _TOPICDESCRIPTOR_AUTHOPTS._serialized_end=1030 - _TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_start=992 - _TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_end=1030 - _TOPICDESCRIPTOR_ENCOPTS._serialized_start=1033 - _TOPICDESCRIPTOR_ENCOPTS._serialized_end=1164 - _TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_start=1121 - _TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_end=1164 + 
_globals['_RPC']._serialized_start=42 + _globals['_RPC']._serialized_end=244 + _globals['_RPC_SUBOPTS']._serialized_start=199 + _globals['_RPC_SUBOPTS']._serialized_end=244 + _globals['_MESSAGE']._serialized_start=246 + _globals['_MESSAGE']._serialized_end=351 + _globals['_CONTROLMESSAGE']._serialized_start=354 + _globals['_CONTROLMESSAGE']._serialized_end=530 + _globals['_CONTROLIHAVE']._serialized_start=532 + _globals['_CONTROLIHAVE']._serialized_end=583 + _globals['_CONTROLIWANT']._serialized_start=585 + _globals['_CONTROLIWANT']._serialized_end=619 + _globals['_CONTROLGRAFT']._serialized_start=621 + _globals['_CONTROLGRAFT']._serialized_end=652 + _globals['_CONTROLPRUNE']._serialized_start=654 + _globals['_CONTROLPRUNE']._serialized_end=738 + _globals['_PEERINFO']._serialized_start=740 + _globals['_PEERINFO']._serialized_end=792 + _globals['_TOPICDESCRIPTOR']._serialized_start=795 + _globals['_TOPICDESCRIPTOR']._serialized_end=1186 + _globals['_TOPICDESCRIPTOR_AUTHOPTS']._serialized_start=928 + _globals['_TOPICDESCRIPTOR_AUTHOPTS']._serialized_end=1052 + _globals['_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE']._serialized_start=1014 + _globals['_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE']._serialized_end=1052 + _globals['_TOPICDESCRIPTOR_ENCOPTS']._serialized_start=1055 + _globals['_TOPICDESCRIPTOR_ENCOPTS']._serialized_end=1186 + _globals['_TOPICDESCRIPTOR_ENCOPTS_ENCMODE']._serialized_start=1143 + _globals['_TOPICDESCRIPTOR_ENCOPTS_ENCMODE']._serialized_end=1186 # @@protoc_insertion_point(module_scope) diff --git a/libp2p/pubsub/pb/rpc_pb2.pyi b/libp2p/pubsub/pb/rpc_pb2.pyi index 88738e2e..2609fd11 100644 --- a/libp2p/pubsub/pb/rpc_pb2.pyi +++ b/libp2p/pubsub/pb/rpc_pb2.pyi @@ -1,323 +1,132 @@ -""" -@generated by mypy-protobuf. Do not edit manually! -isort:skip_file -Modified from https://github.com/libp2p/go-libp2p-pubsub/blob/master/pb/rpc.proto""" +from google.protobuf.internal import containers as _containers +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union -import builtins -import collections.abc -import google.protobuf.descriptor -import google.protobuf.internal.containers -import google.protobuf.internal.enum_type_wrapper -import google.protobuf.message -import sys -import typing +DESCRIPTOR: _descriptor.FileDescriptor -if sys.version_info >= (3, 10): - import typing as typing_extensions -else: - import typing_extensions +class RPC(_message.Message): + __slots__ = ("subscriptions", "publish", "control", "senderRecord") + class SubOpts(_message.Message): + __slots__ = ("subscribe", "topicid") + SUBSCRIBE_FIELD_NUMBER: _ClassVar[int] + TOPICID_FIELD_NUMBER: _ClassVar[int] + subscribe: bool + topicid: str + def __init__(self, subscribe: bool = ..., topicid: _Optional[str] = ...) -> None: ... 
+ SUBSCRIPTIONS_FIELD_NUMBER: _ClassVar[int] + PUBLISH_FIELD_NUMBER: _ClassVar[int] + CONTROL_FIELD_NUMBER: _ClassVar[int] + SENDERRECORD_FIELD_NUMBER: _ClassVar[int] + subscriptions: _containers.RepeatedCompositeFieldContainer[RPC.SubOpts] + publish: _containers.RepeatedCompositeFieldContainer[Message] + control: ControlMessage + senderRecord: bytes + def __init__(self, subscriptions: _Optional[_Iterable[_Union[RPC.SubOpts, _Mapping]]] = ..., publish: _Optional[_Iterable[_Union[Message, _Mapping]]] = ..., control: _Optional[_Union[ControlMessage, _Mapping]] = ..., senderRecord: _Optional[bytes] = ...) -> None: ... # type: ignore -DESCRIPTOR: google.protobuf.descriptor.FileDescriptor +class Message(_message.Message): + __slots__ = ("from_id", "data", "seqno", "topicIDs", "signature", "key") + FROM_ID_FIELD_NUMBER: _ClassVar[int] + DATA_FIELD_NUMBER: _ClassVar[int] + SEQNO_FIELD_NUMBER: _ClassVar[int] + TOPICIDS_FIELD_NUMBER: _ClassVar[int] + SIGNATURE_FIELD_NUMBER: _ClassVar[int] + KEY_FIELD_NUMBER: _ClassVar[int] + from_id: bytes + data: bytes + seqno: bytes + topicIDs: _containers.RepeatedScalarFieldContainer[str] + signature: bytes + key: bytes + def __init__(self, from_id: _Optional[bytes] = ..., data: _Optional[bytes] = ..., seqno: _Optional[bytes] = ..., topicIDs: _Optional[_Iterable[str]] = ..., signature: _Optional[bytes] = ..., key: _Optional[bytes] = ...) -> None: ... -@typing.final -class RPC(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor +class ControlMessage(_message.Message): + __slots__ = ("ihave", "iwant", "graft", "prune") + IHAVE_FIELD_NUMBER: _ClassVar[int] + IWANT_FIELD_NUMBER: _ClassVar[int] + GRAFT_FIELD_NUMBER: _ClassVar[int] + PRUNE_FIELD_NUMBER: _ClassVar[int] + ihave: _containers.RepeatedCompositeFieldContainer[ControlIHave] + iwant: _containers.RepeatedCompositeFieldContainer[ControlIWant] + graft: _containers.RepeatedCompositeFieldContainer[ControlGraft] + prune: _containers.RepeatedCompositeFieldContainer[ControlPrune] + def __init__(self, ihave: _Optional[_Iterable[_Union[ControlIHave, _Mapping]]] = ..., iwant: _Optional[_Iterable[_Union[ControlIWant, _Mapping]]] = ..., graft: _Optional[_Iterable[_Union[ControlGraft, _Mapping]]] = ..., prune: _Optional[_Iterable[_Union[ControlPrune, _Mapping]]] = ...) -> None: ... # type: ignore - @typing.final - class SubOpts(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor +class ControlIHave(_message.Message): + __slots__ = ("topicID", "messageIDs") + TOPICID_FIELD_NUMBER: _ClassVar[int] + MESSAGEIDS_FIELD_NUMBER: _ClassVar[int] + topicID: str + messageIDs: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, topicID: _Optional[str] = ..., messageIDs: _Optional[_Iterable[str]] = ...) -> None: ... - SUBSCRIBE_FIELD_NUMBER: builtins.int - TOPICID_FIELD_NUMBER: builtins.int - subscribe: builtins.bool - """subscribe or unsubscribe""" - topicid: builtins.str - def __init__( - self, - *, - subscribe: builtins.bool | None = ..., - topicid: builtins.str | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["subscribe", b"subscribe", "topicid", b"topicid"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["subscribe", b"subscribe", "topicid", b"topicid"]) -> None: ... 
+class ControlIWant(_message.Message): + __slots__ = ("messageIDs",) + MESSAGEIDS_FIELD_NUMBER: _ClassVar[int] + messageIDs: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, messageIDs: _Optional[_Iterable[str]] = ...) -> None: ... - SUBSCRIPTIONS_FIELD_NUMBER: builtins.int - PUBLISH_FIELD_NUMBER: builtins.int - CONTROL_FIELD_NUMBER: builtins.int - @property - def subscriptions(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___RPC.SubOpts]: ... - @property - def publish(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message]: ... - @property - def control(self) -> global___ControlMessage: ... - def __init__( - self, - *, - subscriptions: collections.abc.Iterable[global___RPC.SubOpts] | None = ..., - publish: collections.abc.Iterable[global___Message] | None = ..., - control: global___ControlMessage | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["control", b"control"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["control", b"control", "publish", b"publish", "subscriptions", b"subscriptions"]) -> None: ... +class ControlGraft(_message.Message): + __slots__ = ("topicID",) + TOPICID_FIELD_NUMBER: _ClassVar[int] + topicID: str + def __init__(self, topicID: _Optional[str] = ...) -> None: ... -global___RPC = RPC +class ControlPrune(_message.Message): + __slots__ = ("topicID", "peers", "backoff") + TOPICID_FIELD_NUMBER: _ClassVar[int] + PEERS_FIELD_NUMBER: _ClassVar[int] + BACKOFF_FIELD_NUMBER: _ClassVar[int] + topicID: str + peers: _containers.RepeatedCompositeFieldContainer[PeerInfo] + backoff: int + def __init__(self, topicID: _Optional[str] = ..., peers: _Optional[_Iterable[_Union[PeerInfo, _Mapping]]] = ..., backoff: _Optional[int] = ...) -> None: ... # type: ignore -@typing.final -class Message(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor +class PeerInfo(_message.Message): + __slots__ = ("peerID", "signedPeerRecord") + PEERID_FIELD_NUMBER: _ClassVar[int] + SIGNEDPEERRECORD_FIELD_NUMBER: _ClassVar[int] + peerID: bytes + signedPeerRecord: bytes + def __init__(self, peerID: _Optional[bytes] = ..., signedPeerRecord: _Optional[bytes] = ...) -> None: ... - FROM_ID_FIELD_NUMBER: builtins.int - DATA_FIELD_NUMBER: builtins.int - SEQNO_FIELD_NUMBER: builtins.int - TOPICIDS_FIELD_NUMBER: builtins.int - SIGNATURE_FIELD_NUMBER: builtins.int - KEY_FIELD_NUMBER: builtins.int - from_id: builtins.bytes - data: builtins.bytes - seqno: builtins.bytes - signature: builtins.bytes - key: builtins.bytes - @property - def topicIDs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ... - def __init__( - self, - *, - from_id: builtins.bytes | None = ..., - data: builtins.bytes | None = ..., - seqno: builtins.bytes | None = ..., - topicIDs: collections.abc.Iterable[builtins.str] | None = ..., - signature: builtins.bytes | None = ..., - key: builtins.bytes | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["data", b"data", "from_id", b"from_id", "key", b"key", "seqno", b"seqno", "signature", b"signature"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["data", b"data", "from_id", b"from_id", "key", b"key", "seqno", b"seqno", "signature", b"signature", "topicIDs", b"topicIDs"]) -> None: ... 
- -global___Message = Message - -@typing.final -class ControlMessage(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - IHAVE_FIELD_NUMBER: builtins.int - IWANT_FIELD_NUMBER: builtins.int - GRAFT_FIELD_NUMBER: builtins.int - PRUNE_FIELD_NUMBER: builtins.int - @property - def ihave(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlIHave]: ... - @property - def iwant(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlIWant]: ... - @property - def graft(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlGraft]: ... - @property - def prune(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlPrune]: ... - def __init__( - self, - *, - ihave: collections.abc.Iterable[global___ControlIHave] | None = ..., - iwant: collections.abc.Iterable[global___ControlIWant] | None = ..., - graft: collections.abc.Iterable[global___ControlGraft] | None = ..., - prune: collections.abc.Iterable[global___ControlPrune] | None = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["graft", b"graft", "ihave", b"ihave", "iwant", b"iwant", "prune", b"prune"]) -> None: ... - -global___ControlMessage = ControlMessage - -@typing.final -class ControlIHave(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - TOPICID_FIELD_NUMBER: builtins.int - MESSAGEIDS_FIELD_NUMBER: builtins.int - topicID: builtins.str - @property - def messageIDs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ... - def __init__( - self, - *, - topicID: builtins.str | None = ..., - messageIDs: collections.abc.Iterable[builtins.str] | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["topicID", b"topicID"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["messageIDs", b"messageIDs", "topicID", b"topicID"]) -> None: ... - -global___ControlIHave = ControlIHave - -@typing.final -class ControlIWant(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - MESSAGEIDS_FIELD_NUMBER: builtins.int - @property - def messageIDs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ... - def __init__( - self, - *, - messageIDs: collections.abc.Iterable[builtins.str] | None = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["messageIDs", b"messageIDs"]) -> None: ... - -global___ControlIWant = ControlIWant - -@typing.final -class ControlGraft(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - TOPICID_FIELD_NUMBER: builtins.int - topicID: builtins.str - def __init__( - self, - *, - topicID: builtins.str | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["topicID", b"topicID"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["topicID", b"topicID"]) -> None: ... - -global___ControlGraft = ControlGraft - -@typing.final -class ControlPrune(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - TOPICID_FIELD_NUMBER: builtins.int - PEERS_FIELD_NUMBER: builtins.int - BACKOFF_FIELD_NUMBER: builtins.int - topicID: builtins.str - backoff: builtins.int - @property - def peers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___PeerInfo]: ... 
- def __init__( - self, - *, - topicID: builtins.str | None = ..., - peers: collections.abc.Iterable[global___PeerInfo] | None = ..., - backoff: builtins.int | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["backoff", b"backoff", "topicID", b"topicID"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["backoff", b"backoff", "peers", b"peers", "topicID", b"topicID"]) -> None: ... - -global___ControlPrune = ControlPrune - -@typing.final -class PeerInfo(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - PEERID_FIELD_NUMBER: builtins.int - SIGNEDPEERRECORD_FIELD_NUMBER: builtins.int - peerID: builtins.bytes - signedPeerRecord: builtins.bytes - def __init__( - self, - *, - peerID: builtins.bytes | None = ..., - signedPeerRecord: builtins.bytes | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["peerID", b"peerID", "signedPeerRecord", b"signedPeerRecord"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["peerID", b"peerID", "signedPeerRecord", b"signedPeerRecord"]) -> None: ... - -global___PeerInfo = PeerInfo - -@typing.final -class TopicDescriptor(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - @typing.final - class AuthOpts(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - class _AuthMode: - ValueType = typing.NewType("ValueType", builtins.int) - V: typing_extensions.TypeAlias = ValueType - - class _AuthModeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[TopicDescriptor.AuthOpts._AuthMode.ValueType], builtins.type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor - NONE: TopicDescriptor.AuthOpts._AuthMode.ValueType # 0 - """no authentication, anyone can publish""" - KEY: TopicDescriptor.AuthOpts._AuthMode.ValueType # 1 - """only messages signed by keys in the topic descriptor are accepted""" - WOT: TopicDescriptor.AuthOpts._AuthMode.ValueType # 2 - """web of trust, certificates can allow publisher set to grow""" - - class AuthMode(_AuthMode, metaclass=_AuthModeEnumTypeWrapper): ... - NONE: TopicDescriptor.AuthOpts.AuthMode.ValueType # 0 - """no authentication, anyone can publish""" - KEY: TopicDescriptor.AuthOpts.AuthMode.ValueType # 1 - """only messages signed by keys in the topic descriptor are accepted""" - WOT: TopicDescriptor.AuthOpts.AuthMode.ValueType # 2 - """web of trust, certificates can allow publisher set to grow""" - - MODE_FIELD_NUMBER: builtins.int - KEYS_FIELD_NUMBER: builtins.int - mode: global___TopicDescriptor.AuthOpts.AuthMode.ValueType - @property - def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: - """root keys to trust""" - - def __init__( - self, - *, - mode: global___TopicDescriptor.AuthOpts.AuthMode.ValueType | None = ..., - keys: collections.abc.Iterable[builtins.bytes] | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["mode", b"mode"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["keys", b"keys", "mode", b"mode"]) -> None: ... 
- - @typing.final - class EncOpts(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - class _EncMode: - ValueType = typing.NewType("ValueType", builtins.int) - V: typing_extensions.TypeAlias = ValueType - - class _EncModeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[TopicDescriptor.EncOpts._EncMode.ValueType], builtins.type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor - NONE: TopicDescriptor.EncOpts._EncMode.ValueType # 0 - """no encryption, anyone can read""" - SHAREDKEY: TopicDescriptor.EncOpts._EncMode.ValueType # 1 - """messages are encrypted with shared key""" - WOT: TopicDescriptor.EncOpts._EncMode.ValueType # 2 - """web of trust, certificates can allow publisher set to grow""" - - class EncMode(_EncMode, metaclass=_EncModeEnumTypeWrapper): ... - NONE: TopicDescriptor.EncOpts.EncMode.ValueType # 0 - """no encryption, anyone can read""" - SHAREDKEY: TopicDescriptor.EncOpts.EncMode.ValueType # 1 - """messages are encrypted with shared key""" - WOT: TopicDescriptor.EncOpts.EncMode.ValueType # 2 - """web of trust, certificates can allow publisher set to grow""" - - MODE_FIELD_NUMBER: builtins.int - KEYHASHES_FIELD_NUMBER: builtins.int - mode: global___TopicDescriptor.EncOpts.EncMode.ValueType - @property - def keyHashes(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: - """the hashes of the shared keys used (salted)""" - - def __init__( - self, - *, - mode: global___TopicDescriptor.EncOpts.EncMode.ValueType | None = ..., - keyHashes: collections.abc.Iterable[builtins.bytes] | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["mode", b"mode"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["keyHashes", b"keyHashes", "mode", b"mode"]) -> None: ... - - NAME_FIELD_NUMBER: builtins.int - AUTH_FIELD_NUMBER: builtins.int - ENC_FIELD_NUMBER: builtins.int - name: builtins.str - @property - def auth(self) -> global___TopicDescriptor.AuthOpts: ... - @property - def enc(self) -> global___TopicDescriptor.EncOpts: ... - def __init__( - self, - *, - name: builtins.str | None = ..., - auth: global___TopicDescriptor.AuthOpts | None = ..., - enc: global___TopicDescriptor.EncOpts | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["auth", b"auth", "enc", b"enc", "name", b"name"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["auth", b"auth", "enc", b"enc", "name", b"name"]) -> None: ... - -global___TopicDescriptor = TopicDescriptor +class TopicDescriptor(_message.Message): + __slots__ = ("name", "auth", "enc") + class AuthOpts(_message.Message): + __slots__ = ("mode", "keys") + class AuthMode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + NONE: _ClassVar[TopicDescriptor.AuthOpts.AuthMode] + KEY: _ClassVar[TopicDescriptor.AuthOpts.AuthMode] + WOT: _ClassVar[TopicDescriptor.AuthOpts.AuthMode] + NONE: TopicDescriptor.AuthOpts.AuthMode + KEY: TopicDescriptor.AuthOpts.AuthMode + WOT: TopicDescriptor.AuthOpts.AuthMode + MODE_FIELD_NUMBER: _ClassVar[int] + KEYS_FIELD_NUMBER: _ClassVar[int] + mode: TopicDescriptor.AuthOpts.AuthMode + keys: _containers.RepeatedScalarFieldContainer[bytes] + def __init__(self, mode: _Optional[_Union[TopicDescriptor.AuthOpts.AuthMode, str]] = ..., keys: _Optional[_Iterable[bytes]] = ...) -> None: ... 
+ class EncOpts(_message.Message): + __slots__ = ("mode", "keyHashes") + class EncMode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + NONE: _ClassVar[TopicDescriptor.EncOpts.EncMode] + SHAREDKEY: _ClassVar[TopicDescriptor.EncOpts.EncMode] + WOT: _ClassVar[TopicDescriptor.EncOpts.EncMode] + NONE: TopicDescriptor.EncOpts.EncMode + SHAREDKEY: TopicDescriptor.EncOpts.EncMode + WOT: TopicDescriptor.EncOpts.EncMode + MODE_FIELD_NUMBER: _ClassVar[int] + KEYHASHES_FIELD_NUMBER: _ClassVar[int] + mode: TopicDescriptor.EncOpts.EncMode + keyHashes: _containers.RepeatedScalarFieldContainer[bytes] + def __init__(self, mode: _Optional[_Union[TopicDescriptor.EncOpts.EncMode, str]] = ..., keyHashes: _Optional[_Iterable[bytes]] = ...) -> None: ... + NAME_FIELD_NUMBER: _ClassVar[int] + AUTH_FIELD_NUMBER: _ClassVar[int] + ENC_FIELD_NUMBER: _ClassVar[int] + name: str + auth: TopicDescriptor.AuthOpts + enc: TopicDescriptor.EncOpts + def __init__(self, name: _Optional[str] = ..., auth: _Optional[_Union[TopicDescriptor.AuthOpts, _Mapping]] = ..., enc: _Optional[_Union[TopicDescriptor.EncOpts, _Mapping]] = ...) -> None: ... # type: ignore diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 5641ec5d..54430f1b 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -50,12 +50,14 @@ from libp2p.network.stream.exceptions import ( StreamEOF, StreamReset, ) +from libp2p.peer.envelope import consume_envelope from libp2p.peer.id import ( ID, ) from libp2p.peer.peerdata import ( PeerDataError, ) +from libp2p.peer.peerstore import create_signed_peer_record from libp2p.tools.async_service import ( Service, ) @@ -247,6 +249,14 @@ class Pubsub(Service, IPubsub): packet.subscriptions.extend( [rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)] ) + # Add the sender's signedRecord in the RPC message + envelope = create_signed_peer_record( + self.host.get_id(), + self.host.get_addrs(), + self.host.get_private_key(), + ) + packet.senderRecord = envelope.marshal_envelope() + return packet async def continuously_read_stream(self, stream: INetStream) -> None: @@ -263,6 +273,27 @@ class Pubsub(Service, IPubsub): incoming: bytes = await read_varint_prefixed_bytes(stream) rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC() rpc_incoming.ParseFromString(incoming) + + # Process the sender's signed-record if sent + if rpc_incoming.HasField("senderRecord"): + try: + # Convert the signed-peer-record(Envelope) from + # protobuf bytes + envelope, _ = consume_envelope( + rpc_incoming.senderRecord, "libp2p-peer-record" + ) + # Use the default TTL of 2 hours (7200 seconds) + if self.host.get_peerstore().consume_peer_record( + envelope, 7200 + ): + logger.error( + "Updating the Certified-Addr-Book was unsuccessful" + ) + except Exception as e: + logger.error( + "Error updating the certified addr book for peer: %s", e + ) + if rpc_incoming.publish: # deal with RPC.publish for msg in rpc_incoming.publish: @@ -572,6 +603,14 @@ class Pubsub(Service, IPubsub): [rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)] ) + # Add the senderRecord of the peer in the RPC msg + envelope = create_signed_peer_record( + self.host.get_id(), + self.host.get_addrs(), + self.host.get_private_key(), + ) + packet.senderRecord = envelope.marshal_envelope() + # Send out subscribe message to all peers await self.message_all_peers(packet.SerializeToString()) @@ -604,6 +643,13 @@ class Pubsub(Service, IPubsub): packet.subscriptions.extend( [rpc_pb2.RPC.SubOpts(subscribe=False, topicid=topic_id)] ) + # Add the senderRecord of the 
peer in the RPC msg + envelope = create_signed_peer_record( + self.host.get_id(), + self.host.get_addrs(), + self.host.get_private_key(), + ) + packet.senderRecord = envelope.marshal_envelope() # Send out unsubscribe message to all peers await self.message_all_peers(packet.SerializeToString()) From d4c387f9234d8231e99a23d9a48d3a269d10a5f9 Mon Sep 17 00:00:00 2001 From: lla-dane Date: Thu, 14 Aug 2025 11:26:14 +0530 Subject: [PATCH 52/71] add reissuing mechanism of records if addrs dont change as done in #815 --- libp2p/pubsub/floodsub.py | 10 +++----- libp2p/pubsub/gossipsub.py | 46 ++++++---------------------------- libp2p/pubsub/pubsub.py | 47 ++++++----------------------------- libp2p/pubsub/utils.py | 51 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 70 insertions(+), 84 deletions(-) create mode 100644 libp2p/pubsub/utils.py diff --git a/libp2p/pubsub/floodsub.py b/libp2p/pubsub/floodsub.py index 170f558d..f0e09404 100644 --- a/libp2p/pubsub/floodsub.py +++ b/libp2p/pubsub/floodsub.py @@ -15,7 +15,7 @@ from libp2p.custom_types import ( from libp2p.peer.id import ( ID, ) -from libp2p.peer.peerstore import create_signed_peer_record +from libp2p.pubsub.utils import env_to_send_in_RPC from .exceptions import ( PubsubRouterError, @@ -106,12 +106,8 @@ class FloodSub(IPubsubRouter): # Add the senderRecord of the peer in the RPC msg if isinstance(self.pubsub, Pubsub): - envelope = create_signed_peer_record( - self.pubsub.host.get_id(), - self.pubsub.host.get_addrs(), - self.pubsub.host.get_private_key(), - ) - rpc_msg.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.pubsub.host) + rpc_msg.senderRecord = envelope_bytes logger.debug("publishing message %s", pubsub_msg) diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py index b7c70c55..fa221a0f 100644 --- a/libp2p/pubsub/gossipsub.py +++ b/libp2p/pubsub/gossipsub.py @@ -24,7 +24,6 @@ from libp2p.abc import ( from libp2p.custom_types import ( TProtocol, ) -from libp2p.peer.envelope import consume_envelope from libp2p.peer.id import ( ID, ) @@ -35,11 +34,11 @@ from libp2p.peer.peerinfo import ( ) from libp2p.peer.peerstore import ( PERMANENT_ADDR_TTL, - create_signed_peer_record, ) from libp2p.pubsub import ( floodsub, ) +from libp2p.pubsub.utils import env_to_send_in_RPC, maybe_consume_signed_record from libp2p.tools.async_service import ( Service, ) @@ -230,24 +229,7 @@ class GossipSub(IPubsubRouter, Service): """ # Process the senderRecord if sent if isinstance(self.pubsub, Pubsub): - if rpc.HasField("senderRecord"): - try: - # Convert the signed-peer-record(Envelope) from - # protobuf bytes - envelope, _ = consume_envelope( - rpc.senderRecord, "libp2p-peer-record" - ) - # Use the default TTL of 2 hours (7200 seconds) - if self.pubsub.host.get_peerstore().consume_peer_record( - envelope, 7200 - ): - logger.error( - "Updating the Certified-Addr-Book was unsuccessful" - ) - except Exception as e: - logger.error( - "Error updating the certified addr book for peer: %s", e - ) + _ = maybe_consume_signed_record(rpc, self.pubsub.host) control_message = rpc.control @@ -278,12 +260,8 @@ class GossipSub(IPubsubRouter, Service): # Add the senderRecord of the peer in the RPC msg if isinstance(self.pubsub, Pubsub): - envelope = create_signed_peer_record( - self.pubsub.host.get_id(), - self.pubsub.host.get_addrs(), - self.pubsub.host.get_private_key(), - ) - rpc_msg.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.pubsub.host) + 
rpc_msg.senderRecord = envelope_bytes logger.debug("publishing message %s", pubsub_msg) @@ -854,12 +832,8 @@ class GossipSub(IPubsubRouter, Service): # to the iwant control msg, so we will send a freshly created senderRecord # with the RPC msg if isinstance(self.pubsub, Pubsub): - envelope = create_signed_peer_record( - self.pubsub.host.get_id(), - self.pubsub.host.get_addrs(), - self.pubsub.host.get_private_key(), - ) - packet.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.pubsub.host) + packet.senderRecord = envelope_bytes packet.publish.extend(msgs_to_forward) @@ -1019,12 +993,8 @@ class GossipSub(IPubsubRouter, Service): # Add the sender's peer-record in the RPC msg if isinstance(self.pubsub, Pubsub): - envelope = create_signed_peer_record( - self.pubsub.host.get_id(), - self.pubsub.host.get_addrs(), - self.pubsub.host.get_private_key(), - ) - packet.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.pubsub.host) + packet.senderRecord = envelope_bytes packet.control.CopyFrom(control_msg) diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 54430f1b..cbaaafb5 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -50,14 +50,13 @@ from libp2p.network.stream.exceptions import ( StreamEOF, StreamReset, ) -from libp2p.peer.envelope import consume_envelope from libp2p.peer.id import ( ID, ) from libp2p.peer.peerdata import ( PeerDataError, ) -from libp2p.peer.peerstore import create_signed_peer_record +from libp2p.pubsub.utils import env_to_send_in_RPC, maybe_consume_signed_record from libp2p.tools.async_service import ( Service, ) @@ -250,12 +249,8 @@ class Pubsub(Service, IPubsub): [rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)] ) # Add the sender's signedRecord in the RPC message - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - packet.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + packet.senderRecord = envelope_bytes return packet @@ -275,24 +270,7 @@ class Pubsub(Service, IPubsub): rpc_incoming.ParseFromString(incoming) # Process the sender's signed-record if sent - if rpc_incoming.HasField("senderRecord"): - try: - # Convert the signed-peer-record(Envelope) from - # protobuf bytes - envelope, _ = consume_envelope( - rpc_incoming.senderRecord, "libp2p-peer-record" - ) - # Use the default TTL of 2 hours (7200 seconds) - if self.host.get_peerstore().consume_peer_record( - envelope, 7200 - ): - logger.error( - "Updating the Certified-Addr-Book was unsuccessful" - ) - except Exception as e: - logger.error( - "Error updating the certified addr book for peer: %s", e - ) + _ = maybe_consume_signed_record(rpc_incoming, self.host) if rpc_incoming.publish: # deal with RPC.publish @@ -604,13 +582,8 @@ class Pubsub(Service, IPubsub): ) # Add the senderRecord of the peer in the RPC msg - envelope = create_signed_peer_record( - self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - packet.senderRecord = envelope.marshal_envelope() - + envelope_bytes, bool = env_to_send_in_RPC(self.host) + packet.senderRecord = envelope_bytes # Send out subscribe message to all peers await self.message_all_peers(packet.SerializeToString()) @@ -644,12 +617,8 @@ class Pubsub(Service, IPubsub): [rpc_pb2.RPC.SubOpts(subscribe=False, topicid=topic_id)] ) # Add the senderRecord of the peer in the RPC msg - envelope = create_signed_peer_record( - 
self.host.get_id(), - self.host.get_addrs(), - self.host.get_private_key(), - ) - packet.senderRecord = envelope.marshal_envelope() + envelope_bytes, bool = env_to_send_in_RPC(self.host) + packet.senderRecord = envelope_bytes # Send out unsubscribe message to all peers await self.message_all_peers(packet.SerializeToString()) diff --git a/libp2p/pubsub/utils.py b/libp2p/pubsub/utils.py new file mode 100644 index 00000000..163a2870 --- /dev/null +++ b/libp2p/pubsub/utils.py @@ -0,0 +1,51 @@ +import logging + +from libp2p.abc import IHost +from libp2p.peer.envelope import consume_envelope +from libp2p.peer.peerstore import create_signed_peer_record +from libp2p.pubsub.pb.rpc_pb2 import RPC + +logger = logging.getLogger("pubsub-example.utils") + + +def maybe_consume_signed_record(msg: RPC, host: IHost) -> bool: + if msg.HasField("senderRecord"): + try: + # Convert the signed-peer-record(Envelope) from + # protobuf bytes + envelope, _ = consume_envelope(msg.senderRecord, "libp2p-peer-record") + # Use the default TTL of 2 hours (7200 seconds) + if not host.get_peerstore().consume_peer_record(envelope, 7200): + logger.error("Updating the certified-addr-book was unsuccessful") + except Exception as e: + logger.error("Error updating the certified addr book for peer: %s", e) + return False + return True + + +def env_to_send_in_RPC(host: IHost) -> tuple[bytes, bool]: + listen_addrs_set = {addr for addr in host.get_addrs()} + local_env = host.get_peerstore().get_local_record() + + if local_env is None: + # No cached SPR yet -> create one + return issue_and_cache_local_record(host), True + else: + record_addrs_set = local_env._env_addrs_set() + if record_addrs_set == listen_addrs_set: + # Perfect match -> reuse the cached envelope + return local_env.marshal_envelope(), False + else: + # Addresses changed -> issue a new SPR and cache it + return issue_and_cache_local_record(host), True + + +def issue_and_cache_local_record(host: IHost) -> bytes: + env = create_signed_peer_record( + host.get_id(), + host.get_addrs(), + host.get_private_key(), + ) + # Cache it for next time + host.get_peerstore().set_local_record(env) + return env.marshal_envelope() From cdfb083c0617ce81cad363889b0b0787e2643570 Mon Sep 17 00:00:00 2001 From: lla-dane Date: Thu, 14 Aug 2025 15:53:05 +0530 Subject: [PATCH 53/71] added tests to see if transfer works correctly --- tests/core/pubsub/test_pubsub.py | 61 ++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/tests/core/pubsub/test_pubsub.py b/tests/core/pubsub/test_pubsub.py index e674dbc0..179f359c 100644 --- a/tests/core/pubsub/test_pubsub.py +++ b/tests/core/pubsub/test_pubsub.py @@ -8,6 +8,7 @@ from typing import ( from unittest.mock import patch import pytest +import multiaddr import trio from libp2p.custom_types import AsyncValidatorFn @@ -17,6 +18,7 @@ from libp2p.exceptions import ( from libp2p.network.stream.exceptions import ( StreamEOF, ) +from libp2p.peer.envelope import Envelope from libp2p.peer.id import ( ID, ) @@ -87,6 +89,45 @@ async def test_re_unsubscribe(): assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids +@pytest.mark.trio +async def test_reissue_when_listen_addrs_change(): + async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub: + await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + # Yield to let 0 notify 1 + await trio.sleep(1) + assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] + + # Check whether signed-records were transfered 
properly in the subscribe call + envelope_b_sub = ( + pubsubs_fsub[1] + .host.get_peerstore() + .get_peer_record(pubsubs_fsub[0].host.get_id()) + ) + assert isinstance(envelope_b_sub, Envelope) + + # Simulate pubsubs_fsub[1].host listen addrs changing (different port) + new_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123") + + # Patch just for the duration we force A to unsubscribe + with patch.object(pubsubs_fsub[0].host, "get_addrs", return_value=[new_addr]): + # Unsubscribe from A's side so that a new_record is issued + await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) + await trio.sleep(1) + + # B should be holding A's new record with bumped seq + envelope_b_unsub = ( + pubsubs_fsub[1] + .host.get_peerstore() + .get_peer_record(pubsubs_fsub[0].host.get_id()) + ) + assert isinstance(envelope_b_unsub, Envelope) + + # This proves that a freshly signed record was issued rather than + # the latest-cached-one creating one. + assert envelope_b_sub.record().seq < envelope_b_unsub.record().seq + + @pytest.mark.trio async def test_peers_subscribe(): async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub: @@ -95,11 +136,31 @@ async def test_peers_subscribe(): # Yield to let 0 notify 1 await trio.sleep(1) assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] + + # Check whether signed-records were transfered properly in the subscribe call + envelope_b_sub = ( + pubsubs_fsub[1] + .host.get_peerstore() + .get_peer_record(pubsubs_fsub[0].host.get_id()) + ) + assert isinstance(envelope_b_sub, Envelope) + await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) # Yield to let 0 notify 1 await trio.sleep(1) assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] + envelope_b_unsub = ( + pubsubs_fsub[1] + .host.get_peerstore() + .get_peer_record(pubsubs_fsub[0].host.get_id()) + ) + assert isinstance(envelope_b_unsub, Envelope) + + # This proves that the latest-cached-record was re-issued rather than + # freshly creating one. 
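+        # (env_to_send_in_RPC only issues a fresh record, with a bumped seq,
+        # when the host's listen addrs differ from those in the cached
+        # envelope; here the addrs are unchanged, so the seq stays the same.)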
+ assert envelope_b_sub.record().seq == envelope_b_unsub.record().seq + @pytest.mark.trio async def test_get_hello_packet(): From d99b67eafa3727f8730597860e3634ea629aeb7f Mon Sep 17 00:00:00 2001 From: lla-dane Date: Sun, 17 Aug 2025 13:53:25 +0530 Subject: [PATCH 54/71] now ignoring pubsub messages upon receving invalid-signed-records --- libp2p/pubsub/gossipsub.py | 4 +++- libp2p/pubsub/pubsub.py | 6 +++++- tests/core/pubsub/test_pubsub.py | 22 ++++++++++++++++++++++ 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py index fa221a0f..aaf0b2fa 100644 --- a/libp2p/pubsub/gossipsub.py +++ b/libp2p/pubsub/gossipsub.py @@ -229,7 +229,9 @@ class GossipSub(IPubsubRouter, Service): """ # Process the senderRecord if sent if isinstance(self.pubsub, Pubsub): - _ = maybe_consume_signed_record(rpc, self.pubsub.host) + if not maybe_consume_signed_record(rpc, self.pubsub.host): + logger.error("Received an invalid-signed-record, ignoring the message") + return control_message = rpc.control diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index cbaaafb5..3200c73a 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -270,7 +270,11 @@ class Pubsub(Service, IPubsub): rpc_incoming.ParseFromString(incoming) # Process the sender's signed-record if sent - _ = maybe_consume_signed_record(rpc_incoming, self.host) + if not maybe_consume_signed_record(rpc_incoming, self.host): + logger.error( + "Received an invalid-signed-record, ignoring the incoming msg" + ) + continue if rpc_incoming.publish: # deal with RPC.publish diff --git a/tests/core/pubsub/test_pubsub.py b/tests/core/pubsub/test_pubsub.py index 179f359c..54bc67a1 100644 --- a/tests/core/pubsub/test_pubsub.py +++ b/tests/core/pubsub/test_pubsub.py @@ -11,6 +11,7 @@ import pytest import multiaddr import trio +from libp2p.crypto.rsa import create_new_key_pair from libp2p.custom_types import AsyncValidatorFn from libp2p.exceptions import ( ValidationError, @@ -162,6 +163,27 @@ async def test_peers_subscribe(): assert envelope_b_sub.record().seq == envelope_b_unsub.record().seq +@pytest.mark.trio +async def test_peer_subscribe_fail_upon_invald_record_transfer(): + async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub: + await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) + + # Corrupt host_a's local peer record + envelope = pubsubs_fsub[0].host.get_peerstore().get_local_record() + key_pair = create_new_key_pair() + + if envelope is not None: + envelope.public_key = key_pair.public_key + pubsubs_fsub[0].host.get_peerstore().set_local_record(envelope) + + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + # Yeild to let 0 notify 1 + await trio.sleep(1) + assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics.get( + TESTING_TOPIC, set() + ) + + @pytest.mark.trio async def test_get_hello_packet(): async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: From b26e8333bdb5a81fa779ff46938d6c69a7602360 Mon Sep 17 00:00:00 2001 From: lla-dane Date: Sat, 23 Aug 2025 18:01:57 +0530 Subject: [PATCH 55/71] updated as per the suggestions in #815 --- libp2p/pubsub/floodsub.py | 4 +-- libp2p/pubsub/gossipsub.py | 11 +++--- libp2p/pubsub/pubsub.py | 11 +++--- libp2p/pubsub/utils.py | 61 ++++++++++++++++---------------- tests/core/pubsub/test_pubsub.py | 22 +++++++++++- 5 files changed, 65 insertions(+), 44 deletions(-) diff --git a/libp2p/pubsub/floodsub.py b/libp2p/pubsub/floodsub.py index f0e09404..8167581d 100644 --- 
a/libp2p/pubsub/floodsub.py +++ b/libp2p/pubsub/floodsub.py @@ -15,7 +15,7 @@ from libp2p.custom_types import ( from libp2p.peer.id import ( ID, ) -from libp2p.pubsub.utils import env_to_send_in_RPC +from libp2p.peer.peerstore import env_to_send_in_RPC from .exceptions import ( PubsubRouterError, @@ -106,7 +106,7 @@ class FloodSub(IPubsubRouter): # Add the senderRecord of the peer in the RPC msg if isinstance(self.pubsub, Pubsub): - envelope_bytes, bool = env_to_send_in_RPC(self.pubsub.host) + envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host) rpc_msg.senderRecord = envelope_bytes logger.debug("publishing message %s", pubsub_msg) diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py index aaf0b2fa..a4c8c463 100644 --- a/libp2p/pubsub/gossipsub.py +++ b/libp2p/pubsub/gossipsub.py @@ -34,11 +34,12 @@ from libp2p.peer.peerinfo import ( ) from libp2p.peer.peerstore import ( PERMANENT_ADDR_TTL, + env_to_send_in_RPC, ) from libp2p.pubsub import ( floodsub, ) -from libp2p.pubsub.utils import env_to_send_in_RPC, maybe_consume_signed_record +from libp2p.pubsub.utils import maybe_consume_signed_record from libp2p.tools.async_service import ( Service, ) @@ -229,7 +230,7 @@ class GossipSub(IPubsubRouter, Service): """ # Process the senderRecord if sent if isinstance(self.pubsub, Pubsub): - if not maybe_consume_signed_record(rpc, self.pubsub.host): + if not maybe_consume_signed_record(rpc, self.pubsub.host, sender_peer_id): logger.error("Received an invalid-signed-record, ignoring the message") return @@ -262,7 +263,7 @@ class GossipSub(IPubsubRouter, Service): # Add the senderRecord of the peer in the RPC msg if isinstance(self.pubsub, Pubsub): - envelope_bytes, bool = env_to_send_in_RPC(self.pubsub.host) + envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host) rpc_msg.senderRecord = envelope_bytes logger.debug("publishing message %s", pubsub_msg) @@ -834,7 +835,7 @@ class GossipSub(IPubsubRouter, Service): # to the iwant control msg, so we will send a freshly created senderRecord # with the RPC msg if isinstance(self.pubsub, Pubsub): - envelope_bytes, bool = env_to_send_in_RPC(self.pubsub.host) + envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host) packet.senderRecord = envelope_bytes packet.publish.extend(msgs_to_forward) @@ -995,7 +996,7 @@ class GossipSub(IPubsubRouter, Service): # Add the sender's peer-record in the RPC msg if isinstance(self.pubsub, Pubsub): - envelope_bytes, bool = env_to_send_in_RPC(self.pubsub.host) + envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host) packet.senderRecord = envelope_bytes packet.control.CopyFrom(control_msg) diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 3200c73a..2c605fc3 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -56,7 +56,8 @@ from libp2p.peer.id import ( from libp2p.peer.peerdata import ( PeerDataError, ) -from libp2p.pubsub.utils import env_to_send_in_RPC, maybe_consume_signed_record +from libp2p.peer.peerstore import env_to_send_in_RPC +from libp2p.pubsub.utils import maybe_consume_signed_record from libp2p.tools.async_service import ( Service, ) @@ -249,7 +250,7 @@ class Pubsub(Service, IPubsub): [rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)] ) # Add the sender's signedRecord in the RPC message - envelope_bytes, bool = env_to_send_in_RPC(self.host) + envelope_bytes, _ = env_to_send_in_RPC(self.host) packet.senderRecord = envelope_bytes return packet @@ -270,7 +271,7 @@ class Pubsub(Service, IPubsub): rpc_incoming.ParseFromString(incoming) # Process the sender's 
signed-record if sent - if not maybe_consume_signed_record(rpc_incoming, self.host): + if not maybe_consume_signed_record(rpc_incoming, self.host, peer_id): logger.error( "Received an invalid-signed-record, ignoring the incoming msg" ) @@ -586,7 +587,7 @@ class Pubsub(Service, IPubsub): ) # Add the senderRecord of the peer in the RPC msg - envelope_bytes, bool = env_to_send_in_RPC(self.host) + envelope_bytes, _ = env_to_send_in_RPC(self.host) packet.senderRecord = envelope_bytes # Send out subscribe message to all peers await self.message_all_peers(packet.SerializeToString()) @@ -621,7 +622,7 @@ class Pubsub(Service, IPubsub): [rpc_pb2.RPC.SubOpts(subscribe=False, topicid=topic_id)] ) # Add the senderRecord of the peer in the RPC msg - envelope_bytes, bool = env_to_send_in_RPC(self.host) + envelope_bytes, _ = env_to_send_in_RPC(self.host) packet.senderRecord = envelope_bytes # Send out unsubscribe message to all peers diff --git a/libp2p/pubsub/utils.py b/libp2p/pubsub/utils.py index 163a2870..3a69becb 100644 --- a/libp2p/pubsub/utils.py +++ b/libp2p/pubsub/utils.py @@ -2,50 +2,49 @@ import logging from libp2p.abc import IHost from libp2p.peer.envelope import consume_envelope -from libp2p.peer.peerstore import create_signed_peer_record +from libp2p.peer.id import ID from libp2p.pubsub.pb.rpc_pb2 import RPC logger = logging.getLogger("pubsub-example.utils") -def maybe_consume_signed_record(msg: RPC, host: IHost) -> bool: +def maybe_consume_signed_record(msg: RPC, host: IHost, peer_id: ID) -> bool: + """ + Attempt to parse and store a signed-peer-record (Envelope) received during + PubSub communication. If the record is invalid, the peer-id does not match, or + updating the peerstore fails, the function logs an error and returns False. + + Parameters + ---------- + msg : RPC + The protobuf message received during PubSub communication. + host : IHost + The local host instance, providing access to the peerstore for storing + verified peer records. + peer_id : ID | None, optional + The expected peer ID for record validation. If provided, the peer ID + inside the record must match this value. + + Returns + ------- + bool + True if a valid signed peer record was successfully consumed and stored, + False otherwise. 
+ + """ if msg.HasField("senderRecord"): try: # Convert the signed-peer-record(Envelope) from # protobuf bytes - envelope, _ = consume_envelope(msg.senderRecord, "libp2p-peer-record") + envelope, record = consume_envelope(msg.senderRecord, "libp2p-peer-record") + if not record.peer_id == peer_id: + return False + # Use the default TTL of 2 hours (7200 seconds) if not host.get_peerstore().consume_peer_record(envelope, 7200): logger.error("Updating the certified-addr-book was unsuccessful") + return False except Exception as e: logger.error("Error updating the certified addr book for peer: %s", e) return False return True - - -def env_to_send_in_RPC(host: IHost) -> tuple[bytes, bool]: - listen_addrs_set = {addr for addr in host.get_addrs()} - local_env = host.get_peerstore().get_local_record() - - if local_env is None: - # No cached SPR yet -> create one - return issue_and_cache_local_record(host), True - else: - record_addrs_set = local_env._env_addrs_set() - if record_addrs_set == listen_addrs_set: - # Perfect match -> reuse the cached envelope - return local_env.marshal_envelope(), False - else: - # Addresses changed -> issue a new SPR and cache it - return issue_and_cache_local_record(host), True - - -def issue_and_cache_local_record(host: IHost) -> bytes: - env = create_signed_peer_record( - host.get_id(), - host.get_addrs(), - host.get_private_key(), - ) - # Cache it for next time - host.get_peerstore().set_local_record(env) - return env.marshal_envelope() diff --git a/tests/core/pubsub/test_pubsub.py b/tests/core/pubsub/test_pubsub.py index 54bc67a1..9a09f34f 100644 --- a/tests/core/pubsub/test_pubsub.py +++ b/tests/core/pubsub/test_pubsub.py @@ -19,10 +19,11 @@ from libp2p.exceptions import ( from libp2p.network.stream.exceptions import ( StreamEOF, ) -from libp2p.peer.envelope import Envelope +from libp2p.peer.envelope import Envelope, seal_record from libp2p.peer.id import ( ID, ) +from libp2p.peer.peer_record import PeerRecord from libp2p.pubsub.pb import ( rpc_pb2, ) @@ -170,6 +171,8 @@ async def test_peer_subscribe_fail_upon_invald_record_transfer(): # Corrupt host_a's local peer record envelope = pubsubs_fsub[0].host.get_peerstore().get_local_record() + if envelope is not None: + true_record = envelope.record() key_pair = create_new_key_pair() if envelope is not None: @@ -183,6 +186,23 @@ async def test_peer_subscribe_fail_upon_invald_record_transfer(): TESTING_TOPIC, set() ) + # Create a corrupt envelope with correct signature but false peer-id + false_record = PeerRecord( + ID.from_pubkey(key_pair.public_key), true_record.addrs + ) + false_envelope = seal_record( + false_record, pubsubs_fsub[0].host.get_private_key() + ) + + pubsubs_fsub[0].host.get_peerstore().set_local_record(false_envelope) + + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + # Yeild to let 0 notify 1 + await trio.sleep(1) + assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics.get( + TESTING_TOPIC, set() + ) + @pytest.mark.trio async def test_get_hello_packet(): From cb5bfeda396d60ab0f5b29030205c04ea4cb73c5 Mon Sep 17 00:00:00 2001 From: lla-dane Date: Sat, 23 Aug 2025 18:22:45 +0530 Subject: [PATCH 56/71] Use the same comment in maybe_consume_peer_record function --- libp2p/pubsub/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libp2p/pubsub/utils.py b/libp2p/pubsub/utils.py index 3a69becb..6686ba69 100644 --- a/libp2p/pubsub/utils.py +++ b/libp2p/pubsub/utils.py @@ -42,9 +42,9 @@ def maybe_consume_signed_record(msg: RPC, host: IHost, peer_id: ID) -> bool: # Use 
the default TTL of 2 hours (7200 seconds) if not host.get_peerstore().consume_peer_record(envelope, 7200): - logger.error("Updating the certified-addr-book was unsuccessful") + logger.error("Failed to update the Certified-Addr-Book") return False except Exception as e: - logger.error("Error updating the certified addr book for peer: %s", e) + logger.error("Failed to update the Certified-Addr-Book: %s", e) return False return True From 96e2149f4d2234a729b7ea9a00d3f73422fa36dc Mon Sep 17 00:00:00 2001 From: lla-dane Date: Tue, 26 Aug 2025 12:56:20 +0530 Subject: [PATCH 57/71] added newsfragment --- newsfragments/835.feature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 newsfragments/835.feature.rst diff --git a/newsfragments/835.feature.rst b/newsfragments/835.feature.rst new file mode 100644 index 00000000..7e42f18e --- /dev/null +++ b/newsfragments/835.feature.rst @@ -0,0 +1 @@ +PubSub routers now include signed-peer-records in RPC messages for secure peer-info exchange. From 31040931ea7543e3d993662ddb9564bd77f40c04 Mon Sep 17 00:00:00 2001 From: acul71 Date: Sat, 30 Aug 2025 23:44:49 +0200 Subject: [PATCH 58/71] fix: remove unused upgrade_listener function (Issue 2 from #726) --- libp2p/transport/upgrader.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/libp2p/transport/upgrader.py b/libp2p/transport/upgrader.py index 8b47fff4..40ba5321 100644 --- a/libp2p/transport/upgrader.py +++ b/libp2p/transport/upgrader.py @@ -1,9 +1,7 @@ from libp2p.abc import ( - IListener, IMuxedConn, IRawConnection, ISecureConn, - ITransport, ) from libp2p.custom_types import ( TMuxerOptions, @@ -43,10 +41,6 @@ class TransportUpgrader: self.security_multistream = SecurityMultistream(secure_transports_by_protocol) self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol) - def upgrade_listener(self, transport: ITransport, listeners: IListener) -> None: - """Upgrade multiaddr listeners to libp2p-transport listeners.""" - # TODO: Figure out what to do with this function. 
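For reference, the two signed-peer-record helpers that the pubsub patches above converge on are meant to be used as a pair: env_to_send_in_RPC (imported from libp2p.peer.peerstore after PATCH 55) produces the bytes placed in RPC.senderRecord, reusing the cached envelope unless the host's listen addrs have changed, while maybe_consume_signed_record (in libp2p.pubsub.utils) verifies and stores an incoming record. A minimal sketch of that pairing follows; the wrapper names attach_sender_record and accept_incoming are illustrative only and are not part of the patch:

    from libp2p.abc import IHost
    from libp2p.peer.id import ID
    from libp2p.peer.peerstore import env_to_send_in_RPC
    from libp2p.pubsub.pb.rpc_pb2 import RPC
    from libp2p.pubsub.utils import maybe_consume_signed_record


    def attach_sender_record(packet: RPC, host: IHost) -> RPC:
        # Reuse the cached signed peer record when the listen addrs are
        # unchanged; otherwise a fresh record (with a bumped seq) is issued.
        envelope_bytes, _ = env_to_send_in_RPC(host)
        packet.senderRecord = envelope_bytes
        return packet


    def accept_incoming(rpc: RPC, host: IHost, sender_peer_id: ID) -> bool:
        # False means the embedded record failed verification or was signed
        # for a different peer-id; the caller then drops the whole RPC.
        return maybe_consume_signed_record(rpc, host, sender_peer_id)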
- async def upgrade_security( self, raw_conn: IRawConnection, From d620270eafa1b1858874f77b05e51f3dcf6e3a45 Mon Sep 17 00:00:00 2001 From: acul71 Date: Sun, 31 Aug 2025 00:10:15 +0200 Subject: [PATCH 59/71] docs: add newsfragment for issue 883 - remove unused upgrade_listener function --- newsfragments/883.internal.rst | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 newsfragments/883.internal.rst diff --git a/newsfragments/883.internal.rst b/newsfragments/883.internal.rst new file mode 100644 index 00000000..a9ca3a0e --- /dev/null +++ b/newsfragments/883.internal.rst @@ -0,0 +1,5 @@ +Remove unused upgrade_listener function from transport upgrader + +- Remove unused `upgrade_listener` function from `libp2p/transport/upgrader.py` (Issue 2 from #726) +- Clean up unused imports related to the removed function +- Improve code maintainability by removing dead code From 59e1d9ae39a09d2730919ace567753412185e940 Mon Sep 17 00:00:00 2001 From: bomanaps Date: Sun, 31 Aug 2025 01:38:29 +0100 Subject: [PATCH 60/71] address architectural refactoring discussed --- docs/examples.multiple_connections.rst | 133 +++++ docs/examples.rst | 1 + .../multiple_connections_example.py} | 111 ++-- libp2p/abc.py | 62 ++- libp2p/host/basic_host.py | 4 +- libp2p/network/swarm.py | 503 +++++++++--------- tests/core/host/test_live_peers.py | 4 +- tests/core/network/test_enhanced_swarm.py | 365 +++++-------- tests/core/network/test_swarm.py | 77 ++- .../security/test_security_multistream.py | 3 + .../test_multiplexer_selection.py | 9 +- tests/utils/factories.py | 4 +- 12 files changed, 705 insertions(+), 571 deletions(-) create mode 100644 docs/examples.multiple_connections.rst rename examples/{enhanced_swarm_example.py => doc-examples/multiple_connections_example.py} (55%) diff --git a/docs/examples.multiple_connections.rst b/docs/examples.multiple_connections.rst new file mode 100644 index 00000000..da1d3b02 --- /dev/null +++ b/docs/examples.multiple_connections.rst @@ -0,0 +1,133 @@ +Multiple Connections Per Peer +============================ + +This example demonstrates how to use the multiple connections per peer feature in py-libp2p. + +Overview +-------- + +The multiple connections per peer feature allows a libp2p node to maintain multiple network connections to the same peer. This provides several benefits: + +- **Improved reliability**: If one connection fails, others remain available +- **Better performance**: Load can be distributed across multiple connections +- **Enhanced throughput**: Multiple streams can be created in parallel +- **Fault tolerance**: Redundant connections provide backup paths + +Configuration +------------- + +The feature is configured through the `ConnectionConfig` class: + +.. code-block:: python + + from libp2p.network.swarm import ConnectionConfig + + # Default configuration + config = ConnectionConfig() + print(f"Max connections per peer: {config.max_connections_per_peer}") + print(f"Load balancing strategy: {config.load_balancing_strategy}") + + # Custom configuration + custom_config = ConnectionConfig( + max_connections_per_peer=5, + connection_timeout=60.0, + load_balancing_strategy="least_loaded" + ) + +Load Balancing Strategies +------------------------ + +Two load balancing strategies are available: + +**Round Robin** (default) + Cycles through connections in order, distributing load evenly. + +**Least Loaded** + Selects the connection with the fewest active streams. + +API Usage +--------- + +The new API provides direct access to multiple connections: + +.. 
code-block:: python + + from libp2p import new_swarm + + # Create swarm with multiple connections support + swarm = new_swarm() + + # Dial a peer - returns list of connections + connections = await swarm.dial_peer(peer_id) + print(f"Established {len(connections)} connections") + + # Get all connections to a peer + peer_connections = swarm.get_connections(peer_id) + + # Get all connections (across all peers) + all_connections = swarm.get_connections() + + # Get the complete connections map + connections_map = swarm.get_connections_map() + + # Backward compatibility - get single connection + single_conn = swarm.get_connection(peer_id) + +Backward Compatibility +--------------------- + +Existing code continues to work through backward compatibility features: + +.. code-block:: python + + # Legacy 1:1 mapping (returns first connection for each peer) + legacy_connections = swarm.connections_legacy + + # Single connection access (returns first available connection) + conn = swarm.get_connection(peer_id) + +Example +------- + +See :doc:`examples/doc-examples/multiple_connections_example.py` for a complete working example. + +Production Configuration +----------------------- + +For production use, consider these settings: + +.. code-block:: python + + from libp2p.network.swarm import ConnectionConfig, RetryConfig + + # Production-ready configuration + retry_config = RetryConfig( + max_retries=3, + initial_delay=0.1, + max_delay=30.0, + backoff_multiplier=2.0, + jitter_factor=0.1 + ) + + connection_config = ConnectionConfig( + max_connections_per_peer=3, # Balance performance and resources + connection_timeout=30.0, # Reasonable timeout + load_balancing_strategy="round_robin" # Predictable behavior + ) + + swarm = new_swarm( + retry_config=retry_config, + connection_config=connection_config + ) + +Architecture +----------- + +The implementation follows the same architectural patterns as the Go and JavaScript reference implementations: + +- **Core data structure**: `dict[ID, list[INetConn]]` for 1:many mapping +- **API consistency**: Methods like `get_connections()` match reference implementations +- **Load balancing**: Integrated at the API level for optimal performance +- **Backward compatibility**: Maintains existing interfaces for gradual migration + +This design ensures consistency across libp2p implementations while providing the benefits of multiple connections per peer. diff --git a/docs/examples.rst b/docs/examples.rst index b8ba44d7..74864cbe 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -15,3 +15,4 @@ Examples examples.kademlia examples.mDNS examples.random_walk + examples.multiple_connections diff --git a/examples/enhanced_swarm_example.py b/examples/doc-examples/multiple_connections_example.py similarity index 55% rename from examples/enhanced_swarm_example.py rename to examples/doc-examples/multiple_connections_example.py index b5367af8..14a71ab8 100644 --- a/examples/enhanced_swarm_example.py +++ b/examples/doc-examples/multiple_connections_example.py @@ -1,18 +1,18 @@ #!/usr/bin/env python3 """ -Example demonstrating the enhanced Swarm with retry logic, exponential backoff, -and multi-connection support. +Example demonstrating multiple connections per peer support in libp2p. This example shows how to: -1. Configure retry behavior with exponential backoff -2. Enable multi-connection support with connection pooling -3. Use different load balancing strategies +1. Configure multiple connections per peer +2. Use different load balancing strategies +3. 
Access multiple connections through the new API 4. Maintain backward compatibility """ -import asyncio import logging +import trio + from libp2p import new_swarm from libp2p.network.swarm import ConnectionConfig, RetryConfig @@ -21,64 +21,32 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -async def example_basic_enhanced_swarm() -> None: - """Example of basic enhanced Swarm usage.""" - logger.info("Creating enhanced Swarm with default configuration...") +async def example_basic_multiple_connections() -> None: + """Example of basic multiple connections per peer usage.""" + logger.info("Creating swarm with multiple connections support...") - # Create enhanced swarm with default retry and connection config + # Create swarm with default configuration swarm = new_swarm() - # Use default configuration values directly - default_retry = RetryConfig() default_connection = ConnectionConfig() logger.info(f"Swarm created with peer ID: {swarm.get_peer_id()}") - logger.info(f"Retry config: max_retries={default_retry.max_retries}") logger.info( f"Connection config: max_connections_per_peer=" f"{default_connection.max_connections_per_peer}" ) - logger.info(f"Connection pool enabled: {default_connection.enable_connection_pool}") await swarm.close() - logger.info("Basic enhanced Swarm example completed") - - -async def example_custom_retry_config() -> None: - """Example of custom retry configuration.""" - logger.info("Creating enhanced Swarm with custom retry configuration...") - - # Custom retry configuration for aggressive retry behavior - retry_config = RetryConfig( - max_retries=5, # More retries - initial_delay=0.05, # Faster initial retry - max_delay=10.0, # Lower max delay - backoff_multiplier=1.5, # Less aggressive backoff - jitter_factor=0.2, # More jitter - ) - - # Create swarm with custom retry config - swarm = new_swarm(retry_config=retry_config) - - logger.info("Custom retry config applied:") - logger.info(f" Max retries: {retry_config.max_retries}") - logger.info(f" Initial delay: {retry_config.initial_delay}s") - logger.info(f" Max delay: {retry_config.max_delay}s") - logger.info(f" Backoff multiplier: {retry_config.backoff_multiplier}") - logger.info(f" Jitter factor: {retry_config.jitter_factor}") - - await swarm.close() - logger.info("Custom retry config example completed") + logger.info("Basic multiple connections example completed") async def example_custom_connection_config() -> None: """Example of custom connection configuration.""" - logger.info("Creating enhanced Swarm with custom connection configuration...") + logger.info("Creating swarm with custom connection configuration...") # Custom connection configuration for high-performance scenarios connection_config = ConnectionConfig( max_connections_per_peer=5, # More connections per peer connection_timeout=60.0, # Longer timeout - enable_connection_pool=True, # Enable connection pooling load_balancing_strategy="least_loaded", # Use least loaded strategy ) @@ -90,9 +58,6 @@ async def example_custom_connection_config() -> None: f" Max connections per peer: {connection_config.max_connections_per_peer}" ) logger.info(f" Connection timeout: {connection_config.connection_timeout}s") - logger.info( - f" Connection pool enabled: {connection_config.enable_connection_pool}" - ) logger.info( f" Load balancing strategy: {connection_config.load_balancing_strategy}" ) @@ -101,22 +66,39 @@ async def example_custom_connection_config() -> None: logger.info("Custom connection config example completed") -async def 
example_backward_compatibility() -> None: - """Example showing backward compatibility.""" - logger.info("Creating enhanced Swarm with backward compatibility...") +async def example_multiple_connections_api() -> None: + """Example of using the new multiple connections API.""" + logger.info("Demonstrating multiple connections API...") - # Disable connection pool to maintain original behavior - connection_config = ConnectionConfig(enable_connection_pool=False) + connection_config = ConnectionConfig( + max_connections_per_peer=3, + load_balancing_strategy="round_robin" + ) - # Create swarm with connection pool disabled swarm = new_swarm(connection_config=connection_config) - logger.info("Backward compatibility mode:") + logger.info("Multiple connections API features:") + logger.info(" - dial_peer() returns list[INetConn]") + logger.info(" - get_connections(peer_id) returns list[INetConn]") + logger.info(" - get_connections_map() returns dict[ID, list[INetConn]]") logger.info( - f" Connection pool enabled: {connection_config.enable_connection_pool}" + " - get_connection(peer_id) returns INetConn | None (backward compatibility)" ) - logger.info(f" Connections dict type: {type(swarm.connections)}") - logger.info(" Retry logic still available: 3 max retries") + + await swarm.close() + logger.info("Multiple connections API example completed") + + +async def example_backward_compatibility() -> None: + """Example of backward compatibility features.""" + logger.info("Demonstrating backward compatibility...") + + swarm = new_swarm() + + logger.info("Backward compatibility features:") + logger.info(" - connections_legacy property provides 1:1 mapping") + logger.info(" - get_connection() method for single connection access") + logger.info(" - Existing code continues to work") await swarm.close() logger.info("Backward compatibility example completed") @@ -124,7 +106,7 @@ async def example_backward_compatibility() -> None: async def example_production_ready_config() -> None: """Example of production-ready configuration.""" - logger.info("Creating enhanced Swarm with production-ready configuration...") + logger.info("Creating swarm with production-ready configuration...") # Production-ready retry configuration retry_config = RetryConfig( @@ -139,7 +121,6 @@ async def example_production_ready_config() -> None: connection_config = ConnectionConfig( max_connections_per_peer=3, # Balance between performance and resource usage connection_timeout=30.0, # Reasonable timeout - enable_connection_pool=True, # Enable for better performance load_balancing_strategy="round_robin", # Simple, predictable strategy ) @@ -160,19 +141,19 @@ async def example_production_ready_config() -> None: async def main() -> None: """Run all examples.""" - logger.info("Enhanced Swarm Examples") + logger.info("Multiple Connections Per Peer Examples") logger.info("=" * 50) try: - await example_basic_enhanced_swarm() - logger.info("-" * 30) - - await example_custom_retry_config() + await example_basic_multiple_connections() logger.info("-" * 30) await example_custom_connection_config() logger.info("-" * 30) + await example_multiple_connections_api() + logger.info("-" * 30) + await example_backward_compatibility() logger.info("-" * 30) @@ -187,4 +168,4 @@ async def main() -> None: if __name__ == "__main__": - asyncio.run(main()) + trio.run(main) diff --git a/libp2p/abc.py b/libp2p/abc.py index a9748339..964c7454 100644 --- a/libp2p/abc.py +++ b/libp2p/abc.py @@ -1412,15 +1412,16 @@ class INetwork(ABC): ---------- peerstore : IPeerStore The 
peer store for managing peer information. - connections : dict[ID, INetConn] - A mapping of peer IDs to network connections. + connections : dict[ID, list[INetConn]] + A mapping of peer IDs to lists of network connections + (multiple connections per peer). listeners : dict[str, IListener] A mapping of listener identifiers to listener instances. """ peerstore: IPeerStore - connections: dict[ID, INetConn] + connections: dict[ID, list[INetConn]] listeners: dict[str, IListener] @abstractmethod @@ -1436,9 +1437,56 @@ class INetwork(ABC): """ @abstractmethod - async def dial_peer(self, peer_id: ID) -> INetConn: + def get_connections(self, peer_id: ID | None = None) -> list[INetConn]: """ - Create a connection to the specified peer. + Get connections for peer (like JS getConnections, Go ConnsToPeer). + + Parameters + ---------- + peer_id : ID | None + The peer ID to get connections for. If None, returns all connections. + + Returns + ------- + list[INetConn] + List of connections to the specified peer, or all connections + if peer_id is None. + + """ + + @abstractmethod + def get_connections_map(self) -> dict[ID, list[INetConn]]: + """ + Get all connections map (like JS getConnectionsMap). + + Returns + ------- + dict[ID, list[INetConn]] + The complete mapping of peer IDs to their connection lists. + + """ + + @abstractmethod + def get_connection(self, peer_id: ID) -> INetConn | None: + """ + Get single connection for backward compatibility. + + Parameters + ---------- + peer_id : ID + The peer ID to get a connection for. + + Returns + ------- + INetConn | None + The first available connection, or None if no connections exist. + + """ + + @abstractmethod + async def dial_peer(self, peer_id: ID) -> list[INetConn]: + """ + Create connections to the specified peer with load balancing. Parameters ---------- @@ -1447,8 +1495,8 @@ class INetwork(ABC): Returns ------- - INetConn - The network connection instance to the specified peer. + list[INetConn] + List of established connections to the peer. 
Raises ------ diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index a0311bd8..a3a89dda 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -338,7 +338,7 @@ class BasicHost(IHost): :param peer_id: ID of the peer to check :return: True if peer has an active connection, False otherwise """ - return peer_id in self._network.connections + return len(self._network.get_connections(peer_id)) > 0 def get_peer_connection_info(self, peer_id: ID) -> INetConn | None: """ @@ -347,4 +347,4 @@ class BasicHost(IHost): :param peer_id: ID of the peer to get info for :return: Connection object if peer is connected, None otherwise """ - return self._network.connections.get(peer_id) + return self._network.get_connection(peer_id) diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 77fe2b6d..23a94fdb 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -74,175 +74,13 @@ class RetryConfig: @dataclass class ConnectionConfig: - """Configuration for connection pool and multi-connection support.""" + """Configuration for multi-connection support.""" max_connections_per_peer: int = 3 connection_timeout: float = 30.0 - enable_connection_pool: bool = True load_balancing_strategy: str = "round_robin" # or "least_loaded" -@dataclass -class ConnectionInfo: - """Information about a connection in the pool.""" - - connection: INetConn - address: str - established_at: float - last_used: float - stream_count: int - is_healthy: bool - - -class ConnectionPool: - """Manages multiple connections per peer with load balancing.""" - - def __init__(self, max_connections_per_peer: int = 3): - self.max_connections_per_peer = max_connections_per_peer - self.peer_connections: dict[ID, list[ConnectionInfo]] = {} - self._round_robin_index: dict[ID, int] = {} - - def add_connection(self, peer_id: ID, connection: INetConn, address: str) -> None: - """Add a connection to the pool with deduplication.""" - if peer_id not in self.peer_connections: - self.peer_connections[peer_id] = [] - - # Check for duplicate connections to the same address - for conn_info in self.peer_connections[peer_id]: - if conn_info.address == address: - logger.debug( - f"Connection to {address} already exists for peer {peer_id}" - ) - return - - # Add new connection - try: - current_time = trio.current_time() - except RuntimeError: - # Fallback for testing contexts where trio is not running - import time - - current_time = time.time() - - conn_info = ConnectionInfo( - connection=connection, - address=address, - established_at=current_time, - last_used=current_time, - stream_count=0, - is_healthy=True, - ) - - self.peer_connections[peer_id].append(conn_info) - - # Trim if we exceed max connections - if len(self.peer_connections[peer_id]) > self.max_connections_per_peer: - self._trim_connections(peer_id) - - def get_connection( - self, peer_id: ID, strategy: str = "round_robin" - ) -> INetConn | None: - """Get a connection using the specified load balancing strategy.""" - if peer_id not in self.peer_connections or not self.peer_connections[peer_id]: - return None - - connections = self.peer_connections[peer_id] - - if strategy == "round_robin": - if peer_id not in self._round_robin_index: - self._round_robin_index[peer_id] = 0 - - index = self._round_robin_index[peer_id] % len(connections) - self._round_robin_index[peer_id] += 1 - - conn_info = connections[index] - try: - conn_info.last_used = trio.current_time() - except RuntimeError: - import time - - conn_info.last_used = time.time() - return 
conn_info.connection - - elif strategy == "least_loaded": - # Find connection with least streams - # Note: stream_count is a custom attribute we add to connections - conn_info = min( - connections, key=lambda c: getattr(c.connection, "stream_count", 0) - ) - try: - conn_info.last_used = trio.current_time() - except RuntimeError: - import time - - conn_info.last_used = time.time() - return conn_info.connection - - else: - # Default to first connection - conn_info = connections[0] - try: - conn_info.last_used = trio.current_time() - except RuntimeError: - import time - - conn_info.last_used = time.time() - return conn_info.connection - - def has_connection(self, peer_id: ID) -> bool: - """Check if we have any connections to the peer.""" - return ( - peer_id in self.peer_connections and len(self.peer_connections[peer_id]) > 0 - ) - - def remove_connection(self, peer_id: ID, connection: INetConn) -> None: - """Remove a connection from the pool.""" - if peer_id in self.peer_connections: - self.peer_connections[peer_id] = [ - conn_info - for conn_info in self.peer_connections[peer_id] - if conn_info.connection != connection - ] - - # Clean up empty peer entries - if not self.peer_connections[peer_id]: - del self.peer_connections[peer_id] - if peer_id in self._round_robin_index: - del self._round_robin_index[peer_id] - - def _trim_connections(self, peer_id: ID) -> None: - """Remove oldest connections when limit is exceeded.""" - connections = self.peer_connections[peer_id] - if len(connections) <= self.max_connections_per_peer: - return - - # Sort by last used time and remove oldest - connections.sort(key=lambda c: c.last_used) - connections_to_remove = connections[: -self.max_connections_per_peer] - - for conn_info in connections_to_remove: - logger.debug( - f"Trimming old connection to {conn_info.address} for peer {peer_id}" - ) - try: - # Close the connection asynchronously - trio.lowlevel.spawn_system_task( - self._close_connection_async, conn_info.connection - ) - except Exception as e: - logger.warning(f"Error closing trimmed connection: {e}") - - # Keep only the most recently used connections - self.peer_connections[peer_id] = connections[-self.max_connections_per_peer :] - - async def _close_connection_async(self, connection: INetConn) -> None: - """Close a connection asynchronously.""" - try: - await connection.close() - except Exception as e: - logger.warning(f"Error closing connection: {e}") - - def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn: async def stream_handler(stream: INetStream) -> None: await network.get_manager().wait_finished() @@ -256,7 +94,7 @@ class Swarm(Service, INetworkService): upgrader: TransportUpgrader transport: ITransport # Enhanced: Support for multiple connections per peer - connections: dict[ID, INetConn] # Backward compatibility + connections: dict[ID, list[INetConn]] # Multiple connections per peer listeners: dict[str, IListener] common_stream_handler: StreamHandlerFn listener_nursery: trio.Nursery | None @@ -264,10 +102,10 @@ class Swarm(Service, INetworkService): notifees: list[INotifee] - # Enhanced: New configuration and connection pool + # Enhanced: New configuration retry_config: RetryConfig connection_config: ConnectionConfig - connection_pool: ConnectionPool | None + _round_robin_index: dict[ID, int] def __init__( self, @@ -287,16 +125,8 @@ class Swarm(Service, INetworkService): self.retry_config = retry_config or RetryConfig() self.connection_config = connection_config or ConnectionConfig() - # Enhanced: Initialize 
connection pool if enabled - if self.connection_config.enable_connection_pool: - self.connection_pool = ConnectionPool( - self.connection_config.max_connections_per_peer - ) - else: - self.connection_pool = None - - # Backward compatibility: Keep existing connections dict - self.connections = dict() + # Enhanced: Initialize connections as 1:many mapping + self.connections = {} self.listeners = dict() # Create Notifee array @@ -307,6 +137,9 @@ class Swarm(Service, INetworkService): self.listener_nursery = None self.event_listener_nursery_created = trio.Event() + # Load balancing state + self._round_robin_index = {} + async def run(self) -> None: async with trio.open_nursery() as nursery: # Create a nursery for listener tasks. @@ -326,26 +159,74 @@ class Swarm(Service, INetworkService): def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None: self.common_stream_handler = stream_handler - async def dial_peer(self, peer_id: ID) -> INetConn: + def get_connections(self, peer_id: ID | None = None) -> list[INetConn]: """ - Try to create a connection to peer_id with enhanced retry logic. + Get connections for peer (like JS getConnections, Go ConnsToPeer). + + Parameters + ---------- + peer_id : ID | None + The peer ID to get connections for. If None, returns all connections. + + Returns + ------- + list[INetConn] + List of connections to the specified peer, or all connections + if peer_id is None. + + """ + if peer_id is not None: + return self.connections.get(peer_id, []) + + # Return all connections from all peers + all_conns = [] + for conns in self.connections.values(): + all_conns.extend(conns) + return all_conns + + def get_connections_map(self) -> dict[ID, list[INetConn]]: + """ + Get all connections map (like JS getConnectionsMap). + + Returns + ------- + dict[ID, list[INetConn]] + The complete mapping of peer IDs to their connection lists. + + """ + return self.connections.copy() + + def get_connection(self, peer_id: ID) -> INetConn | None: + """ + Get single connection for backward compatibility. + + Parameters + ---------- + peer_id : ID + The peer ID to get a connection for. + + Returns + ------- + INetConn | None + The first available connection, or None if no connections exist. + + """ + conns = self.get_connections(peer_id) + return conns[0] if conns else None + + async def dial_peer(self, peer_id: ID) -> list[INetConn]: + """ + Try to create connections to peer_id with enhanced retry logic. 
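+
+        Existing connections to the peer are reused when present; otherwise
+        every known address is dialed (with retry and backoff) until at most
+        connection_config.max_connections_per_peer connections have been
+        established.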
:param peer_id: peer if we want to dial :raises SwarmException: raised when an error occurs - :return: muxed connection + :return: list of muxed connections """ - # Enhanced: Check connection pool first if enabled - if self.connection_pool and self.connection_pool.has_connection(peer_id): - connection = self.connection_pool.get_connection(peer_id) - if connection: - logger.debug(f"Reusing existing connection to peer {peer_id}") - return connection - - # Enhanced: Check existing single connection for backward compatibility - if peer_id in self.connections: - # If muxed connection already exists for peer_id, - # set muxed connection equal to existing muxed connection - return self.connections[peer_id] + # Check if we already have connections + existing_connections = self.get_connections(peer_id) + if existing_connections: + logger.debug(f"Reusing existing connections to peer {peer_id}") + return existing_connections logger.debug("attempting to dial peer %s", peer_id) @@ -358,23 +239,19 @@ class Swarm(Service, INetworkService): if not addrs: raise SwarmException(f"No known addresses to peer {peer_id}") + connections = [] exceptions: list[SwarmException] = [] # Enhanced: Try all known addresses with retry logic for multiaddr in addrs: try: connection = await self._dial_with_retry(multiaddr, peer_id) + connections.append(connection) - # Enhanced: Add to connection pool if enabled - if self.connection_pool: - self.connection_pool.add_connection( - peer_id, connection, str(multiaddr) - ) + # Limit number of connections per peer + if len(connections) >= self.connection_config.max_connections_per_peer: + break - # Backward compatibility: Keep existing connections dict - self.connections[peer_id] = connection - - return connection except SwarmException as e: exceptions.append(e) logger.debug( @@ -384,11 +261,14 @@ class Swarm(Service, INetworkService): exc_info=e, ) - # Tried all addresses, raising exception. - raise SwarmException( - f"unable to connect to {peer_id}, no addresses established a successful " - "connection (with exceptions)" - ) from MultiError(exceptions) + if not connections: + # Tried all addresses, raising exception. 
+ raise SwarmException( + f"unable to connect to {peer_id}, no addresses established a " + "successful connection (with exceptions)" + ) from MultiError(exceptions) + + return connections async def _dial_with_retry(self, addr: Multiaddr, peer_id: ID) -> INetConn: """ @@ -515,33 +395,76 @@ class Swarm(Service, INetworkService): """ logger.debug("attempting to open a stream to peer %s", peer_id) - # Enhanced: Try to get existing connection from pool first - if self.connection_pool and self.connection_pool.has_connection(peer_id): - connection = self.connection_pool.get_connection( - peer_id, self.connection_config.load_balancing_strategy - ) - if connection: - try: - net_stream = await connection.new_stream() - logger.debug( - "successfully opened a stream to peer %s " - "using existing connection", - peer_id, - ) - return net_stream - except Exception as e: - logger.debug( - f"Failed to create stream on existing connection, " - f"will dial new connection: {e}" - ) - # Fall through to dial new connection + # Get existing connections or dial new ones + connections = self.get_connections(peer_id) + if not connections: + connections = await self.dial_peer(peer_id) - # Fall back to existing logic: dial peer and create stream - swarm_conn = await self.dial_peer(peer_id) + # Load balancing strategy at interface level + connection = self._select_connection(connections, peer_id) - net_stream = await swarm_conn.new_stream() - logger.debug("successfully opened a stream to peer %s", peer_id) - return net_stream + try: + net_stream = await connection.new_stream() + logger.debug("successfully opened a stream to peer %s", peer_id) + return net_stream + except Exception as e: + logger.debug(f"Failed to create stream on connection: {e}") + # Try other connections if available + for other_conn in connections: + if other_conn != connection: + try: + net_stream = await other_conn.new_stream() + logger.debug( + f"Successfully opened a stream to peer {peer_id} " + "using alternative connection" + ) + return net_stream + except Exception: + continue + + # All connections failed, raise exception + raise SwarmException(f"Failed to create stream to peer {peer_id}") from e + + def _select_connection(self, connections: list[INetConn], peer_id: ID) -> INetConn: + """ + Select connection based on load balancing strategy. + + Parameters + ---------- + connections : list[INetConn] + List of available connections. + peer_id : ID + The peer ID for round-robin tracking. + strategy : str + Load balancing strategy ("round_robin", "least_loaded", etc.). + + Returns + ------- + INetConn + Selected connection. 
+ + """ + if not connections: + raise ValueError("No connections available") + + strategy = self.connection_config.load_balancing_strategy + + if strategy == "round_robin": + # Simple round-robin selection + if peer_id not in self._round_robin_index: + self._round_robin_index[peer_id] = 0 + + index = self._round_robin_index[peer_id] % len(connections) + self._round_robin_index[peer_id] += 1 + return connections[index] + + elif strategy == "least_loaded": + # Find connection with least streams + return min(connections, key=lambda c: len(c.get_streams())) + + else: + # Default to first connection + return connections[0] async def listen(self, *multiaddrs: Multiaddr) -> bool: """ @@ -637,9 +560,9 @@ class Swarm(Service, INetworkService): # Perform alternative cleanup if the manager isn't initialized # Close all connections manually if hasattr(self, "connections"): - for conn_id in list(self.connections.keys()): - conn = self.connections[conn_id] - await conn.close() + for peer_id, conns in list(self.connections.items()): + for conn in conns: + await conn.close() # Clear connection tracking dictionary self.connections.clear() @@ -669,17 +592,28 @@ class Swarm(Service, INetworkService): logger.debug("swarm successfully closed") async def close_peer(self, peer_id: ID) -> None: - if peer_id not in self.connections: + """ + Close all connections to the specified peer. + + Parameters + ---------- + peer_id : ID + The peer ID to close connections for. + + """ + connections = self.get_connections(peer_id) + if not connections: return - connection = self.connections[peer_id] - # Enhanced: Remove from connection pool if enabled - if self.connection_pool: - self.connection_pool.remove_connection(peer_id, connection) + # Close all connections + for connection in connections: + try: + await connection.close() + except Exception as e: + logger.warning(f"Error closing connection to {peer_id}: {e}") - # NOTE: `connection.close` will delete `peer_id` from `self.connections` - # and `notify_disconnected` for us. 
- await connection.close() + # Remove from connections dict + self.connections.pop(peer_id, None) logger.debug("successfully close the connection to peer %s", peer_id) @@ -698,20 +632,58 @@ class Swarm(Service, INetworkService): await muxed_conn.event_started.wait() self.manager.run_task(swarm_conn.start) await swarm_conn.event_started.wait() - # Enhanced: Add to connection pool if enabled - if self.connection_pool: - # For incoming connections, we don't have a specific address - # Use a placeholder that will be updated when we get more info - self.connection_pool.add_connection( - muxed_conn.peer_id, swarm_conn, "incoming" - ) - # Store muxed_conn with peer id (backward compatibility) - self.connections[muxed_conn.peer_id] = swarm_conn + # Add to connections dict with deduplication + peer_id = muxed_conn.peer_id + if peer_id not in self.connections: + self.connections[peer_id] = [] + + # Check for duplicate connections by comparing the underlying muxed connection + for existing_conn in self.connections[peer_id]: + if existing_conn.muxed_conn == muxed_conn: + logger.debug(f"Connection already exists for peer {peer_id}") + # existing_conn is a SwarmConn since it's stored in the connections list + return existing_conn # type: ignore[return-value] + + self.connections[peer_id].append(swarm_conn) + + # Trim if we exceed max connections + max_conns = self.connection_config.max_connections_per_peer + if len(self.connections[peer_id]) > max_conns: + self._trim_connections(peer_id) + # Call notifiers since event occurred await self.notify_connected(swarm_conn) return swarm_conn + def _trim_connections(self, peer_id: ID) -> None: + """ + Remove oldest connections when limit is exceeded. + """ + connections = self.connections[peer_id] + if len(connections) <= self.connection_config.max_connections_per_peer: + return + + # Sort by creation time and remove oldest + # For now, just keep the most recent connections + max_conns = self.connection_config.max_connections_per_peer + connections_to_remove = connections[:-max_conns] + + for conn in connections_to_remove: + logger.debug(f"Trimming old connection for peer {peer_id}") + trio.lowlevel.spawn_system_task(self._close_connection_async, conn) + + # Keep only the most recent connections + max_conns = self.connection_config.max_connections_per_peer + self.connections[peer_id] = connections[-max_conns:] + + async def _close_connection_async(self, connection: INetConn) -> None: + """Close a connection asynchronously.""" + try: + await connection.close() + except Exception as e: + logger.warning(f"Error closing connection: {e}") + def remove_conn(self, swarm_conn: SwarmConn) -> None: """ Simply remove the connection from Swarm's records, without closing @@ -719,13 +691,12 @@ class Swarm(Service, INetworkService): """ peer_id = swarm_conn.muxed_conn.peer_id - # Enhanced: Remove from connection pool if enabled - if self.connection_pool: - self.connection_pool.remove_connection(peer_id, swarm_conn) - - if peer_id not in self.connections: - return - del self.connections[peer_id] + if peer_id in self.connections: + self.connections[peer_id] = [ + conn for conn in self.connections[peer_id] if conn != swarm_conn + ] + if not self.connections[peer_id]: + del self.connections[peer_id] # Notifee @@ -771,3 +742,21 @@ class Swarm(Service, INetworkService): async with trio.open_nursery() as nursery: for notifee in self.notifees: nursery.start_soon(notifier, notifee) + + # Backward compatibility properties + @property + def connections_legacy(self) -> dict[ID, INetConn]: 
+ """ + Legacy 1:1 mapping for backward compatibility. + + Returns + ------- + dict[ID, INetConn] + Legacy mapping with only the first connection per peer. + + """ + legacy_conns = {} + for peer_id, conns in self.connections.items(): + if conns: + legacy_conns[peer_id] = conns[0] + return legacy_conns diff --git a/tests/core/host/test_live_peers.py b/tests/core/host/test_live_peers.py index 1d7948ad..e5af42ba 100644 --- a/tests/core/host/test_live_peers.py +++ b/tests/core/host/test_live_peers.py @@ -164,8 +164,8 @@ async def test_live_peers_unexpected_drop(security_protocol): assert peer_a_id in host_b.get_live_peers() # Simulate unexpected connection drop by directly closing the connection - conn = host_a.get_network().connections[peer_b_id] - await conn.muxed_conn.close() + conns = host_a.get_network().connections[peer_b_id] + await conns[0].muxed_conn.close() # Allow for connection cleanup await trio.sleep(0.1) diff --git a/tests/core/network/test_enhanced_swarm.py b/tests/core/network/test_enhanced_swarm.py index 9b100ad9..e63de126 100644 --- a/tests/core/network/test_enhanced_swarm.py +++ b/tests/core/network/test_enhanced_swarm.py @@ -1,14 +1,15 @@ import time +from typing import cast from unittest.mock import Mock import pytest from multiaddr import Multiaddr +import trio from libp2p.abc import INetConn, INetStream from libp2p.network.exceptions import SwarmException from libp2p.network.swarm import ( ConnectionConfig, - ConnectionPool, RetryConfig, Swarm, ) @@ -21,10 +22,12 @@ class MockConnection(INetConn): def __init__(self, peer_id: ID, is_closed: bool = False): self.peer_id = peer_id self._is_closed = is_closed - self.stream_count = 0 + self.streams = set() # Track streams properly # Mock the muxed_conn attribute that Swarm expects self.muxed_conn = Mock() self.muxed_conn.peer_id = peer_id + # Required by INetConn interface + self.event_started = trio.Event() async def close(self): self._is_closed = True @@ -34,12 +37,14 @@ class MockConnection(INetConn): return self._is_closed async def new_stream(self) -> INetStream: - self.stream_count += 1 - return Mock(spec=INetStream) + # Create a mock stream and add it to the connection's stream set + mock_stream = Mock(spec=INetStream) + self.streams.add(mock_stream) + return mock_stream def get_streams(self) -> tuple[INetStream, ...]: - """Mock implementation of get_streams.""" - return tuple() + """Return all streams associated with this connection.""" + return tuple(self.streams) def get_transport_addresses(self) -> list[Multiaddr]: """Mock implementation of get_transport_addresses.""" @@ -70,114 +75,9 @@ async def test_connection_config_defaults(): config = ConnectionConfig() assert config.max_connections_per_peer == 3 assert config.connection_timeout == 30.0 - assert config.enable_connection_pool is True assert config.load_balancing_strategy == "round_robin" -@pytest.mark.trio -async def test_connection_pool_basic_operations(): - """Test basic ConnectionPool operations.""" - pool = ConnectionPool(max_connections_per_peer=2) - peer_id = ID(b"QmTest") - - # Test empty pool - assert not pool.has_connection(peer_id) - assert pool.get_connection(peer_id) is None - - # Add connection - conn1 = MockConnection(peer_id) - pool.add_connection(peer_id, conn1, "addr1") - assert pool.has_connection(peer_id) - assert pool.get_connection(peer_id) == conn1 - - # Add second connection - conn2 = MockConnection(peer_id) - pool.add_connection(peer_id, conn2, "addr2") - assert len(pool.peer_connections[peer_id]) == 2 - - # Test round-robin - should 
cycle through connections - first_conn = pool.get_connection(peer_id, "round_robin") - second_conn = pool.get_connection(peer_id, "round_robin") - third_conn = pool.get_connection(peer_id, "round_robin") - - # Should cycle through both connections - assert first_conn in [conn1, conn2] - assert second_conn in [conn1, conn2] - assert third_conn in [conn1, conn2] - assert first_conn != second_conn or second_conn != third_conn - - # Test least loaded - set different stream counts - conn1.stream_count = 5 - conn2.stream_count = 1 - least_loaded_conn = pool.get_connection(peer_id, "least_loaded") - assert least_loaded_conn == conn2 # conn2 has fewer streams - - -@pytest.mark.trio -async def test_connection_pool_deduplication(): - """Test connection deduplication by address.""" - pool = ConnectionPool(max_connections_per_peer=3) - peer_id = ID(b"QmTest") - - conn1 = MockConnection(peer_id) - pool.add_connection(peer_id, conn1, "addr1") - - # Try to add connection with same address - conn2 = MockConnection(peer_id) - pool.add_connection(peer_id, conn2, "addr1") - - # Should only have one connection - assert len(pool.peer_connections[peer_id]) == 1 - assert pool.get_connection(peer_id) == conn1 - - -@pytest.mark.trio -async def test_connection_pool_trimming(): - """Test connection trimming when limit is exceeded.""" - pool = ConnectionPool(max_connections_per_peer=2) - peer_id = ID(b"QmTest") - - # Add 3 connections - conn1 = MockConnection(peer_id) - conn2 = MockConnection(peer_id) - conn3 = MockConnection(peer_id) - - pool.add_connection(peer_id, conn1, "addr1") - pool.add_connection(peer_id, conn2, "addr2") - pool.add_connection(peer_id, conn3, "addr3") - - # Should trim to 2 connections - assert len(pool.peer_connections[peer_id]) == 2 - - # The oldest connections should be removed - remaining_connections = [c.connection for c in pool.peer_connections[peer_id]] - assert conn3 in remaining_connections # Most recent should remain - - -@pytest.mark.trio -async def test_connection_pool_remove_connection(): - """Test removing connections from pool.""" - pool = ConnectionPool(max_connections_per_peer=3) - peer_id = ID(b"QmTest") - - conn1 = MockConnection(peer_id) - conn2 = MockConnection(peer_id) - - pool.add_connection(peer_id, conn1, "addr1") - pool.add_connection(peer_id, conn2, "addr2") - - assert len(pool.peer_connections[peer_id]) == 2 - - # Remove connection - pool.remove_connection(peer_id, conn1) - assert len(pool.peer_connections[peer_id]) == 1 - assert pool.get_connection(peer_id) == conn2 - - # Remove last connection - pool.remove_connection(peer_id, conn2) - assert not pool.has_connection(peer_id) - - @pytest.mark.trio async def test_enhanced_swarm_constructor(): """Test enhanced Swarm constructor with new configuration.""" @@ -191,19 +91,16 @@ async def test_enhanced_swarm_constructor(): swarm = Swarm(peer_id, peerstore, upgrader, transport) assert swarm.retry_config.max_retries == 3 assert swarm.connection_config.max_connections_per_peer == 3 - assert swarm.connection_pool is not None + assert isinstance(swarm.connections, dict) # Test with custom config custom_retry = RetryConfig(max_retries=5, initial_delay=0.5) - custom_conn = ConnectionConfig( - max_connections_per_peer=5, enable_connection_pool=False - ) + custom_conn = ConnectionConfig(max_connections_per_peer=5) swarm = Swarm(peer_id, peerstore, upgrader, transport, custom_retry, custom_conn) assert swarm.retry_config.max_retries == 5 assert swarm.retry_config.initial_delay == 0.5 assert 
swarm.connection_config.max_connections_per_peer == 5 - assert swarm.connection_pool is None @pytest.mark.trio @@ -273,143 +170,155 @@ async def test_swarm_retry_logic(): # Should have succeeded after 3 attempts assert attempt_count[0] == 3 - assert result is not None - - # Should have taken some time due to retries - assert end_time - start_time > 0.02 # At least 2 delays + assert isinstance(result, MockConnection) + assert end_time - start_time > 0.01 # Should have some delay @pytest.mark.trio -async def test_swarm_multi_connection_support(): - """Test multi-connection support in Swarm.""" +async def test_swarm_load_balancing_strategies(): + """Test load balancing strategies.""" peer_id = ID(b"QmTest") peerstore = Mock() upgrader = Mock() transport = Mock() - connection_config = ConnectionConfig( - max_connections_per_peer=3, - enable_connection_pool=True, - load_balancing_strategy="round_robin", - ) + swarm = Swarm(peer_id, peerstore, upgrader, transport) + # Create mock connections with different stream counts + conn1 = MockConnection(peer_id) + conn2 = MockConnection(peer_id) + conn3 = MockConnection(peer_id) + + # Add some streams to simulate load + await conn1.new_stream() + await conn1.new_stream() + await conn2.new_stream() + + connections = [conn1, conn2, conn3] + + # Test round-robin strategy + swarm.connection_config.load_balancing_strategy = "round_robin" + # Cast to satisfy type checker + connections_cast = cast("list[INetConn]", connections) + selected1 = swarm._select_connection(connections_cast, peer_id) + selected2 = swarm._select_connection(connections_cast, peer_id) + selected3 = swarm._select_connection(connections_cast, peer_id) + + # Should cycle through connections + assert selected1 in connections + assert selected2 in connections + assert selected3 in connections + + # Test least loaded strategy + swarm.connection_config.load_balancing_strategy = "least_loaded" + least_loaded = swarm._select_connection(connections_cast, peer_id) + + # conn3 has 0 streams, conn2 has 1 stream, conn1 has 2 streams + # So conn3 should be selected as least loaded + assert least_loaded == conn3 + + # Test default strategy (first connection) + swarm.connection_config.load_balancing_strategy = "unknown" + default_selected = swarm._select_connection(connections_cast, peer_id) + assert default_selected == conn1 + + +@pytest.mark.trio +async def test_swarm_multiple_connections_api(): + """Test the new multiple connections API methods.""" + peer_id = ID(b"QmTest") + peerstore = Mock() + upgrader = Mock() + transport = Mock() + + swarm = Swarm(peer_id, peerstore, upgrader, transport) + + # Test empty connections + assert swarm.get_connections() == [] + assert swarm.get_connections(peer_id) == [] + assert swarm.get_connection(peer_id) is None + assert swarm.get_connections_map() == {} + + # Add some connections + conn1 = MockConnection(peer_id) + conn2 = MockConnection(peer_id) + swarm.connections[peer_id] = [conn1, conn2] + + # Test get_connections with peer_id + peer_connections = swarm.get_connections(peer_id) + assert len(peer_connections) == 2 + assert conn1 in peer_connections + assert conn2 in peer_connections + + # Test get_connections without peer_id (all connections) + all_connections = swarm.get_connections() + assert len(all_connections) == 2 + assert conn1 in all_connections + assert conn2 in all_connections + + # Test get_connection (backward compatibility) + single_conn = swarm.get_connection(peer_id) + assert single_conn in [conn1, conn2] + + # Test get_connections_map + 
connections_map = swarm.get_connections_map() + assert peer_id in connections_map + assert connections_map[peer_id] == [conn1, conn2] + + +@pytest.mark.trio +async def test_swarm_connection_trimming(): + """Test connection trimming when limit is exceeded.""" + peer_id = ID(b"QmTest") + peerstore = Mock() + upgrader = Mock() + transport = Mock() + + # Set max connections to 2 + connection_config = ConnectionConfig(max_connections_per_peer=2) swarm = Swarm( peer_id, peerstore, upgrader, transport, connection_config=connection_config ) - # Mock connection pool methods - assert swarm.connection_pool is not None - connection_pool = swarm.connection_pool - connection_pool.has_connection = Mock(return_value=True) - connection_pool.get_connection = Mock(return_value=MockConnection(peer_id)) + # Add 3 connections + conn1 = MockConnection(peer_id) + conn2 = MockConnection(peer_id) + conn3 = MockConnection(peer_id) - # Test that new_stream uses connection pool - result = await swarm.new_stream(peer_id) - assert result is not None - # Use the mocked method directly to avoid type checking issues - get_connection_mock = connection_pool.get_connection - assert get_connection_mock.call_count == 1 + swarm.connections[peer_id] = [conn1, conn2, conn3] + + # Trigger trimming + swarm._trim_connections(peer_id) + + # Should have only 2 connections + assert len(swarm.connections[peer_id]) == 2 + + # The most recent connections should remain + remaining = swarm.connections[peer_id] + assert conn2 in remaining + assert conn3 in remaining @pytest.mark.trio async def test_swarm_backward_compatibility(): - """Test that enhanced Swarm maintains backward compatibility.""" + """Test backward compatibility features.""" peer_id = ID(b"QmTest") peerstore = Mock() upgrader = Mock() transport = Mock() - # Create swarm with connection pool disabled - connection_config = ConnectionConfig(enable_connection_pool=False) - swarm = Swarm( - peer_id, peerstore, upgrader, transport, connection_config=connection_config - ) + swarm = Swarm(peer_id, peerstore, upgrader, transport) - # Should behave like original swarm - assert swarm.connection_pool is None - assert isinstance(swarm.connections, dict) + # Add connections + conn1 = MockConnection(peer_id) + conn2 = MockConnection(peer_id) + swarm.connections[peer_id] = [conn1, conn2] - # Test that dial_peer still works (will fail due to mocks, but structure is correct) - peerstore.addrs.return_value = [Mock(spec=Multiaddr)] - transport.dial.side_effect = Exception("Transport error") - - with pytest.raises(SwarmException): - await swarm.dial_peer(peer_id) - - -@pytest.mark.trio -async def test_swarm_connection_pool_integration(): - """Test integration between Swarm and ConnectionPool.""" - peer_id = ID(b"QmTest") - peerstore = Mock() - upgrader = Mock() - transport = Mock() - - connection_config = ConnectionConfig( - max_connections_per_peer=2, enable_connection_pool=True - ) - - swarm = Swarm( - peer_id, peerstore, upgrader, transport, connection_config=connection_config - ) - - # Mock successful connection creation - mock_conn = MockConnection(peer_id) - peerstore.addrs.return_value = [Mock(spec=Multiaddr)] - - async def mock_dial_with_retry(addr, peer_id): - return mock_conn - - swarm._dial_with_retry = mock_dial_with_retry - - # Test dial_peer adds to connection pool - result = await swarm.dial_peer(peer_id) - assert result == mock_conn - assert swarm.connection_pool is not None - assert swarm.connection_pool.has_connection(peer_id) - - # Test that subsequent calls reuse connection - 
result2 = await swarm.dial_peer(peer_id) - assert result2 == mock_conn - - -@pytest.mark.trio -async def test_swarm_connection_cleanup(): - """Test connection cleanup in enhanced Swarm.""" - peer_id = ID(b"QmTest") - peerstore = Mock() - upgrader = Mock() - transport = Mock() - - connection_config = ConnectionConfig(enable_connection_pool=True) - swarm = Swarm( - peer_id, peerstore, upgrader, transport, connection_config=connection_config - ) - - # Add a connection - mock_conn = MockConnection(peer_id) - swarm.connections[peer_id] = mock_conn - assert swarm.connection_pool is not None - swarm.connection_pool.add_connection(peer_id, mock_conn, "test_addr") - - # Test close_peer removes from pool - await swarm.close_peer(peer_id) - assert swarm.connection_pool is not None - assert not swarm.connection_pool.has_connection(peer_id) - - # Test remove_conn removes from pool - mock_conn2 = MockConnection(peer_id) - swarm.connections[peer_id] = mock_conn2 - assert swarm.connection_pool is not None - connection_pool = swarm.connection_pool - connection_pool.add_connection(peer_id, mock_conn2, "test_addr2") - - # Note: remove_conn expects SwarmConn, but for testing we'll just - # remove from pool directly - connection_pool = swarm.connection_pool - connection_pool.remove_connection(peer_id, mock_conn2) - assert connection_pool is not None - assert not connection_pool.has_connection(peer_id) + # Test connections_legacy property + legacy_connections = swarm.connections_legacy + assert peer_id in legacy_connections + # Should return first connection + assert legacy_connections[peer_id] in [conn1, conn2] if __name__ == "__main__": diff --git a/tests/core/network/test_swarm.py b/tests/core/network/test_swarm.py index 605913ec..df08ff98 100644 --- a/tests/core/network/test_swarm.py +++ b/tests/core/network/test_swarm.py @@ -51,14 +51,19 @@ async def test_swarm_dial_peer(security_protocol): for addr in transport.get_addrs() ) swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000) - await swarms[0].dial_peer(swarms[1].get_peer_id()) + + # New: dial_peer now returns list of connections + connections = await swarms[0].dial_peer(swarms[1].get_peer_id()) + assert len(connections) > 0 + + # Verify connections are established in both directions assert swarms[0].get_peer_id() in swarms[1].connections assert swarms[1].get_peer_id() in swarms[0].connections # Test: Reuse connections when we already have ones with a peer. - conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()] - conn = await swarms[0].dial_peer(swarms[1].get_peer_id()) - assert conn is conn_to_1 + existing_connections = swarms[0].get_connections(swarms[1].get_peer_id()) + new_connections = await swarms[0].dial_peer(swarms[1].get_peer_id()) + assert new_connections == existing_connections @pytest.mark.trio @@ -107,7 +112,8 @@ async def test_swarm_close_peer(security_protocol): @pytest.mark.trio async def test_swarm_remove_conn(swarm_pair): swarm_0, swarm_1 = swarm_pair - conn_0 = swarm_0.connections[swarm_1.get_peer_id()] + # Get the first connection from the list + conn_0 = swarm_0.connections[swarm_1.get_peer_id()][0] swarm_0.remove_conn(conn_0) assert swarm_1.get_peer_id() not in swarm_0.connections # Test: Remove twice. There should not be errors. 
@@ -115,6 +121,67 @@ async def test_swarm_remove_conn(swarm_pair): assert swarm_1.get_peer_id() not in swarm_0.connections +@pytest.mark.trio +async def test_swarm_multiple_connections(security_protocol): + """Test multiple connections per peer functionality.""" + async with SwarmFactory.create_batch_and_listen( + 2, security_protocol=security_protocol + ) as swarms: + # Setup multiple addresses for peer + addrs = tuple( + addr + for transport in swarms[1].listeners.values() + for addr in transport.get_addrs() + ) + swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000) + + # Dial peer - should return list of connections + connections = await swarms[0].dial_peer(swarms[1].get_peer_id()) + assert len(connections) > 0 + + # Test get_connections method + peer_connections = swarms[0].get_connections(swarms[1].get_peer_id()) + assert len(peer_connections) == len(connections) + + # Test get_connections_map method + connections_map = swarms[0].get_connections_map() + assert swarms[1].get_peer_id() in connections_map + assert len(connections_map[swarms[1].get_peer_id()]) == len(connections) + + # Test get_connection method (backward compatibility) + single_conn = swarms[0].get_connection(swarms[1].get_peer_id()) + assert single_conn is not None + assert single_conn in connections + + +@pytest.mark.trio +async def test_swarm_load_balancing(security_protocol): + """Test load balancing across multiple connections.""" + async with SwarmFactory.create_batch_and_listen( + 2, security_protocol=security_protocol + ) as swarms: + # Setup connection + addrs = tuple( + addr + for transport in swarms[1].listeners.values() + for addr in transport.get_addrs() + ) + swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000) + + # Create multiple streams - should use load balancing + streams = [] + for _ in range(5): + stream = await swarms[0].new_stream(swarms[1].get_peer_id()) + streams.append(stream) + + # Verify streams were created successfully + assert len(streams) == 5 + + # Clean up + for stream in streams: + await stream.close() + + @pytest.mark.trio async def test_swarm_multiaddr(security_protocol): async with SwarmFactory.create_batch_and_listen( diff --git a/tests/core/security/test_security_multistream.py b/tests/core/security/test_security_multistream.py index 577cf404..d4fed72d 100644 --- a/tests/core/security/test_security_multistream.py +++ b/tests/core/security/test_security_multistream.py @@ -51,6 +51,9 @@ async def perform_simple_test(assertion_func, security_protocol): # Extract the secured connection from either Mplex or Yamux implementation def get_secured_conn(conn): + # conn is now a list, get the first connection + if isinstance(conn, list): + conn = conn[0] muxed_conn = conn.muxed_conn # Direct attribute access for known implementations has_secured_conn = hasattr(muxed_conn, "secured_conn") diff --git a/tests/core/stream_muxer/test_multiplexer_selection.py b/tests/core/stream_muxer/test_multiplexer_selection.py index b2f3e305..9b45324e 100644 --- a/tests/core/stream_muxer/test_multiplexer_selection.py +++ b/tests/core/stream_muxer/test_multiplexer_selection.py @@ -74,7 +74,8 @@ async def test_multiplexer_preference_parameter(muxer_preference): assert len(connections) > 0, "Connection not established" # Get the first connection - conn = list(connections.values())[0] + conns = list(connections.values())[0] + conn = conns[0] # Get first connection from the list muxed_conn = conn.muxed_conn # Define a simple echo protocol @@ -150,7 +151,8 @@ async def 
test_explicit_muxer_options(muxer_option_func, expected_stream_class): assert len(connections) > 0, "Connection not established" # Get the first connection - conn = list(connections.values())[0] + conns = list(connections.values())[0] + conn = conns[0] # Get first connection from the list muxed_conn = conn.muxed_conn # Define a simple echo protocol @@ -219,7 +221,8 @@ async def test_global_default_muxer(global_default): assert len(connections) > 0, "Connection not established" # Get the first connection - conn = list(connections.values())[0] + conns = list(connections.values())[0] + conn = conns[0] # Get first connection from the list muxed_conn = conn.muxed_conn # Define a simple echo protocol diff --git a/tests/utils/factories.py b/tests/utils/factories.py index 75639e36..c006200f 100644 --- a/tests/utils/factories.py +++ b/tests/utils/factories.py @@ -669,8 +669,8 @@ async def swarm_conn_pair_factory( async with swarm_pair_factory( security_protocol=security_protocol, muxer_opt=muxer_opt ) as swarms: - conn_0 = swarms[0].connections[swarms[1].get_peer_id()] - conn_1 = swarms[1].connections[swarms[0].get_peer_id()] + conn_0 = swarms[0].connections[swarms[1].get_peer_id()][0] + conn_1 = swarms[1].connections[swarms[0].get_peer_id()][0] yield cast(SwarmConn, conn_0), cast(SwarmConn, conn_1) From 526b65e1d5a544b886555c24622672ecf6f88213 Mon Sep 17 00:00:00 2001 From: bomanaps Date: Sun, 31 Aug 2025 01:43:27 +0100 Subject: [PATCH 61/71] style: apply ruff formatting fixes --- docs/examples.multiple_connections.rst | 8 ++++---- examples/doc-examples/multiple_connections_example.py | 3 +-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/docs/examples.multiple_connections.rst b/docs/examples.multiple_connections.rst index da1d3b02..946d6e8f 100644 --- a/docs/examples.multiple_connections.rst +++ b/docs/examples.multiple_connections.rst @@ -63,13 +63,13 @@ The new API provides direct access to multiple connections: # Get all connections to a peer peer_connections = swarm.get_connections(peer_id) - + # Get all connections (across all peers) all_connections = swarm.get_connections() - + # Get the complete connections map connections_map = swarm.get_connections_map() - + # Backward compatibility - get single connection single_conn = swarm.get_connection(peer_id) @@ -82,7 +82,7 @@ Existing code continues to work through backward compatibility features: # Legacy 1:1 mapping (returns first connection for each peer) legacy_connections = swarm.connections_legacy - + # Single connection access (returns first available connection) conn = swarm.get_connection(peer_id) diff --git a/examples/doc-examples/multiple_connections_example.py b/examples/doc-examples/multiple_connections_example.py index 14a71ab8..f0738283 100644 --- a/examples/doc-examples/multiple_connections_example.py +++ b/examples/doc-examples/multiple_connections_example.py @@ -71,8 +71,7 @@ async def example_multiple_connections_api() -> None: logger.info("Demonstrating multiple connections API...") connection_config = ConnectionConfig( - max_connections_per_peer=3, - load_balancing_strategy="round_robin" + max_connections_per_peer=3, load_balancing_strategy="round_robin" ) swarm = new_swarm(connection_config=connection_config) From 9a06ee429fd6f60b61742e7251348f26b89eae3e Mon Sep 17 00:00:00 2001 From: bomanaps Date: Sun, 31 Aug 2025 02:01:39 +0100 Subject: [PATCH 62/71] Fix documentation build issues and add _build/ to .gitignore --- .gitignore | 3 +++ docs/examples.multiple_connections.rst | 12 ++++++------ 2 files changed, 9 
insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index e46cc8aa..fd2c8231 100644 --- a/.gitignore +++ b/.gitignore @@ -178,3 +178,6 @@ env.bak/ #lockfiles uv.lock poetry.lock + +# Sphinx documentation build +_build/ diff --git a/docs/examples.multiple_connections.rst b/docs/examples.multiple_connections.rst index 946d6e8f..814152b3 100644 --- a/docs/examples.multiple_connections.rst +++ b/docs/examples.multiple_connections.rst @@ -1,5 +1,5 @@ Multiple Connections Per Peer -============================ +============================= This example demonstrates how to use the multiple connections per peer feature in py-libp2p. @@ -35,7 +35,7 @@ The feature is configured through the `ConnectionConfig` class: ) Load Balancing Strategies ------------------------- +------------------------- Two load balancing strategies are available: @@ -74,7 +74,7 @@ The new API provides direct access to multiple connections: single_conn = swarm.get_connection(peer_id) Backward Compatibility ---------------------- +---------------------- Existing code continues to work through backward compatibility features: @@ -89,10 +89,10 @@ Existing code continues to work through backward compatibility features: Example ------- -See :doc:`examples/doc-examples/multiple_connections_example.py` for a complete working example. +A complete working example is available in the `examples/doc-examples/multiple_connections_example.py` file. Production Configuration ------------------------ +------------------------- For production use, consider these settings: @@ -121,7 +121,7 @@ For production use, consider these settings: ) Architecture ------------ +------------ The implementation follows the same architectural patterns as the Go and JavaScript reference implementations: From 6a24b138dd65b690ccc0e1f214d2d29a9f4c9b16 Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Mon, 1 Sep 2025 01:35:32 +0530 Subject: [PATCH 63/71] feat: Add cross-platform path utilities module --- libp2p/utils/paths.py | 162 ++++++++++++++++++++++++++++ scripts/audit_paths.py | 222 ++++++++++++++++++++++++++++++++++++++ tests/utils/test_paths.py | 213 ++++++++++++++++++++++++++++++++++++ 3 files changed, 597 insertions(+) create mode 100644 libp2p/utils/paths.py create mode 100644 scripts/audit_paths.py create mode 100644 tests/utils/test_paths.py diff --git a/libp2p/utils/paths.py b/libp2p/utils/paths.py new file mode 100644 index 00000000..27924d8f --- /dev/null +++ b/libp2p/utils/paths.py @@ -0,0 +1,162 @@ +""" +Cross-platform path utilities for py-libp2p. + +This module provides standardized path operations to ensure consistent +behavior across Windows, macOS, and Linux platforms. +""" + +import os +import tempfile +from pathlib import Path +from typing import Union, Optional + +PathLike = Union[str, Path] + + +def get_temp_dir() -> Path: + """ + Get cross-platform temporary directory. + + Returns: + Path: Platform-specific temporary directory path + """ + return Path(tempfile.gettempdir()) + + +def get_project_root() -> Path: + """ + Get the project root directory. + + Returns: + Path: Path to the py-libp2p project root + """ + # Navigate from libp2p/utils/paths.py to project root + return Path(__file__).parent.parent.parent + + +def join_paths(*parts: PathLike) -> Path: + """ + Cross-platform path joining. 
+ + Args: + *parts: Path components to join + + Returns: + Path: Joined path using platform-appropriate separator + """ + return Path(*parts) + + +def ensure_dir_exists(path: PathLike) -> Path: + """ + Ensure directory exists, create if needed. + + Args: + path: Directory path to ensure exists + + Returns: + Path: Path object for the directory + """ + path_obj = Path(path) + path_obj.mkdir(parents=True, exist_ok=True) + return path_obj + + +def get_config_dir() -> Path: + """ + Get user config directory (cross-platform). + + Returns: + Path: Platform-specific config directory + """ + if os.name == 'nt': # Windows + appdata = os.environ.get('APPDATA', '') + if appdata: + return Path(appdata) / 'py-libp2p' + else: + # Fallback to user home directory + return Path.home() / 'AppData' / 'Roaming' / 'py-libp2p' + else: # Unix-like (Linux, macOS) + return Path.home() / '.config' / 'py-libp2p' + + +def get_script_dir(script_path: Optional[PathLike] = None) -> Path: + """ + Get the directory containing a script file. + + Args: + script_path: Path to the script file. If None, uses __file__ + + Returns: + Path: Directory containing the script + """ + if script_path is None: + # This will be the directory of the calling script + import inspect + frame = inspect.currentframe() + if frame and frame.f_back: + script_path = frame.f_back.f_globals.get('__file__') + else: + raise RuntimeError("Could not determine script path") + + return Path(script_path).parent.absolute() + + +def create_temp_file(prefix: str = "py-libp2p_", suffix: str = ".log") -> Path: + """ + Create a temporary file with a unique name. + + Args: + prefix: File name prefix + suffix: File name suffix + + Returns: + Path: Path to the created temporary file + """ + temp_dir = get_temp_dir() + # Create a unique filename using timestamp and random bytes + import time + import secrets + + timestamp = time.strftime("%Y%m%d_%H%M%S") + microseconds = f"{time.time() % 1:.6f}"[2:] # Get microseconds as string + unique_id = secrets.token_hex(4) + filename = f"{prefix}{timestamp}_{microseconds}_{unique_id}{suffix}" + + temp_file = temp_dir / filename + # Create the file by touching it + temp_file.touch() + return temp_file + + +def resolve_relative_path(base_path: PathLike, relative_path: PathLike) -> Path: + """ + Resolve a relative path from a base path. + + Args: + base_path: Base directory path + relative_path: Relative path to resolve + + Returns: + Path: Resolved absolute path + """ + base = Path(base_path).resolve() + relative = Path(relative_path) + + if relative.is_absolute(): + return relative + else: + return (base / relative).resolve() + + +def normalize_path(path: PathLike) -> Path: + """ + Normalize a path, resolving any symbolic links and relative components. + + Args: + path: Path to normalize + + Returns: + Path: Normalized absolute path + """ + return Path(path).resolve() diff --git a/scripts/audit_paths.py b/scripts/audit_paths.py new file mode 100644 index 00000000..b0079869 --- /dev/null +++ b/scripts/audit_paths.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 +""" +Audit script to identify path handling issues in the py-libp2p codebase. + +This script scans for patterns that should be migrated to use the new +cross-platform path utilities. +""" + +import re +import os +from pathlib import Path +from typing import List, Dict, Any +import argparse + + +def scan_for_path_issues(directory: Path) -> Dict[str, List[Dict[str, Any]]]: + """ + Scan for path handling issues in the codebase. 
+ + Args: + directory: Root directory to scan + + Returns: + Dictionary mapping issue types to lists of found issues + """ + issues = { + 'hard_coded_slash': [], + 'os_path_join': [], + 'temp_hardcode': [], + 'os_path_dirname': [], + 'os_path_abspath': [], + 'direct_path_concat': [], + } + + # Patterns to search for + patterns = { + 'hard_coded_slash': r'["\'][^"\']*\/[^"\']*["\']', + 'os_path_join': r'os\.path\.join\(', + 'temp_hardcode': r'["\']\/tmp\/|["\']C:\\\\', + 'os_path_dirname': r'os\.path\.dirname\(', + 'os_path_abspath': r'os\.path\.abspath\(', + 'direct_path_concat': r'["\'][^"\']*["\']\s*\+\s*["\'][^"\']*["\']', + } + + # Files to exclude + exclude_patterns = [ + r'__pycache__', + r'\.git', + r'\.pytest_cache', + r'\.mypy_cache', + r'\.ruff_cache', + r'env/', + r'venv/', + r'\.venv/', + ] + + for py_file in directory.rglob("*.py"): + # Skip excluded files + if any(re.search(pattern, str(py_file)) for pattern in exclude_patterns): + continue + + try: + content = py_file.read_text(encoding='utf-8') + except UnicodeDecodeError: + print(f"Warning: Could not read {py_file} (encoding issue)") + continue + + for issue_type, pattern in patterns.items(): + matches = re.finditer(pattern, content, re.MULTILINE) + for match in matches: + line_num = content[:match.start()].count('\n') + 1 + line_content = content.split('\n')[line_num - 1].strip() + + issues[issue_type].append({ + 'file': py_file, + 'line': line_num, + 'content': match.group(), + 'full_line': line_content, + 'relative_path': py_file.relative_to(directory) + }) + + return issues + + +def generate_migration_suggestions(issues: Dict[str, List[Dict[str, Any]]]) -> str: + """ + Generate migration suggestions for found issues. + + Args: + issues: Dictionary of found issues + + Returns: + Formatted string with migration suggestions + """ + suggestions = [] + + for issue_type, issue_list in issues.items(): + if not issue_list: + continue + + suggestions.append(f"\n## {issue_type.replace('_', ' ').title()}") + suggestions.append(f"Found {len(issue_list)} instances:") + + for issue in issue_list[:10]: # Show first 10 examples + suggestions.append(f"\n### {issue['relative_path']}:{issue['line']}") + suggestions.append(f"```python") + suggestions.append(f"# Current code:") + suggestions.append(f"{issue['full_line']}") + suggestions.append(f"```") + + # Add migration suggestion based on issue type + if issue_type == 'os_path_join': + suggestions.append(f"```python") + suggestions.append(f"# Suggested fix:") + suggestions.append(f"from libp2p.utils.paths import join_paths") + suggestions.append(f"# Replace os.path.join(a, b, c) with join_paths(a, b, c)") + suggestions.append(f"```") + elif issue_type == 'temp_hardcode': + suggestions.append(f"```python") + suggestions.append(f"# Suggested fix:") + suggestions.append(f"from libp2p.utils.paths import get_temp_dir, create_temp_file") + suggestions.append(f"# Replace hard-coded temp paths with get_temp_dir() or create_temp_file()") + suggestions.append(f"```") + elif issue_type == 'os_path_dirname': + suggestions.append(f"```python") + suggestions.append(f"# Suggested fix:") + suggestions.append(f"from libp2p.utils.paths import get_script_dir") + suggestions.append(f"# Replace os.path.dirname(os.path.abspath(__file__)) with get_script_dir(__file__)") + suggestions.append(f"```") + + if len(issue_list) > 10: + suggestions.append(f"\n... 
and {len(issue_list) - 10} more instances") + + return "\n".join(suggestions) + + +def generate_summary_report(issues: Dict[str, List[Dict[str, Any]]]) -> str: + """ + Generate a summary report of all found issues. + + Args: + issues: Dictionary of found issues + + Returns: + Formatted summary report + """ + total_issues = sum(len(issue_list) for issue_list in issues.values()) + + report = [ + "# Cross-Platform Path Handling Audit Report", + "", + f"## Summary", + f"Total issues found: {total_issues}", + "", + "## Issue Breakdown:", + ] + + for issue_type, issue_list in issues.items(): + if issue_list: + report.append(f"- **{issue_type.replace('_', ' ').title()}**: {len(issue_list)} instances") + + report.append("") + report.append("## Priority Matrix:") + report.append("") + report.append("| Priority | Issue Type | Risk Level | Impact |") + report.append("|----------|------------|------------|---------|") + + priority_map = { + 'temp_hardcode': ('šŸ”“ P0', 'HIGH', 'Core functionality fails on different platforms'), + 'os_path_join': ('🟔 P1', 'MEDIUM', 'Examples and utilities may break'), + 'os_path_dirname': ('🟔 P1', 'MEDIUM', 'Script location detection issues'), + 'hard_coded_slash': ('🟢 P2', 'LOW', 'Future-proofing and consistency'), + 'os_path_abspath': ('🟢 P2', 'LOW', 'Path resolution consistency'), + 'direct_path_concat': ('🟢 P2', 'LOW', 'String concatenation issues'), + } + + for issue_type, issue_list in issues.items(): + if issue_list: + priority, risk, impact = priority_map.get(issue_type, ('🟢 P2', 'LOW', 'General improvement')) + report.append(f"| {priority} | {issue_type.replace('_', ' ').title()} | {risk} | {impact} |") + + return "\n".join(report) + + +def main(): + """Main function to run the audit.""" + parser = argparse.ArgumentParser(description="Audit py-libp2p codebase for path handling issues") + parser.add_argument("--directory", default=".", help="Directory to scan (default: current directory)") + parser.add_argument("--output", help="Output file for detailed report") + parser.add_argument("--summary-only", action="store_true", help="Only show summary report") + + args = parser.parse_args() + + directory = Path(args.directory) + if not directory.exists(): + print(f"Error: Directory {directory} does not exist") + return 1 + + print("šŸ” Scanning for path handling issues...") + issues = scan_for_path_issues(directory) + + # Generate and display summary + summary = generate_summary_report(issues) + print(summary) + + if not args.summary_only: + # Generate detailed suggestions + suggestions = generate_migration_suggestions(issues) + + if args.output: + with open(args.output, 'w', encoding='utf-8') as f: + f.write(summary) + f.write(suggestions) + print(f"\nšŸ“„ Detailed report saved to {args.output}") + else: + print(suggestions) + + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/tests/utils/test_paths.py b/tests/utils/test_paths.py new file mode 100644 index 00000000..fcd4c08a --- /dev/null +++ b/tests/utils/test_paths.py @@ -0,0 +1,213 @@ +""" +Tests for cross-platform path utilities. 
+""" + +import os +import tempfile +from pathlib import Path +import pytest + +from libp2p.utils.paths import ( + get_temp_dir, + get_project_root, + join_paths, + ensure_dir_exists, + get_config_dir, + get_script_dir, + create_temp_file, + resolve_relative_path, + normalize_path, +) + + +class TestPathUtilities: + """Test cross-platform path utilities.""" + + def test_get_temp_dir(self): + """Test that temp directory is accessible and exists.""" + temp_dir = get_temp_dir() + assert isinstance(temp_dir, Path) + assert temp_dir.exists() + assert temp_dir.is_dir() + # Should match system temp directory + assert temp_dir == Path(tempfile.gettempdir()) + + def test_get_project_root(self): + """Test that project root is correctly determined.""" + project_root = get_project_root() + assert isinstance(project_root, Path) + assert project_root.exists() + # Should contain pyproject.toml + assert (project_root / "pyproject.toml").exists() + # Should contain libp2p directory + assert (project_root / "libp2p").exists() + + def test_join_paths(self): + """Test cross-platform path joining.""" + # Test with strings + result = join_paths("a", "b", "c") + expected = Path("a") / "b" / "c" + assert result == expected + + # Test with mixed types + result = join_paths("a", Path("b"), "c") + expected = Path("a") / "b" / "c" + assert result == expected + + # Test with absolute path + result = join_paths("/absolute", "path") + expected = Path("/absolute") / "path" + assert result == expected + + def test_ensure_dir_exists(self, tmp_path): + """Test directory creation and existence checking.""" + # Test creating new directory + new_dir = tmp_path / "new_dir" + result = ensure_dir_exists(new_dir) + assert result == new_dir + assert new_dir.exists() + assert new_dir.is_dir() + + # Test creating nested directory + nested_dir = tmp_path / "parent" / "child" / "grandchild" + result = ensure_dir_exists(nested_dir) + assert result == nested_dir + assert nested_dir.exists() + assert nested_dir.is_dir() + + # Test with existing directory + result = ensure_dir_exists(new_dir) + assert result == new_dir + assert new_dir.exists() + + def test_get_config_dir(self): + """Test platform-specific config directory.""" + config_dir = get_config_dir() + assert isinstance(config_dir, Path) + + if os.name == 'nt': # Windows + # Should be in AppData/Roaming or user home + assert "AppData" in str(config_dir) or "py-libp2p" in str(config_dir) + else: # Unix-like + # Should be in ~/.config + assert ".config" in str(config_dir) + assert "py-libp2p" in str(config_dir) + + def test_get_script_dir(self): + """Test script directory detection.""" + # Test with current file + script_dir = get_script_dir(__file__) + assert isinstance(script_dir, Path) + assert script_dir.exists() + assert script_dir.is_dir() + # Should contain this test file + assert (script_dir / "test_paths.py").exists() + + def test_create_temp_file(self): + """Test temporary file creation.""" + temp_file = create_temp_file() + assert isinstance(temp_file, Path) + assert temp_file.parent == get_temp_dir() + assert temp_file.name.startswith("py-libp2p_") + assert temp_file.name.endswith(".log") + + # Test with custom prefix and suffix + temp_file = create_temp_file(prefix="test_", suffix=".txt") + assert temp_file.name.startswith("test_") + assert temp_file.name.endswith(".txt") + + def test_resolve_relative_path(self, tmp_path): + """Test relative path resolution.""" + base_path = tmp_path / "base" + base_path.mkdir() + + # Test relative path + relative_path = "subdir/file.txt" 
+ result = resolve_relative_path(base_path, relative_path) + expected = (base_path / "subdir" / "file.txt").resolve() + assert result == expected + + # Test absolute path (platform-agnostic) + if os.name == 'nt': # Windows + absolute_path = "C:\\absolute\\path" + else: # Unix-like + absolute_path = "/absolute/path" + result = resolve_relative_path(base_path, absolute_path) + assert result == Path(absolute_path) + + def test_normalize_path(self, tmp_path): + """Test path normalization.""" + # Test with relative path + relative_path = tmp_path / ".." / "normalize_test" + result = normalize_path(relative_path) + assert result.is_absolute() + assert "normalize_test" in str(result) + + # Test with absolute path + absolute_path = tmp_path / "test_file" + result = normalize_path(absolute_path) + assert result.is_absolute() + assert result == absolute_path.resolve() + + +class TestCrossPlatformCompatibility: + """Test cross-platform compatibility.""" + + def test_config_dir_platform_specific_windows(self, monkeypatch): + """Test config directory respects Windows conventions.""" + monkeypatch.setattr('os.name', 'nt') + monkeypatch.setenv('APPDATA', 'C:\\Users\\Test\\AppData\\Roaming') + config_dir = get_config_dir() + assert "AppData" in str(config_dir) + assert "py-libp2p" in str(config_dir) + + def test_path_separators_consistent(self): + """Test that path separators are handled consistently.""" + # Test that join_paths uses platform-appropriate separators + result = join_paths("dir1", "dir2", "file.txt") + expected = Path("dir1") / "dir2" / "file.txt" + assert result == expected + + # Test that the result uses correct separators for the platform + if os.name == 'nt': # Windows + assert "\\" in str(result) or "/" in str(result) + else: # Unix-like + assert "/" in str(result) + + def test_temp_file_uniqueness(self): + """Test that temporary files have unique names.""" + files = set() + for _ in range(10): + temp_file = create_temp_file() + assert temp_file not in files + files.add(temp_file) + + +class TestBackwardCompatibility: + """Test backward compatibility with existing code patterns.""" + + def test_path_operations_equivalent(self): + """Test that new path operations are equivalent to old os.path operations.""" + # Test join_paths vs os.path.join + parts = ["a", "b", "c"] + new_result = join_paths(*parts) + old_result = Path(os.path.join(*parts)) + assert new_result == old_result + + # Test get_script_dir vs os.path.dirname(os.path.abspath(__file__)) + new_script_dir = get_script_dir(__file__) + old_script_dir = Path(os.path.dirname(os.path.abspath(__file__))) + assert new_script_dir == old_script_dir + + def test_existing_functionality_preserved(self): + """Ensure no existing functionality is broken.""" + # Test that all functions return Path objects + assert isinstance(get_temp_dir(), Path) + assert isinstance(get_project_root(), Path) + assert isinstance(join_paths("a", "b"), Path) + assert isinstance(ensure_dir_exists(tempfile.gettempdir()), Path) + assert isinstance(get_config_dir(), Path) + assert isinstance(get_script_dir(__file__), Path) + assert isinstance(create_temp_file(), Path) + assert isinstance(resolve_relative_path(".", "test"), Path) + assert isinstance(normalize_path("."), Path) From 64ccce17eb2e67a7dfa8f8a1cef97ddfa83d3235 Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Mon, 1 Sep 2025 02:03:51 +0530 Subject: [PATCH 64/71] fix(app): 882 Comprehensive cross-platform path handling utilities --- docs/conf.py | 4 +- examples/kademlia/kademlia.py | 5 +- 
libp2p/utils/logging.py | 14 +- libp2p/utils/paths.py | 163 +++++++++++++++++++---- scripts/audit_paths.py | 241 +++++++++++++++++++--------------- tests/utils/test_paths.py | 103 ++++++++++++--- 6 files changed, 367 insertions(+), 163 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 446252f1..64618359 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -27,7 +27,9 @@ except ModuleNotFoundError: import tomli as tomllib # type: ignore (In case of >3.11 Pyrefly doesnt find tomli , which is right but a false flag) # Path to pyproject.toml (assuming conf.py is in a 'docs' subdirectory) -pyproject_path = os.path.join(os.path.dirname(__file__), "..", "pyproject.toml") +from libp2p.utils.paths import get_project_root, join_paths + +pyproject_path = join_paths(get_project_root(), "pyproject.toml") with open(pyproject_path, "rb") as f: pyproject_data = tomllib.load(f) diff --git a/examples/kademlia/kademlia.py b/examples/kademlia/kademlia.py index 5daa70d7..faaa66be 100644 --- a/examples/kademlia/kademlia.py +++ b/examples/kademlia/kademlia.py @@ -41,6 +41,7 @@ from libp2p.tools.async_service import ( from libp2p.tools.utils import ( info_from_p2p_addr, ) +from libp2p.utils.paths import get_script_dir, join_paths # Configure logging logging.basicConfig( @@ -53,8 +54,8 @@ logger = logging.getLogger("kademlia-example") # Configure DHT module loggers to inherit from the parent logger # This ensures all kademlia-example.* loggers use the same configuration # Get the directory where this script is located -SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) -SERVER_ADDR_LOG = os.path.join(SCRIPT_DIR, "server_node_addr.txt") +SCRIPT_DIR = get_script_dir(__file__) +SERVER_ADDR_LOG = join_paths(SCRIPT_DIR, "server_node_addr.txt") # Set the level for all child loggers for module in [ diff --git a/libp2p/utils/logging.py b/libp2p/utils/logging.py index 3458a41e..acc67373 100644 --- a/libp2p/utils/logging.py +++ b/libp2p/utils/logging.py @@ -1,7 +1,4 @@ import atexit -from datetime import ( - datetime, -) import logging import logging.handlers import os @@ -148,13 +145,10 @@ def setup_logging() -> None: log_path = Path(log_file) log_path.parent.mkdir(parents=True, exist_ok=True) else: - # Default log file with timestamp and unique identifier - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") - unique_id = os.urandom(4).hex() # Add a unique identifier to prevent collisions - if os.name == "nt": # Windows - log_file = f"C:\\Windows\\Temp\\py-libp2p_{timestamp}_{unique_id}.log" - else: # Unix-like - log_file = f"/tmp/py-libp2p_{timestamp}_{unique_id}.log" + # Use cross-platform temp file creation + from libp2p.utils.paths import create_temp_file + + log_file = str(create_temp_file(prefix="py-libp2p_", suffix=".log")) # Print the log file path so users know where to find it print(f"Logging to: {log_file}", file=sys.stderr) diff --git a/libp2p/utils/paths.py b/libp2p/utils/paths.py index 27924d8f..23f10dc6 100644 --- a/libp2p/utils/paths.py +++ b/libp2p/utils/paths.py @@ -6,9 +6,10 @@ behavior across Windows, macOS, and Linux platforms. """ import os -import tempfile from pathlib import Path -from typing import Union, Optional +import sys +import tempfile +from typing import Union PathLike = Union[str, Path] @@ -16,9 +17,10 @@ PathLike = Union[str, Path] def get_temp_dir() -> Path: """ Get cross-platform temporary directory. 
- + Returns: Path: Platform-specific temporary directory path + """ return Path(tempfile.gettempdir()) @@ -26,9 +28,10 @@ def get_temp_dir() -> Path: def get_project_root() -> Path: """ Get the project root directory. - + Returns: Path: Path to the py-libp2p project root + """ # Navigate from libp2p/utils/paths.py to project root return Path(__file__).parent.parent.parent @@ -37,12 +40,13 @@ def get_project_root() -> Path: def join_paths(*parts: PathLike) -> Path: """ Cross-platform path joining. - + Args: *parts: Path components to join - + Returns: Path: Joined path using platform-appropriate separator + """ return Path(*parts) @@ -50,12 +54,13 @@ def join_paths(*parts: PathLike) -> Path: def ensure_dir_exists(path: PathLike) -> Path: """ Ensure directory exists, create if needed. - + Args: path: Directory path to ensure exists - + Returns: Path: Path object for the directory + """ path_obj = Path(path) path_obj.mkdir(parents=True, exist_ok=True) @@ -65,64 +70,74 @@ def ensure_dir_exists(path: PathLike) -> Path: def get_config_dir() -> Path: """ Get user config directory (cross-platform). - + Returns: Path: Platform-specific config directory + """ - if os.name == 'nt': # Windows - appdata = os.environ.get('APPDATA', '') + if os.name == "nt": # Windows + appdata = os.environ.get("APPDATA", "") if appdata: - return Path(appdata) / 'py-libp2p' + return Path(appdata) / "py-libp2p" else: # Fallback to user home directory - return Path.home() / 'AppData' / 'Roaming' / 'py-libp2p' + return Path.home() / "AppData" / "Roaming" / "py-libp2p" else: # Unix-like (Linux, macOS) - return Path.home() / '.config' / 'py-libp2p' + return Path.home() / ".config" / "py-libp2p" -def get_script_dir(script_path: Optional[PathLike] = None) -> Path: +def get_script_dir(script_path: PathLike | None = None) -> Path: """ Get the directory containing a script file. - + Args: script_path: Path to the script file. If None, uses __file__ - + Returns: Path: Directory containing the script + + Raises: + RuntimeError: If script path cannot be determined + """ if script_path is None: # This will be the directory of the calling script import inspect + frame = inspect.currentframe() if frame and frame.f_back: - script_path = frame.f_back.f_globals.get('__file__') + script_path = frame.f_back.f_globals.get("__file__") else: raise RuntimeError("Could not determine script path") - + + if script_path is None: + raise RuntimeError("Script path is None") + return Path(script_path).parent.absolute() def create_temp_file(prefix: str = "py-libp2p_", suffix: str = ".log") -> Path: """ Create a temporary file with a unique name. - + Args: prefix: File name prefix suffix: File name suffix - + Returns: Path: Path to the created temporary file + """ temp_dir = get_temp_dir() # Create a unique filename using timestamp and random bytes - import time import secrets - + import time + timestamp = time.strftime("%Y%m%d_%H%M%S") microseconds = f"{time.time() % 1:.6f}"[2:] # Get microseconds as string unique_id = secrets.token_hex(4) filename = f"{prefix}{timestamp}_{microseconds}_{unique_id}{suffix}" - + temp_file = temp_dir / filename # Create the file by touching it temp_file.touch() @@ -132,17 +147,18 @@ def create_temp_file(prefix: str = "py-libp2p_", suffix: str = ".log") -> Path: def resolve_relative_path(base_path: PathLike, relative_path: PathLike) -> Path: """ Resolve a relative path from a base path. 
- + Args: base_path: Base directory path relative_path: Relative path to resolve - + Returns: Path: Resolved absolute path + """ base = Path(base_path).resolve() relative = Path(relative_path) - + if relative.is_absolute(): return relative else: @@ -152,11 +168,100 @@ def resolve_relative_path(base_path: PathLike, relative_path: PathLike) -> Path: def normalize_path(path: PathLike) -> Path: """ Normalize a path, resolving any symbolic links and relative components. - + Args: path: Path to normalize - + Returns: Path: Normalized absolute path + """ return Path(path).resolve() + + +def get_venv_path() -> Path | None: + """ + Get virtual environment path if active. + + Returns: + Path: Virtual environment path if active, None otherwise + + """ + venv_path = os.environ.get("VIRTUAL_ENV") + if venv_path: + return Path(venv_path) + return None + + +def get_python_executable() -> Path: + """ + Get current Python executable path. + + Returns: + Path: Path to the current Python executable + + """ + return Path(sys.executable) + + +def find_executable(name: str) -> Path | None: + """ + Find executable in system PATH. + + Args: + name: Name of the executable to find + + Returns: + Path: Path to executable if found, None otherwise + + """ + # Check if name already contains path + if os.path.dirname(name): + path = Path(name) + if path.exists() and os.access(path, os.X_OK): + return path + return None + + # Search in PATH + for path_dir in os.environ.get("PATH", "").split(os.pathsep): + if not path_dir: + continue + path = Path(path_dir) / name + if path.exists() and os.access(path, os.X_OK): + return path + + return None + + +def get_script_binary_path() -> Path: + """ + Get path to script's binary directory. + + Returns: + Path: Directory containing the script's binary + + """ + return get_python_executable().parent + + +def get_binary_path(binary_name: str) -> Path | None: + """ + Find binary in PATH or virtual environment. + + Args: + binary_name: Name of the binary to find + + Returns: + Path: Path to binary if found, None otherwise + + """ + # First check in virtual environment if active + venv_path = get_venv_path() + if venv_path: + venv_bin = venv_path / "bin" if os.name != "nt" else venv_path / "Scripts" + binary_path = venv_bin / binary_name + if binary_path.exists() and os.access(binary_path, os.X_OK): + return binary_path + + # Fall back to system PATH + return find_executable(binary_name) diff --git a/scripts/audit_paths.py b/scripts/audit_paths.py index b0079869..80df11f8 100644 --- a/scripts/audit_paths.py +++ b/scripts/audit_paths.py @@ -6,215 +6,248 @@ This script scans for patterns that should be migrated to use the new cross-platform path utilities. """ -import re -import os -from pathlib import Path -from typing import List, Dict, Any import argparse +from pathlib import Path +import re +from typing import Any -def scan_for_path_issues(directory: Path) -> Dict[str, List[Dict[str, Any]]]: +def scan_for_path_issues(directory: Path) -> dict[str, list[dict[str, Any]]]: """ Scan for path handling issues in the codebase. 
- + Args: directory: Root directory to scan - + Returns: Dictionary mapping issue types to lists of found issues + """ issues = { - 'hard_coded_slash': [], - 'os_path_join': [], - 'temp_hardcode': [], - 'os_path_dirname': [], - 'os_path_abspath': [], - 'direct_path_concat': [], + "hard_coded_slash": [], + "os_path_join": [], + "temp_hardcode": [], + "os_path_dirname": [], + "os_path_abspath": [], + "direct_path_concat": [], } - + # Patterns to search for patterns = { - 'hard_coded_slash': r'["\'][^"\']*\/[^"\']*["\']', - 'os_path_join': r'os\.path\.join\(', - 'temp_hardcode': r'["\']\/tmp\/|["\']C:\\\\', - 'os_path_dirname': r'os\.path\.dirname\(', - 'os_path_abspath': r'os\.path\.abspath\(', - 'direct_path_concat': r'["\'][^"\']*["\']\s*\+\s*["\'][^"\']*["\']', + "hard_coded_slash": r'["\'][^"\']*\/[^"\']*["\']', + "os_path_join": r"os\.path\.join\(", + "temp_hardcode": r'["\']\/tmp\/|["\']C:\\\\', + "os_path_dirname": r"os\.path\.dirname\(", + "os_path_abspath": r"os\.path\.abspath\(", + "direct_path_concat": r'["\'][^"\']*["\']\s*\+\s*["\'][^"\']*["\']', } - + # Files to exclude exclude_patterns = [ - r'__pycache__', - r'\.git', - r'\.pytest_cache', - r'\.mypy_cache', - r'\.ruff_cache', - r'env/', - r'venv/', - r'\.venv/', + r"__pycache__", + r"\.git", + r"\.pytest_cache", + r"\.mypy_cache", + r"\.ruff_cache", + r"env/", + r"venv/", + r"\.venv/", ] - + for py_file in directory.rglob("*.py"): # Skip excluded files if any(re.search(pattern, str(py_file)) for pattern in exclude_patterns): continue - + try: - content = py_file.read_text(encoding='utf-8') + content = py_file.read_text(encoding="utf-8") except UnicodeDecodeError: print(f"Warning: Could not read {py_file} (encoding issue)") continue - + for issue_type, pattern in patterns.items(): matches = re.finditer(pattern, content, re.MULTILINE) for match in matches: - line_num = content[:match.start()].count('\n') + 1 - line_content = content.split('\n')[line_num - 1].strip() - - issues[issue_type].append({ - 'file': py_file, - 'line': line_num, - 'content': match.group(), - 'full_line': line_content, - 'relative_path': py_file.relative_to(directory) - }) - + line_num = content[: match.start()].count("\n") + 1 + line_content = content.split("\n")[line_num - 1].strip() + + issues[issue_type].append( + { + "file": py_file, + "line": line_num, + "content": match.group(), + "full_line": line_content, + "relative_path": py_file.relative_to(directory), + } + ) + return issues -def generate_migration_suggestions(issues: Dict[str, List[Dict[str, Any]]]) -> str: +def generate_migration_suggestions(issues: dict[str, list[dict[str, Any]]]) -> str: """ Generate migration suggestions for found issues. 
- + Args: issues: Dictionary of found issues - + Returns: Formatted string with migration suggestions + """ suggestions = [] - + for issue_type, issue_list in issues.items(): if not issue_list: continue - + suggestions.append(f"\n## {issue_type.replace('_', ' ').title()}") suggestions.append(f"Found {len(issue_list)} instances:") - + for issue in issue_list[:10]: # Show first 10 examples suggestions.append(f"\n### {issue['relative_path']}:{issue['line']}") - suggestions.append(f"```python") - suggestions.append(f"# Current code:") + suggestions.append("```python") + suggestions.append("# Current code:") suggestions.append(f"{issue['full_line']}") - suggestions.append(f"```") - + suggestions.append("```") + # Add migration suggestion based on issue type - if issue_type == 'os_path_join': - suggestions.append(f"```python") - suggestions.append(f"# Suggested fix:") - suggestions.append(f"from libp2p.utils.paths import join_paths") - suggestions.append(f"# Replace os.path.join(a, b, c) with join_paths(a, b, c)") - suggestions.append(f"```") - elif issue_type == 'temp_hardcode': - suggestions.append(f"```python") - suggestions.append(f"# Suggested fix:") - suggestions.append(f"from libp2p.utils.paths import get_temp_dir, create_temp_file") - suggestions.append(f"# Replace hard-coded temp paths with get_temp_dir() or create_temp_file()") - suggestions.append(f"```") - elif issue_type == 'os_path_dirname': - suggestions.append(f"```python") - suggestions.append(f"# Suggested fix:") - suggestions.append(f"from libp2p.utils.paths import get_script_dir") - suggestions.append(f"# Replace os.path.dirname(os.path.abspath(__file__)) with get_script_dir(__file__)") - suggestions.append(f"```") - + if issue_type == "os_path_join": + suggestions.append("```python") + suggestions.append("# Suggested fix:") + suggestions.append("from libp2p.utils.paths import join_paths") + suggestions.append( + "# Replace os.path.join(a, b, c) with join_paths(a, b, c)" + ) + suggestions.append("```") + elif issue_type == "temp_hardcode": + suggestions.append("```python") + suggestions.append("# Suggested fix:") + suggestions.append( + "from libp2p.utils.paths import get_temp_dir, create_temp_file" + ) + temp_fix_msg = ( + "# Replace hard-coded temp paths with get_temp_dir() or " + "create_temp_file()" + ) + suggestions.append(temp_fix_msg) + suggestions.append("```") + elif issue_type == "os_path_dirname": + suggestions.append("```python") + suggestions.append("# Suggested fix:") + suggestions.append("from libp2p.utils.paths import get_script_dir") + script_dir_fix_msg = ( + "# Replace os.path.dirname(os.path.abspath(__file__)) with " + "get_script_dir(__file__)" + ) + suggestions.append(script_dir_fix_msg) + suggestions.append("```") + if len(issue_list) > 10: suggestions.append(f"\n... and {len(issue_list) - 10} more instances") - + return "\n".join(suggestions) -def generate_summary_report(issues: Dict[str, List[Dict[str, Any]]]) -> str: +def generate_summary_report(issues: dict[str, list[dict[str, Any]]]) -> str: """ Generate a summary report of all found issues. 
- + Args: issues: Dictionary of found issues - + Returns: Formatted summary report + """ total_issues = sum(len(issue_list) for issue_list in issues.values()) - + report = [ "# Cross-Platform Path Handling Audit Report", "", - f"## Summary", + "## Summary", f"Total issues found: {total_issues}", "", "## Issue Breakdown:", ] - + for issue_type, issue_list in issues.items(): if issue_list: - report.append(f"- **{issue_type.replace('_', ' ').title()}**: {len(issue_list)} instances") - + issue_title = issue_type.replace("_", " ").title() + instances_count = len(issue_list) + report.append(f"- **{issue_title}**: {instances_count} instances") + report.append("") report.append("## Priority Matrix:") report.append("") report.append("| Priority | Issue Type | Risk Level | Impact |") report.append("|----------|------------|------------|---------|") - + priority_map = { - 'temp_hardcode': ('šŸ”“ P0', 'HIGH', 'Core functionality fails on different platforms'), - 'os_path_join': ('🟔 P1', 'MEDIUM', 'Examples and utilities may break'), - 'os_path_dirname': ('🟔 P1', 'MEDIUM', 'Script location detection issues'), - 'hard_coded_slash': ('🟢 P2', 'LOW', 'Future-proofing and consistency'), - 'os_path_abspath': ('🟢 P2', 'LOW', 'Path resolution consistency'), - 'direct_path_concat': ('🟢 P2', 'LOW', 'String concatenation issues'), + "temp_hardcode": ( + "šŸ”“ P0", + "HIGH", + "Core functionality fails on different platforms", + ), + "os_path_join": ("🟔 P1", "MEDIUM", "Examples and utilities may break"), + "os_path_dirname": ("🟔 P1", "MEDIUM", "Script location detection issues"), + "hard_coded_slash": ("🟢 P2", "LOW", "Future-proofing and consistency"), + "os_path_abspath": ("🟢 P2", "LOW", "Path resolution consistency"), + "direct_path_concat": ("🟢 P2", "LOW", "String concatenation issues"), } - + for issue_type, issue_list in issues.items(): if issue_list: - priority, risk, impact = priority_map.get(issue_type, ('🟢 P2', 'LOW', 'General improvement')) - report.append(f"| {priority} | {issue_type.replace('_', ' ').title()} | {risk} | {impact} |") - + priority, risk, impact = priority_map.get( + issue_type, ("🟢 P2", "LOW", "General improvement") + ) + issue_title = issue_type.replace("_", " ").title() + report.append(f"| {priority} | {issue_title} | {risk} | {impact} |") + return "\n".join(report) def main(): """Main function to run the audit.""" - parser = argparse.ArgumentParser(description="Audit py-libp2p codebase for path handling issues") - parser.add_argument("--directory", default=".", help="Directory to scan (default: current directory)") + parser = argparse.ArgumentParser( + description="Audit py-libp2p codebase for path handling issues" + ) + parser.add_argument( + "--directory", + default=".", + help="Directory to scan (default: current directory)", + ) parser.add_argument("--output", help="Output file for detailed report") - parser.add_argument("--summary-only", action="store_true", help="Only show summary report") - + parser.add_argument( + "--summary-only", action="store_true", help="Only show summary report" + ) + args = parser.parse_args() - + directory = Path(args.directory) if not directory.exists(): print(f"Error: Directory {directory} does not exist") return 1 - + print("šŸ” Scanning for path handling issues...") issues = scan_for_path_issues(directory) - + # Generate and display summary summary = generate_summary_report(issues) print(summary) - + if not args.summary_only: # Generate detailed suggestions suggestions = generate_migration_suggestions(issues) - + if args.output: - with 
open(args.output, 'w', encoding='utf-8') as f: + with open(args.output, "w", encoding="utf-8") as f: f.write(summary) f.write(suggestions) print(f"\nšŸ“„ Detailed report saved to {args.output}") else: print(suggestions) - + return 0 diff --git a/tests/utils/test_paths.py b/tests/utils/test_paths.py index fcd4c08a..e5247cdb 100644 --- a/tests/utils/test_paths.py +++ b/tests/utils/test_paths.py @@ -3,20 +3,24 @@ Tests for cross-platform path utilities. """ import os -import tempfile from pathlib import Path -import pytest +import tempfile from libp2p.utils.paths import ( - get_temp_dir, - get_project_root, - join_paths, - ensure_dir_exists, - get_config_dir, - get_script_dir, create_temp_file, - resolve_relative_path, + ensure_dir_exists, + find_executable, + get_binary_path, + get_config_dir, + get_project_root, + get_python_executable, + get_script_binary_path, + get_script_dir, + get_temp_dir, + get_venv_path, + join_paths, normalize_path, + resolve_relative_path, ) @@ -84,8 +88,8 @@ class TestPathUtilities: """Test platform-specific config directory.""" config_dir = get_config_dir() assert isinstance(config_dir, Path) - - if os.name == 'nt': # Windows + + if os.name == "nt": # Windows # Should be in AppData/Roaming or user home assert "AppData" in str(config_dir) or "py-libp2p" in str(config_dir) else: # Unix-like @@ -120,7 +124,7 @@ class TestPathUtilities: """Test relative path resolution.""" base_path = tmp_path / "base" base_path.mkdir() - + # Test relative path relative_path = "subdir/file.txt" result = resolve_relative_path(base_path, relative_path) @@ -128,7 +132,7 @@ class TestPathUtilities: assert result == expected # Test absolute path (platform-agnostic) - if os.name == 'nt': # Windows + if os.name == "nt": # Windows absolute_path = "C:\\absolute\\path" else: # Unix-like absolute_path = "/absolute/path" @@ -149,14 +153,72 @@ class TestPathUtilities: assert result.is_absolute() assert result == absolute_path.resolve() + def test_get_venv_path(self, monkeypatch): + """Test virtual environment path detection.""" + # Test when no virtual environment is active + # Temporarily clear VIRTUAL_ENV to test the "no venv" case + monkeypatch.delenv("VIRTUAL_ENV", raising=False) + result = get_venv_path() + assert result is None + + # Test when virtual environment is active + test_venv_path = "/path/to/venv" + monkeypatch.setenv("VIRTUAL_ENV", test_venv_path) + result = get_venv_path() + assert result == Path(test_venv_path) + + def test_get_python_executable(self): + """Test Python executable path detection.""" + result = get_python_executable() + assert isinstance(result, Path) + assert result.exists() + assert result.name.startswith("python") + + def test_find_executable(self): + """Test executable finding in PATH.""" + # Test with non-existent executable + result = find_executable("nonexistent_executable") + assert result is None + + # Test with existing executable (python should be available) + result = find_executable("python") + if result: + assert isinstance(result, Path) + assert result.exists() + + def test_get_script_binary_path(self): + """Test script binary path detection.""" + result = get_script_binary_path() + assert isinstance(result, Path) + assert result.exists() + assert result.is_dir() + + def test_get_binary_path(self, monkeypatch): + """Test binary path resolution with virtual environment.""" + # Test when no virtual environment is active + result = get_binary_path("python") + if result: + assert isinstance(result, Path) + assert result.exists() + + # Test when 
virtual environment is active + test_venv_path = "/path/to/venv" + monkeypatch.setenv("VIRTUAL_ENV", test_venv_path) + # This test is more complex as it depends on the actual venv structure + # We'll just verify the function doesn't crash + result = get_binary_path("python") + # Result can be None if binary not found in venv + if result: + assert isinstance(result, Path) + class TestCrossPlatformCompatibility: """Test cross-platform compatibility.""" def test_config_dir_platform_specific_windows(self, monkeypatch): """Test config directory respects Windows conventions.""" - monkeypatch.setattr('os.name', 'nt') - monkeypatch.setenv('APPDATA', 'C:\\Users\\Test\\AppData\\Roaming') + monkeypatch.setattr("os.name", "nt") + monkeypatch.setenv("APPDATA", "C:\\Users\\Test\\AppData\\Roaming") config_dir = get_config_dir() assert "AppData" in str(config_dir) assert "py-libp2p" in str(config_dir) @@ -167,9 +229,9 @@ class TestCrossPlatformCompatibility: result = join_paths("dir1", "dir2", "file.txt") expected = Path("dir1") / "dir2" / "file.txt" assert result == expected - + # Test that the result uses correct separators for the platform - if os.name == 'nt': # Windows + if os.name == "nt": # Windows assert "\\" in str(result) or "/" in str(result) else: # Unix-like assert "/" in str(result) @@ -211,3 +273,10 @@ class TestBackwardCompatibility: assert isinstance(create_temp_file(), Path) assert isinstance(resolve_relative_path(".", "test"), Path) assert isinstance(normalize_path("."), Path) + assert isinstance(get_python_executable(), Path) + assert isinstance(get_script_binary_path(), Path) + + # Test optional return types + venv_path = get_venv_path() + if venv_path is not None: + assert isinstance(venv_path, Path) From 42c8937a8d419ae31c3b34828e841d49e08c8c9f Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Mon, 1 Sep 2025 02:53:53 +0530 Subject: [PATCH 65/71] build(app): Add fallback to os.path.join + newsfragment 886 --- docs/conf.py | 15 +++++++++------ newsfragments/886.bugfix.rst | 2 ++ 2 files changed, 11 insertions(+), 6 deletions(-) create mode 100644 newsfragments/886.bugfix.rst diff --git a/docs/conf.py b/docs/conf.py index 64618359..6f7be709 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -16,20 +16,23 @@ # sys.path.insert(0, os.path.abspath('.')) import doctest -import os import sys from unittest.mock import MagicMock try: import tomllib -except ModuleNotFoundError: +except ImportError: # For Python < 3.11 - import tomli as tomllib # type: ignore (In case of >3.11 Pyrefly doesnt find tomli , which is right but a false flag) + import tomli as tomllib # type: ignore # Path to pyproject.toml (assuming conf.py is in a 'docs' subdirectory) -from libp2p.utils.paths import get_project_root, join_paths - -pyproject_path = join_paths(get_project_root(), "pyproject.toml") +try: + from libp2p.utils.paths import get_project_root, join_paths + pyproject_path = join_paths(get_project_root(), "pyproject.toml") +except ImportError: + # Fallback for documentation builds where libp2p is not available + import os + pyproject_path = os.path.join(os.path.dirname(__file__), "..", "pyproject.toml") with open(pyproject_path, "rb") as f: pyproject_data = tomllib.load(f) diff --git a/newsfragments/886.bugfix.rst b/newsfragments/886.bugfix.rst new file mode 100644 index 00000000..add7c4ab --- /dev/null +++ b/newsfragments/886.bugfix.rst @@ -0,0 +1,2 @@ +Fixed cross-platform path handling by replacing hardcoded OS-specific +paths with standardized utilities in core modules and examples. 
\ No newline at end of file From fcb35084b399d5c0d228b3594b73772779b74d6f Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Mon, 1 Sep 2025 03:14:09 +0530 Subject: [PATCH 66/71] fix(docs): Update tomllib import handling and streamline pyproject path resolution --- docs/conf.py | 13 ++++--------- newsfragments/886.bugfix.rst | 4 ++-- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 6f7be709..446252f1 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -16,23 +16,18 @@ # sys.path.insert(0, os.path.abspath('.')) import doctest +import os import sys from unittest.mock import MagicMock try: import tomllib -except ImportError: +except ModuleNotFoundError: # For Python < 3.11 - import tomli as tomllib # type: ignore + import tomli as tomllib # type: ignore (In case of >3.11 Pyrefly doesnt find tomli , which is right but a false flag) # Path to pyproject.toml (assuming conf.py is in a 'docs' subdirectory) -try: - from libp2p.utils.paths import get_project_root, join_paths - pyproject_path = join_paths(get_project_root(), "pyproject.toml") -except ImportError: - # Fallback for documentation builds where libp2p is not available - import os - pyproject_path = os.path.join(os.path.dirname(__file__), "..", "pyproject.toml") +pyproject_path = os.path.join(os.path.dirname(__file__), "..", "pyproject.toml") with open(pyproject_path, "rb") as f: pyproject_data = tomllib.load(f) diff --git a/newsfragments/886.bugfix.rst b/newsfragments/886.bugfix.rst index add7c4ab..1ebf38d1 100644 --- a/newsfragments/886.bugfix.rst +++ b/newsfragments/886.bugfix.rst @@ -1,2 +1,2 @@ -Fixed cross-platform path handling by replacing hardcoded OS-specific -paths with standardized utilities in core modules and examples. \ No newline at end of file +Fixed cross-platform path handling by replacing hardcoded OS-specific +paths with standardized utilities in core modules and examples. 
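For context, the net effect of the two docs commits above is that ``docs/conf.py`` resolves ``pyproject.toml`` relative to its own location and parses it with ``tomllib``, falling back to ``tomli`` on Python < 3.11. A minimal standalone sketch of that pattern, assuming only what the diffs show (the ``project`` key lookup at the end is purely illustrative):

.. code-block:: python

    import os

    try:
        import tomllib  # Python 3.11+
    except ModuleNotFoundError:
        import tomli as tomllib  # type: ignore  # older interpreters

    # Resolve pyproject.toml relative to this file, as docs/conf.py does
    pyproject_path = os.path.join(os.path.dirname(__file__), "..", "pyproject.toml")

    with open(pyproject_path, "rb") as f:
        pyproject_data = tomllib.load(f)

    # Illustrative assumption only: which keys conf.py actually reads is not shown here
    print(pyproject_data.get("project", {}).get("name"))
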
From 7d6eb28d7c8fa30468de635890a6d194d56014f5 Mon Sep 17 00:00:00 2001 From: unniznd Date: Mon, 1 Sep 2025 09:48:08 +0530 Subject: [PATCH 67/71] message inconsistency fixed --- libp2p/security/security_multistream.py | 2 +- libp2p/stream_muxer/muxer_multistream.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/libp2p/security/security_multistream.py b/libp2p/security/security_multistream.py index f7c81de1..ee8d4475 100644 --- a/libp2p/security/security_multistream.py +++ b/libp2p/security/security_multistream.py @@ -119,7 +119,7 @@ class SecurityMultistream(ABC): protocol, _ = await self.multiselect.negotiate(communicator) if protocol is None: raise MultiselectError( - "fail to negotiate a security protocol: no protocl selected" + "Failed to negotiate a security protocol: no protocol selected" ) # Return transport from protocol return self.transports[protocol] diff --git a/libp2p/stream_muxer/muxer_multistream.py b/libp2p/stream_muxer/muxer_multistream.py index 76689c17..ef90fac0 100644 --- a/libp2p/stream_muxer/muxer_multistream.py +++ b/libp2p/stream_muxer/muxer_multistream.py @@ -86,7 +86,7 @@ class MuxerMultistream: protocol, _ = await self.multiselect.negotiate(communicator) if protocol is None: raise MultiselectError( - "fail to negotiate a stream muxer protocol: no protocol selected" + "Fail to negotiate a stream muxer protocol: no protocol selected" ) return self.transports[protocol] From aad87f983ff60834dba4a1f682a3f96d3dad1f0f Mon Sep 17 00:00:00 2001 From: bomanaps Date: Mon, 1 Sep 2025 11:58:42 +0100 Subject: [PATCH 68/71] Adress documentation comment --- docs/examples.multiple_connections.rst | 77 +++++++++++++++++++++++--- libp2p/network/swarm.py | 37 ++++++++++++- 2 files changed, 104 insertions(+), 10 deletions(-) diff --git a/docs/examples.multiple_connections.rst b/docs/examples.multiple_connections.rst index 814152b3..85ab8f2d 100644 --- a/docs/examples.multiple_connections.rst +++ b/docs/examples.multiple_connections.rst @@ -96,23 +96,46 @@ Production Configuration For production use, consider these settings: +**RetryConfig Parameters** + +The `RetryConfig` class controls connection retry behavior with exponential backoff: + +- **max_retries**: Maximum number of retry attempts before giving up (default: 3) +- **initial_delay**: Initial delay in seconds before the first retry (default: 0.1s) +- **max_delay**: Maximum delay cap to prevent excessive wait times (default: 30.0s) +- **backoff_multiplier**: Exponential backoff multiplier - each retry multiplies delay by this factor (default: 2.0) +- **jitter_factor**: Random jitter (0.0-1.0) to prevent synchronized retries (default: 0.1) + +**ConnectionConfig Parameters** + +The `ConnectionConfig` class manages multi-connection behavior: + +- **max_connections_per_peer**: Maximum connections allowed to a single peer (default: 3) +- **connection_timeout**: Timeout for establishing new connections in seconds (default: 30.0s) +- **load_balancing_strategy**: Strategy for distributing streams ("round_robin" or "least_loaded") + +**Load Balancing Strategies Explained** + +- **round_robin**: Cycles through connections in order, distributing load evenly. Simple and predictable. +- **least_loaded**: Selects the connection with the fewest active streams. Better for performance but more complex. + .. 
code-block:: python from libp2p.network.swarm import ConnectionConfig, RetryConfig # Production-ready configuration retry_config = RetryConfig( - max_retries=3, - initial_delay=0.1, - max_delay=30.0, - backoff_multiplier=2.0, - jitter_factor=0.1 + max_retries=3, # Maximum retry attempts before giving up + initial_delay=0.1, # Start with 100ms delay + max_delay=30.0, # Cap exponential backoff at 30 seconds + backoff_multiplier=2.0, # Double delay each retry (100ms -> 200ms -> 400ms) + jitter_factor=0.1 # Add 10% random jitter to prevent thundering herd ) connection_config = ConnectionConfig( - max_connections_per_peer=3, # Balance performance and resources - connection_timeout=30.0, # Reasonable timeout - load_balancing_strategy="round_robin" # Predictable behavior + max_connections_per_peer=3, # Allow up to 3 connections per peer + connection_timeout=30.0, # 30 second timeout for new connections + load_balancing_strategy="round_robin" # Simple, predictable load distribution ) swarm = new_swarm( @@ -120,6 +143,44 @@ For production use, consider these settings: connection_config=connection_config ) +**How RetryConfig Works in Practice** + +With the configuration above, connection retries follow this pattern: + +1. **Attempt 1**: Immediate connection attempt +2. **Attempt 2**: Wait 100ms ± 10ms jitter, then retry +3. **Attempt 3**: Wait 200ms ± 20ms jitter, then retry +4. **Attempt 4**: Wait 400ms ± 40ms jitter, then retry +5. **Attempt 5**: Wait 800ms ± 80ms jitter, then retry +6. **Attempt 6**: Wait 1.6s ± 160ms jitter, then retry +7. **Attempt 7**: Wait 3.2s ± 320ms jitter, then retry +8. **Attempt 8**: Wait 6.4s ± 640ms jitter, then retry +9. **Attempt 9**: Wait 12.8s ± 1.28s jitter, then retry +10. **Attempt 10**: Wait 25.6s ± 2.56s jitter, then retry +11. **Attempt 11**: Wait 30.0s (capped) ± 3.0s jitter, then retry +12. **Attempt 12**: Wait 30.0s (capped) ± 3.0s jitter, then retry +13. **Give up**: After 12 retries (3 initial + 9 retries), connection fails + +The jitter prevents multiple clients from retrying simultaneously, reducing server load. + +**Parameter Tuning Guidelines** + +**For Development/Testing:** +- Use lower `max_retries` (1-2) and shorter delays for faster feedback +- Example: `RetryConfig(max_retries=2, initial_delay=0.01, max_delay=0.1)` + +**For Production:** +- Use moderate `max_retries` (3-5) with reasonable delays for reliability +- Example: `RetryConfig(max_retries=5, initial_delay=0.1, max_delay=60.0)` + +**For High-Latency Networks:** +- Use higher `max_retries` (5-10) with longer delays +- Example: `RetryConfig(max_retries=8, initial_delay=0.5, max_delay=120.0)` + +**For Load Balancing:** +- Use `round_robin` for simple, predictable behavior +- Use `least_loaded` when you need optimal performance and can handle complexity + Architecture ------------ diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 23a94fdb..5a3ce7bb 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -63,7 +63,26 @@ logger = logging.getLogger("libp2p.network.swarm") @dataclass class RetryConfig: - """Configuration for retry logic with exponential backoff.""" + """ + Configuration for retry logic with exponential backoff. + + This configuration controls how connection attempts are retried when they fail. + The retry mechanism uses exponential backoff with jitter to prevent thundering + herd problems in distributed systems. + + Attributes: + max_retries: Maximum number of retry attempts before giving up. 
+ Default: 3 attempts + initial_delay: Initial delay in seconds before the first retry. + Default: 0.1 seconds (100ms) + max_delay: Maximum delay cap in seconds to prevent excessive wait times. + Default: 30.0 seconds + backoff_multiplier: Multiplier for exponential backoff (each retry multiplies + the delay by this factor). Default: 2.0 (doubles each time) + jitter_factor: Random jitter factor (0.0-1.0) to add randomness to delays + and prevent synchronized retries. Default: 0.1 (10% jitter) + + """ max_retries: int = 3 initial_delay: float = 0.1 @@ -74,7 +93,21 @@ class RetryConfig: @dataclass class ConnectionConfig: - """Configuration for multi-connection support.""" + """ + Configuration for multi-connection support. + + This configuration controls how multiple connections per peer are managed, + including connection limits, timeouts, and load balancing strategies. + + Attributes: + max_connections_per_peer: Maximum number of connections allowed to a single + peer. Default: 3 connections + connection_timeout: Timeout in seconds for establishing new connections. + Default: 30.0 seconds + load_balancing_strategy: Strategy for distributing streams across connections. + Options: "round_robin" (default) or "least_loaded" + + """ max_connections_per_peer: int = 3 connection_timeout: float = 30.0 From 10775161968d72b733d2df0bb844aab3fa68b7a0 Mon Sep 17 00:00:00 2001 From: lla-dane Date: Mon, 1 Sep 2025 18:11:22 +0530 Subject: [PATCH 69/71] update newsfragment --- newsfragments/{835.feature.rst => 889.feature.rst} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename newsfragments/{835.feature.rst => 889.feature.rst} (100%) diff --git a/newsfragments/835.feature.rst b/newsfragments/889.feature.rst similarity index 100% rename from newsfragments/835.feature.rst rename to newsfragments/889.feature.rst From 84c1a7031a39cc2efaad412c42850fc29cd6d695 Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Tue, 2 Sep 2025 01:23:12 +0530 Subject: [PATCH 70/71] Enhance logging cleanup: Introduce global handler management for proper resource cleanup on exit and during logging setup. Update tests to ensure file handlers are closed correctly across platforms. 
--- libp2p/utils/logging.py | 22 ++++++++++++++-- tests/utils/test_logging.py | 50 ++++++++++++++++++++++++++----------- tests/utils/test_paths.py | 8 ++++++ 3 files changed, 64 insertions(+), 16 deletions(-) diff --git a/libp2p/utils/logging.py b/libp2p/utils/logging.py index acc67373..a9da0d65 100644 --- a/libp2p/utils/logging.py +++ b/libp2p/utils/logging.py @@ -18,6 +18,9 @@ log_queue: "queue.Queue[Any]" = queue.Queue() # Store the current listener to stop it on exit _current_listener: logging.handlers.QueueListener | None = None +# Store the handlers for proper cleanup +_current_handlers: list[logging.Handler] = [] + # Event to track when the listener is ready _listener_ready = threading.Event() @@ -92,7 +95,7 @@ def setup_logging() -> None: - Child loggers inherit their parent's level unless explicitly set - The root libp2p logger controls the default level """ - global _current_listener, _listener_ready + global _current_listener, _listener_ready, _current_handlers # Reset the event _listener_ready.clear() @@ -101,6 +104,12 @@ def setup_logging() -> None: if _current_listener is not None: _current_listener.stop() _current_listener = None + + # Close and clear existing handlers + for handler in _current_handlers: + if isinstance(handler, logging.FileHandler): + handler.close() + _current_handlers.clear() # Get the log level from environment variable debug_str = os.environ.get("LIBP2P_DEBUG", "") @@ -189,6 +198,9 @@ def setup_logging() -> None: logger.setLevel(level) logger.propagate = False # Prevent message duplication + # Store handlers globally for cleanup + _current_handlers.extend(handlers) + # Start the listener AFTER configuring all loggers _current_listener = logging.handlers.QueueListener( log_queue, *handlers, respect_handler_level=True @@ -203,7 +215,13 @@ def setup_logging() -> None: @atexit.register def cleanup_logging() -> None: """Clean up logging resources on exit.""" - global _current_listener + global _current_listener, _current_handlers if _current_listener is not None: _current_listener.stop() _current_listener = None + + # Close all file handlers to ensure proper cleanup on Windows + for handler in _current_handlers: + if isinstance(handler, logging.FileHandler): + handler.close() + _current_handlers.clear() diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py index 603af5e1..05c76ec2 100644 --- a/tests/utils/test_logging.py +++ b/tests/utils/test_logging.py @@ -15,6 +15,7 @@ import pytest import trio from libp2p.utils.logging import ( + _current_handlers, _current_listener, _listener_ready, log_queue, @@ -24,13 +25,19 @@ from libp2p.utils.logging import ( def _reset_logging(): """Reset all logging state.""" - global _current_listener, _listener_ready + global _current_listener, _listener_ready, _current_handlers # Stop existing listener if any if _current_listener is not None: _current_listener.stop() _current_listener = None + # Close all file handlers to ensure proper cleanup on Windows + for handler in _current_handlers: + if isinstance(handler, logging.FileHandler): + handler.close() + _current_handlers.clear() + # Reset the event _listener_ready = threading.Event() @@ -173,6 +180,15 @@ async def test_custom_log_file(clean_env): # Stop the listener to ensure all messages are written if _current_listener is not None: _current_listener.stop() + + # Give a moment for the listener to fully stop + await trio.sleep(0.05) + + # Close all file handlers to release the file + for handler in _current_handlers: + if isinstance(handler, 
logging.FileHandler): + handler.flush() # Ensure all writes are flushed + handler.close() # Check if the file exists and contains our message assert log_file.exists() @@ -185,17 +201,14 @@ async def test_default_log_file(clean_env): """Test logging to the default file path.""" os.environ["LIBP2P_DEBUG"] = "INFO" - with patch("libp2p.utils.logging.datetime") as mock_datetime: - # Mock the timestamp to have a predictable filename - mock_datetime.now.return_value.strftime.return_value = "20240101_120000" - + with patch("libp2p.utils.paths.create_temp_file") as mock_create_temp: + # Mock the temp file creation to return a predictable path + mock_temp_file = Path(tempfile.gettempdir()) / "test_py-libp2p_20240101_120000.log" + mock_create_temp.return_value = mock_temp_file + # Remove the log file if it exists - if os.name == "nt": # Windows - log_file = Path("C:/Windows/Temp/20240101_120000_py-libp2p.log") - else: # Unix-like - log_file = Path("/tmp/20240101_120000_py-libp2p.log") - log_file.unlink(missing_ok=True) - + mock_temp_file.unlink(missing_ok=True) + setup_logging() # Wait for the listener to be ready @@ -210,10 +223,19 @@ async def test_default_log_file(clean_env): # Stop the listener to ensure all messages are written if _current_listener is not None: _current_listener.stop() + + # Give a moment for the listener to fully stop + await trio.sleep(0.05) + + # Close all file handlers to release the file + for handler in _current_handlers: + if isinstance(handler, logging.FileHandler): + handler.flush() # Ensure all writes are flushed + handler.close() - # Check the default log file - if log_file.exists(): # Only check content if we have write permission - content = log_file.read_text() + # Check the mocked temp file + if mock_temp_file.exists(): + content = mock_temp_file.read_text() assert "Test message" in content diff --git a/tests/utils/test_paths.py b/tests/utils/test_paths.py index e5247cdb..a8eb4ed9 100644 --- a/tests/utils/test_paths.py +++ b/tests/utils/test_paths.py @@ -6,6 +6,8 @@ import os from pathlib import Path import tempfile +import pytest + from libp2p.utils.paths import ( create_temp_file, ensure_dir_exists, @@ -217,6 +219,12 @@ class TestCrossPlatformCompatibility: def test_config_dir_platform_specific_windows(self, monkeypatch): """Test config directory respects Windows conventions.""" + import platform + + # Only run this test on Windows systems + if platform.system() != "Windows": + pytest.skip("This test only runs on Windows systems") + monkeypatch.setattr("os.name", "nt") monkeypatch.setenv("APPDATA", "C:\\Users\\Test\\AppData\\Roaming") config_dir = get_config_dir() From 145727a9baae6948575a49c2d6da1a0ff63a32a9 Mon Sep 17 00:00:00 2001 From: yashksaini-coder Date: Tue, 2 Sep 2025 01:39:24 +0530 Subject: [PATCH 71/71] Refactor logging code: Remove unnecessary blank lines in logging setup and cleanup functions for improved readability. Update tests to reflect formatting changes. 
--- libp2p/utils/logging.py | 6 +++--- tests/utils/test_logging.py | 16 +++++++++------- tests/utils/test_paths.py | 4 ++-- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/libp2p/utils/logging.py b/libp2p/utils/logging.py index a9da0d65..b23136f5 100644 --- a/libp2p/utils/logging.py +++ b/libp2p/utils/logging.py @@ -104,7 +104,7 @@ def setup_logging() -> None: if _current_listener is not None: _current_listener.stop() _current_listener = None - + # Close and clear existing handlers for handler in _current_handlers: if isinstance(handler, logging.FileHandler): @@ -200,7 +200,7 @@ def setup_logging() -> None: # Store handlers globally for cleanup _current_handlers.extend(handlers) - + # Start the listener AFTER configuring all loggers _current_listener = logging.handlers.QueueListener( log_queue, *handlers, respect_handler_level=True @@ -219,7 +219,7 @@ def cleanup_logging() -> None: if _current_listener is not None: _current_listener.stop() _current_listener = None - + # Close all file handlers to ensure proper cleanup on Windows for handler in _current_handlers: if isinstance(handler, logging.FileHandler): diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py index 05c76ec2..06be05c7 100644 --- a/tests/utils/test_logging.py +++ b/tests/utils/test_logging.py @@ -180,10 +180,10 @@ async def test_custom_log_file(clean_env): # Stop the listener to ensure all messages are written if _current_listener is not None: _current_listener.stop() - + # Give a moment for the listener to fully stop await trio.sleep(0.05) - + # Close all file handlers to release the file for handler in _current_handlers: if isinstance(handler, logging.FileHandler): @@ -203,12 +203,14 @@ async def test_default_log_file(clean_env): with patch("libp2p.utils.paths.create_temp_file") as mock_create_temp: # Mock the temp file creation to return a predictable path - mock_temp_file = Path(tempfile.gettempdir()) / "test_py-libp2p_20240101_120000.log" + mock_temp_file = ( + Path(tempfile.gettempdir()) / "test_py-libp2p_20240101_120000.log" + ) mock_create_temp.return_value = mock_temp_file - + # Remove the log file if it exists mock_temp_file.unlink(missing_ok=True) - + setup_logging() # Wait for the listener to be ready @@ -223,10 +225,10 @@ async def test_default_log_file(clean_env): # Stop the listener to ensure all messages are written if _current_listener is not None: _current_listener.stop() - + # Give a moment for the listener to fully stop await trio.sleep(0.05) - + # Close all file handlers to release the file for handler in _current_handlers: if isinstance(handler, logging.FileHandler): diff --git a/tests/utils/test_paths.py b/tests/utils/test_paths.py index a8eb4ed9..421fc557 100644 --- a/tests/utils/test_paths.py +++ b/tests/utils/test_paths.py @@ -220,11 +220,11 @@ class TestCrossPlatformCompatibility: def test_config_dir_platform_specific_windows(self, monkeypatch): """Test config directory respects Windows conventions.""" import platform - + # Only run this test on Windows systems if platform.system() != "Windows": pytest.skip("This test only runs on Windows systems") - + monkeypatch.setattr("os.name", "nt") monkeypatch.setenv("APPDATA", "C:\\Users\\Test\\AppData\\Roaming") config_dir = get_config_dir()
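
The path helpers added earlier in this series (``get_venv_path``, ``get_python_executable``, ``find_executable``, and ``get_binary_path`` in ``libp2p.utils.paths``) are exercised above only through tests. A rough usage sketch follows; the binary names passed in are illustrative assumptions, and only the signatures visible in the diffs are relied on:

.. code-block:: python

    from libp2p.utils.paths import (
        find_executable,
        get_binary_path,
        get_python_executable,
        get_venv_path,
    )

    # Current interpreter and, if one is active, the virtual environment root
    print(f"Python executable: {get_python_executable()}")
    venv = get_venv_path()
    if venv is not None:
        print(f"Active virtualenv: {venv}")

    # Prefer a binary from the active venv's bin/Scripts directory, then the system PATH
    ruff = get_binary_path("ruff")
    if ruff is None:
        print("ruff not found in the virtualenv or on PATH")

    # Plain PATH lookup, independent of any virtualenv
    print(f"git: {find_executable('git')}")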