mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
address architectural refactoring discussed
This commit is contained in:
133
docs/examples.multiple_connections.rst
Normal file
133
docs/examples.multiple_connections.rst
Normal file
@ -0,0 +1,133 @@
|
||||
Multiple Connections Per Peer
|
||||
============================
|
||||
|
||||
This example demonstrates how to use the multiple connections per peer feature in py-libp2p.
|
||||
|
||||
Overview
|
||||
--------
|
||||
|
||||
The multiple connections per peer feature allows a libp2p node to maintain multiple network connections to the same peer. This provides several benefits:
|
||||
|
||||
- **Improved reliability**: If one connection fails, others remain available
|
||||
- **Better performance**: Load can be distributed across multiple connections
|
||||
- **Enhanced throughput**: Multiple streams can be created in parallel
|
||||
- **Fault tolerance**: Redundant connections provide backup paths
|
||||
|
||||
Configuration
|
||||
-------------
|
||||
|
||||
The feature is configured through the `ConnectionConfig` class:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from libp2p.network.swarm import ConnectionConfig
|
||||
|
||||
# Default configuration
|
||||
config = ConnectionConfig()
|
||||
print(f"Max connections per peer: {config.max_connections_per_peer}")
|
||||
print(f"Load balancing strategy: {config.load_balancing_strategy}")
|
||||
|
||||
# Custom configuration
|
||||
custom_config = ConnectionConfig(
|
||||
max_connections_per_peer=5,
|
||||
connection_timeout=60.0,
|
||||
load_balancing_strategy="least_loaded"
|
||||
)
|
||||
|
||||
Load Balancing Strategies
|
||||
------------------------
|
||||
|
||||
Two load balancing strategies are available:
|
||||
|
||||
**Round Robin** (default)
|
||||
Cycles through connections in order, distributing load evenly.
|
||||
|
||||
**Least Loaded**
|
||||
Selects the connection with the fewest active streams.
|
||||
|
||||
API Usage
|
||||
---------
|
||||
|
||||
The new API provides direct access to multiple connections:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from libp2p import new_swarm
|
||||
|
||||
# Create swarm with multiple connections support
|
||||
swarm = new_swarm()
|
||||
|
||||
# Dial a peer - returns list of connections
|
||||
connections = await swarm.dial_peer(peer_id)
|
||||
print(f"Established {len(connections)} connections")
|
||||
|
||||
# Get all connections to a peer
|
||||
peer_connections = swarm.get_connections(peer_id)
|
||||
|
||||
# Get all connections (across all peers)
|
||||
all_connections = swarm.get_connections()
|
||||
|
||||
# Get the complete connections map
|
||||
connections_map = swarm.get_connections_map()
|
||||
|
||||
# Backward compatibility - get single connection
|
||||
single_conn = swarm.get_connection(peer_id)
|
||||
|
||||
Backward Compatibility
|
||||
---------------------
|
||||
|
||||
Existing code continues to work through backward compatibility features:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Legacy 1:1 mapping (returns first connection for each peer)
|
||||
legacy_connections = swarm.connections_legacy
|
||||
|
||||
# Single connection access (returns first available connection)
|
||||
conn = swarm.get_connection(peer_id)
|
||||
|
||||
Example
|
||||
-------
|
||||
|
||||
See :doc:`examples/doc-examples/multiple_connections_example.py` for a complete working example.
|
||||
|
||||
Production Configuration
|
||||
-----------------------
|
||||
|
||||
For production use, consider these settings:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from libp2p.network.swarm import ConnectionConfig, RetryConfig
|
||||
|
||||
# Production-ready configuration
|
||||
retry_config = RetryConfig(
|
||||
max_retries=3,
|
||||
initial_delay=0.1,
|
||||
max_delay=30.0,
|
||||
backoff_multiplier=2.0,
|
||||
jitter_factor=0.1
|
||||
)
|
||||
|
||||
connection_config = ConnectionConfig(
|
||||
max_connections_per_peer=3, # Balance performance and resources
|
||||
connection_timeout=30.0, # Reasonable timeout
|
||||
load_balancing_strategy="round_robin" # Predictable behavior
|
||||
)
|
||||
|
||||
swarm = new_swarm(
|
||||
retry_config=retry_config,
|
||||
connection_config=connection_config
|
||||
)
|
||||
|
||||
Architecture
|
||||
-----------
|
||||
|
||||
The implementation follows the same architectural patterns as the Go and JavaScript reference implementations:
|
||||
|
||||
- **Core data structure**: `dict[ID, list[INetConn]]` for 1:many mapping
|
||||
- **API consistency**: Methods like `get_connections()` match reference implementations
|
||||
- **Load balancing**: Integrated at the API level for optimal performance
|
||||
- **Backward compatibility**: Maintains existing interfaces for gradual migration
|
||||
|
||||
This design ensures consistency across libp2p implementations while providing the benefits of multiple connections per peer.
|
||||
@ -15,3 +15,4 @@ Examples
|
||||
examples.kademlia
|
||||
examples.mDNS
|
||||
examples.random_walk
|
||||
examples.multiple_connections
|
||||
|
||||
@ -1,18 +1,18 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Example demonstrating the enhanced Swarm with retry logic, exponential backoff,
|
||||
and multi-connection support.
|
||||
Example demonstrating multiple connections per peer support in libp2p.
|
||||
|
||||
This example shows how to:
|
||||
1. Configure retry behavior with exponential backoff
|
||||
2. Enable multi-connection support with connection pooling
|
||||
3. Use different load balancing strategies
|
||||
1. Configure multiple connections per peer
|
||||
2. Use different load balancing strategies
|
||||
3. Access multiple connections through the new API
|
||||
4. Maintain backward compatibility
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p import new_swarm
|
||||
from libp2p.network.swarm import ConnectionConfig, RetryConfig
|
||||
|
||||
@ -21,64 +21,32 @@ logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def example_basic_enhanced_swarm() -> None:
|
||||
"""Example of basic enhanced Swarm usage."""
|
||||
logger.info("Creating enhanced Swarm with default configuration...")
|
||||
async def example_basic_multiple_connections() -> None:
|
||||
"""Example of basic multiple connections per peer usage."""
|
||||
logger.info("Creating swarm with multiple connections support...")
|
||||
|
||||
# Create enhanced swarm with default retry and connection config
|
||||
# Create swarm with default configuration
|
||||
swarm = new_swarm()
|
||||
# Use default configuration values directly
|
||||
default_retry = RetryConfig()
|
||||
default_connection = ConnectionConfig()
|
||||
|
||||
logger.info(f"Swarm created with peer ID: {swarm.get_peer_id()}")
|
||||
logger.info(f"Retry config: max_retries={default_retry.max_retries}")
|
||||
logger.info(
|
||||
f"Connection config: max_connections_per_peer="
|
||||
f"{default_connection.max_connections_per_peer}"
|
||||
)
|
||||
logger.info(f"Connection pool enabled: {default_connection.enable_connection_pool}")
|
||||
|
||||
await swarm.close()
|
||||
logger.info("Basic enhanced Swarm example completed")
|
||||
|
||||
|
||||
async def example_custom_retry_config() -> None:
|
||||
"""Example of custom retry configuration."""
|
||||
logger.info("Creating enhanced Swarm with custom retry configuration...")
|
||||
|
||||
# Custom retry configuration for aggressive retry behavior
|
||||
retry_config = RetryConfig(
|
||||
max_retries=5, # More retries
|
||||
initial_delay=0.05, # Faster initial retry
|
||||
max_delay=10.0, # Lower max delay
|
||||
backoff_multiplier=1.5, # Less aggressive backoff
|
||||
jitter_factor=0.2, # More jitter
|
||||
)
|
||||
|
||||
# Create swarm with custom retry config
|
||||
swarm = new_swarm(retry_config=retry_config)
|
||||
|
||||
logger.info("Custom retry config applied:")
|
||||
logger.info(f" Max retries: {retry_config.max_retries}")
|
||||
logger.info(f" Initial delay: {retry_config.initial_delay}s")
|
||||
logger.info(f" Max delay: {retry_config.max_delay}s")
|
||||
logger.info(f" Backoff multiplier: {retry_config.backoff_multiplier}")
|
||||
logger.info(f" Jitter factor: {retry_config.jitter_factor}")
|
||||
|
||||
await swarm.close()
|
||||
logger.info("Custom retry config example completed")
|
||||
logger.info("Basic multiple connections example completed")
|
||||
|
||||
|
||||
async def example_custom_connection_config() -> None:
|
||||
"""Example of custom connection configuration."""
|
||||
logger.info("Creating enhanced Swarm with custom connection configuration...")
|
||||
logger.info("Creating swarm with custom connection configuration...")
|
||||
|
||||
# Custom connection configuration for high-performance scenarios
|
||||
connection_config = ConnectionConfig(
|
||||
max_connections_per_peer=5, # More connections per peer
|
||||
connection_timeout=60.0, # Longer timeout
|
||||
enable_connection_pool=True, # Enable connection pooling
|
||||
load_balancing_strategy="least_loaded", # Use least loaded strategy
|
||||
)
|
||||
|
||||
@ -90,9 +58,6 @@ async def example_custom_connection_config() -> None:
|
||||
f" Max connections per peer: {connection_config.max_connections_per_peer}"
|
||||
)
|
||||
logger.info(f" Connection timeout: {connection_config.connection_timeout}s")
|
||||
logger.info(
|
||||
f" Connection pool enabled: {connection_config.enable_connection_pool}"
|
||||
)
|
||||
logger.info(
|
||||
f" Load balancing strategy: {connection_config.load_balancing_strategy}"
|
||||
)
|
||||
@ -101,22 +66,39 @@ async def example_custom_connection_config() -> None:
|
||||
logger.info("Custom connection config example completed")
|
||||
|
||||
|
||||
async def example_backward_compatibility() -> None:
|
||||
"""Example showing backward compatibility."""
|
||||
logger.info("Creating enhanced Swarm with backward compatibility...")
|
||||
async def example_multiple_connections_api() -> None:
|
||||
"""Example of using the new multiple connections API."""
|
||||
logger.info("Demonstrating multiple connections API...")
|
||||
|
||||
# Disable connection pool to maintain original behavior
|
||||
connection_config = ConnectionConfig(enable_connection_pool=False)
|
||||
connection_config = ConnectionConfig(
|
||||
max_connections_per_peer=3,
|
||||
load_balancing_strategy="round_robin"
|
||||
)
|
||||
|
||||
# Create swarm with connection pool disabled
|
||||
swarm = new_swarm(connection_config=connection_config)
|
||||
|
||||
logger.info("Backward compatibility mode:")
|
||||
logger.info("Multiple connections API features:")
|
||||
logger.info(" - dial_peer() returns list[INetConn]")
|
||||
logger.info(" - get_connections(peer_id) returns list[INetConn]")
|
||||
logger.info(" - get_connections_map() returns dict[ID, list[INetConn]]")
|
||||
logger.info(
|
||||
f" Connection pool enabled: {connection_config.enable_connection_pool}"
|
||||
" - get_connection(peer_id) returns INetConn | None (backward compatibility)"
|
||||
)
|
||||
logger.info(f" Connections dict type: {type(swarm.connections)}")
|
||||
logger.info(" Retry logic still available: 3 max retries")
|
||||
|
||||
await swarm.close()
|
||||
logger.info("Multiple connections API example completed")
|
||||
|
||||
|
||||
async def example_backward_compatibility() -> None:
|
||||
"""Example of backward compatibility features."""
|
||||
logger.info("Demonstrating backward compatibility...")
|
||||
|
||||
swarm = new_swarm()
|
||||
|
||||
logger.info("Backward compatibility features:")
|
||||
logger.info(" - connections_legacy property provides 1:1 mapping")
|
||||
logger.info(" - get_connection() method for single connection access")
|
||||
logger.info(" - Existing code continues to work")
|
||||
|
||||
await swarm.close()
|
||||
logger.info("Backward compatibility example completed")
|
||||
@ -124,7 +106,7 @@ async def example_backward_compatibility() -> None:
|
||||
|
||||
async def example_production_ready_config() -> None:
|
||||
"""Example of production-ready configuration."""
|
||||
logger.info("Creating enhanced Swarm with production-ready configuration...")
|
||||
logger.info("Creating swarm with production-ready configuration...")
|
||||
|
||||
# Production-ready retry configuration
|
||||
retry_config = RetryConfig(
|
||||
@ -139,7 +121,6 @@ async def example_production_ready_config() -> None:
|
||||
connection_config = ConnectionConfig(
|
||||
max_connections_per_peer=3, # Balance between performance and resource usage
|
||||
connection_timeout=30.0, # Reasonable timeout
|
||||
enable_connection_pool=True, # Enable for better performance
|
||||
load_balancing_strategy="round_robin", # Simple, predictable strategy
|
||||
)
|
||||
|
||||
@ -160,19 +141,19 @@ async def example_production_ready_config() -> None:
|
||||
|
||||
async def main() -> None:
|
||||
"""Run all examples."""
|
||||
logger.info("Enhanced Swarm Examples")
|
||||
logger.info("Multiple Connections Per Peer Examples")
|
||||
logger.info("=" * 50)
|
||||
|
||||
try:
|
||||
await example_basic_enhanced_swarm()
|
||||
logger.info("-" * 30)
|
||||
|
||||
await example_custom_retry_config()
|
||||
await example_basic_multiple_connections()
|
||||
logger.info("-" * 30)
|
||||
|
||||
await example_custom_connection_config()
|
||||
logger.info("-" * 30)
|
||||
|
||||
await example_multiple_connections_api()
|
||||
logger.info("-" * 30)
|
||||
|
||||
await example_backward_compatibility()
|
||||
logger.info("-" * 30)
|
||||
|
||||
@ -187,4 +168,4 @@ async def main() -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
trio.run(main)
|
||||
@ -1412,15 +1412,16 @@ class INetwork(ABC):
|
||||
----------
|
||||
peerstore : IPeerStore
|
||||
The peer store for managing peer information.
|
||||
connections : dict[ID, INetConn]
|
||||
A mapping of peer IDs to network connections.
|
||||
connections : dict[ID, list[INetConn]]
|
||||
A mapping of peer IDs to lists of network connections
|
||||
(multiple connections per peer).
|
||||
listeners : dict[str, IListener]
|
||||
A mapping of listener identifiers to listener instances.
|
||||
|
||||
"""
|
||||
|
||||
peerstore: IPeerStore
|
||||
connections: dict[ID, INetConn]
|
||||
connections: dict[ID, list[INetConn]]
|
||||
listeners: dict[str, IListener]
|
||||
|
||||
@abstractmethod
|
||||
@ -1436,9 +1437,56 @@ class INetwork(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def dial_peer(self, peer_id: ID) -> INetConn:
|
||||
def get_connections(self, peer_id: ID | None = None) -> list[INetConn]:
|
||||
"""
|
||||
Create a connection to the specified peer.
|
||||
Get connections for peer (like JS getConnections, Go ConnsToPeer).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID | None
|
||||
The peer ID to get connections for. If None, returns all connections.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[INetConn]
|
||||
List of connections to the specified peer, or all connections
|
||||
if peer_id is None.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_connections_map(self) -> dict[ID, list[INetConn]]:
|
||||
"""
|
||||
Get all connections map (like JS getConnectionsMap).
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[ID, list[INetConn]]
|
||||
The complete mapping of peer IDs to their connection lists.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_connection(self, peer_id: ID) -> INetConn | None:
|
||||
"""
|
||||
Get single connection for backward compatibility.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer ID to get a connection for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
INetConn | None
|
||||
The first available connection, or None if no connections exist.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def dial_peer(self, peer_id: ID) -> list[INetConn]:
|
||||
"""
|
||||
Create connections to the specified peer with load balancing.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -1447,8 +1495,8 @@ class INetwork(ABC):
|
||||
|
||||
Returns
|
||||
-------
|
||||
INetConn
|
||||
The network connection instance to the specified peer.
|
||||
list[INetConn]
|
||||
List of established connections to the peer.
|
||||
|
||||
Raises
|
||||
------
|
||||
|
||||
@ -338,7 +338,7 @@ class BasicHost(IHost):
|
||||
:param peer_id: ID of the peer to check
|
||||
:return: True if peer has an active connection, False otherwise
|
||||
"""
|
||||
return peer_id in self._network.connections
|
||||
return len(self._network.get_connections(peer_id)) > 0
|
||||
|
||||
def get_peer_connection_info(self, peer_id: ID) -> INetConn | None:
|
||||
"""
|
||||
@ -347,4 +347,4 @@ class BasicHost(IHost):
|
||||
:param peer_id: ID of the peer to get info for
|
||||
:return: Connection object if peer is connected, None otherwise
|
||||
"""
|
||||
return self._network.connections.get(peer_id)
|
||||
return self._network.get_connection(peer_id)
|
||||
|
||||
@ -74,175 +74,13 @@ class RetryConfig:
|
||||
|
||||
@dataclass
|
||||
class ConnectionConfig:
|
||||
"""Configuration for connection pool and multi-connection support."""
|
||||
"""Configuration for multi-connection support."""
|
||||
|
||||
max_connections_per_peer: int = 3
|
||||
connection_timeout: float = 30.0
|
||||
enable_connection_pool: bool = True
|
||||
load_balancing_strategy: str = "round_robin" # or "least_loaded"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionInfo:
|
||||
"""Information about a connection in the pool."""
|
||||
|
||||
connection: INetConn
|
||||
address: str
|
||||
established_at: float
|
||||
last_used: float
|
||||
stream_count: int
|
||||
is_healthy: bool
|
||||
|
||||
|
||||
class ConnectionPool:
|
||||
"""Manages multiple connections per peer with load balancing."""
|
||||
|
||||
def __init__(self, max_connections_per_peer: int = 3):
|
||||
self.max_connections_per_peer = max_connections_per_peer
|
||||
self.peer_connections: dict[ID, list[ConnectionInfo]] = {}
|
||||
self._round_robin_index: dict[ID, int] = {}
|
||||
|
||||
def add_connection(self, peer_id: ID, connection: INetConn, address: str) -> None:
|
||||
"""Add a connection to the pool with deduplication."""
|
||||
if peer_id not in self.peer_connections:
|
||||
self.peer_connections[peer_id] = []
|
||||
|
||||
# Check for duplicate connections to the same address
|
||||
for conn_info in self.peer_connections[peer_id]:
|
||||
if conn_info.address == address:
|
||||
logger.debug(
|
||||
f"Connection to {address} already exists for peer {peer_id}"
|
||||
)
|
||||
return
|
||||
|
||||
# Add new connection
|
||||
try:
|
||||
current_time = trio.current_time()
|
||||
except RuntimeError:
|
||||
# Fallback for testing contexts where trio is not running
|
||||
import time
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
conn_info = ConnectionInfo(
|
||||
connection=connection,
|
||||
address=address,
|
||||
established_at=current_time,
|
||||
last_used=current_time,
|
||||
stream_count=0,
|
||||
is_healthy=True,
|
||||
)
|
||||
|
||||
self.peer_connections[peer_id].append(conn_info)
|
||||
|
||||
# Trim if we exceed max connections
|
||||
if len(self.peer_connections[peer_id]) > self.max_connections_per_peer:
|
||||
self._trim_connections(peer_id)
|
||||
|
||||
def get_connection(
|
||||
self, peer_id: ID, strategy: str = "round_robin"
|
||||
) -> INetConn | None:
|
||||
"""Get a connection using the specified load balancing strategy."""
|
||||
if peer_id not in self.peer_connections or not self.peer_connections[peer_id]:
|
||||
return None
|
||||
|
||||
connections = self.peer_connections[peer_id]
|
||||
|
||||
if strategy == "round_robin":
|
||||
if peer_id not in self._round_robin_index:
|
||||
self._round_robin_index[peer_id] = 0
|
||||
|
||||
index = self._round_robin_index[peer_id] % len(connections)
|
||||
self._round_robin_index[peer_id] += 1
|
||||
|
||||
conn_info = connections[index]
|
||||
try:
|
||||
conn_info.last_used = trio.current_time()
|
||||
except RuntimeError:
|
||||
import time
|
||||
|
||||
conn_info.last_used = time.time()
|
||||
return conn_info.connection
|
||||
|
||||
elif strategy == "least_loaded":
|
||||
# Find connection with least streams
|
||||
# Note: stream_count is a custom attribute we add to connections
|
||||
conn_info = min(
|
||||
connections, key=lambda c: getattr(c.connection, "stream_count", 0)
|
||||
)
|
||||
try:
|
||||
conn_info.last_used = trio.current_time()
|
||||
except RuntimeError:
|
||||
import time
|
||||
|
||||
conn_info.last_used = time.time()
|
||||
return conn_info.connection
|
||||
|
||||
else:
|
||||
# Default to first connection
|
||||
conn_info = connections[0]
|
||||
try:
|
||||
conn_info.last_used = trio.current_time()
|
||||
except RuntimeError:
|
||||
import time
|
||||
|
||||
conn_info.last_used = time.time()
|
||||
return conn_info.connection
|
||||
|
||||
def has_connection(self, peer_id: ID) -> bool:
|
||||
"""Check if we have any connections to the peer."""
|
||||
return (
|
||||
peer_id in self.peer_connections and len(self.peer_connections[peer_id]) > 0
|
||||
)
|
||||
|
||||
def remove_connection(self, peer_id: ID, connection: INetConn) -> None:
|
||||
"""Remove a connection from the pool."""
|
||||
if peer_id in self.peer_connections:
|
||||
self.peer_connections[peer_id] = [
|
||||
conn_info
|
||||
for conn_info in self.peer_connections[peer_id]
|
||||
if conn_info.connection != connection
|
||||
]
|
||||
|
||||
# Clean up empty peer entries
|
||||
if not self.peer_connections[peer_id]:
|
||||
del self.peer_connections[peer_id]
|
||||
if peer_id in self._round_robin_index:
|
||||
del self._round_robin_index[peer_id]
|
||||
|
||||
def _trim_connections(self, peer_id: ID) -> None:
|
||||
"""Remove oldest connections when limit is exceeded."""
|
||||
connections = self.peer_connections[peer_id]
|
||||
if len(connections) <= self.max_connections_per_peer:
|
||||
return
|
||||
|
||||
# Sort by last used time and remove oldest
|
||||
connections.sort(key=lambda c: c.last_used)
|
||||
connections_to_remove = connections[: -self.max_connections_per_peer]
|
||||
|
||||
for conn_info in connections_to_remove:
|
||||
logger.debug(
|
||||
f"Trimming old connection to {conn_info.address} for peer {peer_id}"
|
||||
)
|
||||
try:
|
||||
# Close the connection asynchronously
|
||||
trio.lowlevel.spawn_system_task(
|
||||
self._close_connection_async, conn_info.connection
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing trimmed connection: {e}")
|
||||
|
||||
# Keep only the most recently used connections
|
||||
self.peer_connections[peer_id] = connections[-self.max_connections_per_peer :]
|
||||
|
||||
async def _close_connection_async(self, connection: INetConn) -> None:
|
||||
"""Close a connection asynchronously."""
|
||||
try:
|
||||
await connection.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing connection: {e}")
|
||||
|
||||
|
||||
def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn:
|
||||
async def stream_handler(stream: INetStream) -> None:
|
||||
await network.get_manager().wait_finished()
|
||||
@ -256,7 +94,7 @@ class Swarm(Service, INetworkService):
|
||||
upgrader: TransportUpgrader
|
||||
transport: ITransport
|
||||
# Enhanced: Support for multiple connections per peer
|
||||
connections: dict[ID, INetConn] # Backward compatibility
|
||||
connections: dict[ID, list[INetConn]] # Multiple connections per peer
|
||||
listeners: dict[str, IListener]
|
||||
common_stream_handler: StreamHandlerFn
|
||||
listener_nursery: trio.Nursery | None
|
||||
@ -264,10 +102,10 @@ class Swarm(Service, INetworkService):
|
||||
|
||||
notifees: list[INotifee]
|
||||
|
||||
# Enhanced: New configuration and connection pool
|
||||
# Enhanced: New configuration
|
||||
retry_config: RetryConfig
|
||||
connection_config: ConnectionConfig
|
||||
connection_pool: ConnectionPool | None
|
||||
_round_robin_index: dict[ID, int]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -287,16 +125,8 @@ class Swarm(Service, INetworkService):
|
||||
self.retry_config = retry_config or RetryConfig()
|
||||
self.connection_config = connection_config or ConnectionConfig()
|
||||
|
||||
# Enhanced: Initialize connection pool if enabled
|
||||
if self.connection_config.enable_connection_pool:
|
||||
self.connection_pool = ConnectionPool(
|
||||
self.connection_config.max_connections_per_peer
|
||||
)
|
||||
else:
|
||||
self.connection_pool = None
|
||||
|
||||
# Backward compatibility: Keep existing connections dict
|
||||
self.connections = dict()
|
||||
# Enhanced: Initialize connections as 1:many mapping
|
||||
self.connections = {}
|
||||
self.listeners = dict()
|
||||
|
||||
# Create Notifee array
|
||||
@ -307,6 +137,9 @@ class Swarm(Service, INetworkService):
|
||||
self.listener_nursery = None
|
||||
self.event_listener_nursery_created = trio.Event()
|
||||
|
||||
# Load balancing state
|
||||
self._round_robin_index = {}
|
||||
|
||||
async def run(self) -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Create a nursery for listener tasks.
|
||||
@ -326,26 +159,74 @@ class Swarm(Service, INetworkService):
|
||||
def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None:
|
||||
self.common_stream_handler = stream_handler
|
||||
|
||||
async def dial_peer(self, peer_id: ID) -> INetConn:
|
||||
def get_connections(self, peer_id: ID | None = None) -> list[INetConn]:
|
||||
"""
|
||||
Try to create a connection to peer_id with enhanced retry logic.
|
||||
Get connections for peer (like JS getConnections, Go ConnsToPeer).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID | None
|
||||
The peer ID to get connections for. If None, returns all connections.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[INetConn]
|
||||
List of connections to the specified peer, or all connections
|
||||
if peer_id is None.
|
||||
|
||||
"""
|
||||
if peer_id is not None:
|
||||
return self.connections.get(peer_id, [])
|
||||
|
||||
# Return all connections from all peers
|
||||
all_conns = []
|
||||
for conns in self.connections.values():
|
||||
all_conns.extend(conns)
|
||||
return all_conns
|
||||
|
||||
def get_connections_map(self) -> dict[ID, list[INetConn]]:
|
||||
"""
|
||||
Get all connections map (like JS getConnectionsMap).
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[ID, list[INetConn]]
|
||||
The complete mapping of peer IDs to their connection lists.
|
||||
|
||||
"""
|
||||
return self.connections.copy()
|
||||
|
||||
def get_connection(self, peer_id: ID) -> INetConn | None:
|
||||
"""
|
||||
Get single connection for backward compatibility.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer ID to get a connection for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
INetConn | None
|
||||
The first available connection, or None if no connections exist.
|
||||
|
||||
"""
|
||||
conns = self.get_connections(peer_id)
|
||||
return conns[0] if conns else None
|
||||
|
||||
async def dial_peer(self, peer_id: ID) -> list[INetConn]:
|
||||
"""
|
||||
Try to create connections to peer_id with enhanced retry logic.
|
||||
|
||||
:param peer_id: peer if we want to dial
|
||||
:raises SwarmException: raised when an error occurs
|
||||
:return: muxed connection
|
||||
:return: list of muxed connections
|
||||
"""
|
||||
# Enhanced: Check connection pool first if enabled
|
||||
if self.connection_pool and self.connection_pool.has_connection(peer_id):
|
||||
connection = self.connection_pool.get_connection(peer_id)
|
||||
if connection:
|
||||
logger.debug(f"Reusing existing connection to peer {peer_id}")
|
||||
return connection
|
||||
|
||||
# Enhanced: Check existing single connection for backward compatibility
|
||||
if peer_id in self.connections:
|
||||
# If muxed connection already exists for peer_id,
|
||||
# set muxed connection equal to existing muxed connection
|
||||
return self.connections[peer_id]
|
||||
# Check if we already have connections
|
||||
existing_connections = self.get_connections(peer_id)
|
||||
if existing_connections:
|
||||
logger.debug(f"Reusing existing connections to peer {peer_id}")
|
||||
return existing_connections
|
||||
|
||||
logger.debug("attempting to dial peer %s", peer_id)
|
||||
|
||||
@ -358,23 +239,19 @@ class Swarm(Service, INetworkService):
|
||||
if not addrs:
|
||||
raise SwarmException(f"No known addresses to peer {peer_id}")
|
||||
|
||||
connections = []
|
||||
exceptions: list[SwarmException] = []
|
||||
|
||||
# Enhanced: Try all known addresses with retry logic
|
||||
for multiaddr in addrs:
|
||||
try:
|
||||
connection = await self._dial_with_retry(multiaddr, peer_id)
|
||||
connections.append(connection)
|
||||
|
||||
# Enhanced: Add to connection pool if enabled
|
||||
if self.connection_pool:
|
||||
self.connection_pool.add_connection(
|
||||
peer_id, connection, str(multiaddr)
|
||||
)
|
||||
# Limit number of connections per peer
|
||||
if len(connections) >= self.connection_config.max_connections_per_peer:
|
||||
break
|
||||
|
||||
# Backward compatibility: Keep existing connections dict
|
||||
self.connections[peer_id] = connection
|
||||
|
||||
return connection
|
||||
except SwarmException as e:
|
||||
exceptions.append(e)
|
||||
logger.debug(
|
||||
@ -384,11 +261,14 @@ class Swarm(Service, INetworkService):
|
||||
exc_info=e,
|
||||
)
|
||||
|
||||
# Tried all addresses, raising exception.
|
||||
raise SwarmException(
|
||||
f"unable to connect to {peer_id}, no addresses established a successful "
|
||||
"connection (with exceptions)"
|
||||
) from MultiError(exceptions)
|
||||
if not connections:
|
||||
# Tried all addresses, raising exception.
|
||||
raise SwarmException(
|
||||
f"unable to connect to {peer_id}, no addresses established a "
|
||||
"successful connection (with exceptions)"
|
||||
) from MultiError(exceptions)
|
||||
|
||||
return connections
|
||||
|
||||
async def _dial_with_retry(self, addr: Multiaddr, peer_id: ID) -> INetConn:
|
||||
"""
|
||||
@ -515,33 +395,76 @@ class Swarm(Service, INetworkService):
|
||||
"""
|
||||
logger.debug("attempting to open a stream to peer %s", peer_id)
|
||||
|
||||
# Enhanced: Try to get existing connection from pool first
|
||||
if self.connection_pool and self.connection_pool.has_connection(peer_id):
|
||||
connection = self.connection_pool.get_connection(
|
||||
peer_id, self.connection_config.load_balancing_strategy
|
||||
)
|
||||
if connection:
|
||||
try:
|
||||
net_stream = await connection.new_stream()
|
||||
logger.debug(
|
||||
"successfully opened a stream to peer %s "
|
||||
"using existing connection",
|
||||
peer_id,
|
||||
)
|
||||
return net_stream
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"Failed to create stream on existing connection, "
|
||||
f"will dial new connection: {e}"
|
||||
)
|
||||
# Fall through to dial new connection
|
||||
# Get existing connections or dial new ones
|
||||
connections = self.get_connections(peer_id)
|
||||
if not connections:
|
||||
connections = await self.dial_peer(peer_id)
|
||||
|
||||
# Fall back to existing logic: dial peer and create stream
|
||||
swarm_conn = await self.dial_peer(peer_id)
|
||||
# Load balancing strategy at interface level
|
||||
connection = self._select_connection(connections, peer_id)
|
||||
|
||||
net_stream = await swarm_conn.new_stream()
|
||||
logger.debug("successfully opened a stream to peer %s", peer_id)
|
||||
return net_stream
|
||||
try:
|
||||
net_stream = await connection.new_stream()
|
||||
logger.debug("successfully opened a stream to peer %s", peer_id)
|
||||
return net_stream
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to create stream on connection: {e}")
|
||||
# Try other connections if available
|
||||
for other_conn in connections:
|
||||
if other_conn != connection:
|
||||
try:
|
||||
net_stream = await other_conn.new_stream()
|
||||
logger.debug(
|
||||
f"Successfully opened a stream to peer {peer_id} "
|
||||
"using alternative connection"
|
||||
)
|
||||
return net_stream
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# All connections failed, raise exception
|
||||
raise SwarmException(f"Failed to create stream to peer {peer_id}") from e
|
||||
|
||||
def _select_connection(self, connections: list[INetConn], peer_id: ID) -> INetConn:
|
||||
"""
|
||||
Select connection based on load balancing strategy.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
connections : list[INetConn]
|
||||
List of available connections.
|
||||
peer_id : ID
|
||||
The peer ID for round-robin tracking.
|
||||
strategy : str
|
||||
Load balancing strategy ("round_robin", "least_loaded", etc.).
|
||||
|
||||
Returns
|
||||
-------
|
||||
INetConn
|
||||
Selected connection.
|
||||
|
||||
"""
|
||||
if not connections:
|
||||
raise ValueError("No connections available")
|
||||
|
||||
strategy = self.connection_config.load_balancing_strategy
|
||||
|
||||
if strategy == "round_robin":
|
||||
# Simple round-robin selection
|
||||
if peer_id not in self._round_robin_index:
|
||||
self._round_robin_index[peer_id] = 0
|
||||
|
||||
index = self._round_robin_index[peer_id] % len(connections)
|
||||
self._round_robin_index[peer_id] += 1
|
||||
return connections[index]
|
||||
|
||||
elif strategy == "least_loaded":
|
||||
# Find connection with least streams
|
||||
return min(connections, key=lambda c: len(c.get_streams()))
|
||||
|
||||
else:
|
||||
# Default to first connection
|
||||
return connections[0]
|
||||
|
||||
async def listen(self, *multiaddrs: Multiaddr) -> bool:
|
||||
"""
|
||||
@ -637,9 +560,9 @@ class Swarm(Service, INetworkService):
|
||||
# Perform alternative cleanup if the manager isn't initialized
|
||||
# Close all connections manually
|
||||
if hasattr(self, "connections"):
|
||||
for conn_id in list(self.connections.keys()):
|
||||
conn = self.connections[conn_id]
|
||||
await conn.close()
|
||||
for peer_id, conns in list(self.connections.items()):
|
||||
for conn in conns:
|
||||
await conn.close()
|
||||
|
||||
# Clear connection tracking dictionary
|
||||
self.connections.clear()
|
||||
@ -669,17 +592,28 @@ class Swarm(Service, INetworkService):
|
||||
logger.debug("swarm successfully closed")
|
||||
|
||||
async def close_peer(self, peer_id: ID) -> None:
|
||||
if peer_id not in self.connections:
|
||||
"""
|
||||
Close all connections to the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer ID to close connections for.
|
||||
|
||||
"""
|
||||
connections = self.get_connections(peer_id)
|
||||
if not connections:
|
||||
return
|
||||
connection = self.connections[peer_id]
|
||||
|
||||
# Enhanced: Remove from connection pool if enabled
|
||||
if self.connection_pool:
|
||||
self.connection_pool.remove_connection(peer_id, connection)
|
||||
# Close all connections
|
||||
for connection in connections:
|
||||
try:
|
||||
await connection.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing connection to {peer_id}: {e}")
|
||||
|
||||
# NOTE: `connection.close` will delete `peer_id` from `self.connections`
|
||||
# and `notify_disconnected` for us.
|
||||
await connection.close()
|
||||
# Remove from connections dict
|
||||
self.connections.pop(peer_id, None)
|
||||
|
||||
logger.debug("successfully close the connection to peer %s", peer_id)
|
||||
|
||||
@ -698,20 +632,58 @@ class Swarm(Service, INetworkService):
|
||||
await muxed_conn.event_started.wait()
|
||||
self.manager.run_task(swarm_conn.start)
|
||||
await swarm_conn.event_started.wait()
|
||||
# Enhanced: Add to connection pool if enabled
|
||||
if self.connection_pool:
|
||||
# For incoming connections, we don't have a specific address
|
||||
# Use a placeholder that will be updated when we get more info
|
||||
self.connection_pool.add_connection(
|
||||
muxed_conn.peer_id, swarm_conn, "incoming"
|
||||
)
|
||||
|
||||
# Store muxed_conn with peer id (backward compatibility)
|
||||
self.connections[muxed_conn.peer_id] = swarm_conn
|
||||
# Add to connections dict with deduplication
|
||||
peer_id = muxed_conn.peer_id
|
||||
if peer_id not in self.connections:
|
||||
self.connections[peer_id] = []
|
||||
|
||||
# Check for duplicate connections by comparing the underlying muxed connection
|
||||
for existing_conn in self.connections[peer_id]:
|
||||
if existing_conn.muxed_conn == muxed_conn:
|
||||
logger.debug(f"Connection already exists for peer {peer_id}")
|
||||
# existing_conn is a SwarmConn since it's stored in the connections list
|
||||
return existing_conn # type: ignore[return-value]
|
||||
|
||||
self.connections[peer_id].append(swarm_conn)
|
||||
|
||||
# Trim if we exceed max connections
|
||||
max_conns = self.connection_config.max_connections_per_peer
|
||||
if len(self.connections[peer_id]) > max_conns:
|
||||
self._trim_connections(peer_id)
|
||||
|
||||
# Call notifiers since event occurred
|
||||
await self.notify_connected(swarm_conn)
|
||||
return swarm_conn
|
||||
|
||||
def _trim_connections(self, peer_id: ID) -> None:
|
||||
"""
|
||||
Remove oldest connections when limit is exceeded.
|
||||
"""
|
||||
connections = self.connections[peer_id]
|
||||
if len(connections) <= self.connection_config.max_connections_per_peer:
|
||||
return
|
||||
|
||||
# Sort by creation time and remove oldest
|
||||
# For now, just keep the most recent connections
|
||||
max_conns = self.connection_config.max_connections_per_peer
|
||||
connections_to_remove = connections[:-max_conns]
|
||||
|
||||
for conn in connections_to_remove:
|
||||
logger.debug(f"Trimming old connection for peer {peer_id}")
|
||||
trio.lowlevel.spawn_system_task(self._close_connection_async, conn)
|
||||
|
||||
# Keep only the most recent connections
|
||||
max_conns = self.connection_config.max_connections_per_peer
|
||||
self.connections[peer_id] = connections[-max_conns:]
|
||||
|
||||
async def _close_connection_async(self, connection: INetConn) -> None:
|
||||
"""Close a connection asynchronously."""
|
||||
try:
|
||||
await connection.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing connection: {e}")
|
||||
|
||||
def remove_conn(self, swarm_conn: SwarmConn) -> None:
|
||||
"""
|
||||
Simply remove the connection from Swarm's records, without closing
|
||||
@ -719,13 +691,12 @@ class Swarm(Service, INetworkService):
|
||||
"""
|
||||
peer_id = swarm_conn.muxed_conn.peer_id
|
||||
|
||||
# Enhanced: Remove from connection pool if enabled
|
||||
if self.connection_pool:
|
||||
self.connection_pool.remove_connection(peer_id, swarm_conn)
|
||||
|
||||
if peer_id not in self.connections:
|
||||
return
|
||||
del self.connections[peer_id]
|
||||
if peer_id in self.connections:
|
||||
self.connections[peer_id] = [
|
||||
conn for conn in self.connections[peer_id] if conn != swarm_conn
|
||||
]
|
||||
if not self.connections[peer_id]:
|
||||
del self.connections[peer_id]
|
||||
|
||||
# Notifee
|
||||
|
||||
@ -771,3 +742,21 @@ class Swarm(Service, INetworkService):
|
||||
async with trio.open_nursery() as nursery:
|
||||
for notifee in self.notifees:
|
||||
nursery.start_soon(notifier, notifee)
|
||||
|
||||
# Backward compatibility properties
|
||||
@property
|
||||
def connections_legacy(self) -> dict[ID, INetConn]:
|
||||
"""
|
||||
Legacy 1:1 mapping for backward compatibility.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[ID, INetConn]
|
||||
Legacy mapping with only the first connection per peer.
|
||||
|
||||
"""
|
||||
legacy_conns = {}
|
||||
for peer_id, conns in self.connections.items():
|
||||
if conns:
|
||||
legacy_conns[peer_id] = conns[0]
|
||||
return legacy_conns
|
||||
|
||||
@ -164,8 +164,8 @@ async def test_live_peers_unexpected_drop(security_protocol):
|
||||
assert peer_a_id in host_b.get_live_peers()
|
||||
|
||||
# Simulate unexpected connection drop by directly closing the connection
|
||||
conn = host_a.get_network().connections[peer_b_id]
|
||||
await conn.muxed_conn.close()
|
||||
conns = host_a.get_network().connections[peer_b_id]
|
||||
await conns[0].muxed_conn.close()
|
||||
|
||||
# Allow for connection cleanup
|
||||
await trio.sleep(0.1)
|
||||
|
||||
@ -1,14 +1,15 @@
|
||||
import time
|
||||
from typing import cast
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.abc import INetConn, INetStream
|
||||
from libp2p.network.exceptions import SwarmException
|
||||
from libp2p.network.swarm import (
|
||||
ConnectionConfig,
|
||||
ConnectionPool,
|
||||
RetryConfig,
|
||||
Swarm,
|
||||
)
|
||||
@ -21,10 +22,12 @@ class MockConnection(INetConn):
|
||||
def __init__(self, peer_id: ID, is_closed: bool = False):
|
||||
self.peer_id = peer_id
|
||||
self._is_closed = is_closed
|
||||
self.stream_count = 0
|
||||
self.streams = set() # Track streams properly
|
||||
# Mock the muxed_conn attribute that Swarm expects
|
||||
self.muxed_conn = Mock()
|
||||
self.muxed_conn.peer_id = peer_id
|
||||
# Required by INetConn interface
|
||||
self.event_started = trio.Event()
|
||||
|
||||
async def close(self):
|
||||
self._is_closed = True
|
||||
@ -34,12 +37,14 @@ class MockConnection(INetConn):
|
||||
return self._is_closed
|
||||
|
||||
async def new_stream(self) -> INetStream:
|
||||
self.stream_count += 1
|
||||
return Mock(spec=INetStream)
|
||||
# Create a mock stream and add it to the connection's stream set
|
||||
mock_stream = Mock(spec=INetStream)
|
||||
self.streams.add(mock_stream)
|
||||
return mock_stream
|
||||
|
||||
def get_streams(self) -> tuple[INetStream, ...]:
|
||||
"""Mock implementation of get_streams."""
|
||||
return tuple()
|
||||
"""Return all streams associated with this connection."""
|
||||
return tuple(self.streams)
|
||||
|
||||
def get_transport_addresses(self) -> list[Multiaddr]:
|
||||
"""Mock implementation of get_transport_addresses."""
|
||||
@ -70,114 +75,9 @@ async def test_connection_config_defaults():
|
||||
config = ConnectionConfig()
|
||||
assert config.max_connections_per_peer == 3
|
||||
assert config.connection_timeout == 30.0
|
||||
assert config.enable_connection_pool is True
|
||||
assert config.load_balancing_strategy == "round_robin"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_pool_basic_operations():
|
||||
"""Test basic ConnectionPool operations."""
|
||||
pool = ConnectionPool(max_connections_per_peer=2)
|
||||
peer_id = ID(b"QmTest")
|
||||
|
||||
# Test empty pool
|
||||
assert not pool.has_connection(peer_id)
|
||||
assert pool.get_connection(peer_id) is None
|
||||
|
||||
# Add connection
|
||||
conn1 = MockConnection(peer_id)
|
||||
pool.add_connection(peer_id, conn1, "addr1")
|
||||
assert pool.has_connection(peer_id)
|
||||
assert pool.get_connection(peer_id) == conn1
|
||||
|
||||
# Add second connection
|
||||
conn2 = MockConnection(peer_id)
|
||||
pool.add_connection(peer_id, conn2, "addr2")
|
||||
assert len(pool.peer_connections[peer_id]) == 2
|
||||
|
||||
# Test round-robin - should cycle through connections
|
||||
first_conn = pool.get_connection(peer_id, "round_robin")
|
||||
second_conn = pool.get_connection(peer_id, "round_robin")
|
||||
third_conn = pool.get_connection(peer_id, "round_robin")
|
||||
|
||||
# Should cycle through both connections
|
||||
assert first_conn in [conn1, conn2]
|
||||
assert second_conn in [conn1, conn2]
|
||||
assert third_conn in [conn1, conn2]
|
||||
assert first_conn != second_conn or second_conn != third_conn
|
||||
|
||||
# Test least loaded - set different stream counts
|
||||
conn1.stream_count = 5
|
||||
conn2.stream_count = 1
|
||||
least_loaded_conn = pool.get_connection(peer_id, "least_loaded")
|
||||
assert least_loaded_conn == conn2 # conn2 has fewer streams
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_pool_deduplication():
|
||||
"""Test connection deduplication by address."""
|
||||
pool = ConnectionPool(max_connections_per_peer=3)
|
||||
peer_id = ID(b"QmTest")
|
||||
|
||||
conn1 = MockConnection(peer_id)
|
||||
pool.add_connection(peer_id, conn1, "addr1")
|
||||
|
||||
# Try to add connection with same address
|
||||
conn2 = MockConnection(peer_id)
|
||||
pool.add_connection(peer_id, conn2, "addr1")
|
||||
|
||||
# Should only have one connection
|
||||
assert len(pool.peer_connections[peer_id]) == 1
|
||||
assert pool.get_connection(peer_id) == conn1
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_pool_trimming():
|
||||
"""Test connection trimming when limit is exceeded."""
|
||||
pool = ConnectionPool(max_connections_per_peer=2)
|
||||
peer_id = ID(b"QmTest")
|
||||
|
||||
# Add 3 connections
|
||||
conn1 = MockConnection(peer_id)
|
||||
conn2 = MockConnection(peer_id)
|
||||
conn3 = MockConnection(peer_id)
|
||||
|
||||
pool.add_connection(peer_id, conn1, "addr1")
|
||||
pool.add_connection(peer_id, conn2, "addr2")
|
||||
pool.add_connection(peer_id, conn3, "addr3")
|
||||
|
||||
# Should trim to 2 connections
|
||||
assert len(pool.peer_connections[peer_id]) == 2
|
||||
|
||||
# The oldest connections should be removed
|
||||
remaining_connections = [c.connection for c in pool.peer_connections[peer_id]]
|
||||
assert conn3 in remaining_connections # Most recent should remain
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_pool_remove_connection():
|
||||
"""Test removing connections from pool."""
|
||||
pool = ConnectionPool(max_connections_per_peer=3)
|
||||
peer_id = ID(b"QmTest")
|
||||
|
||||
conn1 = MockConnection(peer_id)
|
||||
conn2 = MockConnection(peer_id)
|
||||
|
||||
pool.add_connection(peer_id, conn1, "addr1")
|
||||
pool.add_connection(peer_id, conn2, "addr2")
|
||||
|
||||
assert len(pool.peer_connections[peer_id]) == 2
|
||||
|
||||
# Remove connection
|
||||
pool.remove_connection(peer_id, conn1)
|
||||
assert len(pool.peer_connections[peer_id]) == 1
|
||||
assert pool.get_connection(peer_id) == conn2
|
||||
|
||||
# Remove last connection
|
||||
pool.remove_connection(peer_id, conn2)
|
||||
assert not pool.has_connection(peer_id)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_enhanced_swarm_constructor():
|
||||
"""Test enhanced Swarm constructor with new configuration."""
|
||||
@ -191,19 +91,16 @@ async def test_enhanced_swarm_constructor():
|
||||
swarm = Swarm(peer_id, peerstore, upgrader, transport)
|
||||
assert swarm.retry_config.max_retries == 3
|
||||
assert swarm.connection_config.max_connections_per_peer == 3
|
||||
assert swarm.connection_pool is not None
|
||||
assert isinstance(swarm.connections, dict)
|
||||
|
||||
# Test with custom config
|
||||
custom_retry = RetryConfig(max_retries=5, initial_delay=0.5)
|
||||
custom_conn = ConnectionConfig(
|
||||
max_connections_per_peer=5, enable_connection_pool=False
|
||||
)
|
||||
custom_conn = ConnectionConfig(max_connections_per_peer=5)
|
||||
|
||||
swarm = Swarm(peer_id, peerstore, upgrader, transport, custom_retry, custom_conn)
|
||||
assert swarm.retry_config.max_retries == 5
|
||||
assert swarm.retry_config.initial_delay == 0.5
|
||||
assert swarm.connection_config.max_connections_per_peer == 5
|
||||
assert swarm.connection_pool is None
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
@ -273,143 +170,155 @@ async def test_swarm_retry_logic():
|
||||
|
||||
# Should have succeeded after 3 attempts
|
||||
assert attempt_count[0] == 3
|
||||
assert result is not None
|
||||
|
||||
# Should have taken some time due to retries
|
||||
assert end_time - start_time > 0.02 # At least 2 delays
|
||||
assert isinstance(result, MockConnection)
|
||||
assert end_time - start_time > 0.01 # Should have some delay
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_multi_connection_support():
|
||||
"""Test multi-connection support in Swarm."""
|
||||
async def test_swarm_load_balancing_strategies():
|
||||
"""Test load balancing strategies."""
|
||||
peer_id = ID(b"QmTest")
|
||||
peerstore = Mock()
|
||||
upgrader = Mock()
|
||||
transport = Mock()
|
||||
|
||||
connection_config = ConnectionConfig(
|
||||
max_connections_per_peer=3,
|
||||
enable_connection_pool=True,
|
||||
load_balancing_strategy="round_robin",
|
||||
)
|
||||
swarm = Swarm(peer_id, peerstore, upgrader, transport)
|
||||
|
||||
# Create mock connections with different stream counts
|
||||
conn1 = MockConnection(peer_id)
|
||||
conn2 = MockConnection(peer_id)
|
||||
conn3 = MockConnection(peer_id)
|
||||
|
||||
# Add some streams to simulate load
|
||||
await conn1.new_stream()
|
||||
await conn1.new_stream()
|
||||
await conn2.new_stream()
|
||||
|
||||
connections = [conn1, conn2, conn3]
|
||||
|
||||
# Test round-robin strategy
|
||||
swarm.connection_config.load_balancing_strategy = "round_robin"
|
||||
# Cast to satisfy type checker
|
||||
connections_cast = cast("list[INetConn]", connections)
|
||||
selected1 = swarm._select_connection(connections_cast, peer_id)
|
||||
selected2 = swarm._select_connection(connections_cast, peer_id)
|
||||
selected3 = swarm._select_connection(connections_cast, peer_id)
|
||||
|
||||
# Should cycle through connections
|
||||
assert selected1 in connections
|
||||
assert selected2 in connections
|
||||
assert selected3 in connections
|
||||
|
||||
# Test least loaded strategy
|
||||
swarm.connection_config.load_balancing_strategy = "least_loaded"
|
||||
least_loaded = swarm._select_connection(connections_cast, peer_id)
|
||||
|
||||
# conn3 has 0 streams, conn2 has 1 stream, conn1 has 2 streams
|
||||
# So conn3 should be selected as least loaded
|
||||
assert least_loaded == conn3
|
||||
|
||||
# Test default strategy (first connection)
|
||||
swarm.connection_config.load_balancing_strategy = "unknown"
|
||||
default_selected = swarm._select_connection(connections_cast, peer_id)
|
||||
assert default_selected == conn1
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_multiple_connections_api():
|
||||
"""Test the new multiple connections API methods."""
|
||||
peer_id = ID(b"QmTest")
|
||||
peerstore = Mock()
|
||||
upgrader = Mock()
|
||||
transport = Mock()
|
||||
|
||||
swarm = Swarm(peer_id, peerstore, upgrader, transport)
|
||||
|
||||
# Test empty connections
|
||||
assert swarm.get_connections() == []
|
||||
assert swarm.get_connections(peer_id) == []
|
||||
assert swarm.get_connection(peer_id) is None
|
||||
assert swarm.get_connections_map() == {}
|
||||
|
||||
# Add some connections
|
||||
conn1 = MockConnection(peer_id)
|
||||
conn2 = MockConnection(peer_id)
|
||||
swarm.connections[peer_id] = [conn1, conn2]
|
||||
|
||||
# Test get_connections with peer_id
|
||||
peer_connections = swarm.get_connections(peer_id)
|
||||
assert len(peer_connections) == 2
|
||||
assert conn1 in peer_connections
|
||||
assert conn2 in peer_connections
|
||||
|
||||
# Test get_connections without peer_id (all connections)
|
||||
all_connections = swarm.get_connections()
|
||||
assert len(all_connections) == 2
|
||||
assert conn1 in all_connections
|
||||
assert conn2 in all_connections
|
||||
|
||||
# Test get_connection (backward compatibility)
|
||||
single_conn = swarm.get_connection(peer_id)
|
||||
assert single_conn in [conn1, conn2]
|
||||
|
||||
# Test get_connections_map
|
||||
connections_map = swarm.get_connections_map()
|
||||
assert peer_id in connections_map
|
||||
assert connections_map[peer_id] == [conn1, conn2]
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_connection_trimming():
|
||||
"""Test connection trimming when limit is exceeded."""
|
||||
peer_id = ID(b"QmTest")
|
||||
peerstore = Mock()
|
||||
upgrader = Mock()
|
||||
transport = Mock()
|
||||
|
||||
# Set max connections to 2
|
||||
connection_config = ConnectionConfig(max_connections_per_peer=2)
|
||||
swarm = Swarm(
|
||||
peer_id, peerstore, upgrader, transport, connection_config=connection_config
|
||||
)
|
||||
|
||||
# Mock connection pool methods
|
||||
assert swarm.connection_pool is not None
|
||||
connection_pool = swarm.connection_pool
|
||||
connection_pool.has_connection = Mock(return_value=True)
|
||||
connection_pool.get_connection = Mock(return_value=MockConnection(peer_id))
|
||||
# Add 3 connections
|
||||
conn1 = MockConnection(peer_id)
|
||||
conn2 = MockConnection(peer_id)
|
||||
conn3 = MockConnection(peer_id)
|
||||
|
||||
# Test that new_stream uses connection pool
|
||||
result = await swarm.new_stream(peer_id)
|
||||
assert result is not None
|
||||
# Use the mocked method directly to avoid type checking issues
|
||||
get_connection_mock = connection_pool.get_connection
|
||||
assert get_connection_mock.call_count == 1
|
||||
swarm.connections[peer_id] = [conn1, conn2, conn3]
|
||||
|
||||
# Trigger trimming
|
||||
swarm._trim_connections(peer_id)
|
||||
|
||||
# Should have only 2 connections
|
||||
assert len(swarm.connections[peer_id]) == 2
|
||||
|
||||
# The most recent connections should remain
|
||||
remaining = swarm.connections[peer_id]
|
||||
assert conn2 in remaining
|
||||
assert conn3 in remaining
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_backward_compatibility():
|
||||
"""Test that enhanced Swarm maintains backward compatibility."""
|
||||
"""Test backward compatibility features."""
|
||||
peer_id = ID(b"QmTest")
|
||||
peerstore = Mock()
|
||||
upgrader = Mock()
|
||||
transport = Mock()
|
||||
|
||||
# Create swarm with connection pool disabled
|
||||
connection_config = ConnectionConfig(enable_connection_pool=False)
|
||||
swarm = Swarm(
|
||||
peer_id, peerstore, upgrader, transport, connection_config=connection_config
|
||||
)
|
||||
swarm = Swarm(peer_id, peerstore, upgrader, transport)
|
||||
|
||||
# Should behave like original swarm
|
||||
assert swarm.connection_pool is None
|
||||
assert isinstance(swarm.connections, dict)
|
||||
# Add connections
|
||||
conn1 = MockConnection(peer_id)
|
||||
conn2 = MockConnection(peer_id)
|
||||
swarm.connections[peer_id] = [conn1, conn2]
|
||||
|
||||
# Test that dial_peer still works (will fail due to mocks, but structure is correct)
|
||||
peerstore.addrs.return_value = [Mock(spec=Multiaddr)]
|
||||
transport.dial.side_effect = Exception("Transport error")
|
||||
|
||||
with pytest.raises(SwarmException):
|
||||
await swarm.dial_peer(peer_id)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_connection_pool_integration():
|
||||
"""Test integration between Swarm and ConnectionPool."""
|
||||
peer_id = ID(b"QmTest")
|
||||
peerstore = Mock()
|
||||
upgrader = Mock()
|
||||
transport = Mock()
|
||||
|
||||
connection_config = ConnectionConfig(
|
||||
max_connections_per_peer=2, enable_connection_pool=True
|
||||
)
|
||||
|
||||
swarm = Swarm(
|
||||
peer_id, peerstore, upgrader, transport, connection_config=connection_config
|
||||
)
|
||||
|
||||
# Mock successful connection creation
|
||||
mock_conn = MockConnection(peer_id)
|
||||
peerstore.addrs.return_value = [Mock(spec=Multiaddr)]
|
||||
|
||||
async def mock_dial_with_retry(addr, peer_id):
|
||||
return mock_conn
|
||||
|
||||
swarm._dial_with_retry = mock_dial_with_retry
|
||||
|
||||
# Test dial_peer adds to connection pool
|
||||
result = await swarm.dial_peer(peer_id)
|
||||
assert result == mock_conn
|
||||
assert swarm.connection_pool is not None
|
||||
assert swarm.connection_pool.has_connection(peer_id)
|
||||
|
||||
# Test that subsequent calls reuse connection
|
||||
result2 = await swarm.dial_peer(peer_id)
|
||||
assert result2 == mock_conn
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_connection_cleanup():
|
||||
"""Test connection cleanup in enhanced Swarm."""
|
||||
peer_id = ID(b"QmTest")
|
||||
peerstore = Mock()
|
||||
upgrader = Mock()
|
||||
transport = Mock()
|
||||
|
||||
connection_config = ConnectionConfig(enable_connection_pool=True)
|
||||
swarm = Swarm(
|
||||
peer_id, peerstore, upgrader, transport, connection_config=connection_config
|
||||
)
|
||||
|
||||
# Add a connection
|
||||
mock_conn = MockConnection(peer_id)
|
||||
swarm.connections[peer_id] = mock_conn
|
||||
assert swarm.connection_pool is not None
|
||||
swarm.connection_pool.add_connection(peer_id, mock_conn, "test_addr")
|
||||
|
||||
# Test close_peer removes from pool
|
||||
await swarm.close_peer(peer_id)
|
||||
assert swarm.connection_pool is not None
|
||||
assert not swarm.connection_pool.has_connection(peer_id)
|
||||
|
||||
# Test remove_conn removes from pool
|
||||
mock_conn2 = MockConnection(peer_id)
|
||||
swarm.connections[peer_id] = mock_conn2
|
||||
assert swarm.connection_pool is not None
|
||||
connection_pool = swarm.connection_pool
|
||||
connection_pool.add_connection(peer_id, mock_conn2, "test_addr2")
|
||||
|
||||
# Note: remove_conn expects SwarmConn, but for testing we'll just
|
||||
# remove from pool directly
|
||||
connection_pool = swarm.connection_pool
|
||||
connection_pool.remove_connection(peer_id, mock_conn2)
|
||||
assert connection_pool is not None
|
||||
assert not connection_pool.has_connection(peer_id)
|
||||
# Test connections_legacy property
|
||||
legacy_connections = swarm.connections_legacy
|
||||
assert peer_id in legacy_connections
|
||||
# Should return first connection
|
||||
assert legacy_connections[peer_id] in [conn1, conn2]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -51,14 +51,19 @@ async def test_swarm_dial_peer(security_protocol):
|
||||
for addr in transport.get_addrs()
|
||||
)
|
||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
|
||||
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
|
||||
# New: dial_peer now returns list of connections
|
||||
connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
assert len(connections) > 0
|
||||
|
||||
# Verify connections are established in both directions
|
||||
assert swarms[0].get_peer_id() in swarms[1].connections
|
||||
assert swarms[1].get_peer_id() in swarms[0].connections
|
||||
|
||||
# Test: Reuse connections when we already have ones with a peer.
|
||||
conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
|
||||
conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
assert conn is conn_to_1
|
||||
existing_connections = swarms[0].get_connections(swarms[1].get_peer_id())
|
||||
new_connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
assert new_connections == existing_connections
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
@ -107,7 +112,8 @@ async def test_swarm_close_peer(security_protocol):
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_remove_conn(swarm_pair):
|
||||
swarm_0, swarm_1 = swarm_pair
|
||||
conn_0 = swarm_0.connections[swarm_1.get_peer_id()]
|
||||
# Get the first connection from the list
|
||||
conn_0 = swarm_0.connections[swarm_1.get_peer_id()][0]
|
||||
swarm_0.remove_conn(conn_0)
|
||||
assert swarm_1.get_peer_id() not in swarm_0.connections
|
||||
# Test: Remove twice. There should not be errors.
|
||||
@ -115,6 +121,67 @@ async def test_swarm_remove_conn(swarm_pair):
|
||||
assert swarm_1.get_peer_id() not in swarm_0.connections
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_multiple_connections(security_protocol):
|
||||
"""Test multiple connections per peer functionality."""
|
||||
async with SwarmFactory.create_batch_and_listen(
|
||||
2, security_protocol=security_protocol
|
||||
) as swarms:
|
||||
# Setup multiple addresses for peer
|
||||
addrs = tuple(
|
||||
addr
|
||||
for transport in swarms[1].listeners.values()
|
||||
for addr in transport.get_addrs()
|
||||
)
|
||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
|
||||
|
||||
# Dial peer - should return list of connections
|
||||
connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
assert len(connections) > 0
|
||||
|
||||
# Test get_connections method
|
||||
peer_connections = swarms[0].get_connections(swarms[1].get_peer_id())
|
||||
assert len(peer_connections) == len(connections)
|
||||
|
||||
# Test get_connections_map method
|
||||
connections_map = swarms[0].get_connections_map()
|
||||
assert swarms[1].get_peer_id() in connections_map
|
||||
assert len(connections_map[swarms[1].get_peer_id()]) == len(connections)
|
||||
|
||||
# Test get_connection method (backward compatibility)
|
||||
single_conn = swarms[0].get_connection(swarms[1].get_peer_id())
|
||||
assert single_conn is not None
|
||||
assert single_conn in connections
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_load_balancing(security_protocol):
|
||||
"""Test load balancing across multiple connections."""
|
||||
async with SwarmFactory.create_batch_and_listen(
|
||||
2, security_protocol=security_protocol
|
||||
) as swarms:
|
||||
# Setup connection
|
||||
addrs = tuple(
|
||||
addr
|
||||
for transport in swarms[1].listeners.values()
|
||||
for addr in transport.get_addrs()
|
||||
)
|
||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
|
||||
|
||||
# Create multiple streams - should use load balancing
|
||||
streams = []
|
||||
for _ in range(5):
|
||||
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
|
||||
streams.append(stream)
|
||||
|
||||
# Verify streams were created successfully
|
||||
assert len(streams) == 5
|
||||
|
||||
# Clean up
|
||||
for stream in streams:
|
||||
await stream.close()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_multiaddr(security_protocol):
|
||||
async with SwarmFactory.create_batch_and_listen(
|
||||
|
||||
@ -51,6 +51,9 @@ async def perform_simple_test(assertion_func, security_protocol):
|
||||
|
||||
# Extract the secured connection from either Mplex or Yamux implementation
|
||||
def get_secured_conn(conn):
|
||||
# conn is now a list, get the first connection
|
||||
if isinstance(conn, list):
|
||||
conn = conn[0]
|
||||
muxed_conn = conn.muxed_conn
|
||||
# Direct attribute access for known implementations
|
||||
has_secured_conn = hasattr(muxed_conn, "secured_conn")
|
||||
|
||||
@ -74,7 +74,8 @@ async def test_multiplexer_preference_parameter(muxer_preference):
|
||||
assert len(connections) > 0, "Connection not established"
|
||||
|
||||
# Get the first connection
|
||||
conn = list(connections.values())[0]
|
||||
conns = list(connections.values())[0]
|
||||
conn = conns[0] # Get first connection from the list
|
||||
muxed_conn = conn.muxed_conn
|
||||
|
||||
# Define a simple echo protocol
|
||||
@ -150,7 +151,8 @@ async def test_explicit_muxer_options(muxer_option_func, expected_stream_class):
|
||||
assert len(connections) > 0, "Connection not established"
|
||||
|
||||
# Get the first connection
|
||||
conn = list(connections.values())[0]
|
||||
conns = list(connections.values())[0]
|
||||
conn = conns[0] # Get first connection from the list
|
||||
muxed_conn = conn.muxed_conn
|
||||
|
||||
# Define a simple echo protocol
|
||||
@ -219,7 +221,8 @@ async def test_global_default_muxer(global_default):
|
||||
assert len(connections) > 0, "Connection not established"
|
||||
|
||||
# Get the first connection
|
||||
conn = list(connections.values())[0]
|
||||
conns = list(connections.values())[0]
|
||||
conn = conns[0] # Get first connection from the list
|
||||
muxed_conn = conn.muxed_conn
|
||||
|
||||
# Define a simple echo protocol
|
||||
|
||||
@ -669,8 +669,8 @@ async def swarm_conn_pair_factory(
|
||||
async with swarm_pair_factory(
|
||||
security_protocol=security_protocol, muxer_opt=muxer_opt
|
||||
) as swarms:
|
||||
conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
|
||||
conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
|
||||
conn_0 = swarms[0].connections[swarms[1].get_peer_id()][0]
|
||||
conn_1 = swarms[1].connections[swarms[0].get_peer_id()][0]
|
||||
yield cast(SwarmConn, conn_0), cast(SwarmConn, conn_1)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user