From c8c7e55d5cd92437b5976420d3c342775c7f15bc Mon Sep 17 00:00:00 2001 From: sukhman Date: Tue, 23 Sep 2025 09:05:33 +0530 Subject: [PATCH] Add test and fix ci --- libp2p/relay/circuit_v2/dcutr.py | 4 +- libp2p/relay/circuit_v2/transport.py | 6 +- tests/core/relay/test_circuit_v2_transport.py | 76 +++++++++++++++++++ 3 files changed, 82 insertions(+), 4 deletions(-) diff --git a/libp2p/relay/circuit_v2/dcutr.py b/libp2p/relay/circuit_v2/dcutr.py index 2cece5d2..48ba1a3f 100644 --- a/libp2p/relay/circuit_v2/dcutr.py +++ b/libp2p/relay/circuit_v2/dcutr.py @@ -508,7 +508,9 @@ class DCUtRProtocol(Service): # Handle both single connection and list of connections connections: list[INetConn] = ( - [conn_or_conns] if not isinstance(conn_or_conns, list) else conn_or_conns + list(conn_or_conns) + if not isinstance(conn_or_conns, list) + else conn_or_conns ) # Check if any connection is direct (not relayed) diff --git a/libp2p/relay/circuit_v2/transport.py b/libp2p/relay/circuit_v2/transport.py index 44e4e22d..8ac43d99 100644 --- a/libp2p/relay/circuit_v2/transport.py +++ b/libp2p/relay/circuit_v2/transport.py @@ -90,7 +90,7 @@ class CircuitV2Transport(ITransport): discovery_interval=config.discovery_interval, max_relays=config.max_relays, ) - self.relay_counter = 0 # for round robin load balancing + self.relay_counter = 0 # for round robin load balancing async def dial( self, @@ -235,10 +235,10 @@ class CircuitV2Transport(ITransport): self.relay_counter += 1 if relays_with_reservations: return relays_with_reservations[ - (self.relay_counter-1) % len(relays_with_reservations) + (self.relay_counter - 1) % len(relays_with_reservations) ] elif other_relays: - return other_relays[(self.relay_counter-1) % len(other_relays)] + return other_relays[(self.relay_counter - 1) % len(other_relays)] await trio.sleep(1) attempts += 1 diff --git a/tests/core/relay/test_circuit_v2_transport.py b/tests/core/relay/test_circuit_v2_transport.py index 8498dba4..dc027381 100644 --- a/tests/core/relay/test_circuit_v2_transport.py +++ b/tests/core/relay/test_circuit_v2_transport.py @@ -11,6 +11,7 @@ from libp2p.network.stream.exceptions import ( StreamEOF, StreamReset, ) +from libp2p.peer.peerinfo import PeerInfo from libp2p.relay.circuit_v2.config import ( RelayConfig, ) @@ -344,3 +345,78 @@ async def test_circuit_v2_transport_relay_limits(): # Test successful - transports were initialized with the correct limits logger.info("Transport limit test successful") + + +@pytest.mark.trio +async def test_circuit_v2_transport_relay_selection(): + """Test relay round robin load balancing and reservation priority""" + async with HostFactory.create_batch_and_listen(5) as hosts: + client1_host, relay_host1, relay_host2, relay_host3, target_host = hosts + + # Setup relay with strict limits + limits = RelayLimits( + duration=DEFAULT_RELAY_LIMITS.duration, + data=DEFAULT_RELAY_LIMITS.data, + max_circuit_conns=DEFAULT_RELAY_LIMITS.max_circuit_conns, + max_reservations=DEFAULT_RELAY_LIMITS.max_reservations, + ) + + # Register test handler on target + test_protocol = "/test/echo/1.0.0" + target_host.set_stream_handler(TProtocol(test_protocol), echo_stream_handler) + target_host_info = PeerInfo(target_host.get_id(), target_host.get_addrs()) + client_config = RelayConfig() + + # Client setup + client1_protocol = CircuitV2Protocol(client1_host, limits, allow_hop=False) + client1_discovery = RelayDiscovery( + host=client1_host, + auto_reserve=False, + discovery_interval=client_config.discovery_interval, + max_relays=client_config.max_relays, + ) + client1_transport = CircuitV2Transport( + client1_host, client1_protocol, client_config + ) + client1_transport.discovery = client1_discovery + # Add relay to discovery + relay_id1 = relay_host1.get_id() + relay_id2 = relay_host2.get_id() + relay_id3 = relay_host3.get_id() + + # Connect all peers + try: + with trio.fail_after(CONNECT_TIMEOUT): + # Connect clients to relay + await connect(client1_host, relay_host1) + await connect(client1_host, relay_host2) + await connect(client1_host, relay_host3) + + await client1_discovery._add_relay(relay_id1) + await client1_discovery._add_relay(relay_id2) + await client1_discovery._add_relay(relay_id3) + + logger.info("All connections established") + except Exception as e: + logger.error("Failed to connect peers: %s", str(e)) + raise + + selected_relay = await client1_transport._select_relay(target_host_info) + # Without reservation preferance + # Round robin, so 1st time must be relay1 + assert selected_relay is not None and selected_relay is relay_id1 + + selected_relay = await client1_transport._select_relay(target_host_info) + # Round robin, so 2nd time must be relay2 + assert selected_relay is not None and selected_relay is relay_id2 + + # Mock reservation with relay1 to prioritize over relay2 + relay_info3 = client1_discovery.get_relay_info(relay_id3) + if relay_info3: + relay_info3.has_reservation = True + + selected_relay = await client1_transport._select_relay(target_host_info) + # With reservation preferance, relay2 must be chosen for target_peer. + assert selected_relay is not None and selected_relay is relay_host3.get_id() + + logger.info("Relay selection successful")