address architectural refactoring discussed

This commit is contained in:
bomanaps
2025-08-31 01:38:29 +01:00
parent df39e240e7
commit 59e1d9ae39
12 changed files with 705 additions and 571 deletions

View File

@ -74,175 +74,13 @@ class RetryConfig:
@dataclass
class ConnectionConfig:
"""Configuration for connection pool and multi-connection support."""
"""Configuration for multi-connection support."""
max_connections_per_peer: int = 3
connection_timeout: float = 30.0
enable_connection_pool: bool = True
load_balancing_strategy: str = "round_robin" # or "least_loaded"
@dataclass
class ConnectionInfo:
"""Information about a connection in the pool."""
connection: INetConn
address: str
established_at: float
last_used: float
stream_count: int
is_healthy: bool
class ConnectionPool:
"""Manages multiple connections per peer with load balancing."""
def __init__(self, max_connections_per_peer: int = 3):
self.max_connections_per_peer = max_connections_per_peer
self.peer_connections: dict[ID, list[ConnectionInfo]] = {}
self._round_robin_index: dict[ID, int] = {}
def add_connection(self, peer_id: ID, connection: INetConn, address: str) -> None:
"""Add a connection to the pool with deduplication."""
if peer_id not in self.peer_connections:
self.peer_connections[peer_id] = []
# Check for duplicate connections to the same address
for conn_info in self.peer_connections[peer_id]:
if conn_info.address == address:
logger.debug(
f"Connection to {address} already exists for peer {peer_id}"
)
return
# Add new connection
try:
current_time = trio.current_time()
except RuntimeError:
# Fallback for testing contexts where trio is not running
import time
current_time = time.time()
conn_info = ConnectionInfo(
connection=connection,
address=address,
established_at=current_time,
last_used=current_time,
stream_count=0,
is_healthy=True,
)
self.peer_connections[peer_id].append(conn_info)
# Trim if we exceed max connections
if len(self.peer_connections[peer_id]) > self.max_connections_per_peer:
self._trim_connections(peer_id)
def get_connection(
self, peer_id: ID, strategy: str = "round_robin"
) -> INetConn | None:
"""Get a connection using the specified load balancing strategy."""
if peer_id not in self.peer_connections or not self.peer_connections[peer_id]:
return None
connections = self.peer_connections[peer_id]
if strategy == "round_robin":
if peer_id not in self._round_robin_index:
self._round_robin_index[peer_id] = 0
index = self._round_robin_index[peer_id] % len(connections)
self._round_robin_index[peer_id] += 1
conn_info = connections[index]
try:
conn_info.last_used = trio.current_time()
except RuntimeError:
import time
conn_info.last_used = time.time()
return conn_info.connection
elif strategy == "least_loaded":
# Find connection with least streams
# Note: stream_count is a custom attribute we add to connections
conn_info = min(
connections, key=lambda c: getattr(c.connection, "stream_count", 0)
)
try:
conn_info.last_used = trio.current_time()
except RuntimeError:
import time
conn_info.last_used = time.time()
return conn_info.connection
else:
# Default to first connection
conn_info = connections[0]
try:
conn_info.last_used = trio.current_time()
except RuntimeError:
import time
conn_info.last_used = time.time()
return conn_info.connection
def has_connection(self, peer_id: ID) -> bool:
"""Check if we have any connections to the peer."""
return (
peer_id in self.peer_connections and len(self.peer_connections[peer_id]) > 0
)
def remove_connection(self, peer_id: ID, connection: INetConn) -> None:
"""Remove a connection from the pool."""
if peer_id in self.peer_connections:
self.peer_connections[peer_id] = [
conn_info
for conn_info in self.peer_connections[peer_id]
if conn_info.connection != connection
]
# Clean up empty peer entries
if not self.peer_connections[peer_id]:
del self.peer_connections[peer_id]
if peer_id in self._round_robin_index:
del self._round_robin_index[peer_id]
def _trim_connections(self, peer_id: ID) -> None:
"""Remove oldest connections when limit is exceeded."""
connections = self.peer_connections[peer_id]
if len(connections) <= self.max_connections_per_peer:
return
# Sort by last used time and remove oldest
connections.sort(key=lambda c: c.last_used)
connections_to_remove = connections[: -self.max_connections_per_peer]
for conn_info in connections_to_remove:
logger.debug(
f"Trimming old connection to {conn_info.address} for peer {peer_id}"
)
try:
# Close the connection asynchronously
trio.lowlevel.spawn_system_task(
self._close_connection_async, conn_info.connection
)
except Exception as e:
logger.warning(f"Error closing trimmed connection: {e}")
# Keep only the most recently used connections
self.peer_connections[peer_id] = connections[-self.max_connections_per_peer :]
async def _close_connection_async(self, connection: INetConn) -> None:
"""Close a connection asynchronously."""
try:
await connection.close()
except Exception as e:
logger.warning(f"Error closing connection: {e}")
def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn:
async def stream_handler(stream: INetStream) -> None:
await network.get_manager().wait_finished()
@ -256,7 +94,7 @@ class Swarm(Service, INetworkService):
upgrader: TransportUpgrader
transport: ITransport
# Enhanced: Support for multiple connections per peer
connections: dict[ID, INetConn] # Backward compatibility
connections: dict[ID, list[INetConn]] # Multiple connections per peer
listeners: dict[str, IListener]
common_stream_handler: StreamHandlerFn
listener_nursery: trio.Nursery | None
@ -264,10 +102,10 @@ class Swarm(Service, INetworkService):
notifees: list[INotifee]
# Enhanced: New configuration and connection pool
# Enhanced: New configuration
retry_config: RetryConfig
connection_config: ConnectionConfig
connection_pool: ConnectionPool | None
_round_robin_index: dict[ID, int]
def __init__(
self,
@ -287,16 +125,8 @@ class Swarm(Service, INetworkService):
self.retry_config = retry_config or RetryConfig()
self.connection_config = connection_config or ConnectionConfig()
# Enhanced: Initialize connection pool if enabled
if self.connection_config.enable_connection_pool:
self.connection_pool = ConnectionPool(
self.connection_config.max_connections_per_peer
)
else:
self.connection_pool = None
# Backward compatibility: Keep existing connections dict
self.connections = dict()
# Enhanced: Initialize connections as 1:many mapping
self.connections = {}
self.listeners = dict()
# Create Notifee array
@ -307,6 +137,9 @@ class Swarm(Service, INetworkService):
self.listener_nursery = None
self.event_listener_nursery_created = trio.Event()
# Load balancing state
self._round_robin_index = {}
async def run(self) -> None:
async with trio.open_nursery() as nursery:
# Create a nursery for listener tasks.
@ -326,26 +159,74 @@ class Swarm(Service, INetworkService):
def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None:
self.common_stream_handler = stream_handler
async def dial_peer(self, peer_id: ID) -> INetConn:
def get_connections(self, peer_id: ID | None = None) -> list[INetConn]:
"""
Try to create a connection to peer_id with enhanced retry logic.
Get connections for peer (like JS getConnections, Go ConnsToPeer).
Parameters
----------
peer_id : ID | None
The peer ID to get connections for. If None, returns all connections.
Returns
-------
list[INetConn]
List of connections to the specified peer, or all connections
if peer_id is None.
"""
if peer_id is not None:
return self.connections.get(peer_id, [])
# Return all connections from all peers
all_conns = []
for conns in self.connections.values():
all_conns.extend(conns)
return all_conns
def get_connections_map(self) -> dict[ID, list[INetConn]]:
"""
Get all connections map (like JS getConnectionsMap).
Returns
-------
dict[ID, list[INetConn]]
The complete mapping of peer IDs to their connection lists.
"""
return self.connections.copy()
def get_connection(self, peer_id: ID) -> INetConn | None:
"""
Get single connection for backward compatibility.
Parameters
----------
peer_id : ID
The peer ID to get a connection for.
Returns
-------
INetConn | None
The first available connection, or None if no connections exist.
"""
conns = self.get_connections(peer_id)
return conns[0] if conns else None
async def dial_peer(self, peer_id: ID) -> list[INetConn]:
"""
Try to create connections to peer_id with enhanced retry logic.
:param peer_id: peer if we want to dial
:raises SwarmException: raised when an error occurs
:return: muxed connection
:return: list of muxed connections
"""
# Enhanced: Check connection pool first if enabled
if self.connection_pool and self.connection_pool.has_connection(peer_id):
connection = self.connection_pool.get_connection(peer_id)
if connection:
logger.debug(f"Reusing existing connection to peer {peer_id}")
return connection
# Enhanced: Check existing single connection for backward compatibility
if peer_id in self.connections:
# If muxed connection already exists for peer_id,
# set muxed connection equal to existing muxed connection
return self.connections[peer_id]
# Check if we already have connections
existing_connections = self.get_connections(peer_id)
if existing_connections:
logger.debug(f"Reusing existing connections to peer {peer_id}")
return existing_connections
logger.debug("attempting to dial peer %s", peer_id)
@ -358,23 +239,19 @@ class Swarm(Service, INetworkService):
if not addrs:
raise SwarmException(f"No known addresses to peer {peer_id}")
connections = []
exceptions: list[SwarmException] = []
# Enhanced: Try all known addresses with retry logic
for multiaddr in addrs:
try:
connection = await self._dial_with_retry(multiaddr, peer_id)
connections.append(connection)
# Enhanced: Add to connection pool if enabled
if self.connection_pool:
self.connection_pool.add_connection(
peer_id, connection, str(multiaddr)
)
# Limit number of connections per peer
if len(connections) >= self.connection_config.max_connections_per_peer:
break
# Backward compatibility: Keep existing connections dict
self.connections[peer_id] = connection
return connection
except SwarmException as e:
exceptions.append(e)
logger.debug(
@ -384,11 +261,14 @@ class Swarm(Service, INetworkService):
exc_info=e,
)
# Tried all addresses, raising exception.
raise SwarmException(
f"unable to connect to {peer_id}, no addresses established a successful "
"connection (with exceptions)"
) from MultiError(exceptions)
if not connections:
# Tried all addresses, raising exception.
raise SwarmException(
f"unable to connect to {peer_id}, no addresses established a "
"successful connection (with exceptions)"
) from MultiError(exceptions)
return connections
async def _dial_with_retry(self, addr: Multiaddr, peer_id: ID) -> INetConn:
"""
@ -515,33 +395,76 @@ class Swarm(Service, INetworkService):
"""
logger.debug("attempting to open a stream to peer %s", peer_id)
# Enhanced: Try to get existing connection from pool first
if self.connection_pool and self.connection_pool.has_connection(peer_id):
connection = self.connection_pool.get_connection(
peer_id, self.connection_config.load_balancing_strategy
)
if connection:
try:
net_stream = await connection.new_stream()
logger.debug(
"successfully opened a stream to peer %s "
"using existing connection",
peer_id,
)
return net_stream
except Exception as e:
logger.debug(
f"Failed to create stream on existing connection, "
f"will dial new connection: {e}"
)
# Fall through to dial new connection
# Get existing connections or dial new ones
connections = self.get_connections(peer_id)
if not connections:
connections = await self.dial_peer(peer_id)
# Fall back to existing logic: dial peer and create stream
swarm_conn = await self.dial_peer(peer_id)
# Load balancing strategy at interface level
connection = self._select_connection(connections, peer_id)
net_stream = await swarm_conn.new_stream()
logger.debug("successfully opened a stream to peer %s", peer_id)
return net_stream
try:
net_stream = await connection.new_stream()
logger.debug("successfully opened a stream to peer %s", peer_id)
return net_stream
except Exception as e:
logger.debug(f"Failed to create stream on connection: {e}")
# Try other connections if available
for other_conn in connections:
if other_conn != connection:
try:
net_stream = await other_conn.new_stream()
logger.debug(
f"Successfully opened a stream to peer {peer_id} "
"using alternative connection"
)
return net_stream
except Exception:
continue
# All connections failed, raise exception
raise SwarmException(f"Failed to create stream to peer {peer_id}") from e
def _select_connection(self, connections: list[INetConn], peer_id: ID) -> INetConn:
"""
Select connection based on load balancing strategy.
Parameters
----------
connections : list[INetConn]
List of available connections.
peer_id : ID
The peer ID for round-robin tracking.
strategy : str
Load balancing strategy ("round_robin", "least_loaded", etc.).
Returns
-------
INetConn
Selected connection.
"""
if not connections:
raise ValueError("No connections available")
strategy = self.connection_config.load_balancing_strategy
if strategy == "round_robin":
# Simple round-robin selection
if peer_id not in self._round_robin_index:
self._round_robin_index[peer_id] = 0
index = self._round_robin_index[peer_id] % len(connections)
self._round_robin_index[peer_id] += 1
return connections[index]
elif strategy == "least_loaded":
# Find connection with least streams
return min(connections, key=lambda c: len(c.get_streams()))
else:
# Default to first connection
return connections[0]
async def listen(self, *multiaddrs: Multiaddr) -> bool:
"""
@ -637,9 +560,9 @@ class Swarm(Service, INetworkService):
# Perform alternative cleanup if the manager isn't initialized
# Close all connections manually
if hasattr(self, "connections"):
for conn_id in list(self.connections.keys()):
conn = self.connections[conn_id]
await conn.close()
for peer_id, conns in list(self.connections.items()):
for conn in conns:
await conn.close()
# Clear connection tracking dictionary
self.connections.clear()
@ -669,17 +592,28 @@ class Swarm(Service, INetworkService):
logger.debug("swarm successfully closed")
async def close_peer(self, peer_id: ID) -> None:
if peer_id not in self.connections:
"""
Close all connections to the specified peer.
Parameters
----------
peer_id : ID
The peer ID to close connections for.
"""
connections = self.get_connections(peer_id)
if not connections:
return
connection = self.connections[peer_id]
# Enhanced: Remove from connection pool if enabled
if self.connection_pool:
self.connection_pool.remove_connection(peer_id, connection)
# Close all connections
for connection in connections:
try:
await connection.close()
except Exception as e:
logger.warning(f"Error closing connection to {peer_id}: {e}")
# NOTE: `connection.close` will delete `peer_id` from `self.connections`
# and `notify_disconnected` for us.
await connection.close()
# Remove from connections dict
self.connections.pop(peer_id, None)
logger.debug("successfully close the connection to peer %s", peer_id)
@ -698,20 +632,58 @@ class Swarm(Service, INetworkService):
await muxed_conn.event_started.wait()
self.manager.run_task(swarm_conn.start)
await swarm_conn.event_started.wait()
# Enhanced: Add to connection pool if enabled
if self.connection_pool:
# For incoming connections, we don't have a specific address
# Use a placeholder that will be updated when we get more info
self.connection_pool.add_connection(
muxed_conn.peer_id, swarm_conn, "incoming"
)
# Store muxed_conn with peer id (backward compatibility)
self.connections[muxed_conn.peer_id] = swarm_conn
# Add to connections dict with deduplication
peer_id = muxed_conn.peer_id
if peer_id not in self.connections:
self.connections[peer_id] = []
# Check for duplicate connections by comparing the underlying muxed connection
for existing_conn in self.connections[peer_id]:
if existing_conn.muxed_conn == muxed_conn:
logger.debug(f"Connection already exists for peer {peer_id}")
# existing_conn is a SwarmConn since it's stored in the connections list
return existing_conn # type: ignore[return-value]
self.connections[peer_id].append(swarm_conn)
# Trim if we exceed max connections
max_conns = self.connection_config.max_connections_per_peer
if len(self.connections[peer_id]) > max_conns:
self._trim_connections(peer_id)
# Call notifiers since event occurred
await self.notify_connected(swarm_conn)
return swarm_conn
def _trim_connections(self, peer_id: ID) -> None:
"""
Remove oldest connections when limit is exceeded.
"""
connections = self.connections[peer_id]
if len(connections) <= self.connection_config.max_connections_per_peer:
return
# Sort by creation time and remove oldest
# For now, just keep the most recent connections
max_conns = self.connection_config.max_connections_per_peer
connections_to_remove = connections[:-max_conns]
for conn in connections_to_remove:
logger.debug(f"Trimming old connection for peer {peer_id}")
trio.lowlevel.spawn_system_task(self._close_connection_async, conn)
# Keep only the most recent connections
max_conns = self.connection_config.max_connections_per_peer
self.connections[peer_id] = connections[-max_conns:]
async def _close_connection_async(self, connection: INetConn) -> None:
"""Close a connection asynchronously."""
try:
await connection.close()
except Exception as e:
logger.warning(f"Error closing connection: {e}")
def remove_conn(self, swarm_conn: SwarmConn) -> None:
"""
Simply remove the connection from Swarm's records, without closing
@ -719,13 +691,12 @@ class Swarm(Service, INetworkService):
"""
peer_id = swarm_conn.muxed_conn.peer_id
# Enhanced: Remove from connection pool if enabled
if self.connection_pool:
self.connection_pool.remove_connection(peer_id, swarm_conn)
if peer_id not in self.connections:
return
del self.connections[peer_id]
if peer_id in self.connections:
self.connections[peer_id] = [
conn for conn in self.connections[peer_id] if conn != swarm_conn
]
if not self.connections[peer_id]:
del self.connections[peer_id]
# Notifee
@ -771,3 +742,21 @@ class Swarm(Service, INetworkService):
async with trio.open_nursery() as nursery:
for notifee in self.notifees:
nursery.start_soon(notifier, notifee)
# Backward compatibility properties
@property
def connections_legacy(self) -> dict[ID, INetConn]:
"""
Legacy 1:1 mapping for backward compatibility.
Returns
-------
dict[ID, INetConn]
Legacy mapping with only the first connection per peer.
"""
legacy_conns = {}
for peer_id, conns in self.connections.items():
if conns:
legacy_conns[peer_id] = conns[0]
return legacy_conns