mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-10 07:00:54 +00:00
Merge remote-tracking branch 'upstream/main'
This commit is contained in:
@ -51,14 +51,19 @@ async def test_swarm_dial_peer(security_protocol):
|
||||
for addr in transport.get_addrs()
|
||||
)
|
||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
|
||||
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
|
||||
# New: dial_peer now returns list of connections
|
||||
connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
assert len(connections) > 0
|
||||
|
||||
# Verify connections are established in both directions
|
||||
assert swarms[0].get_peer_id() in swarms[1].connections
|
||||
assert swarms[1].get_peer_id() in swarms[0].connections
|
||||
|
||||
# Test: Reuse connections when we already have ones with a peer.
|
||||
conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
|
||||
conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
assert conn is conn_to_1
|
||||
existing_connections = swarms[0].get_connections(swarms[1].get_peer_id())
|
||||
new_connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
assert new_connections == existing_connections
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
@ -107,7 +112,8 @@ async def test_swarm_close_peer(security_protocol):
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_remove_conn(swarm_pair):
|
||||
swarm_0, swarm_1 = swarm_pair
|
||||
conn_0 = swarm_0.connections[swarm_1.get_peer_id()]
|
||||
# Get the first connection from the list
|
||||
conn_0 = swarm_0.connections[swarm_1.get_peer_id()][0]
|
||||
swarm_0.remove_conn(conn_0)
|
||||
assert swarm_1.get_peer_id() not in swarm_0.connections
|
||||
# Test: Remove twice. There should not be errors.
|
||||
@ -115,6 +121,67 @@ async def test_swarm_remove_conn(swarm_pair):
|
||||
assert swarm_1.get_peer_id() not in swarm_0.connections
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_multiple_connections(security_protocol):
|
||||
"""Test multiple connections per peer functionality."""
|
||||
async with SwarmFactory.create_batch_and_listen(
|
||||
2, security_protocol=security_protocol
|
||||
) as swarms:
|
||||
# Setup multiple addresses for peer
|
||||
addrs = tuple(
|
||||
addr
|
||||
for transport in swarms[1].listeners.values()
|
||||
for addr in transport.get_addrs()
|
||||
)
|
||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
|
||||
|
||||
# Dial peer - should return list of connections
|
||||
connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
assert len(connections) > 0
|
||||
|
||||
# Test get_connections method
|
||||
peer_connections = swarms[0].get_connections(swarms[1].get_peer_id())
|
||||
assert len(peer_connections) == len(connections)
|
||||
|
||||
# Test get_connections_map method
|
||||
connections_map = swarms[0].get_connections_map()
|
||||
assert swarms[1].get_peer_id() in connections_map
|
||||
assert len(connections_map[swarms[1].get_peer_id()]) == len(connections)
|
||||
|
||||
# Test get_connection method (backward compatibility)
|
||||
single_conn = swarms[0].get_connection(swarms[1].get_peer_id())
|
||||
assert single_conn is not None
|
||||
assert single_conn in connections
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_load_balancing(security_protocol):
|
||||
"""Test load balancing across multiple connections."""
|
||||
async with SwarmFactory.create_batch_and_listen(
|
||||
2, security_protocol=security_protocol
|
||||
) as swarms:
|
||||
# Setup connection
|
||||
addrs = tuple(
|
||||
addr
|
||||
for transport in swarms[1].listeners.values()
|
||||
for addr in transport.get_addrs()
|
||||
)
|
||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
|
||||
|
||||
# Create multiple streams - should use load balancing
|
||||
streams = []
|
||||
for _ in range(5):
|
||||
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
|
||||
streams.append(stream)
|
||||
|
||||
# Verify streams were created successfully
|
||||
assert len(streams) == 5
|
||||
|
||||
# Clean up
|
||||
for stream in streams:
|
||||
await stream.close()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_multiaddr(security_protocol):
|
||||
async with SwarmFactory.create_batch_and_listen(
|
||||
|
||||
Reference in New Issue
Block a user