From dd14aad47c60b7862c8c3f3a31a5e4e6701a9a3e Mon Sep 17 00:00:00 2001 From: Sukhman Singh <63765293+sukhman-sukh@users.noreply.github.com> Date: Sat, 12 Jul 2025 02:23:27 +0530 Subject: [PATCH] Add tests for discovery methods in circuit_relay_v2 (#750) * Add test for direct_connection_relay_discovery * Add test for mux_method_relay_discovery * Fix newsfragments --- libp2p/peer/peerstore.py | 6 +- libp2p/relay/circuit_v2/discovery.py | 5 +- newsfragments/749.internal.rst | 1 + newsfragments/750.feature.rst | 1 + tests/core/relay/test_circuit_v2_discovery.py | 202 +++++++++++++++++- 5 files changed, 200 insertions(+), 15 deletions(-) create mode 100644 newsfragments/749.internal.rst create mode 100644 newsfragments/750.feature.rst diff --git a/libp2p/peer/peerstore.py b/libp2p/peer/peerstore.py index 40cb7893..7f67e575 100644 --- a/libp2p/peer/peerstore.py +++ b/libp2p/peer/peerstore.py @@ -64,7 +64,11 @@ class PeerStore(IPeerStore): return list(self.peer_data_map.keys()) def clear_peerdata(self, peer_id: ID) -> None: - """Clears the peer data of the peer""" + """Clears all data associated with the given peer_id.""" + if peer_id in self.peer_data_map: + del self.peer_data_map[peer_id] + else: + raise PeerStoreError("peer ID not found") def valid_peer_ids(self) -> list[ID]: """ diff --git a/libp2p/relay/circuit_v2/discovery.py b/libp2p/relay/circuit_v2/discovery.py index b1310d8d..734a7869 100644 --- a/libp2p/relay/circuit_v2/discovery.py +++ b/libp2p/relay/circuit_v2/discovery.py @@ -234,7 +234,8 @@ class RelayDiscovery(Service): if not callable(proto_getter): return None - + if peer_id not in peerstore.peer_ids(): + return None try: # Try to get protocols proto_result = proto_getter(peer_id) @@ -283,8 +284,6 @@ class RelayDiscovery(Service): return None mux = self.host.get_mux() - if not hasattr(mux, "protocols"): - return None peer_protocols = set() # Get protocols from mux with proper type safety diff --git a/newsfragments/749.internal.rst b/newsfragments/749.internal.rst new file mode 100644 index 00000000..c7316d8c --- /dev/null +++ b/newsfragments/749.internal.rst @@ -0,0 +1 @@ +Add comprehensive tests for relay_discovery method in circuit_relay_v2 diff --git a/newsfragments/750.feature.rst b/newsfragments/750.feature.rst new file mode 100644 index 00000000..a49c5fb7 --- /dev/null +++ b/newsfragments/750.feature.rst @@ -0,0 +1 @@ +Add logic to clear_peerdata method in peerstore diff --git a/tests/core/relay/test_circuit_v2_discovery.py b/tests/core/relay/test_circuit_v2_discovery.py index 97ed353f..923f5937 100644 --- a/tests/core/relay/test_circuit_v2_discovery.py +++ b/tests/core/relay/test_circuit_v2_discovery.py @@ -105,11 +105,11 @@ async def test_relay_discovery_initialization(): @pytest.mark.trio -async def test_relay_discovery_find_relay(): - """Test finding a relay node via discovery.""" +async def test_relay_discovery_find_relay_peerstore_method(): + """Test finding a relay node via discovery using the peerstore method.""" async with HostFactory.create_batch_and_listen(2) as hosts: relay_host, client_host = hosts - logger.info("Created hosts for test_relay_discovery_find_relay") + logger.info("Created host for test_relay_discovery_find_relay_peerstore_method") logger.info("Relay host ID: %s", relay_host.get_id()) logger.info("Client host ID: %s", client_host.get_id()) @@ -144,19 +144,19 @@ async def test_relay_discovery_find_relay(): # Start discovery service async with background_trio_service(client_discovery): await client_discovery.event_started.wait() - logger.info("Client discovery service started") + logger.info("Client discovery service started (peerstore method)") - # Wait for discovery to find the relay - logger.info("Waiting for relay discovery...") + # Wait for discovery to find the relay using the peerstore method + logger.info("Waiting for relay discovery using peerstore...") - # Manually trigger discovery instead of waiting + # Manually trigger discovery which uses peerstore as default await client_discovery.discover_relays() # Check if relay was found with trio.fail_after(DISCOVERY_TIMEOUT): for _ in range(20): # Try multiple times if relay_host.get_id() in client_discovery._discovered_relays: - logger.info("Relay discovered successfully") + logger.info("Relay discovered successfully (peerstore method)") break # Wait and try again @@ -164,14 +164,194 @@ async def test_relay_discovery_find_relay(): # Manually trigger discovery again await client_discovery.discover_relays() else: - pytest.fail("Failed to discover relay node within timeout") + pytest.fail( + "Failed to discover relay node within timeout(peerstore method)" + ) # Verify that relay was found and is valid assert relay_host.get_id() in client_discovery._discovered_relays, ( - "Relay should be discovered" + "Relay should be discovered (peerstore method)" ) relay_info = client_discovery._discovered_relays[relay_host.get_id()] - assert relay_info.peer_id == relay_host.get_id(), "Peer ID should match" + assert relay_info.peer_id == relay_host.get_id(), ( + "Peer ID should match (peerstore method)" + ) + + +@pytest.mark.trio +async def test_relay_discovery_find_relay_direct_connection_method(): + """Test finding a relay node via discovery using the direct connection method.""" + async with HostFactory.create_batch_and_listen(2) as hosts: + relay_host, client_host = hosts + logger.info("Created hosts for test_relay_discovery_find_relay_direct_method") + logger.info("Relay host ID: %s", relay_host.get_id()) + logger.info("Client host ID: %s", client_host.get_id()) + + # Explicitly register the protocol handlers on relay_host + relay_host.set_stream_handler(PROTOCOL_ID, simple_stream_handler) + relay_host.set_stream_handler(STOP_PROTOCOL_ID, simple_stream_handler) + + # Manually add protocol to peerstore for testing, then remove to force fallback + client_host.get_peerstore().add_protocols( + relay_host.get_id(), [str(PROTOCOL_ID)] + ) + + # Set up discovery on the client host + client_discovery = RelayDiscovery( + client_host, discovery_interval=5 + ) # Use shorter interval for testing + + try: + # Connect peers so they can discover each other + with trio.fail_after(CONNECT_TIMEOUT): + logger.info("Connecting client host to relay host") + await connect(client_host, relay_host) + assert relay_host.get_network().connections[client_host.get_id()], ( + "Peers not connected" + ) + logger.info("Connection established between peers") + except Exception as e: + logger.error("Failed to connect peers: %s", str(e)) + raise + + # Remove the relay from the peerstore to test fallback to direct connection + client_host.get_peerstore().clear_peerdata(relay_host.get_id()) + # Make sure that peer_id is not present in peerstore + assert relay_host.get_id() not in client_host.get_peerstore().peer_ids() + + # Start discovery service + async with background_trio_service(client_discovery): + await client_discovery.event_started.wait() + logger.info("Client discovery service started (direct connection method)") + + # Wait for discovery to find the relay using the direct connection method + logger.info( + "Waiting for relay discovery using direct connection fallback..." + ) + + # Manually trigger discovery which should fallback to direct connection + await client_discovery.discover_relays() + + # Check if relay was found + with trio.fail_after(DISCOVERY_TIMEOUT): + for _ in range(20): # Try multiple times + if relay_host.get_id() in client_discovery._discovered_relays: + logger.info("Relay discovered successfully (direct method)") + break + + # Wait and try again + await trio.sleep(1) + # Manually trigger discovery again + await client_discovery.discover_relays() + else: + pytest.fail( + "Failed to discover relay node within timeout (direct method)" + ) + + # Verify that relay was found and is valid + assert relay_host.get_id() in client_discovery._discovered_relays, ( + "Relay should be discovered (direct method)" + ) + relay_info = client_discovery._discovered_relays[relay_host.get_id()] + assert relay_info.peer_id == relay_host.get_id(), ( + "Peer ID should match (direct method)" + ) + + +@pytest.mark.trio +async def test_relay_discovery_find_relay_mux_method(): + """ + Test finding a relay node via discovery using the mux method + (fallback after direct connection fails). + """ + async with HostFactory.create_batch_and_listen(2) as hosts: + relay_host, client_host = hosts + logger.info("Created hosts for test_relay_discovery_find_relay_mux_method") + logger.info("Relay host ID: %s", relay_host.get_id()) + logger.info("Client host ID: %s", client_host.get_id()) + + # Explicitly register the protocol handlers on relay_host + relay_host.set_stream_handler(PROTOCOL_ID, simple_stream_handler) + relay_host.set_stream_handler(STOP_PROTOCOL_ID, simple_stream_handler) + + client_host.set_stream_handler(PROTOCOL_ID, simple_stream_handler) + client_host.set_stream_handler(STOP_PROTOCOL_ID, simple_stream_handler) + + # Set up discovery on the client host + client_discovery = RelayDiscovery( + client_host, discovery_interval=5 + ) # Use shorter interval for testing + + try: + # Connect peers so they can discover each other + with trio.fail_after(CONNECT_TIMEOUT): + logger.info("Connecting client host to relay host") + await connect(client_host, relay_host) + assert relay_host.get_network().connections[client_host.get_id()], ( + "Peers not connected" + ) + logger.info("Connection established between peers") + except Exception as e: + logger.error("Failed to connect peers: %s", str(e)) + raise + + # Remove the relay from the peerstore to test fallback + client_host.get_peerstore().clear_peerdata(relay_host.get_id()) + # Make sure that peer_id is not present in peerstore + assert relay_host.get_id() not in client_host.get_peerstore().peer_ids() + + # Mock the _check_via_direct_connection method to return None + # This forces the discovery to fall back to the mux method + async def mock_direct_check_fails(peer_id): + """Mock that always returns None to force mux fallback.""" + return None + + client_discovery._check_via_direct_connection = mock_direct_check_fails + + # Start discovery service + async with background_trio_service(client_discovery): + await client_discovery.event_started.wait() + logger.info("Client discovery service started (mux method)") + + # Wait for discovery to find the relay using the mux method + logger.info("Waiting for relay discovery using mux fallback...") + + # Manually trigger discovery which should fallback to mux method + await client_discovery.discover_relays() + + # Check if relay was found + with trio.fail_after(DISCOVERY_TIMEOUT): + for _ in range(20): # Try multiple times + if relay_host.get_id() in client_discovery._discovered_relays: + logger.info("Relay discovered successfully (mux method)") + break + + # Wait and try again + await trio.sleep(1) + # Manually trigger discovery again + await client_discovery.discover_relays() + else: + pytest.fail( + "Failed to discover relay node within timeout (mux method)" + ) + + # Verify that relay was found and is valid + assert relay_host.get_id() in client_discovery._discovered_relays, ( + "Relay should be discovered (mux method)" + ) + relay_info = client_discovery._discovered_relays[relay_host.get_id()] + assert relay_info.peer_id == relay_host.get_id(), ( + "Peer ID should match (mux method)" + ) + + # Verify that the protocol was cached via mux method + assert relay_host.get_id() in client_discovery._protocol_cache, ( + "Protocol should be cached (mux method)" + ) + assert ( + str(PROTOCOL_ID) + in client_discovery._protocol_cache[relay_host.get_id()] + ), "Relay protocol should be in cache (mux method)" @pytest.mark.trio