mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-12 08:00:54 +00:00
Add tests for discovery methods in circuit_relay_v2 (#750)
* Add test for direct_connection_relay_discovery * Add test for mux_method_relay_discovery * Fix newsfragments
This commit is contained in:
@ -64,7 +64,11 @@ class PeerStore(IPeerStore):
|
|||||||
return list(self.peer_data_map.keys())
|
return list(self.peer_data_map.keys())
|
||||||
|
|
||||||
def clear_peerdata(self, peer_id: ID) -> None:
|
def clear_peerdata(self, peer_id: ID) -> None:
|
||||||
"""Clears the peer data of the peer"""
|
"""Clears all data associated with the given peer_id."""
|
||||||
|
if peer_id in self.peer_data_map:
|
||||||
|
del self.peer_data_map[peer_id]
|
||||||
|
else:
|
||||||
|
raise PeerStoreError("peer ID not found")
|
||||||
|
|
||||||
def valid_peer_ids(self) -> list[ID]:
|
def valid_peer_ids(self) -> list[ID]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -234,7 +234,8 @@ class RelayDiscovery(Service):
|
|||||||
|
|
||||||
if not callable(proto_getter):
|
if not callable(proto_getter):
|
||||||
return None
|
return None
|
||||||
|
if peer_id not in peerstore.peer_ids():
|
||||||
|
return None
|
||||||
try:
|
try:
|
||||||
# Try to get protocols
|
# Try to get protocols
|
||||||
proto_result = proto_getter(peer_id)
|
proto_result = proto_getter(peer_id)
|
||||||
@ -283,8 +284,6 @@ class RelayDiscovery(Service):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
mux = self.host.get_mux()
|
mux = self.host.get_mux()
|
||||||
if not hasattr(mux, "protocols"):
|
|
||||||
return None
|
|
||||||
|
|
||||||
peer_protocols = set()
|
peer_protocols = set()
|
||||||
# Get protocols from mux with proper type safety
|
# Get protocols from mux with proper type safety
|
||||||
|
|||||||
1
newsfragments/749.internal.rst
Normal file
1
newsfragments/749.internal.rst
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add comprehensive tests for relay_discovery method in circuit_relay_v2
|
||||||
1
newsfragments/750.feature.rst
Normal file
1
newsfragments/750.feature.rst
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add logic to clear_peerdata method in peerstore
|
||||||
@ -105,11 +105,11 @@ async def test_relay_discovery_initialization():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_relay_discovery_find_relay():
|
async def test_relay_discovery_find_relay_peerstore_method():
|
||||||
"""Test finding a relay node via discovery."""
|
"""Test finding a relay node via discovery using the peerstore method."""
|
||||||
async with HostFactory.create_batch_and_listen(2) as hosts:
|
async with HostFactory.create_batch_and_listen(2) as hosts:
|
||||||
relay_host, client_host = hosts
|
relay_host, client_host = hosts
|
||||||
logger.info("Created hosts for test_relay_discovery_find_relay")
|
logger.info("Created host for test_relay_discovery_find_relay_peerstore_method")
|
||||||
logger.info("Relay host ID: %s", relay_host.get_id())
|
logger.info("Relay host ID: %s", relay_host.get_id())
|
||||||
logger.info("Client host ID: %s", client_host.get_id())
|
logger.info("Client host ID: %s", client_host.get_id())
|
||||||
|
|
||||||
@ -144,19 +144,19 @@ async def test_relay_discovery_find_relay():
|
|||||||
# Start discovery service
|
# Start discovery service
|
||||||
async with background_trio_service(client_discovery):
|
async with background_trio_service(client_discovery):
|
||||||
await client_discovery.event_started.wait()
|
await client_discovery.event_started.wait()
|
||||||
logger.info("Client discovery service started")
|
logger.info("Client discovery service started (peerstore method)")
|
||||||
|
|
||||||
# Wait for discovery to find the relay
|
# Wait for discovery to find the relay using the peerstore method
|
||||||
logger.info("Waiting for relay discovery...")
|
logger.info("Waiting for relay discovery using peerstore...")
|
||||||
|
|
||||||
# Manually trigger discovery instead of waiting
|
# Manually trigger discovery which uses peerstore as default
|
||||||
await client_discovery.discover_relays()
|
await client_discovery.discover_relays()
|
||||||
|
|
||||||
# Check if relay was found
|
# Check if relay was found
|
||||||
with trio.fail_after(DISCOVERY_TIMEOUT):
|
with trio.fail_after(DISCOVERY_TIMEOUT):
|
||||||
for _ in range(20): # Try multiple times
|
for _ in range(20): # Try multiple times
|
||||||
if relay_host.get_id() in client_discovery._discovered_relays:
|
if relay_host.get_id() in client_discovery._discovered_relays:
|
||||||
logger.info("Relay discovered successfully")
|
logger.info("Relay discovered successfully (peerstore method)")
|
||||||
break
|
break
|
||||||
|
|
||||||
# Wait and try again
|
# Wait and try again
|
||||||
@ -164,14 +164,194 @@ async def test_relay_discovery_find_relay():
|
|||||||
# Manually trigger discovery again
|
# Manually trigger discovery again
|
||||||
await client_discovery.discover_relays()
|
await client_discovery.discover_relays()
|
||||||
else:
|
else:
|
||||||
pytest.fail("Failed to discover relay node within timeout")
|
pytest.fail(
|
||||||
|
"Failed to discover relay node within timeout(peerstore method)"
|
||||||
|
)
|
||||||
|
|
||||||
# Verify that relay was found and is valid
|
# Verify that relay was found and is valid
|
||||||
assert relay_host.get_id() in client_discovery._discovered_relays, (
|
assert relay_host.get_id() in client_discovery._discovered_relays, (
|
||||||
"Relay should be discovered"
|
"Relay should be discovered (peerstore method)"
|
||||||
)
|
)
|
||||||
relay_info = client_discovery._discovered_relays[relay_host.get_id()]
|
relay_info = client_discovery._discovered_relays[relay_host.get_id()]
|
||||||
assert relay_info.peer_id == relay_host.get_id(), "Peer ID should match"
|
assert relay_info.peer_id == relay_host.get_id(), (
|
||||||
|
"Peer ID should match (peerstore method)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_relay_discovery_find_relay_direct_connection_method():
|
||||||
|
"""Test finding a relay node via discovery using the direct connection method."""
|
||||||
|
async with HostFactory.create_batch_and_listen(2) as hosts:
|
||||||
|
relay_host, client_host = hosts
|
||||||
|
logger.info("Created hosts for test_relay_discovery_find_relay_direct_method")
|
||||||
|
logger.info("Relay host ID: %s", relay_host.get_id())
|
||||||
|
logger.info("Client host ID: %s", client_host.get_id())
|
||||||
|
|
||||||
|
# Explicitly register the protocol handlers on relay_host
|
||||||
|
relay_host.set_stream_handler(PROTOCOL_ID, simple_stream_handler)
|
||||||
|
relay_host.set_stream_handler(STOP_PROTOCOL_ID, simple_stream_handler)
|
||||||
|
|
||||||
|
# Manually add protocol to peerstore for testing, then remove to force fallback
|
||||||
|
client_host.get_peerstore().add_protocols(
|
||||||
|
relay_host.get_id(), [str(PROTOCOL_ID)]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set up discovery on the client host
|
||||||
|
client_discovery = RelayDiscovery(
|
||||||
|
client_host, discovery_interval=5
|
||||||
|
) # Use shorter interval for testing
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Connect peers so they can discover each other
|
||||||
|
with trio.fail_after(CONNECT_TIMEOUT):
|
||||||
|
logger.info("Connecting client host to relay host")
|
||||||
|
await connect(client_host, relay_host)
|
||||||
|
assert relay_host.get_network().connections[client_host.get_id()], (
|
||||||
|
"Peers not connected"
|
||||||
|
)
|
||||||
|
logger.info("Connection established between peers")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to connect peers: %s", str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Remove the relay from the peerstore to test fallback to direct connection
|
||||||
|
client_host.get_peerstore().clear_peerdata(relay_host.get_id())
|
||||||
|
# Make sure that peer_id is not present in peerstore
|
||||||
|
assert relay_host.get_id() not in client_host.get_peerstore().peer_ids()
|
||||||
|
|
||||||
|
# Start discovery service
|
||||||
|
async with background_trio_service(client_discovery):
|
||||||
|
await client_discovery.event_started.wait()
|
||||||
|
logger.info("Client discovery service started (direct connection method)")
|
||||||
|
|
||||||
|
# Wait for discovery to find the relay using the direct connection method
|
||||||
|
logger.info(
|
||||||
|
"Waiting for relay discovery using direct connection fallback..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Manually trigger discovery which should fallback to direct connection
|
||||||
|
await client_discovery.discover_relays()
|
||||||
|
|
||||||
|
# Check if relay was found
|
||||||
|
with trio.fail_after(DISCOVERY_TIMEOUT):
|
||||||
|
for _ in range(20): # Try multiple times
|
||||||
|
if relay_host.get_id() in client_discovery._discovered_relays:
|
||||||
|
logger.info("Relay discovered successfully (direct method)")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Wait and try again
|
||||||
|
await trio.sleep(1)
|
||||||
|
# Manually trigger discovery again
|
||||||
|
await client_discovery.discover_relays()
|
||||||
|
else:
|
||||||
|
pytest.fail(
|
||||||
|
"Failed to discover relay node within timeout (direct method)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify that relay was found and is valid
|
||||||
|
assert relay_host.get_id() in client_discovery._discovered_relays, (
|
||||||
|
"Relay should be discovered (direct method)"
|
||||||
|
)
|
||||||
|
relay_info = client_discovery._discovered_relays[relay_host.get_id()]
|
||||||
|
assert relay_info.peer_id == relay_host.get_id(), (
|
||||||
|
"Peer ID should match (direct method)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_relay_discovery_find_relay_mux_method():
|
||||||
|
"""
|
||||||
|
Test finding a relay node via discovery using the mux method
|
||||||
|
(fallback after direct connection fails).
|
||||||
|
"""
|
||||||
|
async with HostFactory.create_batch_and_listen(2) as hosts:
|
||||||
|
relay_host, client_host = hosts
|
||||||
|
logger.info("Created hosts for test_relay_discovery_find_relay_mux_method")
|
||||||
|
logger.info("Relay host ID: %s", relay_host.get_id())
|
||||||
|
logger.info("Client host ID: %s", client_host.get_id())
|
||||||
|
|
||||||
|
# Explicitly register the protocol handlers on relay_host
|
||||||
|
relay_host.set_stream_handler(PROTOCOL_ID, simple_stream_handler)
|
||||||
|
relay_host.set_stream_handler(STOP_PROTOCOL_ID, simple_stream_handler)
|
||||||
|
|
||||||
|
client_host.set_stream_handler(PROTOCOL_ID, simple_stream_handler)
|
||||||
|
client_host.set_stream_handler(STOP_PROTOCOL_ID, simple_stream_handler)
|
||||||
|
|
||||||
|
# Set up discovery on the client host
|
||||||
|
client_discovery = RelayDiscovery(
|
||||||
|
client_host, discovery_interval=5
|
||||||
|
) # Use shorter interval for testing
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Connect peers so they can discover each other
|
||||||
|
with trio.fail_after(CONNECT_TIMEOUT):
|
||||||
|
logger.info("Connecting client host to relay host")
|
||||||
|
await connect(client_host, relay_host)
|
||||||
|
assert relay_host.get_network().connections[client_host.get_id()], (
|
||||||
|
"Peers not connected"
|
||||||
|
)
|
||||||
|
logger.info("Connection established between peers")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to connect peers: %s", str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Remove the relay from the peerstore to test fallback
|
||||||
|
client_host.get_peerstore().clear_peerdata(relay_host.get_id())
|
||||||
|
# Make sure that peer_id is not present in peerstore
|
||||||
|
assert relay_host.get_id() not in client_host.get_peerstore().peer_ids()
|
||||||
|
|
||||||
|
# Mock the _check_via_direct_connection method to return None
|
||||||
|
# This forces the discovery to fall back to the mux method
|
||||||
|
async def mock_direct_check_fails(peer_id):
|
||||||
|
"""Mock that always returns None to force mux fallback."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
client_discovery._check_via_direct_connection = mock_direct_check_fails
|
||||||
|
|
||||||
|
# Start discovery service
|
||||||
|
async with background_trio_service(client_discovery):
|
||||||
|
await client_discovery.event_started.wait()
|
||||||
|
logger.info("Client discovery service started (mux method)")
|
||||||
|
|
||||||
|
# Wait for discovery to find the relay using the mux method
|
||||||
|
logger.info("Waiting for relay discovery using mux fallback...")
|
||||||
|
|
||||||
|
# Manually trigger discovery which should fallback to mux method
|
||||||
|
await client_discovery.discover_relays()
|
||||||
|
|
||||||
|
# Check if relay was found
|
||||||
|
with trio.fail_after(DISCOVERY_TIMEOUT):
|
||||||
|
for _ in range(20): # Try multiple times
|
||||||
|
if relay_host.get_id() in client_discovery._discovered_relays:
|
||||||
|
logger.info("Relay discovered successfully (mux method)")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Wait and try again
|
||||||
|
await trio.sleep(1)
|
||||||
|
# Manually trigger discovery again
|
||||||
|
await client_discovery.discover_relays()
|
||||||
|
else:
|
||||||
|
pytest.fail(
|
||||||
|
"Failed to discover relay node within timeout (mux method)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify that relay was found and is valid
|
||||||
|
assert relay_host.get_id() in client_discovery._discovered_relays, (
|
||||||
|
"Relay should be discovered (mux method)"
|
||||||
|
)
|
||||||
|
relay_info = client_discovery._discovered_relays[relay_host.get_id()]
|
||||||
|
assert relay_info.peer_id == relay_host.get_id(), (
|
||||||
|
"Peer ID should match (mux method)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify that the protocol was cached via mux method
|
||||||
|
assert relay_host.get_id() in client_discovery._protocol_cache, (
|
||||||
|
"Protocol should be cached (mux method)"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
str(PROTOCOL_ID)
|
||||||
|
in client_discovery._protocol_cache[relay_host.get_id()]
|
||||||
|
), "Relay protocol should be in cache (mux method)"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
|
|||||||
Reference in New Issue
Block a user