From dd14aad47c60b7862c8c3f3a31a5e4e6701a9a3e Mon Sep 17 00:00:00 2001 From: Sukhman Singh <63765293+sukhman-sukh@users.noreply.github.com> Date: Sat, 12 Jul 2025 02:23:27 +0530 Subject: [PATCH 1/2] Add tests for discovery methods in circuit_relay_v2 (#750) * Add test for direct_connection_relay_discovery * Add test for mux_method_relay_discovery * Fix newsfragments --- libp2p/peer/peerstore.py | 6 +- libp2p/relay/circuit_v2/discovery.py | 5 +- newsfragments/749.internal.rst | 1 + newsfragments/750.feature.rst | 1 + tests/core/relay/test_circuit_v2_discovery.py | 202 +++++++++++++++++- 5 files changed, 200 insertions(+), 15 deletions(-) create mode 100644 newsfragments/749.internal.rst create mode 100644 newsfragments/750.feature.rst diff --git a/libp2p/peer/peerstore.py b/libp2p/peer/peerstore.py index 40cb7893..7f67e575 100644 --- a/libp2p/peer/peerstore.py +++ b/libp2p/peer/peerstore.py @@ -64,7 +64,11 @@ class PeerStore(IPeerStore): return list(self.peer_data_map.keys()) def clear_peerdata(self, peer_id: ID) -> None: - """Clears the peer data of the peer""" + """Clears all data associated with the given peer_id.""" + if peer_id in self.peer_data_map: + del self.peer_data_map[peer_id] + else: + raise PeerStoreError("peer ID not found") def valid_peer_ids(self) -> list[ID]: """ diff --git a/libp2p/relay/circuit_v2/discovery.py b/libp2p/relay/circuit_v2/discovery.py index b1310d8d..734a7869 100644 --- a/libp2p/relay/circuit_v2/discovery.py +++ b/libp2p/relay/circuit_v2/discovery.py @@ -234,7 +234,8 @@ class RelayDiscovery(Service): if not callable(proto_getter): return None - + if peer_id not in peerstore.peer_ids(): + return None try: # Try to get protocols proto_result = proto_getter(peer_id) @@ -283,8 +284,6 @@ class RelayDiscovery(Service): return None mux = self.host.get_mux() - if not hasattr(mux, "protocols"): - return None peer_protocols = set() # Get protocols from mux with proper type safety diff --git a/newsfragments/749.internal.rst b/newsfragments/749.internal.rst new file mode 100644 index 00000000..c7316d8c --- /dev/null +++ b/newsfragments/749.internal.rst @@ -0,0 +1 @@ +Add comprehensive tests for relay_discovery method in circuit_relay_v2 diff --git a/newsfragments/750.feature.rst b/newsfragments/750.feature.rst new file mode 100644 index 00000000..a49c5fb7 --- /dev/null +++ b/newsfragments/750.feature.rst @@ -0,0 +1 @@ +Add logic to clear_peerdata method in peerstore diff --git a/tests/core/relay/test_circuit_v2_discovery.py b/tests/core/relay/test_circuit_v2_discovery.py index 97ed353f..923f5937 100644 --- a/tests/core/relay/test_circuit_v2_discovery.py +++ b/tests/core/relay/test_circuit_v2_discovery.py @@ -105,11 +105,11 @@ async def test_relay_discovery_initialization(): @pytest.mark.trio -async def test_relay_discovery_find_relay(): - """Test finding a relay node via discovery.""" +async def test_relay_discovery_find_relay_peerstore_method(): + """Test finding a relay node via discovery using the peerstore method.""" async with HostFactory.create_batch_and_listen(2) as hosts: relay_host, client_host = hosts - logger.info("Created hosts for test_relay_discovery_find_relay") + logger.info("Created host for test_relay_discovery_find_relay_peerstore_method") logger.info("Relay host ID: %s", relay_host.get_id()) logger.info("Client host ID: %s", client_host.get_id()) @@ -144,19 +144,19 @@ async def test_relay_discovery_find_relay(): # Start discovery service async with background_trio_service(client_discovery): await client_discovery.event_started.wait() - logger.info("Client discovery service started") + logger.info("Client discovery service started (peerstore method)") - # Wait for discovery to find the relay - logger.info("Waiting for relay discovery...") + # Wait for discovery to find the relay using the peerstore method + logger.info("Waiting for relay discovery using peerstore...") - # Manually trigger discovery instead of waiting + # Manually trigger discovery which uses peerstore as default await client_discovery.discover_relays() # Check if relay was found with trio.fail_after(DISCOVERY_TIMEOUT): for _ in range(20): # Try multiple times if relay_host.get_id() in client_discovery._discovered_relays: - logger.info("Relay discovered successfully") + logger.info("Relay discovered successfully (peerstore method)") break # Wait and try again @@ -164,14 +164,194 @@ async def test_relay_discovery_find_relay(): # Manually trigger discovery again await client_discovery.discover_relays() else: - pytest.fail("Failed to discover relay node within timeout") + pytest.fail( + "Failed to discover relay node within timeout(peerstore method)" + ) # Verify that relay was found and is valid assert relay_host.get_id() in client_discovery._discovered_relays, ( - "Relay should be discovered" + "Relay should be discovered (peerstore method)" ) relay_info = client_discovery._discovered_relays[relay_host.get_id()] - assert relay_info.peer_id == relay_host.get_id(), "Peer ID should match" + assert relay_info.peer_id == relay_host.get_id(), ( + "Peer ID should match (peerstore method)" + ) + + +@pytest.mark.trio +async def test_relay_discovery_find_relay_direct_connection_method(): + """Test finding a relay node via discovery using the direct connection method.""" + async with HostFactory.create_batch_and_listen(2) as hosts: + relay_host, client_host = hosts + logger.info("Created hosts for test_relay_discovery_find_relay_direct_method") + logger.info("Relay host ID: %s", relay_host.get_id()) + logger.info("Client host ID: %s", client_host.get_id()) + + # Explicitly register the protocol handlers on relay_host + relay_host.set_stream_handler(PROTOCOL_ID, simple_stream_handler) + relay_host.set_stream_handler(STOP_PROTOCOL_ID, simple_stream_handler) + + # Manually add protocol to peerstore for testing, then remove to force fallback + client_host.get_peerstore().add_protocols( + relay_host.get_id(), [str(PROTOCOL_ID)] + ) + + # Set up discovery on the client host + client_discovery = RelayDiscovery( + client_host, discovery_interval=5 + ) # Use shorter interval for testing + + try: + # Connect peers so they can discover each other + with trio.fail_after(CONNECT_TIMEOUT): + logger.info("Connecting client host to relay host") + await connect(client_host, relay_host) + assert relay_host.get_network().connections[client_host.get_id()], ( + "Peers not connected" + ) + logger.info("Connection established between peers") + except Exception as e: + logger.error("Failed to connect peers: %s", str(e)) + raise + + # Remove the relay from the peerstore to test fallback to direct connection + client_host.get_peerstore().clear_peerdata(relay_host.get_id()) + # Make sure that peer_id is not present in peerstore + assert relay_host.get_id() not in client_host.get_peerstore().peer_ids() + + # Start discovery service + async with background_trio_service(client_discovery): + await client_discovery.event_started.wait() + logger.info("Client discovery service started (direct connection method)") + + # Wait for discovery to find the relay using the direct connection method + logger.info( + "Waiting for relay discovery using direct connection fallback..." + ) + + # Manually trigger discovery which should fallback to direct connection + await client_discovery.discover_relays() + + # Check if relay was found + with trio.fail_after(DISCOVERY_TIMEOUT): + for _ in range(20): # Try multiple times + if relay_host.get_id() in client_discovery._discovered_relays: + logger.info("Relay discovered successfully (direct method)") + break + + # Wait and try again + await trio.sleep(1) + # Manually trigger discovery again + await client_discovery.discover_relays() + else: + pytest.fail( + "Failed to discover relay node within timeout (direct method)" + ) + + # Verify that relay was found and is valid + assert relay_host.get_id() in client_discovery._discovered_relays, ( + "Relay should be discovered (direct method)" + ) + relay_info = client_discovery._discovered_relays[relay_host.get_id()] + assert relay_info.peer_id == relay_host.get_id(), ( + "Peer ID should match (direct method)" + ) + + +@pytest.mark.trio +async def test_relay_discovery_find_relay_mux_method(): + """ + Test finding a relay node via discovery using the mux method + (fallback after direct connection fails). + """ + async with HostFactory.create_batch_and_listen(2) as hosts: + relay_host, client_host = hosts + logger.info("Created hosts for test_relay_discovery_find_relay_mux_method") + logger.info("Relay host ID: %s", relay_host.get_id()) + logger.info("Client host ID: %s", client_host.get_id()) + + # Explicitly register the protocol handlers on relay_host + relay_host.set_stream_handler(PROTOCOL_ID, simple_stream_handler) + relay_host.set_stream_handler(STOP_PROTOCOL_ID, simple_stream_handler) + + client_host.set_stream_handler(PROTOCOL_ID, simple_stream_handler) + client_host.set_stream_handler(STOP_PROTOCOL_ID, simple_stream_handler) + + # Set up discovery on the client host + client_discovery = RelayDiscovery( + client_host, discovery_interval=5 + ) # Use shorter interval for testing + + try: + # Connect peers so they can discover each other + with trio.fail_after(CONNECT_TIMEOUT): + logger.info("Connecting client host to relay host") + await connect(client_host, relay_host) + assert relay_host.get_network().connections[client_host.get_id()], ( + "Peers not connected" + ) + logger.info("Connection established between peers") + except Exception as e: + logger.error("Failed to connect peers: %s", str(e)) + raise + + # Remove the relay from the peerstore to test fallback + client_host.get_peerstore().clear_peerdata(relay_host.get_id()) + # Make sure that peer_id is not present in peerstore + assert relay_host.get_id() not in client_host.get_peerstore().peer_ids() + + # Mock the _check_via_direct_connection method to return None + # This forces the discovery to fall back to the mux method + async def mock_direct_check_fails(peer_id): + """Mock that always returns None to force mux fallback.""" + return None + + client_discovery._check_via_direct_connection = mock_direct_check_fails + + # Start discovery service + async with background_trio_service(client_discovery): + await client_discovery.event_started.wait() + logger.info("Client discovery service started (mux method)") + + # Wait for discovery to find the relay using the mux method + logger.info("Waiting for relay discovery using mux fallback...") + + # Manually trigger discovery which should fallback to mux method + await client_discovery.discover_relays() + + # Check if relay was found + with trio.fail_after(DISCOVERY_TIMEOUT): + for _ in range(20): # Try multiple times + if relay_host.get_id() in client_discovery._discovered_relays: + logger.info("Relay discovered successfully (mux method)") + break + + # Wait and try again + await trio.sleep(1) + # Manually trigger discovery again + await client_discovery.discover_relays() + else: + pytest.fail( + "Failed to discover relay node within timeout (mux method)" + ) + + # Verify that relay was found and is valid + assert relay_host.get_id() in client_discovery._discovered_relays, ( + "Relay should be discovered (mux method)" + ) + relay_info = client_discovery._discovered_relays[relay_host.get_id()] + assert relay_info.peer_id == relay_host.get_id(), ( + "Peer ID should match (mux method)" + ) + + # Verify that the protocol was cached via mux method + assert relay_host.get_id() in client_discovery._protocol_cache, ( + "Protocol should be cached (mux method)" + ) + assert ( + str(PROTOCOL_ID) + in client_discovery._protocol_cache[relay_host.get_id()] + ), "Relay protocol should be in cache (mux method)" @pytest.mark.trio From 5fcfc677f31e17b9bd842425bab66169dc71a339 Mon Sep 17 00:00:00 2001 From: Archit Dabral <147427717+Minimega12121@users.noreply.github.com> Date: Sat, 12 Jul 2025 02:57:17 +0530 Subject: [PATCH 2/2] fixme/correct-type (#746) * fixme/correct-type * added newsfragment and test --- libp2p/abc.py | 10 ++++++--- libp2p/identity/identify/identify.py | 2 +- libp2p/protocol_muxer/multiselect.py | 12 +++++++++++ libp2p/relay/circuit_v2/discovery.py | 6 ++++-- newsfragments/746.bugfix.rst | 3 +++ .../protocol_muxer/test_protocol_muxer.py | 21 +++++++++++++++++++ 6 files changed, 48 insertions(+), 6 deletions(-) create mode 100644 newsfragments/746.bugfix.rst diff --git a/libp2p/abc.py b/libp2p/abc.py index 70c4ab71..3adb04aa 100644 --- a/libp2p/abc.py +++ b/libp2p/abc.py @@ -50,6 +50,11 @@ if TYPE_CHECKING: Pubsub, ) +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from libp2p.protocol_muxer.multiselect import Multiselect + from libp2p.pubsub.pb import ( rpc_pb2, ) @@ -1545,9 +1550,8 @@ class IHost(ABC): """ - # FIXME: Replace with correct return type @abstractmethod - def get_mux(self) -> Any: + def get_mux(self) -> "Multiselect": """ Retrieve the muxer instance for the host. @@ -2158,6 +2162,7 @@ class IMultiselectMuxer(ABC): """ + @abstractmethod def get_protocols(self) -> tuple[TProtocol | None, ...]: """ Retrieve the protocols for which handlers have been registered. @@ -2168,7 +2173,6 @@ class IMultiselectMuxer(ABC): A tuple of registered protocol names. """ - return tuple(self.handlers.keys()) @abstractmethod async def negotiate( diff --git a/libp2p/identity/identify/identify.py b/libp2p/identity/identify/identify.py index 5d066e37..15367c43 100644 --- a/libp2p/identity/identify/identify.py +++ b/libp2p/identity/identify/identify.py @@ -59,7 +59,7 @@ def _mk_identify_protobuf( ) -> Identify: public_key = host.get_public_key() laddrs = host.get_addrs() - protocols = host.get_mux().get_protocols() + protocols = tuple(str(p) for p in host.get_mux().get_protocols() if p is not None) observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b"" return Identify( diff --git a/libp2p/protocol_muxer/multiselect.py b/libp2p/protocol_muxer/multiselect.py index 3f6ef02f..8d311391 100644 --- a/libp2p/protocol_muxer/multiselect.py +++ b/libp2p/protocol_muxer/multiselect.py @@ -101,6 +101,18 @@ class Multiselect(IMultiselectMuxer): except trio.TooSlowError: raise MultiselectError("handshake read timeout") + def get_protocols(self) -> tuple[TProtocol | None, ...]: + """ + Retrieve the protocols for which handlers have been registered. + + Returns + ------- + tuple[TProtocol, ...] + A tuple of registered protocol names. + + """ + return tuple(self.handlers.keys()) + async def handshake(self, communicator: IMultiselectCommunicator) -> None: """ Perform handshake to agree on multiselect protocol. diff --git a/libp2p/relay/circuit_v2/discovery.py b/libp2p/relay/circuit_v2/discovery.py index 734a7869..a35eacdc 100644 --- a/libp2p/relay/circuit_v2/discovery.py +++ b/libp2p/relay/circuit_v2/discovery.py @@ -292,7 +292,9 @@ class RelayDiscovery(Service): # Get protocols with proper typing mux_protocols = mux.get_protocols() if isinstance(mux_protocols, (list, tuple)): - available_protocols = list(mux_protocols) + available_protocols = [ + p for p in mux.get_protocols() if p is not None + ] for protocol in available_protocols: try: @@ -312,7 +314,7 @@ class RelayDiscovery(Service): self._protocol_cache[peer_id] = peer_protocols protocol_str = str(PROTOCOL_ID) - for protocol in peer_protocols: + for protocol in map(TProtocol, peer_protocols): if protocol == protocol_str: return True return False diff --git a/newsfragments/746.bugfix.rst b/newsfragments/746.bugfix.rst new file mode 100644 index 00000000..71970b48 --- /dev/null +++ b/newsfragments/746.bugfix.rst @@ -0,0 +1,3 @@ +Improved type safety in `get_mux()` and `get_protocols()` by returning properly typed values instead +of `Any`. Also updated `identify.py` and `discovery.py` to handle `None` values safely and +compare protocols correctly. diff --git a/tests/core/protocol_muxer/test_protocol_muxer.py b/tests/core/protocol_muxer/test_protocol_muxer.py index b089390b..1d6a0f86 100644 --- a/tests/core/protocol_muxer/test_protocol_muxer.py +++ b/tests/core/protocol_muxer/test_protocol_muxer.py @@ -3,6 +3,7 @@ import pytest from libp2p.custom_types import ( TProtocol, ) +from libp2p.protocol_muxer.multiselect import Multiselect from libp2p.tools.utils import ( create_echo_stream_handler, ) @@ -138,3 +139,23 @@ async def test_multistream_command(security_protocol): # Dialer asks for unspoorted command with pytest.raises(ValueError, match="Command not supported"): await dialer.send_command(listener.get_id(), "random") + + +@pytest.mark.trio +async def test_get_protocols_returns_all_registered_protocols(): + ms = Multiselect() + + async def dummy_handler(stream): + pass + + p1 = TProtocol("/echo/1.0.0") + p2 = TProtocol("/foo/1.0.0") + p3 = TProtocol("/bar/1.0.0") + + ms.add_handler(p1, dummy_handler) + ms.add_handler(p2, dummy_handler) + ms.add_handler(p3, dummy_handler) + + protocols = ms.get_protocols() + + assert set(protocols) == {p1, p2, p3}