Mirror of https://github.com/varun-r-mallya/py-libp2p.git (synced 2025-12-31 20:36:24 +00:00)
feat(swarm): enhance swarm with retry backoff
libp2p/__init__.py
@@ -1,3 +1,5 @@
"""Libp2p Python implementation."""

from collections.abc import (
    Mapping,
    Sequence,
@@ -6,15 +8,12 @@ from importlib.metadata import version as __version
from typing import (
    Literal,
    Optional,
    Type,
    cast,
)

import multiaddr

from libp2p.abc import (
    IHost,
    IMuxedConn,
    INetworkService,
    IPeerRouting,
    IPeerStore,
@@ -32,9 +31,6 @@ from libp2p.custom_types import (
    TProtocol,
    TSecurityOptions,
)
from libp2p.discovery.mdns.mdns import (
    MDNSDiscovery,
)
from libp2p.host.basic_host import (
    BasicHost,
)
@@ -42,6 +38,8 @@ from libp2p.host.routed_host import (
    RoutedHost,
)
from libp2p.network.swarm import (
    ConnectionConfig,
    RetryConfig,
    Swarm,
)
from libp2p.peer.id import (
@@ -54,17 +52,19 @@ from libp2p.security.insecure.transport import (
    PLAINTEXT_PROTOCOL_ID,
    InsecureTransport,
)
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
from libp2p.security.noise.transport import Transport as NoiseTransport
from libp2p.security.noise.transport import (
    PROTOCOL_ID as NOISE_PROTOCOL_ID,
    Transport as NoiseTransport,
)
import libp2p.security.secio.transport as secio
from libp2p.stream_muxer.mplex.mplex import (
    MPLEX_PROTOCOL_ID,
    Mplex,
)
from libp2p.stream_muxer.yamux.yamux import (
    PROTOCOL_ID as YAMUX_PROTOCOL_ID,
    Yamux,
)
from libp2p.stream_muxer.yamux.yamux import PROTOCOL_ID as YAMUX_PROTOCOL_ID
from libp2p.transport.tcp.tcp import (
    TCP,
)
@@ -87,7 +87,6 @@ MUXER_MPLEX = "MPLEX"
DEFAULT_NEGOTIATE_TIMEOUT = 5


def set_default_muxer(muxer_name: Literal["YAMUX", "MPLEX"]) -> None:
    """
    Set the default multiplexer protocol to use.
@@ -163,6 +162,8 @@ def new_swarm(
    peerstore_opt: IPeerStore | None = None,
    muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
    listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
    retry_config: Optional["RetryConfig"] = None,
    connection_config: Optional["ConnectionConfig"] = None,
) -> INetworkService:
    """
    Create a swarm instance based on the parameters.
@@ -239,7 +240,14 @@ def new_swarm(
    # Store our key pair in peerstore
    peerstore.add_key_pair(id_opt, key_pair)

    return Swarm(id_opt, peerstore, upgrader, transport)
    return Swarm(
        id_opt,
        peerstore,
        upgrader,
        transport,
        retry_config=retry_config,
        connection_config=connection_config
    )


def new_host(
@@ -279,6 +287,12 @@ def new_host(

    if disc_opt is not None:
        return RoutedHost(swarm, disc_opt, enable_mDNS, bootstrap)
    return BasicHost(network=swarm,enable_mDNS=enable_mDNS , bootstrap=bootstrap, negotitate_timeout=negotiate_timeout)
    return BasicHost(
        network=swarm,
        enable_mDNS=enable_mDNS,
        bootstrap=bootstrap,
        negotitate_timeout=negotiate_timeout
    )


__version__ = __version("libp2p")
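
A minimal usage sketch (not part of this commit) of the two new new_swarm() keyword arguments; the field names come from the RetryConfig and ConnectionConfig dataclasses defined in libp2p/network/swarm.py below, everything else is the pre-existing py-libp2p API:

# Sketch: tuning retry and pooling when building a swarm (assumed usage, not from this diff).
from libp2p import new_swarm
from libp2p.network.swarm import ConnectionConfig, RetryConfig

retry = RetryConfig(max_retries=5, initial_delay=0.2, max_delay=10.0)
pool_cfg = ConnectionConfig(max_connections_per_peer=2, load_balancing_strategy="least_loaded")

# Both arguments are optional; omitting them falls back to RetryConfig() / ConnectionConfig().
swarm = new_swarm(retry_config=retry, connection_config=pool_cfg)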

libp2p/network/swarm.py
@@ -2,7 +2,9 @@ from collections.abc import (
    Awaitable,
    Callable,
)
from dataclasses import dataclass
import logging
import random

from multiaddr import (
    Multiaddr,
@@ -59,6 +61,188 @@ from .exceptions import (
logger = logging.getLogger("libp2p.network.swarm")


@dataclass
class RetryConfig:
    """Configuration for retry logic with exponential backoff."""

    max_retries: int = 3
    initial_delay: float = 0.1
    max_delay: float = 30.0
    backoff_multiplier: float = 2.0
    jitter_factor: float = 0.1


@dataclass
class ConnectionConfig:
    """Configuration for connection pool and multi-connection support."""

    max_connections_per_peer: int = 3
    connection_timeout: float = 30.0
    enable_connection_pool: bool = True
    load_balancing_strategy: str = "round_robin"  # or "least_loaded"


@dataclass
class ConnectionInfo:
    """Information about a connection in the pool."""

    connection: INetConn
    address: str
    established_at: float
    last_used: float
    stream_count: int
    is_healthy: bool


class ConnectionPool:
    """Manages multiple connections per peer with load balancing."""

    def __init__(self, max_connections_per_peer: int = 3):
        self.max_connections_per_peer = max_connections_per_peer
        self.peer_connections: dict[ID, list[ConnectionInfo]] = {}
        self._round_robin_index: dict[ID, int] = {}

    def add_connection(self, peer_id: ID, connection: INetConn, address: str) -> None:
        """Add a connection to the pool with deduplication."""
        if peer_id not in self.peer_connections:
            self.peer_connections[peer_id] = []

        # Check for duplicate connections to the same address
        for conn_info in self.peer_connections[peer_id]:
            if conn_info.address == address:
                logger.debug(
                    f"Connection to {address} already exists for peer {peer_id}"
                )
                return

        # Add new connection
        try:
            current_time = trio.current_time()
        except RuntimeError:
            # Fallback for testing contexts where trio is not running
            import time

            current_time = time.time()

        conn_info = ConnectionInfo(
            connection=connection,
            address=address,
            established_at=current_time,
            last_used=current_time,
            stream_count=0,
            is_healthy=True,
        )

        self.peer_connections[peer_id].append(conn_info)

        # Trim if we exceed max connections
        if len(self.peer_connections[peer_id]) > self.max_connections_per_peer:
            self._trim_connections(peer_id)

    def get_connection(
        self, peer_id: ID, strategy: str = "round_robin"
    ) -> INetConn | None:
        """Get a connection using the specified load balancing strategy."""
        if peer_id not in self.peer_connections or not self.peer_connections[peer_id]:
            return None

        connections = self.peer_connections[peer_id]

        if strategy == "round_robin":
            if peer_id not in self._round_robin_index:
                self._round_robin_index[peer_id] = 0

            index = self._round_robin_index[peer_id] % len(connections)
            self._round_robin_index[peer_id] += 1

            conn_info = connections[index]
            try:
                conn_info.last_used = trio.current_time()
            except RuntimeError:
                import time

                conn_info.last_used = time.time()
            return conn_info.connection

        elif strategy == "least_loaded":
            # Find connection with least streams
            # Note: stream_count is a custom attribute we add to connections
            conn_info = min(
                connections, key=lambda c: getattr(c.connection, "stream_count", 0)
            )
            try:
                conn_info.last_used = trio.current_time()
            except RuntimeError:
                import time

                conn_info.last_used = time.time()
            return conn_info.connection

        else:
            # Default to first connection
            conn_info = connections[0]
            try:
                conn_info.last_used = trio.current_time()
            except RuntimeError:
                import time

                conn_info.last_used = time.time()
            return conn_info.connection

    def has_connection(self, peer_id: ID) -> bool:
        """Check if we have any connections to the peer."""
        return (
            peer_id in self.peer_connections and len(self.peer_connections[peer_id]) > 0
        )

    def remove_connection(self, peer_id: ID, connection: INetConn) -> None:
        """Remove a connection from the pool."""
        if peer_id in self.peer_connections:
            self.peer_connections[peer_id] = [
                conn_info
                for conn_info in self.peer_connections[peer_id]
                if conn_info.connection != connection
            ]

            # Clean up empty peer entries
            if not self.peer_connections[peer_id]:
                del self.peer_connections[peer_id]
                if peer_id in self._round_robin_index:
                    del self._round_robin_index[peer_id]

    def _trim_connections(self, peer_id: ID) -> None:
        """Remove oldest connections when limit is exceeded."""
        connections = self.peer_connections[peer_id]
        if len(connections) <= self.max_connections_per_peer:
            return

        # Sort by last used time and remove oldest
        connections.sort(key=lambda c: c.last_used)
        connections_to_remove = connections[: -self.max_connections_per_peer]

        for conn_info in connections_to_remove:
            logger.debug(
                f"Trimming old connection to {conn_info.address} for peer {peer_id}"
            )
            try:
                # Close the connection asynchronously
                trio.lowlevel.spawn_system_task(
                    self._close_connection_async, conn_info.connection
                )
            except Exception as e:
                logger.warning(f"Error closing trimmed connection: {e}")

        # Keep only the most recently used connections
        self.peer_connections[peer_id] = connections[-self.max_connections_per_peer :]

    async def _close_connection_async(self, connection: INetConn) -> None:
        """Close a connection asynchronously."""
        try:
            await connection.close()
        except Exception as e:
            logger.warning(f"Error closing connection: {e}")


def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn:
    async def stream_handler(stream: INetStream) -> None:
        await network.get_manager().wait_finished()
@@ -71,9 +255,8 @@ class Swarm(Service, INetworkService):
    peerstore: IPeerStore
    upgrader: TransportUpgrader
    transport: ITransport
    # TODO: Connection and `peer_id` are 1-1 mapping in our implementation,
    # whereas in Go one `peer_id` may point to multiple connections.
    connections: dict[ID, INetConn]
    # Enhanced: Support for multiple connections per peer
    connections: dict[ID, INetConn]  # Backward compatibility
    listeners: dict[str, IListener]
    common_stream_handler: StreamHandlerFn
    listener_nursery: trio.Nursery | None
@@ -81,17 +264,38 @@ class Swarm(Service, INetworkService):

    notifees: list[INotifee]

    # Enhanced: New configuration and connection pool
    retry_config: RetryConfig
    connection_config: ConnectionConfig
    connection_pool: ConnectionPool | None

    def __init__(
        self,
        peer_id: ID,
        peerstore: IPeerStore,
        upgrader: TransportUpgrader,
        transport: ITransport,
        retry_config: RetryConfig | None = None,
        connection_config: ConnectionConfig | None = None,
    ):
        self.self_id = peer_id
        self.peerstore = peerstore
        self.upgrader = upgrader
        self.transport = transport

        # Enhanced: Initialize retry and connection configuration
        self.retry_config = retry_config or RetryConfig()
        self.connection_config = connection_config or ConnectionConfig()

        # Enhanced: Initialize connection pool if enabled
        if self.connection_config.enable_connection_pool:
            self.connection_pool = ConnectionPool(
                self.connection_config.max_connections_per_peer
            )
        else:
            self.connection_pool = None

        # Backward compatibility: Keep existing connections dict
        self.connections = dict()
        self.listeners = dict()

@@ -124,12 +328,20 @@ class Swarm(Service, INetworkService):

    async def dial_peer(self, peer_id: ID) -> INetConn:
        """
        Try to create a connection to peer_id.
        Try to create a connection to peer_id with enhanced retry logic.

        :param peer_id: peer if we want to dial
        :raises SwarmException: raised when an error occurs
        :return: muxed connection
        """
        # Enhanced: Check connection pool first if enabled
        if self.connection_pool and self.connection_pool.has_connection(peer_id):
            connection = self.connection_pool.get_connection(peer_id)
            if connection:
                logger.debug(f"Reusing existing connection to peer {peer_id}")
                return connection

        # Enhanced: Check existing single connection for backward compatibility
        if peer_id in self.connections:
            # If muxed connection already exists for peer_id,
            # set muxed connection equal to existing muxed connection
@@ -148,10 +360,21 @@ class Swarm(Service, INetworkService):

        exceptions: list[SwarmException] = []

        # Try all known addresses
        # Enhanced: Try all known addresses with retry logic
        for multiaddr in addrs:
            try:
                return await self.dial_addr(multiaddr, peer_id)
                connection = await self._dial_with_retry(multiaddr, peer_id)

                # Enhanced: Add to connection pool if enabled
                if self.connection_pool:
                    self.connection_pool.add_connection(
                        peer_id, connection, str(multiaddr)
                    )

                # Backward compatibility: Keep existing connections dict
                self.connections[peer_id] = connection

                return connection
            except SwarmException as e:
                exceptions.append(e)
                logger.debug(
@@ -167,9 +390,64 @@ class Swarm(Service, INetworkService):
            "connection (with exceptions)"
        ) from MultiError(exceptions)

    async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn:
    async def _dial_with_retry(self, addr: Multiaddr, peer_id: ID) -> INetConn:
        """
        Try to create a connection to peer_id with addr.
        Enhanced: Dial with retry logic and exponential backoff.

        :param addr: the address to dial
        :param peer_id: the peer we want to connect to
        :raises SwarmException: raised when all retry attempts fail
        :return: network connection
        """
        last_exception = None

        for attempt in range(self.retry_config.max_retries + 1):
            try:
                return await self._dial_addr_single_attempt(addr, peer_id)
            except Exception as e:
                last_exception = e
                if attempt < self.retry_config.max_retries:
                    delay = self._calculate_backoff_delay(attempt)
                    logger.debug(
                        f"Connection attempt {attempt + 1} failed, "
                        f"retrying in {delay:.2f}s: {e}"
                    )
                    await trio.sleep(delay)
                else:
                    logger.debug(f"All {self.retry_config.max_retries} attempts failed")

        # Convert the last exception to SwarmException for consistency
        if last_exception is not None:
            if isinstance(last_exception, SwarmException):
                raise last_exception
            else:
                raise SwarmException(
                    f"Failed to connect after {self.retry_config.max_retries} attempts"
                ) from last_exception

        # This should never be reached, but mypy requires it
        raise SwarmException("Unexpected error in retry logic")

    def _calculate_backoff_delay(self, attempt: int) -> float:
        """
        Enhanced: Calculate backoff delay with jitter to prevent thundering herd.

        :param attempt: the current attempt number (0-based)
        :return: delay in seconds
        """
        delay = min(
            self.retry_config.initial_delay
            * (self.retry_config.backoff_multiplier**attempt),
            self.retry_config.max_delay,
        )

        # Add jitter to prevent synchronized retries
        jitter = delay * self.retry_config.jitter_factor
        return delay + random.uniform(-jitter, jitter)
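
With the default RetryConfig (initial_delay=0.1, backoff_multiplier=2.0, max_delay=30.0, jitter_factor=0.1) the nominal delays grow geometrically and each is then perturbed by up to ±10%. A small standalone sketch reproducing the same formula:

# Sketch: the backoff formula above, evaluated outside the class for the defaults.
# attempt 0 -> ~0.1 s, attempt 1 -> ~0.2 s, attempt 2 -> ~0.4 s, ... capped at 30.0 s.
import random

def backoff(attempt: int, initial=0.1, multiplier=2.0, cap=30.0, jitter_factor=0.1) -> float:
    delay = min(initial * (multiplier ** attempt), cap)
    jitter = delay * jitter_factor
    return delay + random.uniform(-jitter, jitter)

print([round(backoff(a), 3) for a in range(4)])  # e.g. [0.102, 0.195, 0.41, 0.83]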

    async def _dial_addr_single_attempt(self, addr: Multiaddr, peer_id: ID) -> INetConn:
        """
        Enhanced: Single attempt to dial an address (extracted from original dial_addr).

        :param addr: the address we want to connect with
        :param peer_id: the peer we want to connect to
@@ -216,14 +494,49 @@ class Swarm(Service, INetworkService):

        return swarm_conn

    async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn:
        """
        Enhanced: Try to create a connection to peer_id with addr using retry logic.

        :param addr: the address we want to connect with
        :param peer_id: the peer we want to connect to
        :raises SwarmException: raised when an error occurs
        :return: network connection
        """
        return await self._dial_with_retry(addr, peer_id)

    async def new_stream(self, peer_id: ID) -> INetStream:
        """
        Enhanced: Create a new stream with load balancing across multiple connections.

        :param peer_id: peer_id of destination
        :raises SwarmException: raised when an error occurs
        :return: net stream instance
        """
        logger.debug("attempting to open a stream to peer %s", peer_id)

        # Enhanced: Try to get existing connection from pool first
        if self.connection_pool and self.connection_pool.has_connection(peer_id):
            connection = self.connection_pool.get_connection(
                peer_id, self.connection_config.load_balancing_strategy
            )
            if connection:
                try:
                    net_stream = await connection.new_stream()
                    logger.debug(
                        "successfully opened a stream to peer %s "
                        "using existing connection",
                        peer_id,
                    )
                    return net_stream
                except Exception as e:
                    logger.debug(
                        f"Failed to create stream on existing connection, "
                        f"will dial new connection: {e}"
                    )
                    # Fall through to dial new connection

        # Fall back to existing logic: dial peer and create stream
        swarm_conn = await self.dial_peer(peer_id)

        net_stream = await swarm_conn.new_stream()
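
With pooling enabled, repeated new_stream() calls for the same peer are spread over the pooled connections according to load_balancing_strategy. A hedged usage sketch, assuming a running swarm and a reachable peer_id:

# Sketch (assumed usage, inside an async function): each call may be served by a
# different pooled connection when more than one connection to the peer exists.
streams = []
for _ in range(4):
    streams.append(await swarm.new_stream(peer_id))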
@@ -359,6 +672,11 @@ class Swarm(Service, INetworkService):
        if peer_id not in self.connections:
            return
        connection = self.connections[peer_id]

        # Enhanced: Remove from connection pool if enabled
        if self.connection_pool:
            self.connection_pool.remove_connection(peer_id, connection)

        # NOTE: `connection.close` will delete `peer_id` from `self.connections`
        # and `notify_disconnected` for us.
        await connection.close()
@@ -380,7 +698,15 @@ class Swarm(Service, INetworkService):
        await muxed_conn.event_started.wait()
        self.manager.run_task(swarm_conn.start)
        await swarm_conn.event_started.wait()
        # Store muxed_conn with peer id
        # Enhanced: Add to connection pool if enabled
        if self.connection_pool:
            # For incoming connections, we don't have a specific address
            # Use a placeholder that will be updated when we get more info
            self.connection_pool.add_connection(
                muxed_conn.peer_id, swarm_conn, "incoming"
            )

        # Store muxed_conn with peer id (backward compatibility)
        self.connections[muxed_conn.peer_id] = swarm_conn
        # Call notifiers since event occurred
        await self.notify_connected(swarm_conn)
@@ -392,6 +718,11 @@ class Swarm(Service, INetworkService):
        the connection.
        """
        peer_id = swarm_conn.muxed_conn.peer_id

        # Enhanced: Remove from connection pool if enabled
        if self.connection_pool:
            self.connection_pool.remove_connection(peer_id, swarm_conn)

        if peer_id not in self.connections:
            return
        del self.connections[peer_id]