mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Merge pull request #876 from bomanaps/feat/swarm-multi-connection-support
feat(swarm): enhance swarm with retry backoff
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@ -178,3 +178,6 @@ env.bak/
|
||||
#lockfiles
|
||||
uv.lock
|
||||
poetry.lock
|
||||
|
||||
# Sphinx documentation build
|
||||
_build/
|
||||
|
||||
194
docs/examples.multiple_connections.rst
Normal file
194
docs/examples.multiple_connections.rst
Normal file
@ -0,0 +1,194 @@
|
||||
Multiple Connections Per Peer
|
||||
=============================
|
||||
|
||||
This example demonstrates how to use the multiple connections per peer feature in py-libp2p.
|
||||
|
||||
Overview
|
||||
--------
|
||||
|
||||
The multiple connections per peer feature allows a libp2p node to maintain multiple network connections to the same peer. This provides several benefits:
|
||||
|
||||
- **Improved reliability**: If one connection fails, others remain available
|
||||
- **Better performance**: Load can be distributed across multiple connections
|
||||
- **Enhanced throughput**: Multiple streams can be created in parallel
|
||||
- **Fault tolerance**: Redundant connections provide backup paths
|
||||
|
||||
Configuration
|
||||
-------------
|
||||
|
||||
The feature is configured through the `ConnectionConfig` class:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from libp2p.network.swarm import ConnectionConfig
|
||||
|
||||
# Default configuration
|
||||
config = ConnectionConfig()
|
||||
print(f"Max connections per peer: {config.max_connections_per_peer}")
|
||||
print(f"Load balancing strategy: {config.load_balancing_strategy}")
|
||||
|
||||
# Custom configuration
|
||||
custom_config = ConnectionConfig(
|
||||
max_connections_per_peer=5,
|
||||
connection_timeout=60.0,
|
||||
load_balancing_strategy="least_loaded"
|
||||
)
|
||||
|
||||
Load Balancing Strategies
|
||||
-------------------------
|
||||
|
||||
Two load balancing strategies are available:
|
||||
|
||||
**Round Robin** (default)
|
||||
Cycles through connections in order, distributing load evenly.
|
||||
|
||||
**Least Loaded**
|
||||
Selects the connection with the fewest active streams.
|
||||
|
||||
API Usage
|
||||
---------
|
||||
|
||||
The new API provides direct access to multiple connections:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from libp2p import new_swarm
|
||||
|
||||
# Create swarm with multiple connections support
|
||||
swarm = new_swarm()
|
||||
|
||||
# Dial a peer - returns list of connections
|
||||
connections = await swarm.dial_peer(peer_id)
|
||||
print(f"Established {len(connections)} connections")
|
||||
|
||||
# Get all connections to a peer
|
||||
peer_connections = swarm.get_connections(peer_id)
|
||||
|
||||
# Get all connections (across all peers)
|
||||
all_connections = swarm.get_connections()
|
||||
|
||||
# Get the complete connections map
|
||||
connections_map = swarm.get_connections_map()
|
||||
|
||||
# Backward compatibility - get single connection
|
||||
single_conn = swarm.get_connection(peer_id)
|
||||
|
||||
Backward Compatibility
|
||||
----------------------
|
||||
|
||||
Existing code continues to work through backward compatibility features:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Legacy 1:1 mapping (returns first connection for each peer)
|
||||
legacy_connections = swarm.connections_legacy
|
||||
|
||||
# Single connection access (returns first available connection)
|
||||
conn = swarm.get_connection(peer_id)
|
||||
|
||||
Example
|
||||
-------
|
||||
|
||||
A complete working example is available in the `examples/doc-examples/multiple_connections_example.py` file.
|
||||
|
||||
Production Configuration
|
||||
-------------------------
|
||||
|
||||
For production use, consider these settings:
|
||||
|
||||
**RetryConfig Parameters**
|
||||
|
||||
The `RetryConfig` class controls connection retry behavior with exponential backoff:
|
||||
|
||||
- **max_retries**: Maximum number of retry attempts before giving up (default: 3)
|
||||
- **initial_delay**: Initial delay in seconds before the first retry (default: 0.1s)
|
||||
- **max_delay**: Maximum delay cap to prevent excessive wait times (default: 30.0s)
|
||||
- **backoff_multiplier**: Exponential backoff multiplier - each retry multiplies delay by this factor (default: 2.0)
|
||||
- **jitter_factor**: Random jitter (0.0-1.0) to prevent synchronized retries (default: 0.1)
|
||||
|
||||
**ConnectionConfig Parameters**
|
||||
|
||||
The `ConnectionConfig` class manages multi-connection behavior:
|
||||
|
||||
- **max_connections_per_peer**: Maximum connections allowed to a single peer (default: 3)
|
||||
- **connection_timeout**: Timeout for establishing new connections in seconds (default: 30.0s)
|
||||
- **load_balancing_strategy**: Strategy for distributing streams ("round_robin" or "least_loaded")
|
||||
|
||||
**Load Balancing Strategies Explained**
|
||||
|
||||
- **round_robin**: Cycles through connections in order, distributing load evenly. Simple and predictable.
|
||||
- **least_loaded**: Selects the connection with the fewest active streams. Better for performance but more complex.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from libp2p.network.swarm import ConnectionConfig, RetryConfig
|
||||
|
||||
# Production-ready configuration
|
||||
retry_config = RetryConfig(
|
||||
max_retries=3, # Maximum retry attempts before giving up
|
||||
initial_delay=0.1, # Start with 100ms delay
|
||||
max_delay=30.0, # Cap exponential backoff at 30 seconds
|
||||
backoff_multiplier=2.0, # Double delay each retry (100ms -> 200ms -> 400ms)
|
||||
jitter_factor=0.1 # Add 10% random jitter to prevent thundering herd
|
||||
)
|
||||
|
||||
connection_config = ConnectionConfig(
|
||||
max_connections_per_peer=3, # Allow up to 3 connections per peer
|
||||
connection_timeout=30.0, # 30 second timeout for new connections
|
||||
load_balancing_strategy="round_robin" # Simple, predictable load distribution
|
||||
)
|
||||
|
||||
swarm = new_swarm(
|
||||
retry_config=retry_config,
|
||||
connection_config=connection_config
|
||||
)
|
||||
|
||||
**How RetryConfig Works in Practice**
|
||||
|
||||
With the configuration above, connection retries follow this pattern:
|
||||
|
||||
1. **Attempt 1**: Immediate connection attempt
|
||||
2. **Attempt 2**: Wait 100ms ± 10ms jitter, then retry
|
||||
3. **Attempt 3**: Wait 200ms ± 20ms jitter, then retry
|
||||
4. **Attempt 4**: Wait 400ms ± 40ms jitter, then retry
|
||||
5. **Attempt 5**: Wait 800ms ± 80ms jitter, then retry
|
||||
6. **Attempt 6**: Wait 1.6s ± 160ms jitter, then retry
|
||||
7. **Attempt 7**: Wait 3.2s ± 320ms jitter, then retry
|
||||
8. **Attempt 8**: Wait 6.4s ± 640ms jitter, then retry
|
||||
9. **Attempt 9**: Wait 12.8s ± 1.28s jitter, then retry
|
||||
10. **Attempt 10**: Wait 25.6s ± 2.56s jitter, then retry
|
||||
11. **Attempt 11**: Wait 30.0s (capped) ± 3.0s jitter, then retry
|
||||
12. **Attempt 12**: Wait 30.0s (capped) ± 3.0s jitter, then retry
|
||||
13. **Give up**: After 12 retries (3 initial + 9 retries), connection fails
|
||||
|
||||
The jitter prevents multiple clients from retrying simultaneously, reducing server load.
|
||||
|
||||
**Parameter Tuning Guidelines**
|
||||
|
||||
**For Development/Testing:**
|
||||
- Use lower `max_retries` (1-2) and shorter delays for faster feedback
|
||||
- Example: `RetryConfig(max_retries=2, initial_delay=0.01, max_delay=0.1)`
|
||||
|
||||
**For Production:**
|
||||
- Use moderate `max_retries` (3-5) with reasonable delays for reliability
|
||||
- Example: `RetryConfig(max_retries=5, initial_delay=0.1, max_delay=60.0)`
|
||||
|
||||
**For High-Latency Networks:**
|
||||
- Use higher `max_retries` (5-10) with longer delays
|
||||
- Example: `RetryConfig(max_retries=8, initial_delay=0.5, max_delay=120.0)`
|
||||
|
||||
**For Load Balancing:**
|
||||
- Use `round_robin` for simple, predictable behavior
|
||||
- Use `least_loaded` when you need optimal performance and can handle complexity
|
||||
|
||||
Architecture
|
||||
------------
|
||||
|
||||
The implementation follows the same architectural patterns as the Go and JavaScript reference implementations:
|
||||
|
||||
- **Core data structure**: `dict[ID, list[INetConn]]` for 1:many mapping
|
||||
- **API consistency**: Methods like `get_connections()` match reference implementations
|
||||
- **Load balancing**: Integrated at the API level for optimal performance
|
||||
- **Backward compatibility**: Maintains existing interfaces for gradual migration
|
||||
|
||||
This design ensures consistency across libp2p implementations while providing the benefits of multiple connections per peer.
|
||||
@ -15,3 +15,4 @@ Examples
|
||||
examples.kademlia
|
||||
examples.mDNS
|
||||
examples.random_walk
|
||||
examples.multiple_connections
|
||||
|
||||
170
examples/doc-examples/multiple_connections_example.py
Normal file
170
examples/doc-examples/multiple_connections_example.py
Normal file
@ -0,0 +1,170 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Example demonstrating multiple connections per peer support in libp2p.
|
||||
|
||||
This example shows how to:
|
||||
1. Configure multiple connections per peer
|
||||
2. Use different load balancing strategies
|
||||
3. Access multiple connections through the new API
|
||||
4. Maintain backward compatibility
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p import new_swarm
|
||||
from libp2p.network.swarm import ConnectionConfig, RetryConfig
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def example_basic_multiple_connections() -> None:
|
||||
"""Example of basic multiple connections per peer usage."""
|
||||
logger.info("Creating swarm with multiple connections support...")
|
||||
|
||||
# Create swarm with default configuration
|
||||
swarm = new_swarm()
|
||||
default_connection = ConnectionConfig()
|
||||
|
||||
logger.info(f"Swarm created with peer ID: {swarm.get_peer_id()}")
|
||||
logger.info(
|
||||
f"Connection config: max_connections_per_peer="
|
||||
f"{default_connection.max_connections_per_peer}"
|
||||
)
|
||||
|
||||
await swarm.close()
|
||||
logger.info("Basic multiple connections example completed")
|
||||
|
||||
|
||||
async def example_custom_connection_config() -> None:
|
||||
"""Example of custom connection configuration."""
|
||||
logger.info("Creating swarm with custom connection configuration...")
|
||||
|
||||
# Custom connection configuration for high-performance scenarios
|
||||
connection_config = ConnectionConfig(
|
||||
max_connections_per_peer=5, # More connections per peer
|
||||
connection_timeout=60.0, # Longer timeout
|
||||
load_balancing_strategy="least_loaded", # Use least loaded strategy
|
||||
)
|
||||
|
||||
# Create swarm with custom connection config
|
||||
swarm = new_swarm(connection_config=connection_config)
|
||||
|
||||
logger.info("Custom connection config applied:")
|
||||
logger.info(
|
||||
f" Max connections per peer: {connection_config.max_connections_per_peer}"
|
||||
)
|
||||
logger.info(f" Connection timeout: {connection_config.connection_timeout}s")
|
||||
logger.info(
|
||||
f" Load balancing strategy: {connection_config.load_balancing_strategy}"
|
||||
)
|
||||
|
||||
await swarm.close()
|
||||
logger.info("Custom connection config example completed")
|
||||
|
||||
|
||||
async def example_multiple_connections_api() -> None:
|
||||
"""Example of using the new multiple connections API."""
|
||||
logger.info("Demonstrating multiple connections API...")
|
||||
|
||||
connection_config = ConnectionConfig(
|
||||
max_connections_per_peer=3, load_balancing_strategy="round_robin"
|
||||
)
|
||||
|
||||
swarm = new_swarm(connection_config=connection_config)
|
||||
|
||||
logger.info("Multiple connections API features:")
|
||||
logger.info(" - dial_peer() returns list[INetConn]")
|
||||
logger.info(" - get_connections(peer_id) returns list[INetConn]")
|
||||
logger.info(" - get_connections_map() returns dict[ID, list[INetConn]]")
|
||||
logger.info(
|
||||
" - get_connection(peer_id) returns INetConn | None (backward compatibility)"
|
||||
)
|
||||
|
||||
await swarm.close()
|
||||
logger.info("Multiple connections API example completed")
|
||||
|
||||
|
||||
async def example_backward_compatibility() -> None:
|
||||
"""Example of backward compatibility features."""
|
||||
logger.info("Demonstrating backward compatibility...")
|
||||
|
||||
swarm = new_swarm()
|
||||
|
||||
logger.info("Backward compatibility features:")
|
||||
logger.info(" - connections_legacy property provides 1:1 mapping")
|
||||
logger.info(" - get_connection() method for single connection access")
|
||||
logger.info(" - Existing code continues to work")
|
||||
|
||||
await swarm.close()
|
||||
logger.info("Backward compatibility example completed")
|
||||
|
||||
|
||||
async def example_production_ready_config() -> None:
|
||||
"""Example of production-ready configuration."""
|
||||
logger.info("Creating swarm with production-ready configuration...")
|
||||
|
||||
# Production-ready retry configuration
|
||||
retry_config = RetryConfig(
|
||||
max_retries=3, # Reasonable retry limit
|
||||
initial_delay=0.1, # Quick initial retry
|
||||
max_delay=30.0, # Cap exponential backoff
|
||||
backoff_multiplier=2.0, # Standard exponential backoff
|
||||
jitter_factor=0.1, # Small jitter to prevent thundering herd
|
||||
)
|
||||
|
||||
# Production-ready connection configuration
|
||||
connection_config = ConnectionConfig(
|
||||
max_connections_per_peer=3, # Balance between performance and resource usage
|
||||
connection_timeout=30.0, # Reasonable timeout
|
||||
load_balancing_strategy="round_robin", # Simple, predictable strategy
|
||||
)
|
||||
|
||||
# Create swarm with production config
|
||||
swarm = new_swarm(retry_config=retry_config, connection_config=connection_config)
|
||||
|
||||
logger.info("Production-ready configuration applied:")
|
||||
logger.info(
|
||||
f" Retry: {retry_config.max_retries} retries, "
|
||||
f"{retry_config.max_delay}s max delay"
|
||||
)
|
||||
logger.info(f" Connections: {connection_config.max_connections_per_peer} per peer")
|
||||
logger.info(f" Load balancing: {connection_config.load_balancing_strategy}")
|
||||
|
||||
await swarm.close()
|
||||
logger.info("Production-ready configuration example completed")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Run all examples."""
|
||||
logger.info("Multiple Connections Per Peer Examples")
|
||||
logger.info("=" * 50)
|
||||
|
||||
try:
|
||||
await example_basic_multiple_connections()
|
||||
logger.info("-" * 30)
|
||||
|
||||
await example_custom_connection_config()
|
||||
logger.info("-" * 30)
|
||||
|
||||
await example_multiple_connections_api()
|
||||
logger.info("-" * 30)
|
||||
|
||||
await example_backward_compatibility()
|
||||
logger.info("-" * 30)
|
||||
|
||||
await example_production_ready_config()
|
||||
logger.info("-" * 30)
|
||||
|
||||
logger.info("All examples completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Example failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
trio.run(main)
|
||||
@ -1,3 +1,5 @@
|
||||
"""Libp2p Python implementation."""
|
||||
|
||||
from collections.abc import (
|
||||
Mapping,
|
||||
Sequence,
|
||||
@ -6,15 +8,12 @@ from importlib.metadata import version as __version
|
||||
from typing import (
|
||||
Literal,
|
||||
Optional,
|
||||
Type,
|
||||
cast,
|
||||
)
|
||||
|
||||
import multiaddr
|
||||
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
IMuxedConn,
|
||||
INetworkService,
|
||||
IPeerRouting,
|
||||
IPeerStore,
|
||||
@ -32,9 +31,6 @@ from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
TSecurityOptions,
|
||||
)
|
||||
from libp2p.discovery.mdns.mdns import (
|
||||
MDNSDiscovery,
|
||||
)
|
||||
from libp2p.host.basic_host import (
|
||||
BasicHost,
|
||||
)
|
||||
@ -42,6 +38,8 @@ from libp2p.host.routed_host import (
|
||||
RoutedHost,
|
||||
)
|
||||
from libp2p.network.swarm import (
|
||||
ConnectionConfig,
|
||||
RetryConfig,
|
||||
Swarm,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
@ -55,17 +53,19 @@ from libp2p.security.insecure.transport import (
|
||||
PLAINTEXT_PROTOCOL_ID,
|
||||
InsecureTransport,
|
||||
)
|
||||
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
|
||||
from libp2p.security.noise.transport import Transport as NoiseTransport
|
||||
from libp2p.security.noise.transport import (
|
||||
PROTOCOL_ID as NOISE_PROTOCOL_ID,
|
||||
Transport as NoiseTransport,
|
||||
)
|
||||
import libp2p.security.secio.transport as secio
|
||||
from libp2p.stream_muxer.mplex.mplex import (
|
||||
MPLEX_PROTOCOL_ID,
|
||||
Mplex,
|
||||
)
|
||||
from libp2p.stream_muxer.yamux.yamux import (
|
||||
PROTOCOL_ID as YAMUX_PROTOCOL_ID,
|
||||
Yamux,
|
||||
)
|
||||
from libp2p.stream_muxer.yamux.yamux import PROTOCOL_ID as YAMUX_PROTOCOL_ID
|
||||
from libp2p.transport.tcp.tcp import (
|
||||
TCP,
|
||||
)
|
||||
@ -88,7 +88,6 @@ MUXER_MPLEX = "MPLEX"
|
||||
DEFAULT_NEGOTIATE_TIMEOUT = 5
|
||||
|
||||
|
||||
|
||||
def set_default_muxer(muxer_name: Literal["YAMUX", "MPLEX"]) -> None:
|
||||
"""
|
||||
Set the default multiplexer protocol to use.
|
||||
@ -163,6 +162,8 @@ def new_swarm(
|
||||
peerstore_opt: IPeerStore | None = None,
|
||||
muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
|
||||
listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
|
||||
retry_config: Optional["RetryConfig"] = None,
|
||||
connection_config: Optional["ConnectionConfig"] = None,
|
||||
) -> INetworkService:
|
||||
"""
|
||||
Create a swarm instance based on the parameters.
|
||||
@ -239,7 +240,14 @@ def new_swarm(
|
||||
# Store our key pair in peerstore
|
||||
peerstore.add_key_pair(id_opt, key_pair)
|
||||
|
||||
return Swarm(id_opt, peerstore, upgrader, transport)
|
||||
return Swarm(
|
||||
id_opt,
|
||||
peerstore,
|
||||
upgrader,
|
||||
transport,
|
||||
retry_config=retry_config,
|
||||
connection_config=connection_config
|
||||
)
|
||||
|
||||
|
||||
def new_host(
|
||||
@ -279,6 +287,12 @@ def new_host(
|
||||
|
||||
if disc_opt is not None:
|
||||
return RoutedHost(swarm, disc_opt, enable_mDNS, bootstrap)
|
||||
return BasicHost(network=swarm,enable_mDNS=enable_mDNS , bootstrap=bootstrap, negotitate_timeout=negotiate_timeout)
|
||||
return BasicHost(
|
||||
network=swarm,
|
||||
enable_mDNS=enable_mDNS,
|
||||
bootstrap=bootstrap,
|
||||
negotitate_timeout=negotiate_timeout
|
||||
)
|
||||
|
||||
|
||||
__version__ = __version("libp2p")
|
||||
|
||||
@ -1412,15 +1412,16 @@ class INetwork(ABC):
|
||||
----------
|
||||
peerstore : IPeerStore
|
||||
The peer store for managing peer information.
|
||||
connections : dict[ID, INetConn]
|
||||
A mapping of peer IDs to network connections.
|
||||
connections : dict[ID, list[INetConn]]
|
||||
A mapping of peer IDs to lists of network connections
|
||||
(multiple connections per peer).
|
||||
listeners : dict[str, IListener]
|
||||
A mapping of listener identifiers to listener instances.
|
||||
|
||||
"""
|
||||
|
||||
peerstore: IPeerStore
|
||||
connections: dict[ID, INetConn]
|
||||
connections: dict[ID, list[INetConn]]
|
||||
listeners: dict[str, IListener]
|
||||
|
||||
@abstractmethod
|
||||
@ -1436,9 +1437,56 @@ class INetwork(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def dial_peer(self, peer_id: ID) -> INetConn:
|
||||
def get_connections(self, peer_id: ID | None = None) -> list[INetConn]:
|
||||
"""
|
||||
Create a connection to the specified peer.
|
||||
Get connections for peer (like JS getConnections, Go ConnsToPeer).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID | None
|
||||
The peer ID to get connections for. If None, returns all connections.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[INetConn]
|
||||
List of connections to the specified peer, or all connections
|
||||
if peer_id is None.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_connections_map(self) -> dict[ID, list[INetConn]]:
|
||||
"""
|
||||
Get all connections map (like JS getConnectionsMap).
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[ID, list[INetConn]]
|
||||
The complete mapping of peer IDs to their connection lists.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_connection(self, peer_id: ID) -> INetConn | None:
|
||||
"""
|
||||
Get single connection for backward compatibility.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer ID to get a connection for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
INetConn | None
|
||||
The first available connection, or None if no connections exist.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def dial_peer(self, peer_id: ID) -> list[INetConn]:
|
||||
"""
|
||||
Create connections to the specified peer with load balancing.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -1447,8 +1495,8 @@ class INetwork(ABC):
|
||||
|
||||
Returns
|
||||
-------
|
||||
INetConn
|
||||
The network connection instance to the specified peer.
|
||||
list[INetConn]
|
||||
List of established connections to the peer.
|
||||
|
||||
Raises
|
||||
------
|
||||
|
||||
@ -338,7 +338,7 @@ class BasicHost(IHost):
|
||||
:param peer_id: ID of the peer to check
|
||||
:return: True if peer has an active connection, False otherwise
|
||||
"""
|
||||
return peer_id in self._network.connections
|
||||
return len(self._network.get_connections(peer_id)) > 0
|
||||
|
||||
def get_peer_connection_info(self, peer_id: ID) -> INetConn | None:
|
||||
"""
|
||||
@ -347,4 +347,4 @@ class BasicHost(IHost):
|
||||
:param peer_id: ID of the peer to get info for
|
||||
:return: Connection object if peer is connected, None otherwise
|
||||
"""
|
||||
return self._network.connections.get(peer_id)
|
||||
return self._network.get_connection(peer_id)
|
||||
|
||||
@ -2,7 +2,9 @@ from collections.abc import (
|
||||
Awaitable,
|
||||
Callable,
|
||||
)
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import random
|
||||
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
@ -59,6 +61,59 @@ from .exceptions import (
|
||||
logger = logging.getLogger("libp2p.network.swarm")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetryConfig:
|
||||
"""
|
||||
Configuration for retry logic with exponential backoff.
|
||||
|
||||
This configuration controls how connection attempts are retried when they fail.
|
||||
The retry mechanism uses exponential backoff with jitter to prevent thundering
|
||||
herd problems in distributed systems.
|
||||
|
||||
Attributes:
|
||||
max_retries: Maximum number of retry attempts before giving up.
|
||||
Default: 3 attempts
|
||||
initial_delay: Initial delay in seconds before the first retry.
|
||||
Default: 0.1 seconds (100ms)
|
||||
max_delay: Maximum delay cap in seconds to prevent excessive wait times.
|
||||
Default: 30.0 seconds
|
||||
backoff_multiplier: Multiplier for exponential backoff (each retry multiplies
|
||||
the delay by this factor). Default: 2.0 (doubles each time)
|
||||
jitter_factor: Random jitter factor (0.0-1.0) to add randomness to delays
|
||||
and prevent synchronized retries. Default: 0.1 (10% jitter)
|
||||
|
||||
"""
|
||||
|
||||
max_retries: int = 3
|
||||
initial_delay: float = 0.1
|
||||
max_delay: float = 30.0
|
||||
backoff_multiplier: float = 2.0
|
||||
jitter_factor: float = 0.1
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionConfig:
|
||||
"""
|
||||
Configuration for multi-connection support.
|
||||
|
||||
This configuration controls how multiple connections per peer are managed,
|
||||
including connection limits, timeouts, and load balancing strategies.
|
||||
|
||||
Attributes:
|
||||
max_connections_per_peer: Maximum number of connections allowed to a single
|
||||
peer. Default: 3 connections
|
||||
connection_timeout: Timeout in seconds for establishing new connections.
|
||||
Default: 30.0 seconds
|
||||
load_balancing_strategy: Strategy for distributing streams across connections.
|
||||
Options: "round_robin" (default) or "least_loaded"
|
||||
|
||||
"""
|
||||
|
||||
max_connections_per_peer: int = 3
|
||||
connection_timeout: float = 30.0
|
||||
load_balancing_strategy: str = "round_robin" # or "least_loaded"
|
||||
|
||||
|
||||
def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn:
|
||||
async def stream_handler(stream: INetStream) -> None:
|
||||
await network.get_manager().wait_finished()
|
||||
@ -71,9 +126,8 @@ class Swarm(Service, INetworkService):
|
||||
peerstore: IPeerStore
|
||||
upgrader: TransportUpgrader
|
||||
transport: ITransport
|
||||
# TODO: Connection and `peer_id` are 1-1 mapping in our implementation,
|
||||
# whereas in Go one `peer_id` may point to multiple connections.
|
||||
connections: dict[ID, INetConn]
|
||||
# Enhanced: Support for multiple connections per peer
|
||||
connections: dict[ID, list[INetConn]] # Multiple connections per peer
|
||||
listeners: dict[str, IListener]
|
||||
common_stream_handler: StreamHandlerFn
|
||||
listener_nursery: trio.Nursery | None
|
||||
@ -81,18 +135,31 @@ class Swarm(Service, INetworkService):
|
||||
|
||||
notifees: list[INotifee]
|
||||
|
||||
# Enhanced: New configuration
|
||||
retry_config: RetryConfig
|
||||
connection_config: ConnectionConfig
|
||||
_round_robin_index: dict[ID, int]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
peer_id: ID,
|
||||
peerstore: IPeerStore,
|
||||
upgrader: TransportUpgrader,
|
||||
transport: ITransport,
|
||||
retry_config: RetryConfig | None = None,
|
||||
connection_config: ConnectionConfig | None = None,
|
||||
):
|
||||
self.self_id = peer_id
|
||||
self.peerstore = peerstore
|
||||
self.upgrader = upgrader
|
||||
self.transport = transport
|
||||
self.connections = dict()
|
||||
|
||||
# Enhanced: Initialize retry and connection configuration
|
||||
self.retry_config = retry_config or RetryConfig()
|
||||
self.connection_config = connection_config or ConnectionConfig()
|
||||
|
||||
# Enhanced: Initialize connections as 1:many mapping
|
||||
self.connections = {}
|
||||
self.listeners = dict()
|
||||
|
||||
# Create Notifee array
|
||||
@ -103,6 +170,9 @@ class Swarm(Service, INetworkService):
|
||||
self.listener_nursery = None
|
||||
self.event_listener_nursery_created = trio.Event()
|
||||
|
||||
# Load balancing state
|
||||
self._round_robin_index = {}
|
||||
|
||||
async def run(self) -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Create a nursery for listener tasks.
|
||||
@ -122,18 +192,74 @@ class Swarm(Service, INetworkService):
|
||||
def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None:
|
||||
self.common_stream_handler = stream_handler
|
||||
|
||||
async def dial_peer(self, peer_id: ID) -> INetConn:
|
||||
def get_connections(self, peer_id: ID | None = None) -> list[INetConn]:
|
||||
"""
|
||||
Try to create a connection to peer_id.
|
||||
Get connections for peer (like JS getConnections, Go ConnsToPeer).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID | None
|
||||
The peer ID to get connections for. If None, returns all connections.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[INetConn]
|
||||
List of connections to the specified peer, or all connections
|
||||
if peer_id is None.
|
||||
|
||||
"""
|
||||
if peer_id is not None:
|
||||
return self.connections.get(peer_id, [])
|
||||
|
||||
# Return all connections from all peers
|
||||
all_conns = []
|
||||
for conns in self.connections.values():
|
||||
all_conns.extend(conns)
|
||||
return all_conns
|
||||
|
||||
def get_connections_map(self) -> dict[ID, list[INetConn]]:
|
||||
"""
|
||||
Get all connections map (like JS getConnectionsMap).
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[ID, list[INetConn]]
|
||||
The complete mapping of peer IDs to their connection lists.
|
||||
|
||||
"""
|
||||
return self.connections.copy()
|
||||
|
||||
def get_connection(self, peer_id: ID) -> INetConn | None:
|
||||
"""
|
||||
Get single connection for backward compatibility.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer ID to get a connection for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
INetConn | None
|
||||
The first available connection, or None if no connections exist.
|
||||
|
||||
"""
|
||||
conns = self.get_connections(peer_id)
|
||||
return conns[0] if conns else None
|
||||
|
||||
async def dial_peer(self, peer_id: ID) -> list[INetConn]:
|
||||
"""
|
||||
Try to create connections to peer_id with enhanced retry logic.
|
||||
|
||||
:param peer_id: peer if we want to dial
|
||||
:raises SwarmException: raised when an error occurs
|
||||
:return: muxed connection
|
||||
:return: list of muxed connections
|
||||
"""
|
||||
if peer_id in self.connections:
|
||||
# If muxed connection already exists for peer_id,
|
||||
# set muxed connection equal to existing muxed connection
|
||||
return self.connections[peer_id]
|
||||
# Check if we already have connections
|
||||
existing_connections = self.get_connections(peer_id)
|
||||
if existing_connections:
|
||||
logger.debug(f"Reusing existing connections to peer {peer_id}")
|
||||
return existing_connections
|
||||
|
||||
logger.debug("attempting to dial peer %s", peer_id)
|
||||
|
||||
@ -146,12 +272,19 @@ class Swarm(Service, INetworkService):
|
||||
if not addrs:
|
||||
raise SwarmException(f"No known addresses to peer {peer_id}")
|
||||
|
||||
connections = []
|
||||
exceptions: list[SwarmException] = []
|
||||
|
||||
# Try all known addresses
|
||||
# Enhanced: Try all known addresses with retry logic
|
||||
for multiaddr in addrs:
|
||||
try:
|
||||
return await self.dial_addr(multiaddr, peer_id)
|
||||
connection = await self._dial_with_retry(multiaddr, peer_id)
|
||||
connections.append(connection)
|
||||
|
||||
# Limit number of connections per peer
|
||||
if len(connections) >= self.connection_config.max_connections_per_peer:
|
||||
break
|
||||
|
||||
except SwarmException as e:
|
||||
exceptions.append(e)
|
||||
logger.debug(
|
||||
@ -161,15 +294,73 @@ class Swarm(Service, INetworkService):
|
||||
exc_info=e,
|
||||
)
|
||||
|
||||
# Tried all addresses, raising exception.
|
||||
raise SwarmException(
|
||||
f"unable to connect to {peer_id}, no addresses established a successful "
|
||||
"connection (with exceptions)"
|
||||
) from MultiError(exceptions)
|
||||
if not connections:
|
||||
# Tried all addresses, raising exception.
|
||||
raise SwarmException(
|
||||
f"unable to connect to {peer_id}, no addresses established a "
|
||||
"successful connection (with exceptions)"
|
||||
) from MultiError(exceptions)
|
||||
|
||||
async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn:
|
||||
return connections
|
||||
|
||||
async def _dial_with_retry(self, addr: Multiaddr, peer_id: ID) -> INetConn:
|
||||
"""
|
||||
Try to create a connection to peer_id with addr.
|
||||
Enhanced: Dial with retry logic and exponential backoff.
|
||||
|
||||
:param addr: the address to dial
|
||||
:param peer_id: the peer we want to connect to
|
||||
:raises SwarmException: raised when all retry attempts fail
|
||||
:return: network connection
|
||||
"""
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(self.retry_config.max_retries + 1):
|
||||
try:
|
||||
return await self._dial_addr_single_attempt(addr, peer_id)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if attempt < self.retry_config.max_retries:
|
||||
delay = self._calculate_backoff_delay(attempt)
|
||||
logger.debug(
|
||||
f"Connection attempt {attempt + 1} failed, "
|
||||
f"retrying in {delay:.2f}s: {e}"
|
||||
)
|
||||
await trio.sleep(delay)
|
||||
else:
|
||||
logger.debug(f"All {self.retry_config.max_retries} attempts failed")
|
||||
|
||||
# Convert the last exception to SwarmException for consistency
|
||||
if last_exception is not None:
|
||||
if isinstance(last_exception, SwarmException):
|
||||
raise last_exception
|
||||
else:
|
||||
raise SwarmException(
|
||||
f"Failed to connect after {self.retry_config.max_retries} attempts"
|
||||
) from last_exception
|
||||
|
||||
# This should never be reached, but mypy requires it
|
||||
raise SwarmException("Unexpected error in retry logic")
|
||||
|
||||
def _calculate_backoff_delay(self, attempt: int) -> float:
|
||||
"""
|
||||
Enhanced: Calculate backoff delay with jitter to prevent thundering herd.
|
||||
|
||||
:param attempt: the current attempt number (0-based)
|
||||
:return: delay in seconds
|
||||
"""
|
||||
delay = min(
|
||||
self.retry_config.initial_delay
|
||||
* (self.retry_config.backoff_multiplier**attempt),
|
||||
self.retry_config.max_delay,
|
||||
)
|
||||
|
||||
# Add jitter to prevent synchronized retries
|
||||
jitter = delay * self.retry_config.jitter_factor
|
||||
return delay + random.uniform(-jitter, jitter)
|
||||
|
||||
async def _dial_addr_single_attempt(self, addr: Multiaddr, peer_id: ID) -> INetConn:
|
||||
"""
|
||||
Enhanced: Single attempt to dial an address (extracted from original dial_addr).
|
||||
|
||||
:param addr: the address we want to connect with
|
||||
:param peer_id: the peer we want to connect to
|
||||
@ -216,19 +407,97 @@ class Swarm(Service, INetworkService):
|
||||
|
||||
return swarm_conn
|
||||
|
||||
async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn:
|
||||
"""
|
||||
Enhanced: Try to create a connection to peer_id with addr using retry logic.
|
||||
|
||||
:param addr: the address we want to connect with
|
||||
:param peer_id: the peer we want to connect to
|
||||
:raises SwarmException: raised when an error occurs
|
||||
:return: network connection
|
||||
"""
|
||||
return await self._dial_with_retry(addr, peer_id)
|
||||
|
||||
async def new_stream(self, peer_id: ID) -> INetStream:
|
||||
"""
|
||||
Enhanced: Create a new stream with load balancing across multiple connections.
|
||||
|
||||
:param peer_id: peer_id of destination
|
||||
:raises SwarmException: raised when an error occurs
|
||||
:return: net stream instance
|
||||
"""
|
||||
logger.debug("attempting to open a stream to peer %s", peer_id)
|
||||
|
||||
swarm_conn = await self.dial_peer(peer_id)
|
||||
# Get existing connections or dial new ones
|
||||
connections = self.get_connections(peer_id)
|
||||
if not connections:
|
||||
connections = await self.dial_peer(peer_id)
|
||||
|
||||
net_stream = await swarm_conn.new_stream()
|
||||
logger.debug("successfully opened a stream to peer %s", peer_id)
|
||||
return net_stream
|
||||
# Load balancing strategy at interface level
|
||||
connection = self._select_connection(connections, peer_id)
|
||||
|
||||
try:
|
||||
net_stream = await connection.new_stream()
|
||||
logger.debug("successfully opened a stream to peer %s", peer_id)
|
||||
return net_stream
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to create stream on connection: {e}")
|
||||
# Try other connections if available
|
||||
for other_conn in connections:
|
||||
if other_conn != connection:
|
||||
try:
|
||||
net_stream = await other_conn.new_stream()
|
||||
logger.debug(
|
||||
f"Successfully opened a stream to peer {peer_id} "
|
||||
"using alternative connection"
|
||||
)
|
||||
return net_stream
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# All connections failed, raise exception
|
||||
raise SwarmException(f"Failed to create stream to peer {peer_id}") from e
|
||||
|
||||
def _select_connection(self, connections: list[INetConn], peer_id: ID) -> INetConn:
|
||||
"""
|
||||
Select connection based on load balancing strategy.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
connections : list[INetConn]
|
||||
List of available connections.
|
||||
peer_id : ID
|
||||
The peer ID for round-robin tracking.
|
||||
strategy : str
|
||||
Load balancing strategy ("round_robin", "least_loaded", etc.).
|
||||
|
||||
Returns
|
||||
-------
|
||||
INetConn
|
||||
Selected connection.
|
||||
|
||||
"""
|
||||
if not connections:
|
||||
raise ValueError("No connections available")
|
||||
|
||||
strategy = self.connection_config.load_balancing_strategy
|
||||
|
||||
if strategy == "round_robin":
|
||||
# Simple round-robin selection
|
||||
if peer_id not in self._round_robin_index:
|
||||
self._round_robin_index[peer_id] = 0
|
||||
|
||||
index = self._round_robin_index[peer_id] % len(connections)
|
||||
self._round_robin_index[peer_id] += 1
|
||||
return connections[index]
|
||||
|
||||
elif strategy == "least_loaded":
|
||||
# Find connection with least streams
|
||||
return min(connections, key=lambda c: len(c.get_streams()))
|
||||
|
||||
else:
|
||||
# Default to first connection
|
||||
return connections[0]
|
||||
|
||||
async def listen(self, *multiaddrs: Multiaddr) -> bool:
|
||||
"""
|
||||
@ -324,9 +593,9 @@ class Swarm(Service, INetworkService):
|
||||
# Perform alternative cleanup if the manager isn't initialized
|
||||
# Close all connections manually
|
||||
if hasattr(self, "connections"):
|
||||
for conn_id in list(self.connections.keys()):
|
||||
conn = self.connections[conn_id]
|
||||
await conn.close()
|
||||
for peer_id, conns in list(self.connections.items()):
|
||||
for conn in conns:
|
||||
await conn.close()
|
||||
|
||||
# Clear connection tracking dictionary
|
||||
self.connections.clear()
|
||||
@ -356,12 +625,28 @@ class Swarm(Service, INetworkService):
|
||||
logger.debug("swarm successfully closed")
|
||||
|
||||
async def close_peer(self, peer_id: ID) -> None:
|
||||
if peer_id not in self.connections:
|
||||
"""
|
||||
Close all connections to the specified peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer ID to close connections for.
|
||||
|
||||
"""
|
||||
connections = self.get_connections(peer_id)
|
||||
if not connections:
|
||||
return
|
||||
connection = self.connections[peer_id]
|
||||
# NOTE: `connection.close` will delete `peer_id` from `self.connections`
|
||||
# and `notify_disconnected` for us.
|
||||
await connection.close()
|
||||
|
||||
# Close all connections
|
||||
for connection in connections:
|
||||
try:
|
||||
await connection.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing connection to {peer_id}: {e}")
|
||||
|
||||
# Remove from connections dict
|
||||
self.connections.pop(peer_id, None)
|
||||
|
||||
logger.debug("successfully close the connection to peer %s", peer_id)
|
||||
|
||||
@ -380,21 +665,71 @@ class Swarm(Service, INetworkService):
|
||||
await muxed_conn.event_started.wait()
|
||||
self.manager.run_task(swarm_conn.start)
|
||||
await swarm_conn.event_started.wait()
|
||||
# Store muxed_conn with peer id
|
||||
self.connections[muxed_conn.peer_id] = swarm_conn
|
||||
|
||||
# Add to connections dict with deduplication
|
||||
peer_id = muxed_conn.peer_id
|
||||
if peer_id not in self.connections:
|
||||
self.connections[peer_id] = []
|
||||
|
||||
# Check for duplicate connections by comparing the underlying muxed connection
|
||||
for existing_conn in self.connections[peer_id]:
|
||||
if existing_conn.muxed_conn == muxed_conn:
|
||||
logger.debug(f"Connection already exists for peer {peer_id}")
|
||||
# existing_conn is a SwarmConn since it's stored in the connections list
|
||||
return existing_conn # type: ignore[return-value]
|
||||
|
||||
self.connections[peer_id].append(swarm_conn)
|
||||
|
||||
# Trim if we exceed max connections
|
||||
max_conns = self.connection_config.max_connections_per_peer
|
||||
if len(self.connections[peer_id]) > max_conns:
|
||||
self._trim_connections(peer_id)
|
||||
|
||||
# Call notifiers since event occurred
|
||||
await self.notify_connected(swarm_conn)
|
||||
return swarm_conn
|
||||
|
||||
def _trim_connections(self, peer_id: ID) -> None:
|
||||
"""
|
||||
Remove oldest connections when limit is exceeded.
|
||||
"""
|
||||
connections = self.connections[peer_id]
|
||||
if len(connections) <= self.connection_config.max_connections_per_peer:
|
||||
return
|
||||
|
||||
# Sort by creation time and remove oldest
|
||||
# For now, just keep the most recent connections
|
||||
max_conns = self.connection_config.max_connections_per_peer
|
||||
connections_to_remove = connections[:-max_conns]
|
||||
|
||||
for conn in connections_to_remove:
|
||||
logger.debug(f"Trimming old connection for peer {peer_id}")
|
||||
trio.lowlevel.spawn_system_task(self._close_connection_async, conn)
|
||||
|
||||
# Keep only the most recent connections
|
||||
max_conns = self.connection_config.max_connections_per_peer
|
||||
self.connections[peer_id] = connections[-max_conns:]
|
||||
|
||||
async def _close_connection_async(self, connection: INetConn) -> None:
|
||||
"""Close a connection asynchronously."""
|
||||
try:
|
||||
await connection.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing connection: {e}")
|
||||
|
||||
def remove_conn(self, swarm_conn: SwarmConn) -> None:
|
||||
"""
|
||||
Simply remove the connection from Swarm's records, without closing
|
||||
the connection.
|
||||
"""
|
||||
peer_id = swarm_conn.muxed_conn.peer_id
|
||||
if peer_id not in self.connections:
|
||||
return
|
||||
del self.connections[peer_id]
|
||||
|
||||
if peer_id in self.connections:
|
||||
self.connections[peer_id] = [
|
||||
conn for conn in self.connections[peer_id] if conn != swarm_conn
|
||||
]
|
||||
if not self.connections[peer_id]:
|
||||
del self.connections[peer_id]
|
||||
|
||||
# Notifee
|
||||
|
||||
@ -440,3 +775,21 @@ class Swarm(Service, INetworkService):
|
||||
async with trio.open_nursery() as nursery:
|
||||
for notifee in self.notifees:
|
||||
nursery.start_soon(notifier, notifee)
|
||||
|
||||
# Backward compatibility properties
|
||||
@property
|
||||
def connections_legacy(self) -> dict[ID, INetConn]:
|
||||
"""
|
||||
Legacy 1:1 mapping for backward compatibility.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[ID, INetConn]
|
||||
Legacy mapping with only the first connection per peer.
|
||||
|
||||
"""
|
||||
legacy_conns = {}
|
||||
for peer_id, conns in self.connections.items():
|
||||
if conns:
|
||||
legacy_conns[peer_id] = conns[0]
|
||||
return legacy_conns
|
||||
|
||||
1
newsfragments/874.feature.rst
Normal file
1
newsfragments/874.feature.rst
Normal file
@ -0,0 +1 @@
|
||||
Enhanced Swarm networking with retry logic, exponential backoff, and multi-connection support. Added configurable retry mechanisms that automatically recover from transient connection failures using exponential backoff with jitter to prevent thundering herd problems. Introduced connection pooling that allows multiple concurrent connections per peer for improved performance and fault tolerance. Added load balancing across connections and automatic connection health management. All enhancements are fully backward compatible and can be configured through new RetryConfig and ConnectionConfig classes.
|
||||
@ -164,8 +164,8 @@ async def test_live_peers_unexpected_drop(security_protocol):
|
||||
assert peer_a_id in host_b.get_live_peers()
|
||||
|
||||
# Simulate unexpected connection drop by directly closing the connection
|
||||
conn = host_a.get_network().connections[peer_b_id]
|
||||
await conn.muxed_conn.close()
|
||||
conns = host_a.get_network().connections[peer_b_id]
|
||||
await conns[0].muxed_conn.close()
|
||||
|
||||
# Allow for connection cleanup
|
||||
await trio.sleep(0.1)
|
||||
|
||||
325
tests/core/network/test_enhanced_swarm.py
Normal file
325
tests/core/network/test_enhanced_swarm.py
Normal file
@ -0,0 +1,325 @@
|
||||
import time
|
||||
from typing import cast
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.abc import INetConn, INetStream
|
||||
from libp2p.network.exceptions import SwarmException
|
||||
from libp2p.network.swarm import (
|
||||
ConnectionConfig,
|
||||
RetryConfig,
|
||||
Swarm,
|
||||
)
|
||||
from libp2p.peer.id import ID
|
||||
|
||||
|
||||
class MockConnection(INetConn):
|
||||
"""Mock connection for testing."""
|
||||
|
||||
def __init__(self, peer_id: ID, is_closed: bool = False):
|
||||
self.peer_id = peer_id
|
||||
self._is_closed = is_closed
|
||||
self.streams = set() # Track streams properly
|
||||
# Mock the muxed_conn attribute that Swarm expects
|
||||
self.muxed_conn = Mock()
|
||||
self.muxed_conn.peer_id = peer_id
|
||||
# Required by INetConn interface
|
||||
self.event_started = trio.Event()
|
||||
|
||||
async def close(self):
|
||||
self._is_closed = True
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return self._is_closed
|
||||
|
||||
async def new_stream(self) -> INetStream:
|
||||
# Create a mock stream and add it to the connection's stream set
|
||||
mock_stream = Mock(spec=INetStream)
|
||||
self.streams.add(mock_stream)
|
||||
return mock_stream
|
||||
|
||||
def get_streams(self) -> tuple[INetStream, ...]:
|
||||
"""Return all streams associated with this connection."""
|
||||
return tuple(self.streams)
|
||||
|
||||
def get_transport_addresses(self) -> list[Multiaddr]:
|
||||
"""Mock implementation of get_transport_addresses."""
|
||||
return []
|
||||
|
||||
|
||||
class MockNetStream(INetStream):
|
||||
"""Mock network stream for testing."""
|
||||
|
||||
def __init__(self, peer_id: ID):
|
||||
self.peer_id = peer_id
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_retry_config_defaults():
|
||||
"""Test RetryConfig default values."""
|
||||
config = RetryConfig()
|
||||
assert config.max_retries == 3
|
||||
assert config.initial_delay == 0.1
|
||||
assert config.max_delay == 30.0
|
||||
assert config.backoff_multiplier == 2.0
|
||||
assert config.jitter_factor == 0.1
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_config_defaults():
|
||||
"""Test ConnectionConfig default values."""
|
||||
config = ConnectionConfig()
|
||||
assert config.max_connections_per_peer == 3
|
||||
assert config.connection_timeout == 30.0
|
||||
assert config.load_balancing_strategy == "round_robin"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_enhanced_swarm_constructor():
|
||||
"""Test enhanced Swarm constructor with new configuration."""
|
||||
# Create mock dependencies
|
||||
peer_id = ID(b"QmTest")
|
||||
peerstore = Mock()
|
||||
upgrader = Mock()
|
||||
transport = Mock()
|
||||
|
||||
# Test with default config
|
||||
swarm = Swarm(peer_id, peerstore, upgrader, transport)
|
||||
assert swarm.retry_config.max_retries == 3
|
||||
assert swarm.connection_config.max_connections_per_peer == 3
|
||||
assert isinstance(swarm.connections, dict)
|
||||
|
||||
# Test with custom config
|
||||
custom_retry = RetryConfig(max_retries=5, initial_delay=0.5)
|
||||
custom_conn = ConnectionConfig(max_connections_per_peer=5)
|
||||
|
||||
swarm = Swarm(peer_id, peerstore, upgrader, transport, custom_retry, custom_conn)
|
||||
assert swarm.retry_config.max_retries == 5
|
||||
assert swarm.retry_config.initial_delay == 0.5
|
||||
assert swarm.connection_config.max_connections_per_peer == 5
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_backoff_calculation():
|
||||
"""Test exponential backoff calculation with jitter."""
|
||||
peer_id = ID(b"QmTest")
|
||||
peerstore = Mock()
|
||||
upgrader = Mock()
|
||||
transport = Mock()
|
||||
|
||||
retry_config = RetryConfig(
|
||||
initial_delay=0.1, max_delay=1.0, backoff_multiplier=2.0, jitter_factor=0.1
|
||||
)
|
||||
|
||||
swarm = Swarm(peer_id, peerstore, upgrader, transport, retry_config)
|
||||
|
||||
# Test backoff calculation
|
||||
delay1 = swarm._calculate_backoff_delay(0)
|
||||
delay2 = swarm._calculate_backoff_delay(1)
|
||||
delay3 = swarm._calculate_backoff_delay(2)
|
||||
|
||||
# Should increase exponentially
|
||||
assert delay2 > delay1
|
||||
assert delay3 > delay2
|
||||
|
||||
# Should respect max delay
|
||||
assert delay1 <= 1.0
|
||||
assert delay2 <= 1.0
|
||||
assert delay3 <= 1.0
|
||||
|
||||
# Should have jitter
|
||||
assert delay1 != 0.1 # Should have jitter added
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_retry_logic():
|
||||
"""Test retry logic in dial operations."""
|
||||
peer_id = ID(b"QmTest")
|
||||
peerstore = Mock()
|
||||
upgrader = Mock()
|
||||
transport = Mock()
|
||||
|
||||
# Configure for fast testing
|
||||
retry_config = RetryConfig(
|
||||
max_retries=2,
|
||||
initial_delay=0.01, # Very short for testing
|
||||
max_delay=0.1,
|
||||
)
|
||||
|
||||
swarm = Swarm(peer_id, peerstore, upgrader, transport, retry_config)
|
||||
|
||||
# Mock the single attempt method to fail twice then succeed
|
||||
attempt_count = [0]
|
||||
|
||||
async def mock_single_attempt(addr, peer_id):
|
||||
attempt_count[0] += 1
|
||||
if attempt_count[0] < 3:
|
||||
raise SwarmException(f"Attempt {attempt_count[0]} failed")
|
||||
return MockConnection(peer_id)
|
||||
|
||||
swarm._dial_addr_single_attempt = mock_single_attempt
|
||||
|
||||
# Test retry logic
|
||||
start_time = time.time()
|
||||
result = await swarm._dial_with_retry(Mock(spec=Multiaddr), peer_id)
|
||||
end_time = time.time()
|
||||
|
||||
# Should have succeeded after 3 attempts
|
||||
assert attempt_count[0] == 3
|
||||
assert isinstance(result, MockConnection)
|
||||
assert end_time - start_time > 0.01 # Should have some delay
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_load_balancing_strategies():
|
||||
"""Test load balancing strategies."""
|
||||
peer_id = ID(b"QmTest")
|
||||
peerstore = Mock()
|
||||
upgrader = Mock()
|
||||
transport = Mock()
|
||||
|
||||
swarm = Swarm(peer_id, peerstore, upgrader, transport)
|
||||
|
||||
# Create mock connections with different stream counts
|
||||
conn1 = MockConnection(peer_id)
|
||||
conn2 = MockConnection(peer_id)
|
||||
conn3 = MockConnection(peer_id)
|
||||
|
||||
# Add some streams to simulate load
|
||||
await conn1.new_stream()
|
||||
await conn1.new_stream()
|
||||
await conn2.new_stream()
|
||||
|
||||
connections = [conn1, conn2, conn3]
|
||||
|
||||
# Test round-robin strategy
|
||||
swarm.connection_config.load_balancing_strategy = "round_robin"
|
||||
# Cast to satisfy type checker
|
||||
connections_cast = cast("list[INetConn]", connections)
|
||||
selected1 = swarm._select_connection(connections_cast, peer_id)
|
||||
selected2 = swarm._select_connection(connections_cast, peer_id)
|
||||
selected3 = swarm._select_connection(connections_cast, peer_id)
|
||||
|
||||
# Should cycle through connections
|
||||
assert selected1 in connections
|
||||
assert selected2 in connections
|
||||
assert selected3 in connections
|
||||
|
||||
# Test least loaded strategy
|
||||
swarm.connection_config.load_balancing_strategy = "least_loaded"
|
||||
least_loaded = swarm._select_connection(connections_cast, peer_id)
|
||||
|
||||
# conn3 has 0 streams, conn2 has 1 stream, conn1 has 2 streams
|
||||
# So conn3 should be selected as least loaded
|
||||
assert least_loaded == conn3
|
||||
|
||||
# Test default strategy (first connection)
|
||||
swarm.connection_config.load_balancing_strategy = "unknown"
|
||||
default_selected = swarm._select_connection(connections_cast, peer_id)
|
||||
assert default_selected == conn1
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_multiple_connections_api():
|
||||
"""Test the new multiple connections API methods."""
|
||||
peer_id = ID(b"QmTest")
|
||||
peerstore = Mock()
|
||||
upgrader = Mock()
|
||||
transport = Mock()
|
||||
|
||||
swarm = Swarm(peer_id, peerstore, upgrader, transport)
|
||||
|
||||
# Test empty connections
|
||||
assert swarm.get_connections() == []
|
||||
assert swarm.get_connections(peer_id) == []
|
||||
assert swarm.get_connection(peer_id) is None
|
||||
assert swarm.get_connections_map() == {}
|
||||
|
||||
# Add some connections
|
||||
conn1 = MockConnection(peer_id)
|
||||
conn2 = MockConnection(peer_id)
|
||||
swarm.connections[peer_id] = [conn1, conn2]
|
||||
|
||||
# Test get_connections with peer_id
|
||||
peer_connections = swarm.get_connections(peer_id)
|
||||
assert len(peer_connections) == 2
|
||||
assert conn1 in peer_connections
|
||||
assert conn2 in peer_connections
|
||||
|
||||
# Test get_connections without peer_id (all connections)
|
||||
all_connections = swarm.get_connections()
|
||||
assert len(all_connections) == 2
|
||||
assert conn1 in all_connections
|
||||
assert conn2 in all_connections
|
||||
|
||||
# Test get_connection (backward compatibility)
|
||||
single_conn = swarm.get_connection(peer_id)
|
||||
assert single_conn in [conn1, conn2]
|
||||
|
||||
# Test get_connections_map
|
||||
connections_map = swarm.get_connections_map()
|
||||
assert peer_id in connections_map
|
||||
assert connections_map[peer_id] == [conn1, conn2]
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_connection_trimming():
|
||||
"""Test connection trimming when limit is exceeded."""
|
||||
peer_id = ID(b"QmTest")
|
||||
peerstore = Mock()
|
||||
upgrader = Mock()
|
||||
transport = Mock()
|
||||
|
||||
# Set max connections to 2
|
||||
connection_config = ConnectionConfig(max_connections_per_peer=2)
|
||||
swarm = Swarm(
|
||||
peer_id, peerstore, upgrader, transport, connection_config=connection_config
|
||||
)
|
||||
|
||||
# Add 3 connections
|
||||
conn1 = MockConnection(peer_id)
|
||||
conn2 = MockConnection(peer_id)
|
||||
conn3 = MockConnection(peer_id)
|
||||
|
||||
swarm.connections[peer_id] = [conn1, conn2, conn3]
|
||||
|
||||
# Trigger trimming
|
||||
swarm._trim_connections(peer_id)
|
||||
|
||||
# Should have only 2 connections
|
||||
assert len(swarm.connections[peer_id]) == 2
|
||||
|
||||
# The most recent connections should remain
|
||||
remaining = swarm.connections[peer_id]
|
||||
assert conn2 in remaining
|
||||
assert conn3 in remaining
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_backward_compatibility():
|
||||
"""Test backward compatibility features."""
|
||||
peer_id = ID(b"QmTest")
|
||||
peerstore = Mock()
|
||||
upgrader = Mock()
|
||||
transport = Mock()
|
||||
|
||||
swarm = Swarm(peer_id, peerstore, upgrader, transport)
|
||||
|
||||
# Add connections
|
||||
conn1 = MockConnection(peer_id)
|
||||
conn2 = MockConnection(peer_id)
|
||||
swarm.connections[peer_id] = [conn1, conn2]
|
||||
|
||||
# Test connections_legacy property
|
||||
legacy_connections = swarm.connections_legacy
|
||||
assert peer_id in legacy_connections
|
||||
# Should return first connection
|
||||
assert legacy_connections[peer_id] in [conn1, conn2]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
@ -51,14 +51,19 @@ async def test_swarm_dial_peer(security_protocol):
|
||||
for addr in transport.get_addrs()
|
||||
)
|
||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
|
||||
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
|
||||
# New: dial_peer now returns list of connections
|
||||
connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
assert len(connections) > 0
|
||||
|
||||
# Verify connections are established in both directions
|
||||
assert swarms[0].get_peer_id() in swarms[1].connections
|
||||
assert swarms[1].get_peer_id() in swarms[0].connections
|
||||
|
||||
# Test: Reuse connections when we already have ones with a peer.
|
||||
conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
|
||||
conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
assert conn is conn_to_1
|
||||
existing_connections = swarms[0].get_connections(swarms[1].get_peer_id())
|
||||
new_connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
assert new_connections == existing_connections
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
@ -107,7 +112,8 @@ async def test_swarm_close_peer(security_protocol):
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_remove_conn(swarm_pair):
|
||||
swarm_0, swarm_1 = swarm_pair
|
||||
conn_0 = swarm_0.connections[swarm_1.get_peer_id()]
|
||||
# Get the first connection from the list
|
||||
conn_0 = swarm_0.connections[swarm_1.get_peer_id()][0]
|
||||
swarm_0.remove_conn(conn_0)
|
||||
assert swarm_1.get_peer_id() not in swarm_0.connections
|
||||
# Test: Remove twice. There should not be errors.
|
||||
@ -115,6 +121,67 @@ async def test_swarm_remove_conn(swarm_pair):
|
||||
assert swarm_1.get_peer_id() not in swarm_0.connections
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_multiple_connections(security_protocol):
|
||||
"""Test multiple connections per peer functionality."""
|
||||
async with SwarmFactory.create_batch_and_listen(
|
||||
2, security_protocol=security_protocol
|
||||
) as swarms:
|
||||
# Setup multiple addresses for peer
|
||||
addrs = tuple(
|
||||
addr
|
||||
for transport in swarms[1].listeners.values()
|
||||
for addr in transport.get_addrs()
|
||||
)
|
||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
|
||||
|
||||
# Dial peer - should return list of connections
|
||||
connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
assert len(connections) > 0
|
||||
|
||||
# Test get_connections method
|
||||
peer_connections = swarms[0].get_connections(swarms[1].get_peer_id())
|
||||
assert len(peer_connections) == len(connections)
|
||||
|
||||
# Test get_connections_map method
|
||||
connections_map = swarms[0].get_connections_map()
|
||||
assert swarms[1].get_peer_id() in connections_map
|
||||
assert len(connections_map[swarms[1].get_peer_id()]) == len(connections)
|
||||
|
||||
# Test get_connection method (backward compatibility)
|
||||
single_conn = swarms[0].get_connection(swarms[1].get_peer_id())
|
||||
assert single_conn is not None
|
||||
assert single_conn in connections
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_load_balancing(security_protocol):
|
||||
"""Test load balancing across multiple connections."""
|
||||
async with SwarmFactory.create_batch_and_listen(
|
||||
2, security_protocol=security_protocol
|
||||
) as swarms:
|
||||
# Setup connection
|
||||
addrs = tuple(
|
||||
addr
|
||||
for transport in swarms[1].listeners.values()
|
||||
for addr in transport.get_addrs()
|
||||
)
|
||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
|
||||
|
||||
# Create multiple streams - should use load balancing
|
||||
streams = []
|
||||
for _ in range(5):
|
||||
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
|
||||
streams.append(stream)
|
||||
|
||||
# Verify streams were created successfully
|
||||
assert len(streams) == 5
|
||||
|
||||
# Clean up
|
||||
for stream in streams:
|
||||
await stream.close()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_multiaddr(security_protocol):
|
||||
async with SwarmFactory.create_batch_and_listen(
|
||||
|
||||
@ -51,6 +51,9 @@ async def perform_simple_test(assertion_func, security_protocol):
|
||||
|
||||
# Extract the secured connection from either Mplex or Yamux implementation
|
||||
def get_secured_conn(conn):
|
||||
# conn is now a list, get the first connection
|
||||
if isinstance(conn, list):
|
||||
conn = conn[0]
|
||||
muxed_conn = conn.muxed_conn
|
||||
# Direct attribute access for known implementations
|
||||
has_secured_conn = hasattr(muxed_conn, "secured_conn")
|
||||
|
||||
@ -74,7 +74,8 @@ async def test_multiplexer_preference_parameter(muxer_preference):
|
||||
assert len(connections) > 0, "Connection not established"
|
||||
|
||||
# Get the first connection
|
||||
conn = list(connections.values())[0]
|
||||
conns = list(connections.values())[0]
|
||||
conn = conns[0] # Get first connection from the list
|
||||
muxed_conn = conn.muxed_conn
|
||||
|
||||
# Define a simple echo protocol
|
||||
@ -150,7 +151,8 @@ async def test_explicit_muxer_options(muxer_option_func, expected_stream_class):
|
||||
assert len(connections) > 0, "Connection not established"
|
||||
|
||||
# Get the first connection
|
||||
conn = list(connections.values())[0]
|
||||
conns = list(connections.values())[0]
|
||||
conn = conns[0] # Get first connection from the list
|
||||
muxed_conn = conn.muxed_conn
|
||||
|
||||
# Define a simple echo protocol
|
||||
@ -219,7 +221,8 @@ async def test_global_default_muxer(global_default):
|
||||
assert len(connections) > 0, "Connection not established"
|
||||
|
||||
# Get the first connection
|
||||
conn = list(connections.values())[0]
|
||||
conns = list(connections.values())[0]
|
||||
conn = conns[0] # Get first connection from the list
|
||||
muxed_conn = conn.muxed_conn
|
||||
|
||||
# Define a simple echo protocol
|
||||
|
||||
@ -669,8 +669,8 @@ async def swarm_conn_pair_factory(
|
||||
async with swarm_pair_factory(
|
||||
security_protocol=security_protocol, muxer_opt=muxer_opt
|
||||
) as swarms:
|
||||
conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
|
||||
conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
|
||||
conn_0 = swarms[0].connections[swarms[1].get_peer_id()][0]
|
||||
conn_1 = swarms[1].connections[swarms[0].get_peer_id()][0]
|
||||
yield cast(SwarmConn, conn_0), cast(SwarmConn, conn_1)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user