address architectural refactoring discussed

This commit is contained in:
bomanaps
2025-08-31 01:38:29 +01:00
parent df39e240e7
commit 59e1d9ae39
12 changed files with 705 additions and 571 deletions

View File

@ -1,14 +1,15 @@
import time
from typing import cast
from unittest.mock import Mock
import pytest
from multiaddr import Multiaddr
import trio
from libp2p.abc import INetConn, INetStream
from libp2p.network.exceptions import SwarmException
from libp2p.network.swarm import (
ConnectionConfig,
ConnectionPool,
RetryConfig,
Swarm,
)
@ -21,10 +22,12 @@ class MockConnection(INetConn):
def __init__(self, peer_id: ID, is_closed: bool = False):
self.peer_id = peer_id
self._is_closed = is_closed
self.stream_count = 0
self.streams = set() # Track streams properly
# Mock the muxed_conn attribute that Swarm expects
self.muxed_conn = Mock()
self.muxed_conn.peer_id = peer_id
# Required by INetConn interface
self.event_started = trio.Event()
async def close(self):
self._is_closed = True
@ -34,12 +37,14 @@ class MockConnection(INetConn):
return self._is_closed
async def new_stream(self) -> INetStream:
self.stream_count += 1
return Mock(spec=INetStream)
# Create a mock stream and add it to the connection's stream set
mock_stream = Mock(spec=INetStream)
self.streams.add(mock_stream)
return mock_stream
def get_streams(self) -> tuple[INetStream, ...]:
"""Mock implementation of get_streams."""
return tuple()
"""Return all streams associated with this connection."""
return tuple(self.streams)
def get_transport_addresses(self) -> list[Multiaddr]:
"""Mock implementation of get_transport_addresses."""
@ -70,114 +75,9 @@ async def test_connection_config_defaults():
config = ConnectionConfig()
assert config.max_connections_per_peer == 3
assert config.connection_timeout == 30.0
assert config.enable_connection_pool is True
assert config.load_balancing_strategy == "round_robin"
@pytest.mark.trio
async def test_connection_pool_basic_operations():
"""Test basic ConnectionPool operations."""
pool = ConnectionPool(max_connections_per_peer=2)
peer_id = ID(b"QmTest")
# Test empty pool
assert not pool.has_connection(peer_id)
assert pool.get_connection(peer_id) is None
# Add connection
conn1 = MockConnection(peer_id)
pool.add_connection(peer_id, conn1, "addr1")
assert pool.has_connection(peer_id)
assert pool.get_connection(peer_id) == conn1
# Add second connection
conn2 = MockConnection(peer_id)
pool.add_connection(peer_id, conn2, "addr2")
assert len(pool.peer_connections[peer_id]) == 2
# Test round-robin - should cycle through connections
first_conn = pool.get_connection(peer_id, "round_robin")
second_conn = pool.get_connection(peer_id, "round_robin")
third_conn = pool.get_connection(peer_id, "round_robin")
# Should cycle through both connections
assert first_conn in [conn1, conn2]
assert second_conn in [conn1, conn2]
assert third_conn in [conn1, conn2]
assert first_conn != second_conn or second_conn != third_conn
# Test least loaded - set different stream counts
conn1.stream_count = 5
conn2.stream_count = 1
least_loaded_conn = pool.get_connection(peer_id, "least_loaded")
assert least_loaded_conn == conn2 # conn2 has fewer streams
@pytest.mark.trio
async def test_connection_pool_deduplication():
"""Test connection deduplication by address."""
pool = ConnectionPool(max_connections_per_peer=3)
peer_id = ID(b"QmTest")
conn1 = MockConnection(peer_id)
pool.add_connection(peer_id, conn1, "addr1")
# Try to add connection with same address
conn2 = MockConnection(peer_id)
pool.add_connection(peer_id, conn2, "addr1")
# Should only have one connection
assert len(pool.peer_connections[peer_id]) == 1
assert pool.get_connection(peer_id) == conn1
@pytest.mark.trio
async def test_connection_pool_trimming():
"""Test connection trimming when limit is exceeded."""
pool = ConnectionPool(max_connections_per_peer=2)
peer_id = ID(b"QmTest")
# Add 3 connections
conn1 = MockConnection(peer_id)
conn2 = MockConnection(peer_id)
conn3 = MockConnection(peer_id)
pool.add_connection(peer_id, conn1, "addr1")
pool.add_connection(peer_id, conn2, "addr2")
pool.add_connection(peer_id, conn3, "addr3")
# Should trim to 2 connections
assert len(pool.peer_connections[peer_id]) == 2
# The oldest connections should be removed
remaining_connections = [c.connection for c in pool.peer_connections[peer_id]]
assert conn3 in remaining_connections # Most recent should remain
@pytest.mark.trio
async def test_connection_pool_remove_connection():
"""Test removing connections from pool."""
pool = ConnectionPool(max_connections_per_peer=3)
peer_id = ID(b"QmTest")
conn1 = MockConnection(peer_id)
conn2 = MockConnection(peer_id)
pool.add_connection(peer_id, conn1, "addr1")
pool.add_connection(peer_id, conn2, "addr2")
assert len(pool.peer_connections[peer_id]) == 2
# Remove connection
pool.remove_connection(peer_id, conn1)
assert len(pool.peer_connections[peer_id]) == 1
assert pool.get_connection(peer_id) == conn2
# Remove last connection
pool.remove_connection(peer_id, conn2)
assert not pool.has_connection(peer_id)
@pytest.mark.trio
async def test_enhanced_swarm_constructor():
"""Test enhanced Swarm constructor with new configuration."""
@ -191,19 +91,16 @@ async def test_enhanced_swarm_constructor():
swarm = Swarm(peer_id, peerstore, upgrader, transport)
assert swarm.retry_config.max_retries == 3
assert swarm.connection_config.max_connections_per_peer == 3
assert swarm.connection_pool is not None
assert isinstance(swarm.connections, dict)
# Test with custom config
custom_retry = RetryConfig(max_retries=5, initial_delay=0.5)
custom_conn = ConnectionConfig(
max_connections_per_peer=5, enable_connection_pool=False
)
custom_conn = ConnectionConfig(max_connections_per_peer=5)
swarm = Swarm(peer_id, peerstore, upgrader, transport, custom_retry, custom_conn)
assert swarm.retry_config.max_retries == 5
assert swarm.retry_config.initial_delay == 0.5
assert swarm.connection_config.max_connections_per_peer == 5
assert swarm.connection_pool is None
@pytest.mark.trio
@ -273,143 +170,155 @@ async def test_swarm_retry_logic():
# Should have succeeded after 3 attempts
assert attempt_count[0] == 3
assert result is not None
# Should have taken some time due to retries
assert end_time - start_time > 0.02 # At least 2 delays
assert isinstance(result, MockConnection)
assert end_time - start_time > 0.01 # Should have some delay
@pytest.mark.trio
async def test_swarm_multi_connection_support():
"""Test multi-connection support in Swarm."""
async def test_swarm_load_balancing_strategies():
"""Test load balancing strategies."""
peer_id = ID(b"QmTest")
peerstore = Mock()
upgrader = Mock()
transport = Mock()
connection_config = ConnectionConfig(
max_connections_per_peer=3,
enable_connection_pool=True,
load_balancing_strategy="round_robin",
)
swarm = Swarm(peer_id, peerstore, upgrader, transport)
# Create mock connections with different stream counts
conn1 = MockConnection(peer_id)
conn2 = MockConnection(peer_id)
conn3 = MockConnection(peer_id)
# Add some streams to simulate load
await conn1.new_stream()
await conn1.new_stream()
await conn2.new_stream()
connections = [conn1, conn2, conn3]
# Test round-robin strategy
swarm.connection_config.load_balancing_strategy = "round_robin"
# Cast to satisfy type checker
connections_cast = cast("list[INetConn]", connections)
selected1 = swarm._select_connection(connections_cast, peer_id)
selected2 = swarm._select_connection(connections_cast, peer_id)
selected3 = swarm._select_connection(connections_cast, peer_id)
# Should cycle through connections
assert selected1 in connections
assert selected2 in connections
assert selected3 in connections
# Test least loaded strategy
swarm.connection_config.load_balancing_strategy = "least_loaded"
least_loaded = swarm._select_connection(connections_cast, peer_id)
# conn3 has 0 streams, conn2 has 1 stream, conn1 has 2 streams
# So conn3 should be selected as least loaded
assert least_loaded == conn3
# Test default strategy (first connection)
swarm.connection_config.load_balancing_strategy = "unknown"
default_selected = swarm._select_connection(connections_cast, peer_id)
assert default_selected == conn1
@pytest.mark.trio
async def test_swarm_multiple_connections_api():
"""Test the new multiple connections API methods."""
peer_id = ID(b"QmTest")
peerstore = Mock()
upgrader = Mock()
transport = Mock()
swarm = Swarm(peer_id, peerstore, upgrader, transport)
# Test empty connections
assert swarm.get_connections() == []
assert swarm.get_connections(peer_id) == []
assert swarm.get_connection(peer_id) is None
assert swarm.get_connections_map() == {}
# Add some connections
conn1 = MockConnection(peer_id)
conn2 = MockConnection(peer_id)
swarm.connections[peer_id] = [conn1, conn2]
# Test get_connections with peer_id
peer_connections = swarm.get_connections(peer_id)
assert len(peer_connections) == 2
assert conn1 in peer_connections
assert conn2 in peer_connections
# Test get_connections without peer_id (all connections)
all_connections = swarm.get_connections()
assert len(all_connections) == 2
assert conn1 in all_connections
assert conn2 in all_connections
# Test get_connection (backward compatibility)
single_conn = swarm.get_connection(peer_id)
assert single_conn in [conn1, conn2]
# Test get_connections_map
connections_map = swarm.get_connections_map()
assert peer_id in connections_map
assert connections_map[peer_id] == [conn1, conn2]
@pytest.mark.trio
async def test_swarm_connection_trimming():
"""Test connection trimming when limit is exceeded."""
peer_id = ID(b"QmTest")
peerstore = Mock()
upgrader = Mock()
transport = Mock()
# Set max connections to 2
connection_config = ConnectionConfig(max_connections_per_peer=2)
swarm = Swarm(
peer_id, peerstore, upgrader, transport, connection_config=connection_config
)
# Mock connection pool methods
assert swarm.connection_pool is not None
connection_pool = swarm.connection_pool
connection_pool.has_connection = Mock(return_value=True)
connection_pool.get_connection = Mock(return_value=MockConnection(peer_id))
# Add 3 connections
conn1 = MockConnection(peer_id)
conn2 = MockConnection(peer_id)
conn3 = MockConnection(peer_id)
# Test that new_stream uses connection pool
result = await swarm.new_stream(peer_id)
assert result is not None
# Use the mocked method directly to avoid type checking issues
get_connection_mock = connection_pool.get_connection
assert get_connection_mock.call_count == 1
swarm.connections[peer_id] = [conn1, conn2, conn3]
# Trigger trimming
swarm._trim_connections(peer_id)
# Should have only 2 connections
assert len(swarm.connections[peer_id]) == 2
# The most recent connections should remain
remaining = swarm.connections[peer_id]
assert conn2 in remaining
assert conn3 in remaining
@pytest.mark.trio
async def test_swarm_backward_compatibility():
"""Test that enhanced Swarm maintains backward compatibility."""
"""Test backward compatibility features."""
peer_id = ID(b"QmTest")
peerstore = Mock()
upgrader = Mock()
transport = Mock()
# Create swarm with connection pool disabled
connection_config = ConnectionConfig(enable_connection_pool=False)
swarm = Swarm(
peer_id, peerstore, upgrader, transport, connection_config=connection_config
)
swarm = Swarm(peer_id, peerstore, upgrader, transport)
# Should behave like original swarm
assert swarm.connection_pool is None
assert isinstance(swarm.connections, dict)
# Add connections
conn1 = MockConnection(peer_id)
conn2 = MockConnection(peer_id)
swarm.connections[peer_id] = [conn1, conn2]
# Test that dial_peer still works (will fail due to mocks, but structure is correct)
peerstore.addrs.return_value = [Mock(spec=Multiaddr)]
transport.dial.side_effect = Exception("Transport error")
with pytest.raises(SwarmException):
await swarm.dial_peer(peer_id)
@pytest.mark.trio
async def test_swarm_connection_pool_integration():
"""Test integration between Swarm and ConnectionPool."""
peer_id = ID(b"QmTest")
peerstore = Mock()
upgrader = Mock()
transport = Mock()
connection_config = ConnectionConfig(
max_connections_per_peer=2, enable_connection_pool=True
)
swarm = Swarm(
peer_id, peerstore, upgrader, transport, connection_config=connection_config
)
# Mock successful connection creation
mock_conn = MockConnection(peer_id)
peerstore.addrs.return_value = [Mock(spec=Multiaddr)]
async def mock_dial_with_retry(addr, peer_id):
return mock_conn
swarm._dial_with_retry = mock_dial_with_retry
# Test dial_peer adds to connection pool
result = await swarm.dial_peer(peer_id)
assert result == mock_conn
assert swarm.connection_pool is not None
assert swarm.connection_pool.has_connection(peer_id)
# Test that subsequent calls reuse connection
result2 = await swarm.dial_peer(peer_id)
assert result2 == mock_conn
@pytest.mark.trio
async def test_swarm_connection_cleanup():
"""Test connection cleanup in enhanced Swarm."""
peer_id = ID(b"QmTest")
peerstore = Mock()
upgrader = Mock()
transport = Mock()
connection_config = ConnectionConfig(enable_connection_pool=True)
swarm = Swarm(
peer_id, peerstore, upgrader, transport, connection_config=connection_config
)
# Add a connection
mock_conn = MockConnection(peer_id)
swarm.connections[peer_id] = mock_conn
assert swarm.connection_pool is not None
swarm.connection_pool.add_connection(peer_id, mock_conn, "test_addr")
# Test close_peer removes from pool
await swarm.close_peer(peer_id)
assert swarm.connection_pool is not None
assert not swarm.connection_pool.has_connection(peer_id)
# Test remove_conn removes from pool
mock_conn2 = MockConnection(peer_id)
swarm.connections[peer_id] = mock_conn2
assert swarm.connection_pool is not None
connection_pool = swarm.connection_pool
connection_pool.add_connection(peer_id, mock_conn2, "test_addr2")
# Note: remove_conn expects SwarmConn, but for testing we'll just
# remove from pool directly
connection_pool = swarm.connection_pool
connection_pool.remove_connection(peer_id, mock_conn2)
assert connection_pool is not None
assert not connection_pool.has_connection(peer_id)
# Test connections_legacy property
legacy_connections = swarm.connections_legacy
assert peer_id in legacy_connections
# Should return first connection
assert legacy_connections[peer_id] in [conn1, conn2]
if __name__ == "__main__":

View File

@ -51,14 +51,19 @@ async def test_swarm_dial_peer(security_protocol):
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
await swarms[0].dial_peer(swarms[1].get_peer_id())
# New: dial_peer now returns list of connections
connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
assert len(connections) > 0
# Verify connections are established in both directions
assert swarms[0].get_peer_id() in swarms[1].connections
assert swarms[1].get_peer_id() in swarms[0].connections
# Test: Reuse connections when we already have ones with a peer.
conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
assert conn is conn_to_1
existing_connections = swarms[0].get_connections(swarms[1].get_peer_id())
new_connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
assert new_connections == existing_connections
@pytest.mark.trio
@ -107,7 +112,8 @@ async def test_swarm_close_peer(security_protocol):
@pytest.mark.trio
async def test_swarm_remove_conn(swarm_pair):
swarm_0, swarm_1 = swarm_pair
conn_0 = swarm_0.connections[swarm_1.get_peer_id()]
# Get the first connection from the list
conn_0 = swarm_0.connections[swarm_1.get_peer_id()][0]
swarm_0.remove_conn(conn_0)
assert swarm_1.get_peer_id() not in swarm_0.connections
# Test: Remove twice. There should not be errors.
@ -115,6 +121,67 @@ async def test_swarm_remove_conn(swarm_pair):
assert swarm_1.get_peer_id() not in swarm_0.connections
@pytest.mark.trio
async def test_swarm_multiple_connections(security_protocol):
"""Test multiple connections per peer functionality."""
async with SwarmFactory.create_batch_and_listen(
2, security_protocol=security_protocol
) as swarms:
# Setup multiple addresses for peer
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
# Dial peer - should return list of connections
connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
assert len(connections) > 0
# Test get_connections method
peer_connections = swarms[0].get_connections(swarms[1].get_peer_id())
assert len(peer_connections) == len(connections)
# Test get_connections_map method
connections_map = swarms[0].get_connections_map()
assert swarms[1].get_peer_id() in connections_map
assert len(connections_map[swarms[1].get_peer_id()]) == len(connections)
# Test get_connection method (backward compatibility)
single_conn = swarms[0].get_connection(swarms[1].get_peer_id())
assert single_conn is not None
assert single_conn in connections
@pytest.mark.trio
async def test_swarm_load_balancing(security_protocol):
"""Test load balancing across multiple connections."""
async with SwarmFactory.create_batch_and_listen(
2, security_protocol=security_protocol
) as swarms:
# Setup connection
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
# Create multiple streams - should use load balancing
streams = []
for _ in range(5):
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
streams.append(stream)
# Verify streams were created successfully
assert len(streams) == 5
# Clean up
for stream in streams:
await stream.close()
@pytest.mark.trio
async def test_swarm_multiaddr(security_protocol):
async with SwarmFactory.create_batch_and_listen(