Add test and fix ci

This commit is contained in:
sukhman
2025-09-23 09:05:33 +05:30
parent 3363f57338
commit c8c7e55d5c
3 changed files with 82 additions and 4 deletions

View File

@ -508,7 +508,9 @@ class DCUtRProtocol(Service):
# Handle both single connection and list of connections
connections: list[INetConn] = (
[conn_or_conns] if not isinstance(conn_or_conns, list) else conn_or_conns
list(conn_or_conns)
if not isinstance(conn_or_conns, list)
else conn_or_conns
)
# Check if any connection is direct (not relayed)

View File

@ -90,7 +90,7 @@ class CircuitV2Transport(ITransport):
discovery_interval=config.discovery_interval,
max_relays=config.max_relays,
)
self.relay_counter = 0 # for round robin load balancing
self.relay_counter = 0 # for round robin load balancing
async def dial(
self,
@ -235,10 +235,10 @@ class CircuitV2Transport(ITransport):
self.relay_counter += 1
if relays_with_reservations:
return relays_with_reservations[
(self.relay_counter-1) % len(relays_with_reservations)
(self.relay_counter - 1) % len(relays_with_reservations)
]
elif other_relays:
return other_relays[(self.relay_counter-1) % len(other_relays)]
return other_relays[(self.relay_counter - 1) % len(other_relays)]
await trio.sleep(1)
attempts += 1

View File

@ -11,6 +11,7 @@ from libp2p.network.stream.exceptions import (
StreamEOF,
StreamReset,
)
from libp2p.peer.peerinfo import PeerInfo
from libp2p.relay.circuit_v2.config import (
RelayConfig,
)
@ -344,3 +345,78 @@ async def test_circuit_v2_transport_relay_limits():
# Test successful - transports were initialized with the correct limits
logger.info("Transport limit test successful")
@pytest.mark.trio
async def test_circuit_v2_transport_relay_selection():
"""Test relay round robin load balancing and reservation priority"""
async with HostFactory.create_batch_and_listen(5) as hosts:
client1_host, relay_host1, relay_host2, relay_host3, target_host = hosts
# Setup relay with strict limits
limits = RelayLimits(
duration=DEFAULT_RELAY_LIMITS.duration,
data=DEFAULT_RELAY_LIMITS.data,
max_circuit_conns=DEFAULT_RELAY_LIMITS.max_circuit_conns,
max_reservations=DEFAULT_RELAY_LIMITS.max_reservations,
)
# Register test handler on target
test_protocol = "/test/echo/1.0.0"
target_host.set_stream_handler(TProtocol(test_protocol), echo_stream_handler)
target_host_info = PeerInfo(target_host.get_id(), target_host.get_addrs())
client_config = RelayConfig()
# Client setup
client1_protocol = CircuitV2Protocol(client1_host, limits, allow_hop=False)
client1_discovery = RelayDiscovery(
host=client1_host,
auto_reserve=False,
discovery_interval=client_config.discovery_interval,
max_relays=client_config.max_relays,
)
client1_transport = CircuitV2Transport(
client1_host, client1_protocol, client_config
)
client1_transport.discovery = client1_discovery
# Add relay to discovery
relay_id1 = relay_host1.get_id()
relay_id2 = relay_host2.get_id()
relay_id3 = relay_host3.get_id()
# Connect all peers
try:
with trio.fail_after(CONNECT_TIMEOUT):
# Connect clients to relay
await connect(client1_host, relay_host1)
await connect(client1_host, relay_host2)
await connect(client1_host, relay_host3)
await client1_discovery._add_relay(relay_id1)
await client1_discovery._add_relay(relay_id2)
await client1_discovery._add_relay(relay_id3)
logger.info("All connections established")
except Exception as e:
logger.error("Failed to connect peers: %s", str(e))
raise
selected_relay = await client1_transport._select_relay(target_host_info)
# Without reservation preferance
# Round robin, so 1st time must be relay1
assert selected_relay is not None and selected_relay is relay_id1
selected_relay = await client1_transport._select_relay(target_host_info)
# Round robin, so 2nd time must be relay2
assert selected_relay is not None and selected_relay is relay_id2
# Mock reservation with relay1 to prioritize over relay2
relay_info3 = client1_discovery.get_relay_info(relay_id3)
if relay_info3:
relay_info3.has_reservation = True
selected_relay = await client1_transport._select_relay(target_host_info)
# With reservation preferance, relay2 must be chosen for target_peer.
assert selected_relay is not None and selected_relay is relay_host3.get_id()
logger.info("Relay selection successful")