From 59e1d9ae39a09d2730919ace567753412185e940 Mon Sep 17 00:00:00 2001 From: bomanaps Date: Sun, 31 Aug 2025 01:38:29 +0100 Subject: [PATCH] address architectural refactoring discussed --- docs/examples.multiple_connections.rst | 133 +++++ docs/examples.rst | 1 + .../multiple_connections_example.py} | 111 ++-- libp2p/abc.py | 62 ++- libp2p/host/basic_host.py | 4 +- libp2p/network/swarm.py | 503 +++++++++--------- tests/core/host/test_live_peers.py | 4 +- tests/core/network/test_enhanced_swarm.py | 365 +++++-------- tests/core/network/test_swarm.py | 77 ++- .../security/test_security_multistream.py | 3 + .../test_multiplexer_selection.py | 9 +- tests/utils/factories.py | 4 +- 12 files changed, 705 insertions(+), 571 deletions(-) create mode 100644 docs/examples.multiple_connections.rst rename examples/{enhanced_swarm_example.py => doc-examples/multiple_connections_example.py} (55%) diff --git a/docs/examples.multiple_connections.rst b/docs/examples.multiple_connections.rst new file mode 100644 index 00000000..da1d3b02 --- /dev/null +++ b/docs/examples.multiple_connections.rst @@ -0,0 +1,133 @@ +Multiple Connections Per Peer +============================ + +This example demonstrates how to use the multiple connections per peer feature in py-libp2p. + +Overview +-------- + +The multiple connections per peer feature allows a libp2p node to maintain multiple network connections to the same peer. This provides several benefits: + +- **Improved reliability**: If one connection fails, others remain available +- **Better performance**: Load can be distributed across multiple connections +- **Enhanced throughput**: Multiple streams can be created in parallel +- **Fault tolerance**: Redundant connections provide backup paths + +Configuration +------------- + +The feature is configured through the `ConnectionConfig` class: + +.. code-block:: python + + from libp2p.network.swarm import ConnectionConfig + + # Default configuration + config = ConnectionConfig() + print(f"Max connections per peer: {config.max_connections_per_peer}") + print(f"Load balancing strategy: {config.load_balancing_strategy}") + + # Custom configuration + custom_config = ConnectionConfig( + max_connections_per_peer=5, + connection_timeout=60.0, + load_balancing_strategy="least_loaded" + ) + +Load Balancing Strategies +------------------------ + +Two load balancing strategies are available: + +**Round Robin** (default) + Cycles through connections in order, distributing load evenly. + +**Least Loaded** + Selects the connection with the fewest active streams. + +API Usage +--------- + +The new API provides direct access to multiple connections: + +.. code-block:: python + + from libp2p import new_swarm + + # Create swarm with multiple connections support + swarm = new_swarm() + + # Dial a peer - returns list of connections + connections = await swarm.dial_peer(peer_id) + print(f"Established {len(connections)} connections") + + # Get all connections to a peer + peer_connections = swarm.get_connections(peer_id) + + # Get all connections (across all peers) + all_connections = swarm.get_connections() + + # Get the complete connections map + connections_map = swarm.get_connections_map() + + # Backward compatibility - get single connection + single_conn = swarm.get_connection(peer_id) + +Backward Compatibility +--------------------- + +Existing code continues to work through backward compatibility features: + +.. code-block:: python + + # Legacy 1:1 mapping (returns first connection for each peer) + legacy_connections = swarm.connections_legacy + + # Single connection access (returns first available connection) + conn = swarm.get_connection(peer_id) + +Example +------- + +See :doc:`examples/doc-examples/multiple_connections_example.py` for a complete working example. + +Production Configuration +----------------------- + +For production use, consider these settings: + +.. code-block:: python + + from libp2p.network.swarm import ConnectionConfig, RetryConfig + + # Production-ready configuration + retry_config = RetryConfig( + max_retries=3, + initial_delay=0.1, + max_delay=30.0, + backoff_multiplier=2.0, + jitter_factor=0.1 + ) + + connection_config = ConnectionConfig( + max_connections_per_peer=3, # Balance performance and resources + connection_timeout=30.0, # Reasonable timeout + load_balancing_strategy="round_robin" # Predictable behavior + ) + + swarm = new_swarm( + retry_config=retry_config, + connection_config=connection_config + ) + +Architecture +----------- + +The implementation follows the same architectural patterns as the Go and JavaScript reference implementations: + +- **Core data structure**: `dict[ID, list[INetConn]]` for 1:many mapping +- **API consistency**: Methods like `get_connections()` match reference implementations +- **Load balancing**: Integrated at the API level for optimal performance +- **Backward compatibility**: Maintains existing interfaces for gradual migration + +This design ensures consistency across libp2p implementations while providing the benefits of multiple connections per peer. diff --git a/docs/examples.rst b/docs/examples.rst index b8ba44d7..74864cbe 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -15,3 +15,4 @@ Examples examples.kademlia examples.mDNS examples.random_walk + examples.multiple_connections diff --git a/examples/enhanced_swarm_example.py b/examples/doc-examples/multiple_connections_example.py similarity index 55% rename from examples/enhanced_swarm_example.py rename to examples/doc-examples/multiple_connections_example.py index b5367af8..14a71ab8 100644 --- a/examples/enhanced_swarm_example.py +++ b/examples/doc-examples/multiple_connections_example.py @@ -1,18 +1,18 @@ #!/usr/bin/env python3 """ -Example demonstrating the enhanced Swarm with retry logic, exponential backoff, -and multi-connection support. +Example demonstrating multiple connections per peer support in libp2p. This example shows how to: -1. Configure retry behavior with exponential backoff -2. Enable multi-connection support with connection pooling -3. Use different load balancing strategies +1. Configure multiple connections per peer +2. Use different load balancing strategies +3. Access multiple connections through the new API 4. Maintain backward compatibility """ -import asyncio import logging +import trio + from libp2p import new_swarm from libp2p.network.swarm import ConnectionConfig, RetryConfig @@ -21,64 +21,32 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -async def example_basic_enhanced_swarm() -> None: - """Example of basic enhanced Swarm usage.""" - logger.info("Creating enhanced Swarm with default configuration...") +async def example_basic_multiple_connections() -> None: + """Example of basic multiple connections per peer usage.""" + logger.info("Creating swarm with multiple connections support...") - # Create enhanced swarm with default retry and connection config + # Create swarm with default configuration swarm = new_swarm() - # Use default configuration values directly - default_retry = RetryConfig() default_connection = ConnectionConfig() logger.info(f"Swarm created with peer ID: {swarm.get_peer_id()}") - logger.info(f"Retry config: max_retries={default_retry.max_retries}") logger.info( f"Connection config: max_connections_per_peer=" f"{default_connection.max_connections_per_peer}" ) - logger.info(f"Connection pool enabled: {default_connection.enable_connection_pool}") await swarm.close() - logger.info("Basic enhanced Swarm example completed") - - -async def example_custom_retry_config() -> None: - """Example of custom retry configuration.""" - logger.info("Creating enhanced Swarm with custom retry configuration...") - - # Custom retry configuration for aggressive retry behavior - retry_config = RetryConfig( - max_retries=5, # More retries - initial_delay=0.05, # Faster initial retry - max_delay=10.0, # Lower max delay - backoff_multiplier=1.5, # Less aggressive backoff - jitter_factor=0.2, # More jitter - ) - - # Create swarm with custom retry config - swarm = new_swarm(retry_config=retry_config) - - logger.info("Custom retry config applied:") - logger.info(f" Max retries: {retry_config.max_retries}") - logger.info(f" Initial delay: {retry_config.initial_delay}s") - logger.info(f" Max delay: {retry_config.max_delay}s") - logger.info(f" Backoff multiplier: {retry_config.backoff_multiplier}") - logger.info(f" Jitter factor: {retry_config.jitter_factor}") - - await swarm.close() - logger.info("Custom retry config example completed") + logger.info("Basic multiple connections example completed") async def example_custom_connection_config() -> None: """Example of custom connection configuration.""" - logger.info("Creating enhanced Swarm with custom connection configuration...") + logger.info("Creating swarm with custom connection configuration...") # Custom connection configuration for high-performance scenarios connection_config = ConnectionConfig( max_connections_per_peer=5, # More connections per peer connection_timeout=60.0, # Longer timeout - enable_connection_pool=True, # Enable connection pooling load_balancing_strategy="least_loaded", # Use least loaded strategy ) @@ -90,9 +58,6 @@ async def example_custom_connection_config() -> None: f" Max connections per peer: {connection_config.max_connections_per_peer}" ) logger.info(f" Connection timeout: {connection_config.connection_timeout}s") - logger.info( - f" Connection pool enabled: {connection_config.enable_connection_pool}" - ) logger.info( f" Load balancing strategy: {connection_config.load_balancing_strategy}" ) @@ -101,22 +66,39 @@ async def example_custom_connection_config() -> None: logger.info("Custom connection config example completed") -async def example_backward_compatibility() -> None: - """Example showing backward compatibility.""" - logger.info("Creating enhanced Swarm with backward compatibility...") +async def example_multiple_connections_api() -> None: + """Example of using the new multiple connections API.""" + logger.info("Demonstrating multiple connections API...") - # Disable connection pool to maintain original behavior - connection_config = ConnectionConfig(enable_connection_pool=False) + connection_config = ConnectionConfig( + max_connections_per_peer=3, + load_balancing_strategy="round_robin" + ) - # Create swarm with connection pool disabled swarm = new_swarm(connection_config=connection_config) - logger.info("Backward compatibility mode:") + logger.info("Multiple connections API features:") + logger.info(" - dial_peer() returns list[INetConn]") + logger.info(" - get_connections(peer_id) returns list[INetConn]") + logger.info(" - get_connections_map() returns dict[ID, list[INetConn]]") logger.info( - f" Connection pool enabled: {connection_config.enable_connection_pool}" + " - get_connection(peer_id) returns INetConn | None (backward compatibility)" ) - logger.info(f" Connections dict type: {type(swarm.connections)}") - logger.info(" Retry logic still available: 3 max retries") + + await swarm.close() + logger.info("Multiple connections API example completed") + + +async def example_backward_compatibility() -> None: + """Example of backward compatibility features.""" + logger.info("Demonstrating backward compatibility...") + + swarm = new_swarm() + + logger.info("Backward compatibility features:") + logger.info(" - connections_legacy property provides 1:1 mapping") + logger.info(" - get_connection() method for single connection access") + logger.info(" - Existing code continues to work") await swarm.close() logger.info("Backward compatibility example completed") @@ -124,7 +106,7 @@ async def example_backward_compatibility() -> None: async def example_production_ready_config() -> None: """Example of production-ready configuration.""" - logger.info("Creating enhanced Swarm with production-ready configuration...") + logger.info("Creating swarm with production-ready configuration...") # Production-ready retry configuration retry_config = RetryConfig( @@ -139,7 +121,6 @@ async def example_production_ready_config() -> None: connection_config = ConnectionConfig( max_connections_per_peer=3, # Balance between performance and resource usage connection_timeout=30.0, # Reasonable timeout - enable_connection_pool=True, # Enable for better performance load_balancing_strategy="round_robin", # Simple, predictable strategy ) @@ -160,19 +141,19 @@ async def example_production_ready_config() -> None: async def main() -> None: """Run all examples.""" - logger.info("Enhanced Swarm Examples") + logger.info("Multiple Connections Per Peer Examples") logger.info("=" * 50) try: - await example_basic_enhanced_swarm() - logger.info("-" * 30) - - await example_custom_retry_config() + await example_basic_multiple_connections() logger.info("-" * 30) await example_custom_connection_config() logger.info("-" * 30) + await example_multiple_connections_api() + logger.info("-" * 30) + await example_backward_compatibility() logger.info("-" * 30) @@ -187,4 +168,4 @@ async def main() -> None: if __name__ == "__main__": - asyncio.run(main()) + trio.run(main) diff --git a/libp2p/abc.py b/libp2p/abc.py index a9748339..964c7454 100644 --- a/libp2p/abc.py +++ b/libp2p/abc.py @@ -1412,15 +1412,16 @@ class INetwork(ABC): ---------- peerstore : IPeerStore The peer store for managing peer information. - connections : dict[ID, INetConn] - A mapping of peer IDs to network connections. + connections : dict[ID, list[INetConn]] + A mapping of peer IDs to lists of network connections + (multiple connections per peer). listeners : dict[str, IListener] A mapping of listener identifiers to listener instances. """ peerstore: IPeerStore - connections: dict[ID, INetConn] + connections: dict[ID, list[INetConn]] listeners: dict[str, IListener] @abstractmethod @@ -1436,9 +1437,56 @@ class INetwork(ABC): """ @abstractmethod - async def dial_peer(self, peer_id: ID) -> INetConn: + def get_connections(self, peer_id: ID | None = None) -> list[INetConn]: """ - Create a connection to the specified peer. + Get connections for peer (like JS getConnections, Go ConnsToPeer). + + Parameters + ---------- + peer_id : ID | None + The peer ID to get connections for. If None, returns all connections. + + Returns + ------- + list[INetConn] + List of connections to the specified peer, or all connections + if peer_id is None. + + """ + + @abstractmethod + def get_connections_map(self) -> dict[ID, list[INetConn]]: + """ + Get all connections map (like JS getConnectionsMap). + + Returns + ------- + dict[ID, list[INetConn]] + The complete mapping of peer IDs to their connection lists. + + """ + + @abstractmethod + def get_connection(self, peer_id: ID) -> INetConn | None: + """ + Get single connection for backward compatibility. + + Parameters + ---------- + peer_id : ID + The peer ID to get a connection for. + + Returns + ------- + INetConn | None + The first available connection, or None if no connections exist. + + """ + + @abstractmethod + async def dial_peer(self, peer_id: ID) -> list[INetConn]: + """ + Create connections to the specified peer with load balancing. Parameters ---------- @@ -1447,8 +1495,8 @@ class INetwork(ABC): Returns ------- - INetConn - The network connection instance to the specified peer. + list[INetConn] + List of established connections to the peer. Raises ------ diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index a0311bd8..a3a89dda 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -338,7 +338,7 @@ class BasicHost(IHost): :param peer_id: ID of the peer to check :return: True if peer has an active connection, False otherwise """ - return peer_id in self._network.connections + return len(self._network.get_connections(peer_id)) > 0 def get_peer_connection_info(self, peer_id: ID) -> INetConn | None: """ @@ -347,4 +347,4 @@ class BasicHost(IHost): :param peer_id: ID of the peer to get info for :return: Connection object if peer is connected, None otherwise """ - return self._network.connections.get(peer_id) + return self._network.get_connection(peer_id) diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 77fe2b6d..23a94fdb 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -74,175 +74,13 @@ class RetryConfig: @dataclass class ConnectionConfig: - """Configuration for connection pool and multi-connection support.""" + """Configuration for multi-connection support.""" max_connections_per_peer: int = 3 connection_timeout: float = 30.0 - enable_connection_pool: bool = True load_balancing_strategy: str = "round_robin" # or "least_loaded" -@dataclass -class ConnectionInfo: - """Information about a connection in the pool.""" - - connection: INetConn - address: str - established_at: float - last_used: float - stream_count: int - is_healthy: bool - - -class ConnectionPool: - """Manages multiple connections per peer with load balancing.""" - - def __init__(self, max_connections_per_peer: int = 3): - self.max_connections_per_peer = max_connections_per_peer - self.peer_connections: dict[ID, list[ConnectionInfo]] = {} - self._round_robin_index: dict[ID, int] = {} - - def add_connection(self, peer_id: ID, connection: INetConn, address: str) -> None: - """Add a connection to the pool with deduplication.""" - if peer_id not in self.peer_connections: - self.peer_connections[peer_id] = [] - - # Check for duplicate connections to the same address - for conn_info in self.peer_connections[peer_id]: - if conn_info.address == address: - logger.debug( - f"Connection to {address} already exists for peer {peer_id}" - ) - return - - # Add new connection - try: - current_time = trio.current_time() - except RuntimeError: - # Fallback for testing contexts where trio is not running - import time - - current_time = time.time() - - conn_info = ConnectionInfo( - connection=connection, - address=address, - established_at=current_time, - last_used=current_time, - stream_count=0, - is_healthy=True, - ) - - self.peer_connections[peer_id].append(conn_info) - - # Trim if we exceed max connections - if len(self.peer_connections[peer_id]) > self.max_connections_per_peer: - self._trim_connections(peer_id) - - def get_connection( - self, peer_id: ID, strategy: str = "round_robin" - ) -> INetConn | None: - """Get a connection using the specified load balancing strategy.""" - if peer_id not in self.peer_connections or not self.peer_connections[peer_id]: - return None - - connections = self.peer_connections[peer_id] - - if strategy == "round_robin": - if peer_id not in self._round_robin_index: - self._round_robin_index[peer_id] = 0 - - index = self._round_robin_index[peer_id] % len(connections) - self._round_robin_index[peer_id] += 1 - - conn_info = connections[index] - try: - conn_info.last_used = trio.current_time() - except RuntimeError: - import time - - conn_info.last_used = time.time() - return conn_info.connection - - elif strategy == "least_loaded": - # Find connection with least streams - # Note: stream_count is a custom attribute we add to connections - conn_info = min( - connections, key=lambda c: getattr(c.connection, "stream_count", 0) - ) - try: - conn_info.last_used = trio.current_time() - except RuntimeError: - import time - - conn_info.last_used = time.time() - return conn_info.connection - - else: - # Default to first connection - conn_info = connections[0] - try: - conn_info.last_used = trio.current_time() - except RuntimeError: - import time - - conn_info.last_used = time.time() - return conn_info.connection - - def has_connection(self, peer_id: ID) -> bool: - """Check if we have any connections to the peer.""" - return ( - peer_id in self.peer_connections and len(self.peer_connections[peer_id]) > 0 - ) - - def remove_connection(self, peer_id: ID, connection: INetConn) -> None: - """Remove a connection from the pool.""" - if peer_id in self.peer_connections: - self.peer_connections[peer_id] = [ - conn_info - for conn_info in self.peer_connections[peer_id] - if conn_info.connection != connection - ] - - # Clean up empty peer entries - if not self.peer_connections[peer_id]: - del self.peer_connections[peer_id] - if peer_id in self._round_robin_index: - del self._round_robin_index[peer_id] - - def _trim_connections(self, peer_id: ID) -> None: - """Remove oldest connections when limit is exceeded.""" - connections = self.peer_connections[peer_id] - if len(connections) <= self.max_connections_per_peer: - return - - # Sort by last used time and remove oldest - connections.sort(key=lambda c: c.last_used) - connections_to_remove = connections[: -self.max_connections_per_peer] - - for conn_info in connections_to_remove: - logger.debug( - f"Trimming old connection to {conn_info.address} for peer {peer_id}" - ) - try: - # Close the connection asynchronously - trio.lowlevel.spawn_system_task( - self._close_connection_async, conn_info.connection - ) - except Exception as e: - logger.warning(f"Error closing trimmed connection: {e}") - - # Keep only the most recently used connections - self.peer_connections[peer_id] = connections[-self.max_connections_per_peer :] - - async def _close_connection_async(self, connection: INetConn) -> None: - """Close a connection asynchronously.""" - try: - await connection.close() - except Exception as e: - logger.warning(f"Error closing connection: {e}") - - def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn: async def stream_handler(stream: INetStream) -> None: await network.get_manager().wait_finished() @@ -256,7 +94,7 @@ class Swarm(Service, INetworkService): upgrader: TransportUpgrader transport: ITransport # Enhanced: Support for multiple connections per peer - connections: dict[ID, INetConn] # Backward compatibility + connections: dict[ID, list[INetConn]] # Multiple connections per peer listeners: dict[str, IListener] common_stream_handler: StreamHandlerFn listener_nursery: trio.Nursery | None @@ -264,10 +102,10 @@ class Swarm(Service, INetworkService): notifees: list[INotifee] - # Enhanced: New configuration and connection pool + # Enhanced: New configuration retry_config: RetryConfig connection_config: ConnectionConfig - connection_pool: ConnectionPool | None + _round_robin_index: dict[ID, int] def __init__( self, @@ -287,16 +125,8 @@ class Swarm(Service, INetworkService): self.retry_config = retry_config or RetryConfig() self.connection_config = connection_config or ConnectionConfig() - # Enhanced: Initialize connection pool if enabled - if self.connection_config.enable_connection_pool: - self.connection_pool = ConnectionPool( - self.connection_config.max_connections_per_peer - ) - else: - self.connection_pool = None - - # Backward compatibility: Keep existing connections dict - self.connections = dict() + # Enhanced: Initialize connections as 1:many mapping + self.connections = {} self.listeners = dict() # Create Notifee array @@ -307,6 +137,9 @@ class Swarm(Service, INetworkService): self.listener_nursery = None self.event_listener_nursery_created = trio.Event() + # Load balancing state + self._round_robin_index = {} + async def run(self) -> None: async with trio.open_nursery() as nursery: # Create a nursery for listener tasks. @@ -326,26 +159,74 @@ class Swarm(Service, INetworkService): def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None: self.common_stream_handler = stream_handler - async def dial_peer(self, peer_id: ID) -> INetConn: + def get_connections(self, peer_id: ID | None = None) -> list[INetConn]: """ - Try to create a connection to peer_id with enhanced retry logic. + Get connections for peer (like JS getConnections, Go ConnsToPeer). + + Parameters + ---------- + peer_id : ID | None + The peer ID to get connections for. If None, returns all connections. + + Returns + ------- + list[INetConn] + List of connections to the specified peer, or all connections + if peer_id is None. + + """ + if peer_id is not None: + return self.connections.get(peer_id, []) + + # Return all connections from all peers + all_conns = [] + for conns in self.connections.values(): + all_conns.extend(conns) + return all_conns + + def get_connections_map(self) -> dict[ID, list[INetConn]]: + """ + Get all connections map (like JS getConnectionsMap). + + Returns + ------- + dict[ID, list[INetConn]] + The complete mapping of peer IDs to their connection lists. + + """ + return self.connections.copy() + + def get_connection(self, peer_id: ID) -> INetConn | None: + """ + Get single connection for backward compatibility. + + Parameters + ---------- + peer_id : ID + The peer ID to get a connection for. + + Returns + ------- + INetConn | None + The first available connection, or None if no connections exist. + + """ + conns = self.get_connections(peer_id) + return conns[0] if conns else None + + async def dial_peer(self, peer_id: ID) -> list[INetConn]: + """ + Try to create connections to peer_id with enhanced retry logic. :param peer_id: peer if we want to dial :raises SwarmException: raised when an error occurs - :return: muxed connection + :return: list of muxed connections """ - # Enhanced: Check connection pool first if enabled - if self.connection_pool and self.connection_pool.has_connection(peer_id): - connection = self.connection_pool.get_connection(peer_id) - if connection: - logger.debug(f"Reusing existing connection to peer {peer_id}") - return connection - - # Enhanced: Check existing single connection for backward compatibility - if peer_id in self.connections: - # If muxed connection already exists for peer_id, - # set muxed connection equal to existing muxed connection - return self.connections[peer_id] + # Check if we already have connections + existing_connections = self.get_connections(peer_id) + if existing_connections: + logger.debug(f"Reusing existing connections to peer {peer_id}") + return existing_connections logger.debug("attempting to dial peer %s", peer_id) @@ -358,23 +239,19 @@ class Swarm(Service, INetworkService): if not addrs: raise SwarmException(f"No known addresses to peer {peer_id}") + connections = [] exceptions: list[SwarmException] = [] # Enhanced: Try all known addresses with retry logic for multiaddr in addrs: try: connection = await self._dial_with_retry(multiaddr, peer_id) + connections.append(connection) - # Enhanced: Add to connection pool if enabled - if self.connection_pool: - self.connection_pool.add_connection( - peer_id, connection, str(multiaddr) - ) + # Limit number of connections per peer + if len(connections) >= self.connection_config.max_connections_per_peer: + break - # Backward compatibility: Keep existing connections dict - self.connections[peer_id] = connection - - return connection except SwarmException as e: exceptions.append(e) logger.debug( @@ -384,11 +261,14 @@ class Swarm(Service, INetworkService): exc_info=e, ) - # Tried all addresses, raising exception. - raise SwarmException( - f"unable to connect to {peer_id}, no addresses established a successful " - "connection (with exceptions)" - ) from MultiError(exceptions) + if not connections: + # Tried all addresses, raising exception. + raise SwarmException( + f"unable to connect to {peer_id}, no addresses established a " + "successful connection (with exceptions)" + ) from MultiError(exceptions) + + return connections async def _dial_with_retry(self, addr: Multiaddr, peer_id: ID) -> INetConn: """ @@ -515,33 +395,76 @@ class Swarm(Service, INetworkService): """ logger.debug("attempting to open a stream to peer %s", peer_id) - # Enhanced: Try to get existing connection from pool first - if self.connection_pool and self.connection_pool.has_connection(peer_id): - connection = self.connection_pool.get_connection( - peer_id, self.connection_config.load_balancing_strategy - ) - if connection: - try: - net_stream = await connection.new_stream() - logger.debug( - "successfully opened a stream to peer %s " - "using existing connection", - peer_id, - ) - return net_stream - except Exception as e: - logger.debug( - f"Failed to create stream on existing connection, " - f"will dial new connection: {e}" - ) - # Fall through to dial new connection + # Get existing connections or dial new ones + connections = self.get_connections(peer_id) + if not connections: + connections = await self.dial_peer(peer_id) - # Fall back to existing logic: dial peer and create stream - swarm_conn = await self.dial_peer(peer_id) + # Load balancing strategy at interface level + connection = self._select_connection(connections, peer_id) - net_stream = await swarm_conn.new_stream() - logger.debug("successfully opened a stream to peer %s", peer_id) - return net_stream + try: + net_stream = await connection.new_stream() + logger.debug("successfully opened a stream to peer %s", peer_id) + return net_stream + except Exception as e: + logger.debug(f"Failed to create stream on connection: {e}") + # Try other connections if available + for other_conn in connections: + if other_conn != connection: + try: + net_stream = await other_conn.new_stream() + logger.debug( + f"Successfully opened a stream to peer {peer_id} " + "using alternative connection" + ) + return net_stream + except Exception: + continue + + # All connections failed, raise exception + raise SwarmException(f"Failed to create stream to peer {peer_id}") from e + + def _select_connection(self, connections: list[INetConn], peer_id: ID) -> INetConn: + """ + Select connection based on load balancing strategy. + + Parameters + ---------- + connections : list[INetConn] + List of available connections. + peer_id : ID + The peer ID for round-robin tracking. + strategy : str + Load balancing strategy ("round_robin", "least_loaded", etc.). + + Returns + ------- + INetConn + Selected connection. + + """ + if not connections: + raise ValueError("No connections available") + + strategy = self.connection_config.load_balancing_strategy + + if strategy == "round_robin": + # Simple round-robin selection + if peer_id not in self._round_robin_index: + self._round_robin_index[peer_id] = 0 + + index = self._round_robin_index[peer_id] % len(connections) + self._round_robin_index[peer_id] += 1 + return connections[index] + + elif strategy == "least_loaded": + # Find connection with least streams + return min(connections, key=lambda c: len(c.get_streams())) + + else: + # Default to first connection + return connections[0] async def listen(self, *multiaddrs: Multiaddr) -> bool: """ @@ -637,9 +560,9 @@ class Swarm(Service, INetworkService): # Perform alternative cleanup if the manager isn't initialized # Close all connections manually if hasattr(self, "connections"): - for conn_id in list(self.connections.keys()): - conn = self.connections[conn_id] - await conn.close() + for peer_id, conns in list(self.connections.items()): + for conn in conns: + await conn.close() # Clear connection tracking dictionary self.connections.clear() @@ -669,17 +592,28 @@ class Swarm(Service, INetworkService): logger.debug("swarm successfully closed") async def close_peer(self, peer_id: ID) -> None: - if peer_id not in self.connections: + """ + Close all connections to the specified peer. + + Parameters + ---------- + peer_id : ID + The peer ID to close connections for. + + """ + connections = self.get_connections(peer_id) + if not connections: return - connection = self.connections[peer_id] - # Enhanced: Remove from connection pool if enabled - if self.connection_pool: - self.connection_pool.remove_connection(peer_id, connection) + # Close all connections + for connection in connections: + try: + await connection.close() + except Exception as e: + logger.warning(f"Error closing connection to {peer_id}: {e}") - # NOTE: `connection.close` will delete `peer_id` from `self.connections` - # and `notify_disconnected` for us. - await connection.close() + # Remove from connections dict + self.connections.pop(peer_id, None) logger.debug("successfully close the connection to peer %s", peer_id) @@ -698,20 +632,58 @@ class Swarm(Service, INetworkService): await muxed_conn.event_started.wait() self.manager.run_task(swarm_conn.start) await swarm_conn.event_started.wait() - # Enhanced: Add to connection pool if enabled - if self.connection_pool: - # For incoming connections, we don't have a specific address - # Use a placeholder that will be updated when we get more info - self.connection_pool.add_connection( - muxed_conn.peer_id, swarm_conn, "incoming" - ) - # Store muxed_conn with peer id (backward compatibility) - self.connections[muxed_conn.peer_id] = swarm_conn + # Add to connections dict with deduplication + peer_id = muxed_conn.peer_id + if peer_id not in self.connections: + self.connections[peer_id] = [] + + # Check for duplicate connections by comparing the underlying muxed connection + for existing_conn in self.connections[peer_id]: + if existing_conn.muxed_conn == muxed_conn: + logger.debug(f"Connection already exists for peer {peer_id}") + # existing_conn is a SwarmConn since it's stored in the connections list + return existing_conn # type: ignore[return-value] + + self.connections[peer_id].append(swarm_conn) + + # Trim if we exceed max connections + max_conns = self.connection_config.max_connections_per_peer + if len(self.connections[peer_id]) > max_conns: + self._trim_connections(peer_id) + # Call notifiers since event occurred await self.notify_connected(swarm_conn) return swarm_conn + def _trim_connections(self, peer_id: ID) -> None: + """ + Remove oldest connections when limit is exceeded. + """ + connections = self.connections[peer_id] + if len(connections) <= self.connection_config.max_connections_per_peer: + return + + # Sort by creation time and remove oldest + # For now, just keep the most recent connections + max_conns = self.connection_config.max_connections_per_peer + connections_to_remove = connections[:-max_conns] + + for conn in connections_to_remove: + logger.debug(f"Trimming old connection for peer {peer_id}") + trio.lowlevel.spawn_system_task(self._close_connection_async, conn) + + # Keep only the most recent connections + max_conns = self.connection_config.max_connections_per_peer + self.connections[peer_id] = connections[-max_conns:] + + async def _close_connection_async(self, connection: INetConn) -> None: + """Close a connection asynchronously.""" + try: + await connection.close() + except Exception as e: + logger.warning(f"Error closing connection: {e}") + def remove_conn(self, swarm_conn: SwarmConn) -> None: """ Simply remove the connection from Swarm's records, without closing @@ -719,13 +691,12 @@ class Swarm(Service, INetworkService): """ peer_id = swarm_conn.muxed_conn.peer_id - # Enhanced: Remove from connection pool if enabled - if self.connection_pool: - self.connection_pool.remove_connection(peer_id, swarm_conn) - - if peer_id not in self.connections: - return - del self.connections[peer_id] + if peer_id in self.connections: + self.connections[peer_id] = [ + conn for conn in self.connections[peer_id] if conn != swarm_conn + ] + if not self.connections[peer_id]: + del self.connections[peer_id] # Notifee @@ -771,3 +742,21 @@ class Swarm(Service, INetworkService): async with trio.open_nursery() as nursery: for notifee in self.notifees: nursery.start_soon(notifier, notifee) + + # Backward compatibility properties + @property + def connections_legacy(self) -> dict[ID, INetConn]: + """ + Legacy 1:1 mapping for backward compatibility. + + Returns + ------- + dict[ID, INetConn] + Legacy mapping with only the first connection per peer. + + """ + legacy_conns = {} + for peer_id, conns in self.connections.items(): + if conns: + legacy_conns[peer_id] = conns[0] + return legacy_conns diff --git a/tests/core/host/test_live_peers.py b/tests/core/host/test_live_peers.py index 1d7948ad..e5af42ba 100644 --- a/tests/core/host/test_live_peers.py +++ b/tests/core/host/test_live_peers.py @@ -164,8 +164,8 @@ async def test_live_peers_unexpected_drop(security_protocol): assert peer_a_id in host_b.get_live_peers() # Simulate unexpected connection drop by directly closing the connection - conn = host_a.get_network().connections[peer_b_id] - await conn.muxed_conn.close() + conns = host_a.get_network().connections[peer_b_id] + await conns[0].muxed_conn.close() # Allow for connection cleanup await trio.sleep(0.1) diff --git a/tests/core/network/test_enhanced_swarm.py b/tests/core/network/test_enhanced_swarm.py index 9b100ad9..e63de126 100644 --- a/tests/core/network/test_enhanced_swarm.py +++ b/tests/core/network/test_enhanced_swarm.py @@ -1,14 +1,15 @@ import time +from typing import cast from unittest.mock import Mock import pytest from multiaddr import Multiaddr +import trio from libp2p.abc import INetConn, INetStream from libp2p.network.exceptions import SwarmException from libp2p.network.swarm import ( ConnectionConfig, - ConnectionPool, RetryConfig, Swarm, ) @@ -21,10 +22,12 @@ class MockConnection(INetConn): def __init__(self, peer_id: ID, is_closed: bool = False): self.peer_id = peer_id self._is_closed = is_closed - self.stream_count = 0 + self.streams = set() # Track streams properly # Mock the muxed_conn attribute that Swarm expects self.muxed_conn = Mock() self.muxed_conn.peer_id = peer_id + # Required by INetConn interface + self.event_started = trio.Event() async def close(self): self._is_closed = True @@ -34,12 +37,14 @@ class MockConnection(INetConn): return self._is_closed async def new_stream(self) -> INetStream: - self.stream_count += 1 - return Mock(spec=INetStream) + # Create a mock stream and add it to the connection's stream set + mock_stream = Mock(spec=INetStream) + self.streams.add(mock_stream) + return mock_stream def get_streams(self) -> tuple[INetStream, ...]: - """Mock implementation of get_streams.""" - return tuple() + """Return all streams associated with this connection.""" + return tuple(self.streams) def get_transport_addresses(self) -> list[Multiaddr]: """Mock implementation of get_transport_addresses.""" @@ -70,114 +75,9 @@ async def test_connection_config_defaults(): config = ConnectionConfig() assert config.max_connections_per_peer == 3 assert config.connection_timeout == 30.0 - assert config.enable_connection_pool is True assert config.load_balancing_strategy == "round_robin" -@pytest.mark.trio -async def test_connection_pool_basic_operations(): - """Test basic ConnectionPool operations.""" - pool = ConnectionPool(max_connections_per_peer=2) - peer_id = ID(b"QmTest") - - # Test empty pool - assert not pool.has_connection(peer_id) - assert pool.get_connection(peer_id) is None - - # Add connection - conn1 = MockConnection(peer_id) - pool.add_connection(peer_id, conn1, "addr1") - assert pool.has_connection(peer_id) - assert pool.get_connection(peer_id) == conn1 - - # Add second connection - conn2 = MockConnection(peer_id) - pool.add_connection(peer_id, conn2, "addr2") - assert len(pool.peer_connections[peer_id]) == 2 - - # Test round-robin - should cycle through connections - first_conn = pool.get_connection(peer_id, "round_robin") - second_conn = pool.get_connection(peer_id, "round_robin") - third_conn = pool.get_connection(peer_id, "round_robin") - - # Should cycle through both connections - assert first_conn in [conn1, conn2] - assert second_conn in [conn1, conn2] - assert third_conn in [conn1, conn2] - assert first_conn != second_conn or second_conn != third_conn - - # Test least loaded - set different stream counts - conn1.stream_count = 5 - conn2.stream_count = 1 - least_loaded_conn = pool.get_connection(peer_id, "least_loaded") - assert least_loaded_conn == conn2 # conn2 has fewer streams - - -@pytest.mark.trio -async def test_connection_pool_deduplication(): - """Test connection deduplication by address.""" - pool = ConnectionPool(max_connections_per_peer=3) - peer_id = ID(b"QmTest") - - conn1 = MockConnection(peer_id) - pool.add_connection(peer_id, conn1, "addr1") - - # Try to add connection with same address - conn2 = MockConnection(peer_id) - pool.add_connection(peer_id, conn2, "addr1") - - # Should only have one connection - assert len(pool.peer_connections[peer_id]) == 1 - assert pool.get_connection(peer_id) == conn1 - - -@pytest.mark.trio -async def test_connection_pool_trimming(): - """Test connection trimming when limit is exceeded.""" - pool = ConnectionPool(max_connections_per_peer=2) - peer_id = ID(b"QmTest") - - # Add 3 connections - conn1 = MockConnection(peer_id) - conn2 = MockConnection(peer_id) - conn3 = MockConnection(peer_id) - - pool.add_connection(peer_id, conn1, "addr1") - pool.add_connection(peer_id, conn2, "addr2") - pool.add_connection(peer_id, conn3, "addr3") - - # Should trim to 2 connections - assert len(pool.peer_connections[peer_id]) == 2 - - # The oldest connections should be removed - remaining_connections = [c.connection for c in pool.peer_connections[peer_id]] - assert conn3 in remaining_connections # Most recent should remain - - -@pytest.mark.trio -async def test_connection_pool_remove_connection(): - """Test removing connections from pool.""" - pool = ConnectionPool(max_connections_per_peer=3) - peer_id = ID(b"QmTest") - - conn1 = MockConnection(peer_id) - conn2 = MockConnection(peer_id) - - pool.add_connection(peer_id, conn1, "addr1") - pool.add_connection(peer_id, conn2, "addr2") - - assert len(pool.peer_connections[peer_id]) == 2 - - # Remove connection - pool.remove_connection(peer_id, conn1) - assert len(pool.peer_connections[peer_id]) == 1 - assert pool.get_connection(peer_id) == conn2 - - # Remove last connection - pool.remove_connection(peer_id, conn2) - assert not pool.has_connection(peer_id) - - @pytest.mark.trio async def test_enhanced_swarm_constructor(): """Test enhanced Swarm constructor with new configuration.""" @@ -191,19 +91,16 @@ async def test_enhanced_swarm_constructor(): swarm = Swarm(peer_id, peerstore, upgrader, transport) assert swarm.retry_config.max_retries == 3 assert swarm.connection_config.max_connections_per_peer == 3 - assert swarm.connection_pool is not None + assert isinstance(swarm.connections, dict) # Test with custom config custom_retry = RetryConfig(max_retries=5, initial_delay=0.5) - custom_conn = ConnectionConfig( - max_connections_per_peer=5, enable_connection_pool=False - ) + custom_conn = ConnectionConfig(max_connections_per_peer=5) swarm = Swarm(peer_id, peerstore, upgrader, transport, custom_retry, custom_conn) assert swarm.retry_config.max_retries == 5 assert swarm.retry_config.initial_delay == 0.5 assert swarm.connection_config.max_connections_per_peer == 5 - assert swarm.connection_pool is None @pytest.mark.trio @@ -273,143 +170,155 @@ async def test_swarm_retry_logic(): # Should have succeeded after 3 attempts assert attempt_count[0] == 3 - assert result is not None - - # Should have taken some time due to retries - assert end_time - start_time > 0.02 # At least 2 delays + assert isinstance(result, MockConnection) + assert end_time - start_time > 0.01 # Should have some delay @pytest.mark.trio -async def test_swarm_multi_connection_support(): - """Test multi-connection support in Swarm.""" +async def test_swarm_load_balancing_strategies(): + """Test load balancing strategies.""" peer_id = ID(b"QmTest") peerstore = Mock() upgrader = Mock() transport = Mock() - connection_config = ConnectionConfig( - max_connections_per_peer=3, - enable_connection_pool=True, - load_balancing_strategy="round_robin", - ) + swarm = Swarm(peer_id, peerstore, upgrader, transport) + # Create mock connections with different stream counts + conn1 = MockConnection(peer_id) + conn2 = MockConnection(peer_id) + conn3 = MockConnection(peer_id) + + # Add some streams to simulate load + await conn1.new_stream() + await conn1.new_stream() + await conn2.new_stream() + + connections = [conn1, conn2, conn3] + + # Test round-robin strategy + swarm.connection_config.load_balancing_strategy = "round_robin" + # Cast to satisfy type checker + connections_cast = cast("list[INetConn]", connections) + selected1 = swarm._select_connection(connections_cast, peer_id) + selected2 = swarm._select_connection(connections_cast, peer_id) + selected3 = swarm._select_connection(connections_cast, peer_id) + + # Should cycle through connections + assert selected1 in connections + assert selected2 in connections + assert selected3 in connections + + # Test least loaded strategy + swarm.connection_config.load_balancing_strategy = "least_loaded" + least_loaded = swarm._select_connection(connections_cast, peer_id) + + # conn3 has 0 streams, conn2 has 1 stream, conn1 has 2 streams + # So conn3 should be selected as least loaded + assert least_loaded == conn3 + + # Test default strategy (first connection) + swarm.connection_config.load_balancing_strategy = "unknown" + default_selected = swarm._select_connection(connections_cast, peer_id) + assert default_selected == conn1 + + +@pytest.mark.trio +async def test_swarm_multiple_connections_api(): + """Test the new multiple connections API methods.""" + peer_id = ID(b"QmTest") + peerstore = Mock() + upgrader = Mock() + transport = Mock() + + swarm = Swarm(peer_id, peerstore, upgrader, transport) + + # Test empty connections + assert swarm.get_connections() == [] + assert swarm.get_connections(peer_id) == [] + assert swarm.get_connection(peer_id) is None + assert swarm.get_connections_map() == {} + + # Add some connections + conn1 = MockConnection(peer_id) + conn2 = MockConnection(peer_id) + swarm.connections[peer_id] = [conn1, conn2] + + # Test get_connections with peer_id + peer_connections = swarm.get_connections(peer_id) + assert len(peer_connections) == 2 + assert conn1 in peer_connections + assert conn2 in peer_connections + + # Test get_connections without peer_id (all connections) + all_connections = swarm.get_connections() + assert len(all_connections) == 2 + assert conn1 in all_connections + assert conn2 in all_connections + + # Test get_connection (backward compatibility) + single_conn = swarm.get_connection(peer_id) + assert single_conn in [conn1, conn2] + + # Test get_connections_map + connections_map = swarm.get_connections_map() + assert peer_id in connections_map + assert connections_map[peer_id] == [conn1, conn2] + + +@pytest.mark.trio +async def test_swarm_connection_trimming(): + """Test connection trimming when limit is exceeded.""" + peer_id = ID(b"QmTest") + peerstore = Mock() + upgrader = Mock() + transport = Mock() + + # Set max connections to 2 + connection_config = ConnectionConfig(max_connections_per_peer=2) swarm = Swarm( peer_id, peerstore, upgrader, transport, connection_config=connection_config ) - # Mock connection pool methods - assert swarm.connection_pool is not None - connection_pool = swarm.connection_pool - connection_pool.has_connection = Mock(return_value=True) - connection_pool.get_connection = Mock(return_value=MockConnection(peer_id)) + # Add 3 connections + conn1 = MockConnection(peer_id) + conn2 = MockConnection(peer_id) + conn3 = MockConnection(peer_id) - # Test that new_stream uses connection pool - result = await swarm.new_stream(peer_id) - assert result is not None - # Use the mocked method directly to avoid type checking issues - get_connection_mock = connection_pool.get_connection - assert get_connection_mock.call_count == 1 + swarm.connections[peer_id] = [conn1, conn2, conn3] + + # Trigger trimming + swarm._trim_connections(peer_id) + + # Should have only 2 connections + assert len(swarm.connections[peer_id]) == 2 + + # The most recent connections should remain + remaining = swarm.connections[peer_id] + assert conn2 in remaining + assert conn3 in remaining @pytest.mark.trio async def test_swarm_backward_compatibility(): - """Test that enhanced Swarm maintains backward compatibility.""" + """Test backward compatibility features.""" peer_id = ID(b"QmTest") peerstore = Mock() upgrader = Mock() transport = Mock() - # Create swarm with connection pool disabled - connection_config = ConnectionConfig(enable_connection_pool=False) - swarm = Swarm( - peer_id, peerstore, upgrader, transport, connection_config=connection_config - ) + swarm = Swarm(peer_id, peerstore, upgrader, transport) - # Should behave like original swarm - assert swarm.connection_pool is None - assert isinstance(swarm.connections, dict) + # Add connections + conn1 = MockConnection(peer_id) + conn2 = MockConnection(peer_id) + swarm.connections[peer_id] = [conn1, conn2] - # Test that dial_peer still works (will fail due to mocks, but structure is correct) - peerstore.addrs.return_value = [Mock(spec=Multiaddr)] - transport.dial.side_effect = Exception("Transport error") - - with pytest.raises(SwarmException): - await swarm.dial_peer(peer_id) - - -@pytest.mark.trio -async def test_swarm_connection_pool_integration(): - """Test integration between Swarm and ConnectionPool.""" - peer_id = ID(b"QmTest") - peerstore = Mock() - upgrader = Mock() - transport = Mock() - - connection_config = ConnectionConfig( - max_connections_per_peer=2, enable_connection_pool=True - ) - - swarm = Swarm( - peer_id, peerstore, upgrader, transport, connection_config=connection_config - ) - - # Mock successful connection creation - mock_conn = MockConnection(peer_id) - peerstore.addrs.return_value = [Mock(spec=Multiaddr)] - - async def mock_dial_with_retry(addr, peer_id): - return mock_conn - - swarm._dial_with_retry = mock_dial_with_retry - - # Test dial_peer adds to connection pool - result = await swarm.dial_peer(peer_id) - assert result == mock_conn - assert swarm.connection_pool is not None - assert swarm.connection_pool.has_connection(peer_id) - - # Test that subsequent calls reuse connection - result2 = await swarm.dial_peer(peer_id) - assert result2 == mock_conn - - -@pytest.mark.trio -async def test_swarm_connection_cleanup(): - """Test connection cleanup in enhanced Swarm.""" - peer_id = ID(b"QmTest") - peerstore = Mock() - upgrader = Mock() - transport = Mock() - - connection_config = ConnectionConfig(enable_connection_pool=True) - swarm = Swarm( - peer_id, peerstore, upgrader, transport, connection_config=connection_config - ) - - # Add a connection - mock_conn = MockConnection(peer_id) - swarm.connections[peer_id] = mock_conn - assert swarm.connection_pool is not None - swarm.connection_pool.add_connection(peer_id, mock_conn, "test_addr") - - # Test close_peer removes from pool - await swarm.close_peer(peer_id) - assert swarm.connection_pool is not None - assert not swarm.connection_pool.has_connection(peer_id) - - # Test remove_conn removes from pool - mock_conn2 = MockConnection(peer_id) - swarm.connections[peer_id] = mock_conn2 - assert swarm.connection_pool is not None - connection_pool = swarm.connection_pool - connection_pool.add_connection(peer_id, mock_conn2, "test_addr2") - - # Note: remove_conn expects SwarmConn, but for testing we'll just - # remove from pool directly - connection_pool = swarm.connection_pool - connection_pool.remove_connection(peer_id, mock_conn2) - assert connection_pool is not None - assert not connection_pool.has_connection(peer_id) + # Test connections_legacy property + legacy_connections = swarm.connections_legacy + assert peer_id in legacy_connections + # Should return first connection + assert legacy_connections[peer_id] in [conn1, conn2] if __name__ == "__main__": diff --git a/tests/core/network/test_swarm.py b/tests/core/network/test_swarm.py index 605913ec..df08ff98 100644 --- a/tests/core/network/test_swarm.py +++ b/tests/core/network/test_swarm.py @@ -51,14 +51,19 @@ async def test_swarm_dial_peer(security_protocol): for addr in transport.get_addrs() ) swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000) - await swarms[0].dial_peer(swarms[1].get_peer_id()) + + # New: dial_peer now returns list of connections + connections = await swarms[0].dial_peer(swarms[1].get_peer_id()) + assert len(connections) > 0 + + # Verify connections are established in both directions assert swarms[0].get_peer_id() in swarms[1].connections assert swarms[1].get_peer_id() in swarms[0].connections # Test: Reuse connections when we already have ones with a peer. - conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()] - conn = await swarms[0].dial_peer(swarms[1].get_peer_id()) - assert conn is conn_to_1 + existing_connections = swarms[0].get_connections(swarms[1].get_peer_id()) + new_connections = await swarms[0].dial_peer(swarms[1].get_peer_id()) + assert new_connections == existing_connections @pytest.mark.trio @@ -107,7 +112,8 @@ async def test_swarm_close_peer(security_protocol): @pytest.mark.trio async def test_swarm_remove_conn(swarm_pair): swarm_0, swarm_1 = swarm_pair - conn_0 = swarm_0.connections[swarm_1.get_peer_id()] + # Get the first connection from the list + conn_0 = swarm_0.connections[swarm_1.get_peer_id()][0] swarm_0.remove_conn(conn_0) assert swarm_1.get_peer_id() not in swarm_0.connections # Test: Remove twice. There should not be errors. @@ -115,6 +121,67 @@ async def test_swarm_remove_conn(swarm_pair): assert swarm_1.get_peer_id() not in swarm_0.connections +@pytest.mark.trio +async def test_swarm_multiple_connections(security_protocol): + """Test multiple connections per peer functionality.""" + async with SwarmFactory.create_batch_and_listen( + 2, security_protocol=security_protocol + ) as swarms: + # Setup multiple addresses for peer + addrs = tuple( + addr + for transport in swarms[1].listeners.values() + for addr in transport.get_addrs() + ) + swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000) + + # Dial peer - should return list of connections + connections = await swarms[0].dial_peer(swarms[1].get_peer_id()) + assert len(connections) > 0 + + # Test get_connections method + peer_connections = swarms[0].get_connections(swarms[1].get_peer_id()) + assert len(peer_connections) == len(connections) + + # Test get_connections_map method + connections_map = swarms[0].get_connections_map() + assert swarms[1].get_peer_id() in connections_map + assert len(connections_map[swarms[1].get_peer_id()]) == len(connections) + + # Test get_connection method (backward compatibility) + single_conn = swarms[0].get_connection(swarms[1].get_peer_id()) + assert single_conn is not None + assert single_conn in connections + + +@pytest.mark.trio +async def test_swarm_load_balancing(security_protocol): + """Test load balancing across multiple connections.""" + async with SwarmFactory.create_batch_and_listen( + 2, security_protocol=security_protocol + ) as swarms: + # Setup connection + addrs = tuple( + addr + for transport in swarms[1].listeners.values() + for addr in transport.get_addrs() + ) + swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000) + + # Create multiple streams - should use load balancing + streams = [] + for _ in range(5): + stream = await swarms[0].new_stream(swarms[1].get_peer_id()) + streams.append(stream) + + # Verify streams were created successfully + assert len(streams) == 5 + + # Clean up + for stream in streams: + await stream.close() + + @pytest.mark.trio async def test_swarm_multiaddr(security_protocol): async with SwarmFactory.create_batch_and_listen( diff --git a/tests/core/security/test_security_multistream.py b/tests/core/security/test_security_multistream.py index 577cf404..d4fed72d 100644 --- a/tests/core/security/test_security_multistream.py +++ b/tests/core/security/test_security_multistream.py @@ -51,6 +51,9 @@ async def perform_simple_test(assertion_func, security_protocol): # Extract the secured connection from either Mplex or Yamux implementation def get_secured_conn(conn): + # conn is now a list, get the first connection + if isinstance(conn, list): + conn = conn[0] muxed_conn = conn.muxed_conn # Direct attribute access for known implementations has_secured_conn = hasattr(muxed_conn, "secured_conn") diff --git a/tests/core/stream_muxer/test_multiplexer_selection.py b/tests/core/stream_muxer/test_multiplexer_selection.py index b2f3e305..9b45324e 100644 --- a/tests/core/stream_muxer/test_multiplexer_selection.py +++ b/tests/core/stream_muxer/test_multiplexer_selection.py @@ -74,7 +74,8 @@ async def test_multiplexer_preference_parameter(muxer_preference): assert len(connections) > 0, "Connection not established" # Get the first connection - conn = list(connections.values())[0] + conns = list(connections.values())[0] + conn = conns[0] # Get first connection from the list muxed_conn = conn.muxed_conn # Define a simple echo protocol @@ -150,7 +151,8 @@ async def test_explicit_muxer_options(muxer_option_func, expected_stream_class): assert len(connections) > 0, "Connection not established" # Get the first connection - conn = list(connections.values())[0] + conns = list(connections.values())[0] + conn = conns[0] # Get first connection from the list muxed_conn = conn.muxed_conn # Define a simple echo protocol @@ -219,7 +221,8 @@ async def test_global_default_muxer(global_default): assert len(connections) > 0, "Connection not established" # Get the first connection - conn = list(connections.values())[0] + conns = list(connections.values())[0] + conn = conns[0] # Get first connection from the list muxed_conn = conn.muxed_conn # Define a simple echo protocol diff --git a/tests/utils/factories.py b/tests/utils/factories.py index 75639e36..c006200f 100644 --- a/tests/utils/factories.py +++ b/tests/utils/factories.py @@ -669,8 +669,8 @@ async def swarm_conn_pair_factory( async with swarm_pair_factory( security_protocol=security_protocol, muxer_opt=muxer_opt ) as swarms: - conn_0 = swarms[0].connections[swarms[1].get_peer_id()] - conn_1 = swarms[1].connections[swarms[0].get_peer_id()] + conn_0 = swarms[0].connections[swarms[1].get_peer_id()][0] + conn_1 = swarms[1].connections[swarms[0].get_peer_id()][0] yield cast(SwarmConn, conn_0), cast(SwarmConn, conn_1)