mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-12 16:10:57 +00:00
address architectural refactoring discussed
This commit is contained in:
133
docs/examples.multiple_connections.rst
Normal file
133
docs/examples.multiple_connections.rst
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
Multiple Connections Per Peer
|
||||||
|
============================
|
||||||
|
|
||||||
|
This example demonstrates how to use the multiple connections per peer feature in py-libp2p.
|
||||||
|
|
||||||
|
Overview
|
||||||
|
--------
|
||||||
|
|
||||||
|
The multiple connections per peer feature allows a libp2p node to maintain multiple network connections to the same peer. This provides several benefits:
|
||||||
|
|
||||||
|
- **Improved reliability**: If one connection fails, others remain available
|
||||||
|
- **Better performance**: Load can be distributed across multiple connections
|
||||||
|
- **Enhanced throughput**: Multiple streams can be created in parallel
|
||||||
|
- **Fault tolerance**: Redundant connections provide backup paths
|
||||||
|
|
||||||
|
Configuration
|
||||||
|
-------------
|
||||||
|
|
||||||
|
The feature is configured through the `ConnectionConfig` class:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from libp2p.network.swarm import ConnectionConfig
|
||||||
|
|
||||||
|
# Default configuration
|
||||||
|
config = ConnectionConfig()
|
||||||
|
print(f"Max connections per peer: {config.max_connections_per_peer}")
|
||||||
|
print(f"Load balancing strategy: {config.load_balancing_strategy}")
|
||||||
|
|
||||||
|
# Custom configuration
|
||||||
|
custom_config = ConnectionConfig(
|
||||||
|
max_connections_per_peer=5,
|
||||||
|
connection_timeout=60.0,
|
||||||
|
load_balancing_strategy="least_loaded"
|
||||||
|
)
|
||||||
|
|
||||||
|
Load Balancing Strategies
|
||||||
|
------------------------
|
||||||
|
|
||||||
|
Two load balancing strategies are available:
|
||||||
|
|
||||||
|
**Round Robin** (default)
|
||||||
|
Cycles through connections in order, distributing load evenly.
|
||||||
|
|
||||||
|
**Least Loaded**
|
||||||
|
Selects the connection with the fewest active streams.
|
||||||
|
|
||||||
|
API Usage
|
||||||
|
---------
|
||||||
|
|
||||||
|
The new API provides direct access to multiple connections:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from libp2p import new_swarm
|
||||||
|
|
||||||
|
# Create swarm with multiple connections support
|
||||||
|
swarm = new_swarm()
|
||||||
|
|
||||||
|
# Dial a peer - returns list of connections
|
||||||
|
connections = await swarm.dial_peer(peer_id)
|
||||||
|
print(f"Established {len(connections)} connections")
|
||||||
|
|
||||||
|
# Get all connections to a peer
|
||||||
|
peer_connections = swarm.get_connections(peer_id)
|
||||||
|
|
||||||
|
# Get all connections (across all peers)
|
||||||
|
all_connections = swarm.get_connections()
|
||||||
|
|
||||||
|
# Get the complete connections map
|
||||||
|
connections_map = swarm.get_connections_map()
|
||||||
|
|
||||||
|
# Backward compatibility - get single connection
|
||||||
|
single_conn = swarm.get_connection(peer_id)
|
||||||
|
|
||||||
|
Backward Compatibility
|
||||||
|
---------------------
|
||||||
|
|
||||||
|
Existing code continues to work through backward compatibility features:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
# Legacy 1:1 mapping (returns first connection for each peer)
|
||||||
|
legacy_connections = swarm.connections_legacy
|
||||||
|
|
||||||
|
# Single connection access (returns first available connection)
|
||||||
|
conn = swarm.get_connection(peer_id)
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
|
||||||
|
See :doc:`examples/doc-examples/multiple_connections_example.py` for a complete working example.
|
||||||
|
|
||||||
|
Production Configuration
|
||||||
|
-----------------------
|
||||||
|
|
||||||
|
For production use, consider these settings:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from libp2p.network.swarm import ConnectionConfig, RetryConfig
|
||||||
|
|
||||||
|
# Production-ready configuration
|
||||||
|
retry_config = RetryConfig(
|
||||||
|
max_retries=3,
|
||||||
|
initial_delay=0.1,
|
||||||
|
max_delay=30.0,
|
||||||
|
backoff_multiplier=2.0,
|
||||||
|
jitter_factor=0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
connection_config = ConnectionConfig(
|
||||||
|
max_connections_per_peer=3, # Balance performance and resources
|
||||||
|
connection_timeout=30.0, # Reasonable timeout
|
||||||
|
load_balancing_strategy="round_robin" # Predictable behavior
|
||||||
|
)
|
||||||
|
|
||||||
|
swarm = new_swarm(
|
||||||
|
retry_config=retry_config,
|
||||||
|
connection_config=connection_config
|
||||||
|
)
|
||||||
|
|
||||||
|
Architecture
|
||||||
|
-----------
|
||||||
|
|
||||||
|
The implementation follows the same architectural patterns as the Go and JavaScript reference implementations:
|
||||||
|
|
||||||
|
- **Core data structure**: `dict[ID, list[INetConn]]` for 1:many mapping
|
||||||
|
- **API consistency**: Methods like `get_connections()` match reference implementations
|
||||||
|
- **Load balancing**: Integrated at the API level for optimal performance
|
||||||
|
- **Backward compatibility**: Maintains existing interfaces for gradual migration
|
||||||
|
|
||||||
|
This design ensures consistency across libp2p implementations while providing the benefits of multiple connections per peer.
|
||||||
@ -15,3 +15,4 @@ Examples
|
|||||||
examples.kademlia
|
examples.kademlia
|
||||||
examples.mDNS
|
examples.mDNS
|
||||||
examples.random_walk
|
examples.random_walk
|
||||||
|
examples.multiple_connections
|
||||||
|
|||||||
@ -1,18 +1,18 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
Example demonstrating the enhanced Swarm with retry logic, exponential backoff,
|
Example demonstrating multiple connections per peer support in libp2p.
|
||||||
and multi-connection support.
|
|
||||||
|
|
||||||
This example shows how to:
|
This example shows how to:
|
||||||
1. Configure retry behavior with exponential backoff
|
1. Configure multiple connections per peer
|
||||||
2. Enable multi-connection support with connection pooling
|
2. Use different load balancing strategies
|
||||||
3. Use different load balancing strategies
|
3. Access multiple connections through the new API
|
||||||
4. Maintain backward compatibility
|
4. Maintain backward compatibility
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import trio
|
||||||
|
|
||||||
from libp2p import new_swarm
|
from libp2p import new_swarm
|
||||||
from libp2p.network.swarm import ConnectionConfig, RetryConfig
|
from libp2p.network.swarm import ConnectionConfig, RetryConfig
|
||||||
|
|
||||||
@ -21,64 +21,32 @@ logging.basicConfig(level=logging.INFO)
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def example_basic_enhanced_swarm() -> None:
|
async def example_basic_multiple_connections() -> None:
|
||||||
"""Example of basic enhanced Swarm usage."""
|
"""Example of basic multiple connections per peer usage."""
|
||||||
logger.info("Creating enhanced Swarm with default configuration...")
|
logger.info("Creating swarm with multiple connections support...")
|
||||||
|
|
||||||
# Create enhanced swarm with default retry and connection config
|
# Create swarm with default configuration
|
||||||
swarm = new_swarm()
|
swarm = new_swarm()
|
||||||
# Use default configuration values directly
|
|
||||||
default_retry = RetryConfig()
|
|
||||||
default_connection = ConnectionConfig()
|
default_connection = ConnectionConfig()
|
||||||
|
|
||||||
logger.info(f"Swarm created with peer ID: {swarm.get_peer_id()}")
|
logger.info(f"Swarm created with peer ID: {swarm.get_peer_id()}")
|
||||||
logger.info(f"Retry config: max_retries={default_retry.max_retries}")
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Connection config: max_connections_per_peer="
|
f"Connection config: max_connections_per_peer="
|
||||||
f"{default_connection.max_connections_per_peer}"
|
f"{default_connection.max_connections_per_peer}"
|
||||||
)
|
)
|
||||||
logger.info(f"Connection pool enabled: {default_connection.enable_connection_pool}")
|
|
||||||
|
|
||||||
await swarm.close()
|
await swarm.close()
|
||||||
logger.info("Basic enhanced Swarm example completed")
|
logger.info("Basic multiple connections example completed")
|
||||||
|
|
||||||
|
|
||||||
async def example_custom_retry_config() -> None:
|
|
||||||
"""Example of custom retry configuration."""
|
|
||||||
logger.info("Creating enhanced Swarm with custom retry configuration...")
|
|
||||||
|
|
||||||
# Custom retry configuration for aggressive retry behavior
|
|
||||||
retry_config = RetryConfig(
|
|
||||||
max_retries=5, # More retries
|
|
||||||
initial_delay=0.05, # Faster initial retry
|
|
||||||
max_delay=10.0, # Lower max delay
|
|
||||||
backoff_multiplier=1.5, # Less aggressive backoff
|
|
||||||
jitter_factor=0.2, # More jitter
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create swarm with custom retry config
|
|
||||||
swarm = new_swarm(retry_config=retry_config)
|
|
||||||
|
|
||||||
logger.info("Custom retry config applied:")
|
|
||||||
logger.info(f" Max retries: {retry_config.max_retries}")
|
|
||||||
logger.info(f" Initial delay: {retry_config.initial_delay}s")
|
|
||||||
logger.info(f" Max delay: {retry_config.max_delay}s")
|
|
||||||
logger.info(f" Backoff multiplier: {retry_config.backoff_multiplier}")
|
|
||||||
logger.info(f" Jitter factor: {retry_config.jitter_factor}")
|
|
||||||
|
|
||||||
await swarm.close()
|
|
||||||
logger.info("Custom retry config example completed")
|
|
||||||
|
|
||||||
|
|
||||||
async def example_custom_connection_config() -> None:
|
async def example_custom_connection_config() -> None:
|
||||||
"""Example of custom connection configuration."""
|
"""Example of custom connection configuration."""
|
||||||
logger.info("Creating enhanced Swarm with custom connection configuration...")
|
logger.info("Creating swarm with custom connection configuration...")
|
||||||
|
|
||||||
# Custom connection configuration for high-performance scenarios
|
# Custom connection configuration for high-performance scenarios
|
||||||
connection_config = ConnectionConfig(
|
connection_config = ConnectionConfig(
|
||||||
max_connections_per_peer=5, # More connections per peer
|
max_connections_per_peer=5, # More connections per peer
|
||||||
connection_timeout=60.0, # Longer timeout
|
connection_timeout=60.0, # Longer timeout
|
||||||
enable_connection_pool=True, # Enable connection pooling
|
|
||||||
load_balancing_strategy="least_loaded", # Use least loaded strategy
|
load_balancing_strategy="least_loaded", # Use least loaded strategy
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -90,9 +58,6 @@ async def example_custom_connection_config() -> None:
|
|||||||
f" Max connections per peer: {connection_config.max_connections_per_peer}"
|
f" Max connections per peer: {connection_config.max_connections_per_peer}"
|
||||||
)
|
)
|
||||||
logger.info(f" Connection timeout: {connection_config.connection_timeout}s")
|
logger.info(f" Connection timeout: {connection_config.connection_timeout}s")
|
||||||
logger.info(
|
|
||||||
f" Connection pool enabled: {connection_config.enable_connection_pool}"
|
|
||||||
)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f" Load balancing strategy: {connection_config.load_balancing_strategy}"
|
f" Load balancing strategy: {connection_config.load_balancing_strategy}"
|
||||||
)
|
)
|
||||||
@ -101,22 +66,39 @@ async def example_custom_connection_config() -> None:
|
|||||||
logger.info("Custom connection config example completed")
|
logger.info("Custom connection config example completed")
|
||||||
|
|
||||||
|
|
||||||
async def example_backward_compatibility() -> None:
|
async def example_multiple_connections_api() -> None:
|
||||||
"""Example showing backward compatibility."""
|
"""Example of using the new multiple connections API."""
|
||||||
logger.info("Creating enhanced Swarm with backward compatibility...")
|
logger.info("Demonstrating multiple connections API...")
|
||||||
|
|
||||||
# Disable connection pool to maintain original behavior
|
connection_config = ConnectionConfig(
|
||||||
connection_config = ConnectionConfig(enable_connection_pool=False)
|
max_connections_per_peer=3,
|
||||||
|
load_balancing_strategy="round_robin"
|
||||||
|
)
|
||||||
|
|
||||||
# Create swarm with connection pool disabled
|
|
||||||
swarm = new_swarm(connection_config=connection_config)
|
swarm = new_swarm(connection_config=connection_config)
|
||||||
|
|
||||||
logger.info("Backward compatibility mode:")
|
logger.info("Multiple connections API features:")
|
||||||
|
logger.info(" - dial_peer() returns list[INetConn]")
|
||||||
|
logger.info(" - get_connections(peer_id) returns list[INetConn]")
|
||||||
|
logger.info(" - get_connections_map() returns dict[ID, list[INetConn]]")
|
||||||
logger.info(
|
logger.info(
|
||||||
f" Connection pool enabled: {connection_config.enable_connection_pool}"
|
" - get_connection(peer_id) returns INetConn | None (backward compatibility)"
|
||||||
)
|
)
|
||||||
logger.info(f" Connections dict type: {type(swarm.connections)}")
|
|
||||||
logger.info(" Retry logic still available: 3 max retries")
|
await swarm.close()
|
||||||
|
logger.info("Multiple connections API example completed")
|
||||||
|
|
||||||
|
|
||||||
|
async def example_backward_compatibility() -> None:
|
||||||
|
"""Example of backward compatibility features."""
|
||||||
|
logger.info("Demonstrating backward compatibility...")
|
||||||
|
|
||||||
|
swarm = new_swarm()
|
||||||
|
|
||||||
|
logger.info("Backward compatibility features:")
|
||||||
|
logger.info(" - connections_legacy property provides 1:1 mapping")
|
||||||
|
logger.info(" - get_connection() method for single connection access")
|
||||||
|
logger.info(" - Existing code continues to work")
|
||||||
|
|
||||||
await swarm.close()
|
await swarm.close()
|
||||||
logger.info("Backward compatibility example completed")
|
logger.info("Backward compatibility example completed")
|
||||||
@ -124,7 +106,7 @@ async def example_backward_compatibility() -> None:
|
|||||||
|
|
||||||
async def example_production_ready_config() -> None:
|
async def example_production_ready_config() -> None:
|
||||||
"""Example of production-ready configuration."""
|
"""Example of production-ready configuration."""
|
||||||
logger.info("Creating enhanced Swarm with production-ready configuration...")
|
logger.info("Creating swarm with production-ready configuration...")
|
||||||
|
|
||||||
# Production-ready retry configuration
|
# Production-ready retry configuration
|
||||||
retry_config = RetryConfig(
|
retry_config = RetryConfig(
|
||||||
@ -139,7 +121,6 @@ async def example_production_ready_config() -> None:
|
|||||||
connection_config = ConnectionConfig(
|
connection_config = ConnectionConfig(
|
||||||
max_connections_per_peer=3, # Balance between performance and resource usage
|
max_connections_per_peer=3, # Balance between performance and resource usage
|
||||||
connection_timeout=30.0, # Reasonable timeout
|
connection_timeout=30.0, # Reasonable timeout
|
||||||
enable_connection_pool=True, # Enable for better performance
|
|
||||||
load_balancing_strategy="round_robin", # Simple, predictable strategy
|
load_balancing_strategy="round_robin", # Simple, predictable strategy
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -160,19 +141,19 @@ async def example_production_ready_config() -> None:
|
|||||||
|
|
||||||
async def main() -> None:
|
async def main() -> None:
|
||||||
"""Run all examples."""
|
"""Run all examples."""
|
||||||
logger.info("Enhanced Swarm Examples")
|
logger.info("Multiple Connections Per Peer Examples")
|
||||||
logger.info("=" * 50)
|
logger.info("=" * 50)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await example_basic_enhanced_swarm()
|
await example_basic_multiple_connections()
|
||||||
logger.info("-" * 30)
|
|
||||||
|
|
||||||
await example_custom_retry_config()
|
|
||||||
logger.info("-" * 30)
|
logger.info("-" * 30)
|
||||||
|
|
||||||
await example_custom_connection_config()
|
await example_custom_connection_config()
|
||||||
logger.info("-" * 30)
|
logger.info("-" * 30)
|
||||||
|
|
||||||
|
await example_multiple_connections_api()
|
||||||
|
logger.info("-" * 30)
|
||||||
|
|
||||||
await example_backward_compatibility()
|
await example_backward_compatibility()
|
||||||
logger.info("-" * 30)
|
logger.info("-" * 30)
|
||||||
|
|
||||||
@ -187,4 +168,4 @@ async def main() -> None:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
trio.run(main)
|
||||||
@ -1412,15 +1412,16 @@ class INetwork(ABC):
|
|||||||
----------
|
----------
|
||||||
peerstore : IPeerStore
|
peerstore : IPeerStore
|
||||||
The peer store for managing peer information.
|
The peer store for managing peer information.
|
||||||
connections : dict[ID, INetConn]
|
connections : dict[ID, list[INetConn]]
|
||||||
A mapping of peer IDs to network connections.
|
A mapping of peer IDs to lists of network connections
|
||||||
|
(multiple connections per peer).
|
||||||
listeners : dict[str, IListener]
|
listeners : dict[str, IListener]
|
||||||
A mapping of listener identifiers to listener instances.
|
A mapping of listener identifiers to listener instances.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
peerstore: IPeerStore
|
peerstore: IPeerStore
|
||||||
connections: dict[ID, INetConn]
|
connections: dict[ID, list[INetConn]]
|
||||||
listeners: dict[str, IListener]
|
listeners: dict[str, IListener]
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -1436,9 +1437,56 @@ class INetwork(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def dial_peer(self, peer_id: ID) -> INetConn:
|
def get_connections(self, peer_id: ID | None = None) -> list[INetConn]:
|
||||||
"""
|
"""
|
||||||
Create a connection to the specified peer.
|
Get connections for peer (like JS getConnections, Go ConnsToPeer).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
peer_id : ID | None
|
||||||
|
The peer ID to get connections for. If None, returns all connections.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[INetConn]
|
||||||
|
List of connections to the specified peer, or all connections
|
||||||
|
if peer_id is None.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_connections_map(self) -> dict[ID, list[INetConn]]:
|
||||||
|
"""
|
||||||
|
Get all connections map (like JS getConnectionsMap).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dict[ID, list[INetConn]]
|
||||||
|
The complete mapping of peer IDs to their connection lists.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_connection(self, peer_id: ID) -> INetConn | None:
|
||||||
|
"""
|
||||||
|
Get single connection for backward compatibility.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
peer_id : ID
|
||||||
|
The peer ID to get a connection for.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
INetConn | None
|
||||||
|
The first available connection, or None if no connections exist.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def dial_peer(self, peer_id: ID) -> list[INetConn]:
|
||||||
|
"""
|
||||||
|
Create connections to the specified peer with load balancing.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@ -1447,8 +1495,8 @@ class INetwork(ABC):
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
INetConn
|
list[INetConn]
|
||||||
The network connection instance to the specified peer.
|
List of established connections to the peer.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
|
|||||||
@ -338,7 +338,7 @@ class BasicHost(IHost):
|
|||||||
:param peer_id: ID of the peer to check
|
:param peer_id: ID of the peer to check
|
||||||
:return: True if peer has an active connection, False otherwise
|
:return: True if peer has an active connection, False otherwise
|
||||||
"""
|
"""
|
||||||
return peer_id in self._network.connections
|
return len(self._network.get_connections(peer_id)) > 0
|
||||||
|
|
||||||
def get_peer_connection_info(self, peer_id: ID) -> INetConn | None:
|
def get_peer_connection_info(self, peer_id: ID) -> INetConn | None:
|
||||||
"""
|
"""
|
||||||
@ -347,4 +347,4 @@ class BasicHost(IHost):
|
|||||||
:param peer_id: ID of the peer to get info for
|
:param peer_id: ID of the peer to get info for
|
||||||
:return: Connection object if peer is connected, None otherwise
|
:return: Connection object if peer is connected, None otherwise
|
||||||
"""
|
"""
|
||||||
return self._network.connections.get(peer_id)
|
return self._network.get_connection(peer_id)
|
||||||
|
|||||||
@ -74,175 +74,13 @@ class RetryConfig:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConnectionConfig:
|
class ConnectionConfig:
|
||||||
"""Configuration for connection pool and multi-connection support."""
|
"""Configuration for multi-connection support."""
|
||||||
|
|
||||||
max_connections_per_peer: int = 3
|
max_connections_per_peer: int = 3
|
||||||
connection_timeout: float = 30.0
|
connection_timeout: float = 30.0
|
||||||
enable_connection_pool: bool = True
|
|
||||||
load_balancing_strategy: str = "round_robin" # or "least_loaded"
|
load_balancing_strategy: str = "round_robin" # or "least_loaded"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ConnectionInfo:
|
|
||||||
"""Information about a connection in the pool."""
|
|
||||||
|
|
||||||
connection: INetConn
|
|
||||||
address: str
|
|
||||||
established_at: float
|
|
||||||
last_used: float
|
|
||||||
stream_count: int
|
|
||||||
is_healthy: bool
|
|
||||||
|
|
||||||
|
|
||||||
class ConnectionPool:
|
|
||||||
"""Manages multiple connections per peer with load balancing."""
|
|
||||||
|
|
||||||
def __init__(self, max_connections_per_peer: int = 3):
|
|
||||||
self.max_connections_per_peer = max_connections_per_peer
|
|
||||||
self.peer_connections: dict[ID, list[ConnectionInfo]] = {}
|
|
||||||
self._round_robin_index: dict[ID, int] = {}
|
|
||||||
|
|
||||||
def add_connection(self, peer_id: ID, connection: INetConn, address: str) -> None:
|
|
||||||
"""Add a connection to the pool with deduplication."""
|
|
||||||
if peer_id not in self.peer_connections:
|
|
||||||
self.peer_connections[peer_id] = []
|
|
||||||
|
|
||||||
# Check for duplicate connections to the same address
|
|
||||||
for conn_info in self.peer_connections[peer_id]:
|
|
||||||
if conn_info.address == address:
|
|
||||||
logger.debug(
|
|
||||||
f"Connection to {address} already exists for peer {peer_id}"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Add new connection
|
|
||||||
try:
|
|
||||||
current_time = trio.current_time()
|
|
||||||
except RuntimeError:
|
|
||||||
# Fallback for testing contexts where trio is not running
|
|
||||||
import time
|
|
||||||
|
|
||||||
current_time = time.time()
|
|
||||||
|
|
||||||
conn_info = ConnectionInfo(
|
|
||||||
connection=connection,
|
|
||||||
address=address,
|
|
||||||
established_at=current_time,
|
|
||||||
last_used=current_time,
|
|
||||||
stream_count=0,
|
|
||||||
is_healthy=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.peer_connections[peer_id].append(conn_info)
|
|
||||||
|
|
||||||
# Trim if we exceed max connections
|
|
||||||
if len(self.peer_connections[peer_id]) > self.max_connections_per_peer:
|
|
||||||
self._trim_connections(peer_id)
|
|
||||||
|
|
||||||
def get_connection(
|
|
||||||
self, peer_id: ID, strategy: str = "round_robin"
|
|
||||||
) -> INetConn | None:
|
|
||||||
"""Get a connection using the specified load balancing strategy."""
|
|
||||||
if peer_id not in self.peer_connections or not self.peer_connections[peer_id]:
|
|
||||||
return None
|
|
||||||
|
|
||||||
connections = self.peer_connections[peer_id]
|
|
||||||
|
|
||||||
if strategy == "round_robin":
|
|
||||||
if peer_id not in self._round_robin_index:
|
|
||||||
self._round_robin_index[peer_id] = 0
|
|
||||||
|
|
||||||
index = self._round_robin_index[peer_id] % len(connections)
|
|
||||||
self._round_robin_index[peer_id] += 1
|
|
||||||
|
|
||||||
conn_info = connections[index]
|
|
||||||
try:
|
|
||||||
conn_info.last_used = trio.current_time()
|
|
||||||
except RuntimeError:
|
|
||||||
import time
|
|
||||||
|
|
||||||
conn_info.last_used = time.time()
|
|
||||||
return conn_info.connection
|
|
||||||
|
|
||||||
elif strategy == "least_loaded":
|
|
||||||
# Find connection with least streams
|
|
||||||
# Note: stream_count is a custom attribute we add to connections
|
|
||||||
conn_info = min(
|
|
||||||
connections, key=lambda c: getattr(c.connection, "stream_count", 0)
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
conn_info.last_used = trio.current_time()
|
|
||||||
except RuntimeError:
|
|
||||||
import time
|
|
||||||
|
|
||||||
conn_info.last_used = time.time()
|
|
||||||
return conn_info.connection
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Default to first connection
|
|
||||||
conn_info = connections[0]
|
|
||||||
try:
|
|
||||||
conn_info.last_used = trio.current_time()
|
|
||||||
except RuntimeError:
|
|
||||||
import time
|
|
||||||
|
|
||||||
conn_info.last_used = time.time()
|
|
||||||
return conn_info.connection
|
|
||||||
|
|
||||||
def has_connection(self, peer_id: ID) -> bool:
|
|
||||||
"""Check if we have any connections to the peer."""
|
|
||||||
return (
|
|
||||||
peer_id in self.peer_connections and len(self.peer_connections[peer_id]) > 0
|
|
||||||
)
|
|
||||||
|
|
||||||
def remove_connection(self, peer_id: ID, connection: INetConn) -> None:
|
|
||||||
"""Remove a connection from the pool."""
|
|
||||||
if peer_id in self.peer_connections:
|
|
||||||
self.peer_connections[peer_id] = [
|
|
||||||
conn_info
|
|
||||||
for conn_info in self.peer_connections[peer_id]
|
|
||||||
if conn_info.connection != connection
|
|
||||||
]
|
|
||||||
|
|
||||||
# Clean up empty peer entries
|
|
||||||
if not self.peer_connections[peer_id]:
|
|
||||||
del self.peer_connections[peer_id]
|
|
||||||
if peer_id in self._round_robin_index:
|
|
||||||
del self._round_robin_index[peer_id]
|
|
||||||
|
|
||||||
def _trim_connections(self, peer_id: ID) -> None:
|
|
||||||
"""Remove oldest connections when limit is exceeded."""
|
|
||||||
connections = self.peer_connections[peer_id]
|
|
||||||
if len(connections) <= self.max_connections_per_peer:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Sort by last used time and remove oldest
|
|
||||||
connections.sort(key=lambda c: c.last_used)
|
|
||||||
connections_to_remove = connections[: -self.max_connections_per_peer]
|
|
||||||
|
|
||||||
for conn_info in connections_to_remove:
|
|
||||||
logger.debug(
|
|
||||||
f"Trimming old connection to {conn_info.address} for peer {peer_id}"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
# Close the connection asynchronously
|
|
||||||
trio.lowlevel.spawn_system_task(
|
|
||||||
self._close_connection_async, conn_info.connection
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error closing trimmed connection: {e}")
|
|
||||||
|
|
||||||
# Keep only the most recently used connections
|
|
||||||
self.peer_connections[peer_id] = connections[-self.max_connections_per_peer :]
|
|
||||||
|
|
||||||
async def _close_connection_async(self, connection: INetConn) -> None:
|
|
||||||
"""Close a connection asynchronously."""
|
|
||||||
try:
|
|
||||||
await connection.close()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error closing connection: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn:
|
def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn:
|
||||||
async def stream_handler(stream: INetStream) -> None:
|
async def stream_handler(stream: INetStream) -> None:
|
||||||
await network.get_manager().wait_finished()
|
await network.get_manager().wait_finished()
|
||||||
@ -256,7 +94,7 @@ class Swarm(Service, INetworkService):
|
|||||||
upgrader: TransportUpgrader
|
upgrader: TransportUpgrader
|
||||||
transport: ITransport
|
transport: ITransport
|
||||||
# Enhanced: Support for multiple connections per peer
|
# Enhanced: Support for multiple connections per peer
|
||||||
connections: dict[ID, INetConn] # Backward compatibility
|
connections: dict[ID, list[INetConn]] # Multiple connections per peer
|
||||||
listeners: dict[str, IListener]
|
listeners: dict[str, IListener]
|
||||||
common_stream_handler: StreamHandlerFn
|
common_stream_handler: StreamHandlerFn
|
||||||
listener_nursery: trio.Nursery | None
|
listener_nursery: trio.Nursery | None
|
||||||
@ -264,10 +102,10 @@ class Swarm(Service, INetworkService):
|
|||||||
|
|
||||||
notifees: list[INotifee]
|
notifees: list[INotifee]
|
||||||
|
|
||||||
# Enhanced: New configuration and connection pool
|
# Enhanced: New configuration
|
||||||
retry_config: RetryConfig
|
retry_config: RetryConfig
|
||||||
connection_config: ConnectionConfig
|
connection_config: ConnectionConfig
|
||||||
connection_pool: ConnectionPool | None
|
_round_robin_index: dict[ID, int]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -287,16 +125,8 @@ class Swarm(Service, INetworkService):
|
|||||||
self.retry_config = retry_config or RetryConfig()
|
self.retry_config = retry_config or RetryConfig()
|
||||||
self.connection_config = connection_config or ConnectionConfig()
|
self.connection_config = connection_config or ConnectionConfig()
|
||||||
|
|
||||||
# Enhanced: Initialize connection pool if enabled
|
# Enhanced: Initialize connections as 1:many mapping
|
||||||
if self.connection_config.enable_connection_pool:
|
self.connections = {}
|
||||||
self.connection_pool = ConnectionPool(
|
|
||||||
self.connection_config.max_connections_per_peer
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.connection_pool = None
|
|
||||||
|
|
||||||
# Backward compatibility: Keep existing connections dict
|
|
||||||
self.connections = dict()
|
|
||||||
self.listeners = dict()
|
self.listeners = dict()
|
||||||
|
|
||||||
# Create Notifee array
|
# Create Notifee array
|
||||||
@ -307,6 +137,9 @@ class Swarm(Service, INetworkService):
|
|||||||
self.listener_nursery = None
|
self.listener_nursery = None
|
||||||
self.event_listener_nursery_created = trio.Event()
|
self.event_listener_nursery_created = trio.Event()
|
||||||
|
|
||||||
|
# Load balancing state
|
||||||
|
self._round_robin_index = {}
|
||||||
|
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
async with trio.open_nursery() as nursery:
|
async with trio.open_nursery() as nursery:
|
||||||
# Create a nursery for listener tasks.
|
# Create a nursery for listener tasks.
|
||||||
@ -326,26 +159,74 @@ class Swarm(Service, INetworkService):
|
|||||||
def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None:
|
def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None:
|
||||||
self.common_stream_handler = stream_handler
|
self.common_stream_handler = stream_handler
|
||||||
|
|
||||||
async def dial_peer(self, peer_id: ID) -> INetConn:
|
def get_connections(self, peer_id: ID | None = None) -> list[INetConn]:
|
||||||
"""
|
"""
|
||||||
Try to create a connection to peer_id with enhanced retry logic.
|
Get connections for peer (like JS getConnections, Go ConnsToPeer).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
peer_id : ID | None
|
||||||
|
The peer ID to get connections for. If None, returns all connections.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[INetConn]
|
||||||
|
List of connections to the specified peer, or all connections
|
||||||
|
if peer_id is None.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if peer_id is not None:
|
||||||
|
return self.connections.get(peer_id, [])
|
||||||
|
|
||||||
|
# Return all connections from all peers
|
||||||
|
all_conns = []
|
||||||
|
for conns in self.connections.values():
|
||||||
|
all_conns.extend(conns)
|
||||||
|
return all_conns
|
||||||
|
|
||||||
|
def get_connections_map(self) -> dict[ID, list[INetConn]]:
|
||||||
|
"""
|
||||||
|
Get all connections map (like JS getConnectionsMap).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dict[ID, list[INetConn]]
|
||||||
|
The complete mapping of peer IDs to their connection lists.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return self.connections.copy()
|
||||||
|
|
||||||
|
def get_connection(self, peer_id: ID) -> INetConn | None:
|
||||||
|
"""
|
||||||
|
Get single connection for backward compatibility.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
peer_id : ID
|
||||||
|
The peer ID to get a connection for.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
INetConn | None
|
||||||
|
The first available connection, or None if no connections exist.
|
||||||
|
|
||||||
|
"""
|
||||||
|
conns = self.get_connections(peer_id)
|
||||||
|
return conns[0] if conns else None
|
||||||
|
|
||||||
|
async def dial_peer(self, peer_id: ID) -> list[INetConn]:
|
||||||
|
"""
|
||||||
|
Try to create connections to peer_id with enhanced retry logic.
|
||||||
|
|
||||||
:param peer_id: peer if we want to dial
|
:param peer_id: peer if we want to dial
|
||||||
:raises SwarmException: raised when an error occurs
|
:raises SwarmException: raised when an error occurs
|
||||||
:return: muxed connection
|
:return: list of muxed connections
|
||||||
"""
|
"""
|
||||||
# Enhanced: Check connection pool first if enabled
|
# Check if we already have connections
|
||||||
if self.connection_pool and self.connection_pool.has_connection(peer_id):
|
existing_connections = self.get_connections(peer_id)
|
||||||
connection = self.connection_pool.get_connection(peer_id)
|
if existing_connections:
|
||||||
if connection:
|
logger.debug(f"Reusing existing connections to peer {peer_id}")
|
||||||
logger.debug(f"Reusing existing connection to peer {peer_id}")
|
return existing_connections
|
||||||
return connection
|
|
||||||
|
|
||||||
# Enhanced: Check existing single connection for backward compatibility
|
|
||||||
if peer_id in self.connections:
|
|
||||||
# If muxed connection already exists for peer_id,
|
|
||||||
# set muxed connection equal to existing muxed connection
|
|
||||||
return self.connections[peer_id]
|
|
||||||
|
|
||||||
logger.debug("attempting to dial peer %s", peer_id)
|
logger.debug("attempting to dial peer %s", peer_id)
|
||||||
|
|
||||||
@ -358,23 +239,19 @@ class Swarm(Service, INetworkService):
|
|||||||
if not addrs:
|
if not addrs:
|
||||||
raise SwarmException(f"No known addresses to peer {peer_id}")
|
raise SwarmException(f"No known addresses to peer {peer_id}")
|
||||||
|
|
||||||
|
connections = []
|
||||||
exceptions: list[SwarmException] = []
|
exceptions: list[SwarmException] = []
|
||||||
|
|
||||||
# Enhanced: Try all known addresses with retry logic
|
# Enhanced: Try all known addresses with retry logic
|
||||||
for multiaddr in addrs:
|
for multiaddr in addrs:
|
||||||
try:
|
try:
|
||||||
connection = await self._dial_with_retry(multiaddr, peer_id)
|
connection = await self._dial_with_retry(multiaddr, peer_id)
|
||||||
|
connections.append(connection)
|
||||||
|
|
||||||
# Enhanced: Add to connection pool if enabled
|
# Limit number of connections per peer
|
||||||
if self.connection_pool:
|
if len(connections) >= self.connection_config.max_connections_per_peer:
|
||||||
self.connection_pool.add_connection(
|
break
|
||||||
peer_id, connection, str(multiaddr)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Backward compatibility: Keep existing connections dict
|
|
||||||
self.connections[peer_id] = connection
|
|
||||||
|
|
||||||
return connection
|
|
||||||
except SwarmException as e:
|
except SwarmException as e:
|
||||||
exceptions.append(e)
|
exceptions.append(e)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@ -384,11 +261,14 @@ class Swarm(Service, INetworkService):
|
|||||||
exc_info=e,
|
exc_info=e,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Tried all addresses, raising exception.
|
if not connections:
|
||||||
raise SwarmException(
|
# Tried all addresses, raising exception.
|
||||||
f"unable to connect to {peer_id}, no addresses established a successful "
|
raise SwarmException(
|
||||||
"connection (with exceptions)"
|
f"unable to connect to {peer_id}, no addresses established a "
|
||||||
) from MultiError(exceptions)
|
"successful connection (with exceptions)"
|
||||||
|
) from MultiError(exceptions)
|
||||||
|
|
||||||
|
return connections
|
||||||
|
|
||||||
async def _dial_with_retry(self, addr: Multiaddr, peer_id: ID) -> INetConn:
|
async def _dial_with_retry(self, addr: Multiaddr, peer_id: ID) -> INetConn:
|
||||||
"""
|
"""
|
||||||
@ -515,33 +395,76 @@ class Swarm(Service, INetworkService):
|
|||||||
"""
|
"""
|
||||||
logger.debug("attempting to open a stream to peer %s", peer_id)
|
logger.debug("attempting to open a stream to peer %s", peer_id)
|
||||||
|
|
||||||
# Enhanced: Try to get existing connection from pool first
|
# Get existing connections or dial new ones
|
||||||
if self.connection_pool and self.connection_pool.has_connection(peer_id):
|
connections = self.get_connections(peer_id)
|
||||||
connection = self.connection_pool.get_connection(
|
if not connections:
|
||||||
peer_id, self.connection_config.load_balancing_strategy
|
connections = await self.dial_peer(peer_id)
|
||||||
)
|
|
||||||
if connection:
|
|
||||||
try:
|
|
||||||
net_stream = await connection.new_stream()
|
|
||||||
logger.debug(
|
|
||||||
"successfully opened a stream to peer %s "
|
|
||||||
"using existing connection",
|
|
||||||
peer_id,
|
|
||||||
)
|
|
||||||
return net_stream
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(
|
|
||||||
f"Failed to create stream on existing connection, "
|
|
||||||
f"will dial new connection: {e}"
|
|
||||||
)
|
|
||||||
# Fall through to dial new connection
|
|
||||||
|
|
||||||
# Fall back to existing logic: dial peer and create stream
|
# Load balancing strategy at interface level
|
||||||
swarm_conn = await self.dial_peer(peer_id)
|
connection = self._select_connection(connections, peer_id)
|
||||||
|
|
||||||
net_stream = await swarm_conn.new_stream()
|
try:
|
||||||
logger.debug("successfully opened a stream to peer %s", peer_id)
|
net_stream = await connection.new_stream()
|
||||||
return net_stream
|
logger.debug("successfully opened a stream to peer %s", peer_id)
|
||||||
|
return net_stream
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Failed to create stream on connection: {e}")
|
||||||
|
# Try other connections if available
|
||||||
|
for other_conn in connections:
|
||||||
|
if other_conn != connection:
|
||||||
|
try:
|
||||||
|
net_stream = await other_conn.new_stream()
|
||||||
|
logger.debug(
|
||||||
|
f"Successfully opened a stream to peer {peer_id} "
|
||||||
|
"using alternative connection"
|
||||||
|
)
|
||||||
|
return net_stream
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# All connections failed, raise exception
|
||||||
|
raise SwarmException(f"Failed to create stream to peer {peer_id}") from e
|
||||||
|
|
||||||
|
def _select_connection(self, connections: list[INetConn], peer_id: ID) -> INetConn:
|
||||||
|
"""
|
||||||
|
Select connection based on load balancing strategy.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
connections : list[INetConn]
|
||||||
|
List of available connections.
|
||||||
|
peer_id : ID
|
||||||
|
The peer ID for round-robin tracking.
|
||||||
|
strategy : str
|
||||||
|
Load balancing strategy ("round_robin", "least_loaded", etc.).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
INetConn
|
||||||
|
Selected connection.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not connections:
|
||||||
|
raise ValueError("No connections available")
|
||||||
|
|
||||||
|
strategy = self.connection_config.load_balancing_strategy
|
||||||
|
|
||||||
|
if strategy == "round_robin":
|
||||||
|
# Simple round-robin selection
|
||||||
|
if peer_id not in self._round_robin_index:
|
||||||
|
self._round_robin_index[peer_id] = 0
|
||||||
|
|
||||||
|
index = self._round_robin_index[peer_id] % len(connections)
|
||||||
|
self._round_robin_index[peer_id] += 1
|
||||||
|
return connections[index]
|
||||||
|
|
||||||
|
elif strategy == "least_loaded":
|
||||||
|
# Find connection with least streams
|
||||||
|
return min(connections, key=lambda c: len(c.get_streams()))
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Default to first connection
|
||||||
|
return connections[0]
|
||||||
|
|
||||||
async def listen(self, *multiaddrs: Multiaddr) -> bool:
|
async def listen(self, *multiaddrs: Multiaddr) -> bool:
|
||||||
"""
|
"""
|
||||||
@ -637,9 +560,9 @@ class Swarm(Service, INetworkService):
|
|||||||
# Perform alternative cleanup if the manager isn't initialized
|
# Perform alternative cleanup if the manager isn't initialized
|
||||||
# Close all connections manually
|
# Close all connections manually
|
||||||
if hasattr(self, "connections"):
|
if hasattr(self, "connections"):
|
||||||
for conn_id in list(self.connections.keys()):
|
for peer_id, conns in list(self.connections.items()):
|
||||||
conn = self.connections[conn_id]
|
for conn in conns:
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
# Clear connection tracking dictionary
|
# Clear connection tracking dictionary
|
||||||
self.connections.clear()
|
self.connections.clear()
|
||||||
@ -669,17 +592,28 @@ class Swarm(Service, INetworkService):
|
|||||||
logger.debug("swarm successfully closed")
|
logger.debug("swarm successfully closed")
|
||||||
|
|
||||||
async def close_peer(self, peer_id: ID) -> None:
|
async def close_peer(self, peer_id: ID) -> None:
|
||||||
if peer_id not in self.connections:
|
"""
|
||||||
|
Close all connections to the specified peer.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
peer_id : ID
|
||||||
|
The peer ID to close connections for.
|
||||||
|
|
||||||
|
"""
|
||||||
|
connections = self.get_connections(peer_id)
|
||||||
|
if not connections:
|
||||||
return
|
return
|
||||||
connection = self.connections[peer_id]
|
|
||||||
|
|
||||||
# Enhanced: Remove from connection pool if enabled
|
# Close all connections
|
||||||
if self.connection_pool:
|
for connection in connections:
|
||||||
self.connection_pool.remove_connection(peer_id, connection)
|
try:
|
||||||
|
await connection.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error closing connection to {peer_id}: {e}")
|
||||||
|
|
||||||
# NOTE: `connection.close` will delete `peer_id` from `self.connections`
|
# Remove from connections dict
|
||||||
# and `notify_disconnected` for us.
|
self.connections.pop(peer_id, None)
|
||||||
await connection.close()
|
|
||||||
|
|
||||||
logger.debug("successfully close the connection to peer %s", peer_id)
|
logger.debug("successfully close the connection to peer %s", peer_id)
|
||||||
|
|
||||||
@ -698,20 +632,58 @@ class Swarm(Service, INetworkService):
|
|||||||
await muxed_conn.event_started.wait()
|
await muxed_conn.event_started.wait()
|
||||||
self.manager.run_task(swarm_conn.start)
|
self.manager.run_task(swarm_conn.start)
|
||||||
await swarm_conn.event_started.wait()
|
await swarm_conn.event_started.wait()
|
||||||
# Enhanced: Add to connection pool if enabled
|
|
||||||
if self.connection_pool:
|
|
||||||
# For incoming connections, we don't have a specific address
|
|
||||||
# Use a placeholder that will be updated when we get more info
|
|
||||||
self.connection_pool.add_connection(
|
|
||||||
muxed_conn.peer_id, swarm_conn, "incoming"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store muxed_conn with peer id (backward compatibility)
|
# Add to connections dict with deduplication
|
||||||
self.connections[muxed_conn.peer_id] = swarm_conn
|
peer_id = muxed_conn.peer_id
|
||||||
|
if peer_id not in self.connections:
|
||||||
|
self.connections[peer_id] = []
|
||||||
|
|
||||||
|
# Check for duplicate connections by comparing the underlying muxed connection
|
||||||
|
for existing_conn in self.connections[peer_id]:
|
||||||
|
if existing_conn.muxed_conn == muxed_conn:
|
||||||
|
logger.debug(f"Connection already exists for peer {peer_id}")
|
||||||
|
# existing_conn is a SwarmConn since it's stored in the connections list
|
||||||
|
return existing_conn # type: ignore[return-value]
|
||||||
|
|
||||||
|
self.connections[peer_id].append(swarm_conn)
|
||||||
|
|
||||||
|
# Trim if we exceed max connections
|
||||||
|
max_conns = self.connection_config.max_connections_per_peer
|
||||||
|
if len(self.connections[peer_id]) > max_conns:
|
||||||
|
self._trim_connections(peer_id)
|
||||||
|
|
||||||
# Call notifiers since event occurred
|
# Call notifiers since event occurred
|
||||||
await self.notify_connected(swarm_conn)
|
await self.notify_connected(swarm_conn)
|
||||||
return swarm_conn
|
return swarm_conn
|
||||||
|
|
||||||
|
def _trim_connections(self, peer_id: ID) -> None:
|
||||||
|
"""
|
||||||
|
Remove oldest connections when limit is exceeded.
|
||||||
|
"""
|
||||||
|
connections = self.connections[peer_id]
|
||||||
|
if len(connections) <= self.connection_config.max_connections_per_peer:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Sort by creation time and remove oldest
|
||||||
|
# For now, just keep the most recent connections
|
||||||
|
max_conns = self.connection_config.max_connections_per_peer
|
||||||
|
connections_to_remove = connections[:-max_conns]
|
||||||
|
|
||||||
|
for conn in connections_to_remove:
|
||||||
|
logger.debug(f"Trimming old connection for peer {peer_id}")
|
||||||
|
trio.lowlevel.spawn_system_task(self._close_connection_async, conn)
|
||||||
|
|
||||||
|
# Keep only the most recent connections
|
||||||
|
max_conns = self.connection_config.max_connections_per_peer
|
||||||
|
self.connections[peer_id] = connections[-max_conns:]
|
||||||
|
|
||||||
|
async def _close_connection_async(self, connection: INetConn) -> None:
|
||||||
|
"""Close a connection asynchronously."""
|
||||||
|
try:
|
||||||
|
await connection.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error closing connection: {e}")
|
||||||
|
|
||||||
def remove_conn(self, swarm_conn: SwarmConn) -> None:
|
def remove_conn(self, swarm_conn: SwarmConn) -> None:
|
||||||
"""
|
"""
|
||||||
Simply remove the connection from Swarm's records, without closing
|
Simply remove the connection from Swarm's records, without closing
|
||||||
@ -719,13 +691,12 @@ class Swarm(Service, INetworkService):
|
|||||||
"""
|
"""
|
||||||
peer_id = swarm_conn.muxed_conn.peer_id
|
peer_id = swarm_conn.muxed_conn.peer_id
|
||||||
|
|
||||||
# Enhanced: Remove from connection pool if enabled
|
if peer_id in self.connections:
|
||||||
if self.connection_pool:
|
self.connections[peer_id] = [
|
||||||
self.connection_pool.remove_connection(peer_id, swarm_conn)
|
conn for conn in self.connections[peer_id] if conn != swarm_conn
|
||||||
|
]
|
||||||
if peer_id not in self.connections:
|
if not self.connections[peer_id]:
|
||||||
return
|
del self.connections[peer_id]
|
||||||
del self.connections[peer_id]
|
|
||||||
|
|
||||||
# Notifee
|
# Notifee
|
||||||
|
|
||||||
@ -771,3 +742,21 @@ class Swarm(Service, INetworkService):
|
|||||||
async with trio.open_nursery() as nursery:
|
async with trio.open_nursery() as nursery:
|
||||||
for notifee in self.notifees:
|
for notifee in self.notifees:
|
||||||
nursery.start_soon(notifier, notifee)
|
nursery.start_soon(notifier, notifee)
|
||||||
|
|
||||||
|
# Backward compatibility properties
|
||||||
|
@property
|
||||||
|
def connections_legacy(self) -> dict[ID, INetConn]:
|
||||||
|
"""
|
||||||
|
Legacy 1:1 mapping for backward compatibility.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dict[ID, INetConn]
|
||||||
|
Legacy mapping with only the first connection per peer.
|
||||||
|
|
||||||
|
"""
|
||||||
|
legacy_conns = {}
|
||||||
|
for peer_id, conns in self.connections.items():
|
||||||
|
if conns:
|
||||||
|
legacy_conns[peer_id] = conns[0]
|
||||||
|
return legacy_conns
|
||||||
|
|||||||
@ -164,8 +164,8 @@ async def test_live_peers_unexpected_drop(security_protocol):
|
|||||||
assert peer_a_id in host_b.get_live_peers()
|
assert peer_a_id in host_b.get_live_peers()
|
||||||
|
|
||||||
# Simulate unexpected connection drop by directly closing the connection
|
# Simulate unexpected connection drop by directly closing the connection
|
||||||
conn = host_a.get_network().connections[peer_b_id]
|
conns = host_a.get_network().connections[peer_b_id]
|
||||||
await conn.muxed_conn.close()
|
await conns[0].muxed_conn.close()
|
||||||
|
|
||||||
# Allow for connection cleanup
|
# Allow for connection cleanup
|
||||||
await trio.sleep(0.1)
|
await trio.sleep(0.1)
|
||||||
|
|||||||
@ -1,14 +1,15 @@
|
|||||||
import time
|
import time
|
||||||
|
from typing import cast
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from multiaddr import Multiaddr
|
from multiaddr import Multiaddr
|
||||||
|
import trio
|
||||||
|
|
||||||
from libp2p.abc import INetConn, INetStream
|
from libp2p.abc import INetConn, INetStream
|
||||||
from libp2p.network.exceptions import SwarmException
|
from libp2p.network.exceptions import SwarmException
|
||||||
from libp2p.network.swarm import (
|
from libp2p.network.swarm import (
|
||||||
ConnectionConfig,
|
ConnectionConfig,
|
||||||
ConnectionPool,
|
|
||||||
RetryConfig,
|
RetryConfig,
|
||||||
Swarm,
|
Swarm,
|
||||||
)
|
)
|
||||||
@ -21,10 +22,12 @@ class MockConnection(INetConn):
|
|||||||
def __init__(self, peer_id: ID, is_closed: bool = False):
|
def __init__(self, peer_id: ID, is_closed: bool = False):
|
||||||
self.peer_id = peer_id
|
self.peer_id = peer_id
|
||||||
self._is_closed = is_closed
|
self._is_closed = is_closed
|
||||||
self.stream_count = 0
|
self.streams = set() # Track streams properly
|
||||||
# Mock the muxed_conn attribute that Swarm expects
|
# Mock the muxed_conn attribute that Swarm expects
|
||||||
self.muxed_conn = Mock()
|
self.muxed_conn = Mock()
|
||||||
self.muxed_conn.peer_id = peer_id
|
self.muxed_conn.peer_id = peer_id
|
||||||
|
# Required by INetConn interface
|
||||||
|
self.event_started = trio.Event()
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
self._is_closed = True
|
self._is_closed = True
|
||||||
@ -34,12 +37,14 @@ class MockConnection(INetConn):
|
|||||||
return self._is_closed
|
return self._is_closed
|
||||||
|
|
||||||
async def new_stream(self) -> INetStream:
|
async def new_stream(self) -> INetStream:
|
||||||
self.stream_count += 1
|
# Create a mock stream and add it to the connection's stream set
|
||||||
return Mock(spec=INetStream)
|
mock_stream = Mock(spec=INetStream)
|
||||||
|
self.streams.add(mock_stream)
|
||||||
|
return mock_stream
|
||||||
|
|
||||||
def get_streams(self) -> tuple[INetStream, ...]:
|
def get_streams(self) -> tuple[INetStream, ...]:
|
||||||
"""Mock implementation of get_streams."""
|
"""Return all streams associated with this connection."""
|
||||||
return tuple()
|
return tuple(self.streams)
|
||||||
|
|
||||||
def get_transport_addresses(self) -> list[Multiaddr]:
|
def get_transport_addresses(self) -> list[Multiaddr]:
|
||||||
"""Mock implementation of get_transport_addresses."""
|
"""Mock implementation of get_transport_addresses."""
|
||||||
@ -70,114 +75,9 @@ async def test_connection_config_defaults():
|
|||||||
config = ConnectionConfig()
|
config = ConnectionConfig()
|
||||||
assert config.max_connections_per_peer == 3
|
assert config.max_connections_per_peer == 3
|
||||||
assert config.connection_timeout == 30.0
|
assert config.connection_timeout == 30.0
|
||||||
assert config.enable_connection_pool is True
|
|
||||||
assert config.load_balancing_strategy == "round_robin"
|
assert config.load_balancing_strategy == "round_robin"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_connection_pool_basic_operations():
|
|
||||||
"""Test basic ConnectionPool operations."""
|
|
||||||
pool = ConnectionPool(max_connections_per_peer=2)
|
|
||||||
peer_id = ID(b"QmTest")
|
|
||||||
|
|
||||||
# Test empty pool
|
|
||||||
assert not pool.has_connection(peer_id)
|
|
||||||
assert pool.get_connection(peer_id) is None
|
|
||||||
|
|
||||||
# Add connection
|
|
||||||
conn1 = MockConnection(peer_id)
|
|
||||||
pool.add_connection(peer_id, conn1, "addr1")
|
|
||||||
assert pool.has_connection(peer_id)
|
|
||||||
assert pool.get_connection(peer_id) == conn1
|
|
||||||
|
|
||||||
# Add second connection
|
|
||||||
conn2 = MockConnection(peer_id)
|
|
||||||
pool.add_connection(peer_id, conn2, "addr2")
|
|
||||||
assert len(pool.peer_connections[peer_id]) == 2
|
|
||||||
|
|
||||||
# Test round-robin - should cycle through connections
|
|
||||||
first_conn = pool.get_connection(peer_id, "round_robin")
|
|
||||||
second_conn = pool.get_connection(peer_id, "round_robin")
|
|
||||||
third_conn = pool.get_connection(peer_id, "round_robin")
|
|
||||||
|
|
||||||
# Should cycle through both connections
|
|
||||||
assert first_conn in [conn1, conn2]
|
|
||||||
assert second_conn in [conn1, conn2]
|
|
||||||
assert third_conn in [conn1, conn2]
|
|
||||||
assert first_conn != second_conn or second_conn != third_conn
|
|
||||||
|
|
||||||
# Test least loaded - set different stream counts
|
|
||||||
conn1.stream_count = 5
|
|
||||||
conn2.stream_count = 1
|
|
||||||
least_loaded_conn = pool.get_connection(peer_id, "least_loaded")
|
|
||||||
assert least_loaded_conn == conn2 # conn2 has fewer streams
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_connection_pool_deduplication():
|
|
||||||
"""Test connection deduplication by address."""
|
|
||||||
pool = ConnectionPool(max_connections_per_peer=3)
|
|
||||||
peer_id = ID(b"QmTest")
|
|
||||||
|
|
||||||
conn1 = MockConnection(peer_id)
|
|
||||||
pool.add_connection(peer_id, conn1, "addr1")
|
|
||||||
|
|
||||||
# Try to add connection with same address
|
|
||||||
conn2 = MockConnection(peer_id)
|
|
||||||
pool.add_connection(peer_id, conn2, "addr1")
|
|
||||||
|
|
||||||
# Should only have one connection
|
|
||||||
assert len(pool.peer_connections[peer_id]) == 1
|
|
||||||
assert pool.get_connection(peer_id) == conn1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_connection_pool_trimming():
|
|
||||||
"""Test connection trimming when limit is exceeded."""
|
|
||||||
pool = ConnectionPool(max_connections_per_peer=2)
|
|
||||||
peer_id = ID(b"QmTest")
|
|
||||||
|
|
||||||
# Add 3 connections
|
|
||||||
conn1 = MockConnection(peer_id)
|
|
||||||
conn2 = MockConnection(peer_id)
|
|
||||||
conn3 = MockConnection(peer_id)
|
|
||||||
|
|
||||||
pool.add_connection(peer_id, conn1, "addr1")
|
|
||||||
pool.add_connection(peer_id, conn2, "addr2")
|
|
||||||
pool.add_connection(peer_id, conn3, "addr3")
|
|
||||||
|
|
||||||
# Should trim to 2 connections
|
|
||||||
assert len(pool.peer_connections[peer_id]) == 2
|
|
||||||
|
|
||||||
# The oldest connections should be removed
|
|
||||||
remaining_connections = [c.connection for c in pool.peer_connections[peer_id]]
|
|
||||||
assert conn3 in remaining_connections # Most recent should remain
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_connection_pool_remove_connection():
|
|
||||||
"""Test removing connections from pool."""
|
|
||||||
pool = ConnectionPool(max_connections_per_peer=3)
|
|
||||||
peer_id = ID(b"QmTest")
|
|
||||||
|
|
||||||
conn1 = MockConnection(peer_id)
|
|
||||||
conn2 = MockConnection(peer_id)
|
|
||||||
|
|
||||||
pool.add_connection(peer_id, conn1, "addr1")
|
|
||||||
pool.add_connection(peer_id, conn2, "addr2")
|
|
||||||
|
|
||||||
assert len(pool.peer_connections[peer_id]) == 2
|
|
||||||
|
|
||||||
# Remove connection
|
|
||||||
pool.remove_connection(peer_id, conn1)
|
|
||||||
assert len(pool.peer_connections[peer_id]) == 1
|
|
||||||
assert pool.get_connection(peer_id) == conn2
|
|
||||||
|
|
||||||
# Remove last connection
|
|
||||||
pool.remove_connection(peer_id, conn2)
|
|
||||||
assert not pool.has_connection(peer_id)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_enhanced_swarm_constructor():
|
async def test_enhanced_swarm_constructor():
|
||||||
"""Test enhanced Swarm constructor with new configuration."""
|
"""Test enhanced Swarm constructor with new configuration."""
|
||||||
@ -191,19 +91,16 @@ async def test_enhanced_swarm_constructor():
|
|||||||
swarm = Swarm(peer_id, peerstore, upgrader, transport)
|
swarm = Swarm(peer_id, peerstore, upgrader, transport)
|
||||||
assert swarm.retry_config.max_retries == 3
|
assert swarm.retry_config.max_retries == 3
|
||||||
assert swarm.connection_config.max_connections_per_peer == 3
|
assert swarm.connection_config.max_connections_per_peer == 3
|
||||||
assert swarm.connection_pool is not None
|
assert isinstance(swarm.connections, dict)
|
||||||
|
|
||||||
# Test with custom config
|
# Test with custom config
|
||||||
custom_retry = RetryConfig(max_retries=5, initial_delay=0.5)
|
custom_retry = RetryConfig(max_retries=5, initial_delay=0.5)
|
||||||
custom_conn = ConnectionConfig(
|
custom_conn = ConnectionConfig(max_connections_per_peer=5)
|
||||||
max_connections_per_peer=5, enable_connection_pool=False
|
|
||||||
)
|
|
||||||
|
|
||||||
swarm = Swarm(peer_id, peerstore, upgrader, transport, custom_retry, custom_conn)
|
swarm = Swarm(peer_id, peerstore, upgrader, transport, custom_retry, custom_conn)
|
||||||
assert swarm.retry_config.max_retries == 5
|
assert swarm.retry_config.max_retries == 5
|
||||||
assert swarm.retry_config.initial_delay == 0.5
|
assert swarm.retry_config.initial_delay == 0.5
|
||||||
assert swarm.connection_config.max_connections_per_peer == 5
|
assert swarm.connection_config.max_connections_per_peer == 5
|
||||||
assert swarm.connection_pool is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
@ -273,143 +170,155 @@ async def test_swarm_retry_logic():
|
|||||||
|
|
||||||
# Should have succeeded after 3 attempts
|
# Should have succeeded after 3 attempts
|
||||||
assert attempt_count[0] == 3
|
assert attempt_count[0] == 3
|
||||||
assert result is not None
|
assert isinstance(result, MockConnection)
|
||||||
|
assert end_time - start_time > 0.01 # Should have some delay
|
||||||
# Should have taken some time due to retries
|
|
||||||
assert end_time - start_time > 0.02 # At least 2 delays
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_swarm_multi_connection_support():
|
async def test_swarm_load_balancing_strategies():
|
||||||
"""Test multi-connection support in Swarm."""
|
"""Test load balancing strategies."""
|
||||||
peer_id = ID(b"QmTest")
|
peer_id = ID(b"QmTest")
|
||||||
peerstore = Mock()
|
peerstore = Mock()
|
||||||
upgrader = Mock()
|
upgrader = Mock()
|
||||||
transport = Mock()
|
transport = Mock()
|
||||||
|
|
||||||
connection_config = ConnectionConfig(
|
swarm = Swarm(peer_id, peerstore, upgrader, transport)
|
||||||
max_connections_per_peer=3,
|
|
||||||
enable_connection_pool=True,
|
|
||||||
load_balancing_strategy="round_robin",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# Create mock connections with different stream counts
|
||||||
|
conn1 = MockConnection(peer_id)
|
||||||
|
conn2 = MockConnection(peer_id)
|
||||||
|
conn3 = MockConnection(peer_id)
|
||||||
|
|
||||||
|
# Add some streams to simulate load
|
||||||
|
await conn1.new_stream()
|
||||||
|
await conn1.new_stream()
|
||||||
|
await conn2.new_stream()
|
||||||
|
|
||||||
|
connections = [conn1, conn2, conn3]
|
||||||
|
|
||||||
|
# Test round-robin strategy
|
||||||
|
swarm.connection_config.load_balancing_strategy = "round_robin"
|
||||||
|
# Cast to satisfy type checker
|
||||||
|
connections_cast = cast("list[INetConn]", connections)
|
||||||
|
selected1 = swarm._select_connection(connections_cast, peer_id)
|
||||||
|
selected2 = swarm._select_connection(connections_cast, peer_id)
|
||||||
|
selected3 = swarm._select_connection(connections_cast, peer_id)
|
||||||
|
|
||||||
|
# Should cycle through connections
|
||||||
|
assert selected1 in connections
|
||||||
|
assert selected2 in connections
|
||||||
|
assert selected3 in connections
|
||||||
|
|
||||||
|
# Test least loaded strategy
|
||||||
|
swarm.connection_config.load_balancing_strategy = "least_loaded"
|
||||||
|
least_loaded = swarm._select_connection(connections_cast, peer_id)
|
||||||
|
|
||||||
|
# conn3 has 0 streams, conn2 has 1 stream, conn1 has 2 streams
|
||||||
|
# So conn3 should be selected as least loaded
|
||||||
|
assert least_loaded == conn3
|
||||||
|
|
||||||
|
# Test default strategy (first connection)
|
||||||
|
swarm.connection_config.load_balancing_strategy = "unknown"
|
||||||
|
default_selected = swarm._select_connection(connections_cast, peer_id)
|
||||||
|
assert default_selected == conn1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_swarm_multiple_connections_api():
|
||||||
|
"""Test the new multiple connections API methods."""
|
||||||
|
peer_id = ID(b"QmTest")
|
||||||
|
peerstore = Mock()
|
||||||
|
upgrader = Mock()
|
||||||
|
transport = Mock()
|
||||||
|
|
||||||
|
swarm = Swarm(peer_id, peerstore, upgrader, transport)
|
||||||
|
|
||||||
|
# Test empty connections
|
||||||
|
assert swarm.get_connections() == []
|
||||||
|
assert swarm.get_connections(peer_id) == []
|
||||||
|
assert swarm.get_connection(peer_id) is None
|
||||||
|
assert swarm.get_connections_map() == {}
|
||||||
|
|
||||||
|
# Add some connections
|
||||||
|
conn1 = MockConnection(peer_id)
|
||||||
|
conn2 = MockConnection(peer_id)
|
||||||
|
swarm.connections[peer_id] = [conn1, conn2]
|
||||||
|
|
||||||
|
# Test get_connections with peer_id
|
||||||
|
peer_connections = swarm.get_connections(peer_id)
|
||||||
|
assert len(peer_connections) == 2
|
||||||
|
assert conn1 in peer_connections
|
||||||
|
assert conn2 in peer_connections
|
||||||
|
|
||||||
|
# Test get_connections without peer_id (all connections)
|
||||||
|
all_connections = swarm.get_connections()
|
||||||
|
assert len(all_connections) == 2
|
||||||
|
assert conn1 in all_connections
|
||||||
|
assert conn2 in all_connections
|
||||||
|
|
||||||
|
# Test get_connection (backward compatibility)
|
||||||
|
single_conn = swarm.get_connection(peer_id)
|
||||||
|
assert single_conn in [conn1, conn2]
|
||||||
|
|
||||||
|
# Test get_connections_map
|
||||||
|
connections_map = swarm.get_connections_map()
|
||||||
|
assert peer_id in connections_map
|
||||||
|
assert connections_map[peer_id] == [conn1, conn2]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_swarm_connection_trimming():
|
||||||
|
"""Test connection trimming when limit is exceeded."""
|
||||||
|
peer_id = ID(b"QmTest")
|
||||||
|
peerstore = Mock()
|
||||||
|
upgrader = Mock()
|
||||||
|
transport = Mock()
|
||||||
|
|
||||||
|
# Set max connections to 2
|
||||||
|
connection_config = ConnectionConfig(max_connections_per_peer=2)
|
||||||
swarm = Swarm(
|
swarm = Swarm(
|
||||||
peer_id, peerstore, upgrader, transport, connection_config=connection_config
|
peer_id, peerstore, upgrader, transport, connection_config=connection_config
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock connection pool methods
|
# Add 3 connections
|
||||||
assert swarm.connection_pool is not None
|
conn1 = MockConnection(peer_id)
|
||||||
connection_pool = swarm.connection_pool
|
conn2 = MockConnection(peer_id)
|
||||||
connection_pool.has_connection = Mock(return_value=True)
|
conn3 = MockConnection(peer_id)
|
||||||
connection_pool.get_connection = Mock(return_value=MockConnection(peer_id))
|
|
||||||
|
|
||||||
# Test that new_stream uses connection pool
|
swarm.connections[peer_id] = [conn1, conn2, conn3]
|
||||||
result = await swarm.new_stream(peer_id)
|
|
||||||
assert result is not None
|
# Trigger trimming
|
||||||
# Use the mocked method directly to avoid type checking issues
|
swarm._trim_connections(peer_id)
|
||||||
get_connection_mock = connection_pool.get_connection
|
|
||||||
assert get_connection_mock.call_count == 1
|
# Should have only 2 connections
|
||||||
|
assert len(swarm.connections[peer_id]) == 2
|
||||||
|
|
||||||
|
# The most recent connections should remain
|
||||||
|
remaining = swarm.connections[peer_id]
|
||||||
|
assert conn2 in remaining
|
||||||
|
assert conn3 in remaining
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_swarm_backward_compatibility():
|
async def test_swarm_backward_compatibility():
|
||||||
"""Test that enhanced Swarm maintains backward compatibility."""
|
"""Test backward compatibility features."""
|
||||||
peer_id = ID(b"QmTest")
|
peer_id = ID(b"QmTest")
|
||||||
peerstore = Mock()
|
peerstore = Mock()
|
||||||
upgrader = Mock()
|
upgrader = Mock()
|
||||||
transport = Mock()
|
transport = Mock()
|
||||||
|
|
||||||
# Create swarm with connection pool disabled
|
swarm = Swarm(peer_id, peerstore, upgrader, transport)
|
||||||
connection_config = ConnectionConfig(enable_connection_pool=False)
|
|
||||||
swarm = Swarm(
|
|
||||||
peer_id, peerstore, upgrader, transport, connection_config=connection_config
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should behave like original swarm
|
# Add connections
|
||||||
assert swarm.connection_pool is None
|
conn1 = MockConnection(peer_id)
|
||||||
assert isinstance(swarm.connections, dict)
|
conn2 = MockConnection(peer_id)
|
||||||
|
swarm.connections[peer_id] = [conn1, conn2]
|
||||||
|
|
||||||
# Test that dial_peer still works (will fail due to mocks, but structure is correct)
|
# Test connections_legacy property
|
||||||
peerstore.addrs.return_value = [Mock(spec=Multiaddr)]
|
legacy_connections = swarm.connections_legacy
|
||||||
transport.dial.side_effect = Exception("Transport error")
|
assert peer_id in legacy_connections
|
||||||
|
# Should return first connection
|
||||||
with pytest.raises(SwarmException):
|
assert legacy_connections[peer_id] in [conn1, conn2]
|
||||||
await swarm.dial_peer(peer_id)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_swarm_connection_pool_integration():
|
|
||||||
"""Test integration between Swarm and ConnectionPool."""
|
|
||||||
peer_id = ID(b"QmTest")
|
|
||||||
peerstore = Mock()
|
|
||||||
upgrader = Mock()
|
|
||||||
transport = Mock()
|
|
||||||
|
|
||||||
connection_config = ConnectionConfig(
|
|
||||||
max_connections_per_peer=2, enable_connection_pool=True
|
|
||||||
)
|
|
||||||
|
|
||||||
swarm = Swarm(
|
|
||||||
peer_id, peerstore, upgrader, transport, connection_config=connection_config
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock successful connection creation
|
|
||||||
mock_conn = MockConnection(peer_id)
|
|
||||||
peerstore.addrs.return_value = [Mock(spec=Multiaddr)]
|
|
||||||
|
|
||||||
async def mock_dial_with_retry(addr, peer_id):
|
|
||||||
return mock_conn
|
|
||||||
|
|
||||||
swarm._dial_with_retry = mock_dial_with_retry
|
|
||||||
|
|
||||||
# Test dial_peer adds to connection pool
|
|
||||||
result = await swarm.dial_peer(peer_id)
|
|
||||||
assert result == mock_conn
|
|
||||||
assert swarm.connection_pool is not None
|
|
||||||
assert swarm.connection_pool.has_connection(peer_id)
|
|
||||||
|
|
||||||
# Test that subsequent calls reuse connection
|
|
||||||
result2 = await swarm.dial_peer(peer_id)
|
|
||||||
assert result2 == mock_conn
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_swarm_connection_cleanup():
|
|
||||||
"""Test connection cleanup in enhanced Swarm."""
|
|
||||||
peer_id = ID(b"QmTest")
|
|
||||||
peerstore = Mock()
|
|
||||||
upgrader = Mock()
|
|
||||||
transport = Mock()
|
|
||||||
|
|
||||||
connection_config = ConnectionConfig(enable_connection_pool=True)
|
|
||||||
swarm = Swarm(
|
|
||||||
peer_id, peerstore, upgrader, transport, connection_config=connection_config
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add a connection
|
|
||||||
mock_conn = MockConnection(peer_id)
|
|
||||||
swarm.connections[peer_id] = mock_conn
|
|
||||||
assert swarm.connection_pool is not None
|
|
||||||
swarm.connection_pool.add_connection(peer_id, mock_conn, "test_addr")
|
|
||||||
|
|
||||||
# Test close_peer removes from pool
|
|
||||||
await swarm.close_peer(peer_id)
|
|
||||||
assert swarm.connection_pool is not None
|
|
||||||
assert not swarm.connection_pool.has_connection(peer_id)
|
|
||||||
|
|
||||||
# Test remove_conn removes from pool
|
|
||||||
mock_conn2 = MockConnection(peer_id)
|
|
||||||
swarm.connections[peer_id] = mock_conn2
|
|
||||||
assert swarm.connection_pool is not None
|
|
||||||
connection_pool = swarm.connection_pool
|
|
||||||
connection_pool.add_connection(peer_id, mock_conn2, "test_addr2")
|
|
||||||
|
|
||||||
# Note: remove_conn expects SwarmConn, but for testing we'll just
|
|
||||||
# remove from pool directly
|
|
||||||
connection_pool = swarm.connection_pool
|
|
||||||
connection_pool.remove_connection(peer_id, mock_conn2)
|
|
||||||
assert connection_pool is not None
|
|
||||||
assert not connection_pool.has_connection(peer_id)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -51,14 +51,19 @@ async def test_swarm_dial_peer(security_protocol):
|
|||||||
for addr in transport.get_addrs()
|
for addr in transport.get_addrs()
|
||||||
)
|
)
|
||||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
|
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
|
||||||
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
|
||||||
|
# New: dial_peer now returns list of connections
|
||||||
|
connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||||
|
assert len(connections) > 0
|
||||||
|
|
||||||
|
# Verify connections are established in both directions
|
||||||
assert swarms[0].get_peer_id() in swarms[1].connections
|
assert swarms[0].get_peer_id() in swarms[1].connections
|
||||||
assert swarms[1].get_peer_id() in swarms[0].connections
|
assert swarms[1].get_peer_id() in swarms[0].connections
|
||||||
|
|
||||||
# Test: Reuse connections when we already have ones with a peer.
|
# Test: Reuse connections when we already have ones with a peer.
|
||||||
conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
|
existing_connections = swarms[0].get_connections(swarms[1].get_peer_id())
|
||||||
conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
new_connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||||
assert conn is conn_to_1
|
assert new_connections == existing_connections
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
@ -107,7 +112,8 @@ async def test_swarm_close_peer(security_protocol):
|
|||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_swarm_remove_conn(swarm_pair):
|
async def test_swarm_remove_conn(swarm_pair):
|
||||||
swarm_0, swarm_1 = swarm_pair
|
swarm_0, swarm_1 = swarm_pair
|
||||||
conn_0 = swarm_0.connections[swarm_1.get_peer_id()]
|
# Get the first connection from the list
|
||||||
|
conn_0 = swarm_0.connections[swarm_1.get_peer_id()][0]
|
||||||
swarm_0.remove_conn(conn_0)
|
swarm_0.remove_conn(conn_0)
|
||||||
assert swarm_1.get_peer_id() not in swarm_0.connections
|
assert swarm_1.get_peer_id() not in swarm_0.connections
|
||||||
# Test: Remove twice. There should not be errors.
|
# Test: Remove twice. There should not be errors.
|
||||||
@ -115,6 +121,67 @@ async def test_swarm_remove_conn(swarm_pair):
|
|||||||
assert swarm_1.get_peer_id() not in swarm_0.connections
|
assert swarm_1.get_peer_id() not in swarm_0.connections
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_swarm_multiple_connections(security_protocol):
|
||||||
|
"""Test multiple connections per peer functionality."""
|
||||||
|
async with SwarmFactory.create_batch_and_listen(
|
||||||
|
2, security_protocol=security_protocol
|
||||||
|
) as swarms:
|
||||||
|
# Setup multiple addresses for peer
|
||||||
|
addrs = tuple(
|
||||||
|
addr
|
||||||
|
for transport in swarms[1].listeners.values()
|
||||||
|
for addr in transport.get_addrs()
|
||||||
|
)
|
||||||
|
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
|
||||||
|
|
||||||
|
# Dial peer - should return list of connections
|
||||||
|
connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||||
|
assert len(connections) > 0
|
||||||
|
|
||||||
|
# Test get_connections method
|
||||||
|
peer_connections = swarms[0].get_connections(swarms[1].get_peer_id())
|
||||||
|
assert len(peer_connections) == len(connections)
|
||||||
|
|
||||||
|
# Test get_connections_map method
|
||||||
|
connections_map = swarms[0].get_connections_map()
|
||||||
|
assert swarms[1].get_peer_id() in connections_map
|
||||||
|
assert len(connections_map[swarms[1].get_peer_id()]) == len(connections)
|
||||||
|
|
||||||
|
# Test get_connection method (backward compatibility)
|
||||||
|
single_conn = swarms[0].get_connection(swarms[1].get_peer_id())
|
||||||
|
assert single_conn is not None
|
||||||
|
assert single_conn in connections
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_swarm_load_balancing(security_protocol):
|
||||||
|
"""Test load balancing across multiple connections."""
|
||||||
|
async with SwarmFactory.create_batch_and_listen(
|
||||||
|
2, security_protocol=security_protocol
|
||||||
|
) as swarms:
|
||||||
|
# Setup connection
|
||||||
|
addrs = tuple(
|
||||||
|
addr
|
||||||
|
for transport in swarms[1].listeners.values()
|
||||||
|
for addr in transport.get_addrs()
|
||||||
|
)
|
||||||
|
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
|
||||||
|
|
||||||
|
# Create multiple streams - should use load balancing
|
||||||
|
streams = []
|
||||||
|
for _ in range(5):
|
||||||
|
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
|
||||||
|
streams.append(stream)
|
||||||
|
|
||||||
|
# Verify streams were created successfully
|
||||||
|
assert len(streams) == 5
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
for stream in streams:
|
||||||
|
await stream.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_swarm_multiaddr(security_protocol):
|
async def test_swarm_multiaddr(security_protocol):
|
||||||
async with SwarmFactory.create_batch_and_listen(
|
async with SwarmFactory.create_batch_and_listen(
|
||||||
|
|||||||
@ -51,6 +51,9 @@ async def perform_simple_test(assertion_func, security_protocol):
|
|||||||
|
|
||||||
# Extract the secured connection from either Mplex or Yamux implementation
|
# Extract the secured connection from either Mplex or Yamux implementation
|
||||||
def get_secured_conn(conn):
|
def get_secured_conn(conn):
|
||||||
|
# conn is now a list, get the first connection
|
||||||
|
if isinstance(conn, list):
|
||||||
|
conn = conn[0]
|
||||||
muxed_conn = conn.muxed_conn
|
muxed_conn = conn.muxed_conn
|
||||||
# Direct attribute access for known implementations
|
# Direct attribute access for known implementations
|
||||||
has_secured_conn = hasattr(muxed_conn, "secured_conn")
|
has_secured_conn = hasattr(muxed_conn, "secured_conn")
|
||||||
|
|||||||
@ -74,7 +74,8 @@ async def test_multiplexer_preference_parameter(muxer_preference):
|
|||||||
assert len(connections) > 0, "Connection not established"
|
assert len(connections) > 0, "Connection not established"
|
||||||
|
|
||||||
# Get the first connection
|
# Get the first connection
|
||||||
conn = list(connections.values())[0]
|
conns = list(connections.values())[0]
|
||||||
|
conn = conns[0] # Get first connection from the list
|
||||||
muxed_conn = conn.muxed_conn
|
muxed_conn = conn.muxed_conn
|
||||||
|
|
||||||
# Define a simple echo protocol
|
# Define a simple echo protocol
|
||||||
@ -150,7 +151,8 @@ async def test_explicit_muxer_options(muxer_option_func, expected_stream_class):
|
|||||||
assert len(connections) > 0, "Connection not established"
|
assert len(connections) > 0, "Connection not established"
|
||||||
|
|
||||||
# Get the first connection
|
# Get the first connection
|
||||||
conn = list(connections.values())[0]
|
conns = list(connections.values())[0]
|
||||||
|
conn = conns[0] # Get first connection from the list
|
||||||
muxed_conn = conn.muxed_conn
|
muxed_conn = conn.muxed_conn
|
||||||
|
|
||||||
# Define a simple echo protocol
|
# Define a simple echo protocol
|
||||||
@ -219,7 +221,8 @@ async def test_global_default_muxer(global_default):
|
|||||||
assert len(connections) > 0, "Connection not established"
|
assert len(connections) > 0, "Connection not established"
|
||||||
|
|
||||||
# Get the first connection
|
# Get the first connection
|
||||||
conn = list(connections.values())[0]
|
conns = list(connections.values())[0]
|
||||||
|
conn = conns[0] # Get first connection from the list
|
||||||
muxed_conn = conn.muxed_conn
|
muxed_conn = conn.muxed_conn
|
||||||
|
|
||||||
# Define a simple echo protocol
|
# Define a simple echo protocol
|
||||||
|
|||||||
@ -669,8 +669,8 @@ async def swarm_conn_pair_factory(
|
|||||||
async with swarm_pair_factory(
|
async with swarm_pair_factory(
|
||||||
security_protocol=security_protocol, muxer_opt=muxer_opt
|
security_protocol=security_protocol, muxer_opt=muxer_opt
|
||||||
) as swarms:
|
) as swarms:
|
||||||
conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
|
conn_0 = swarms[0].connections[swarms[1].get_peer_id()][0]
|
||||||
conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
|
conn_1 = swarms[1].connections[swarms[0].get_peer_id()][0]
|
||||||
yield cast(SwarmConn, conn_0), cast(SwarmConn, conn_1)
|
yield cast(SwarmConn, conn_0), cast(SwarmConn, conn_1)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user