Feat: Adding Yamux as default multiplexer, keeping Mplex as fallback (#538)

* feat: Replace mplex with yamux as default multiplexer in py-libp2p

* Retain Mplex alongside Yamux in new_swarm with messaging that Yamux is preferred

* moved !BBHII to a constant YAMUX_HEADER_FORMAT at the top of yamux.py with a comment explaining its structure

* renamed the news fragment to 534.feature.rst and updated the description

* renamed the news fragment to 534.feature.rst and updated the description

* added a docstring to clarify that Yamux does not support deadlines natively

* Remove the __main__ block entirely from test_yamux.py

* Replaced the print statements in test_yamux.py with logging.debug

* Added a comment linking to the spec for clarity

* Raise NotImplementedError in YamuxStream.set_deadline per review

* Add muxed_conn to YamuxStream and test deadline NotImplementedError

* Fix Yamux implementation to meet libp2p spec

* Fix None handling in YamuxStream.read and Yamux.read_stream

* Fix test_connected_peers.py to correctly handle peer connections

* fix: Ensure StreamReset is raised on read after local reset in yamux

* fix: Map MuxedStreamError to StreamClosed in NetStream.write for Yamux

* fix: Raise MuxedStreamReset in Yamux.read_stream for closed streams

* fix: Correct Yamux stream read behavior for NetStream tests

Fixed 	est_net_stream_read_after_remote_closed by updating NetStream.read to raise StreamEOF when the stream is remotely closed and no data is available, aligning with test expectations and Fixed 	est_net_stream_read_until_eof by modifying YamuxStream.read to block until the stream is closed (
ecv_closed=True) for
=-1 reads, ensuring data is only returned after remote closure.

* fix: Correct Yamux stream read behavior for NetStream tests

Fixed 	est_net_stream_read_after_remote_closed by updating NetStream.read to raise StreamEOF when the stream is remotely closed and no data is available, aligning with test expectations and Fixed 	est_net_stream_read_until_eof by modifying YamuxStream.read to block until the stream is closed (
ecv_closed=True) for
=-1 reads, ensuring data is only returned after remote closure.

* fix: raise StreamEOF when reading from closed stream with empty buffer

* fix: prioritize returning buffered data even after stream reset

* fix: prioritize returning buffered data even after stream reset

* fix: Ensure test_net_stream_read_after_remote_closed_and_reset passes in full suite

* fix: Add __init__.py to yamux module to fix documentation build

* fix: Add __init__.py to yamux module to fix documentation build

* fix: Add libp2p.stream_muxer.yamux to libp2p.stream_muxer.rst toctree

* fix: Correct title underline length in libp2p.stream_muxer.yamux.rst

* fix: Add a = so that is matches the libp2p.stream\_muxer.yamux length

* fix(tests): Resolve race condition in network notification test

* fix: fixing failing tests and examples with yamux and noise

* refactor: remove debug logging and improve x25519 tests

* fix: Add functionality for users to choose between Yamux and Mplex

* fix: increased trio sleep to 0.1 sec for slow environment

* feat: Add test for switching between Yamux and mplex

* refactor: move host fixtures to interop tests

* chore: Update __init__.py removing unused import

removed unused
```python
import os
import logging
```

* lint: fix import order

* fix: Resolve conftest.py conflict by removing trio test support

* fix: Resolve test skipping by keeping trio test support

* Fix: add a newline at end of the file

---------

Co-authored-by: acul71 <luca.pisani@birdo.net>
Co-authored-by: acul71 <34693171+acul71@users.noreply.github.com>
This commit is contained in:
Paschal
2025-05-22 21:01:51 +01:00
committed by GitHub
parent 18c6f529c6
commit 4b1860766d
29 changed files with 2215 additions and 101 deletions

View File

@ -1,10 +1,21 @@
from collections.abc import (
Mapping,
)
from importlib.metadata import version as __version
from typing import (
Literal,
Optional,
Type,
cast,
)
from libp2p.abc import (
IHost,
IMuxedConn,
INetworkService,
IPeerRouting,
IPeerStore,
ISecureTransport,
)
from libp2p.crypto.keys import (
KeyPair,
@ -12,6 +23,7 @@ from libp2p.crypto.keys import (
from libp2p.crypto.rsa import (
create_new_key_pair,
)
from libp2p.crypto.x25519 import create_new_key_pair as create_new_x25519_key_pair
from libp2p.custom_types import (
TMuxerOptions,
TProtocol,
@ -36,11 +48,17 @@ from libp2p.security.insecure.transport import (
PLAINTEXT_PROTOCOL_ID,
InsecureTransport,
)
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
from libp2p.security.noise.transport import Transport as NoiseTransport
import libp2p.security.secio.transport as secio
from libp2p.stream_muxer.mplex.mplex import (
MPLEX_PROTOCOL_ID,
Mplex,
)
from libp2p.stream_muxer.yamux.yamux import (
Yamux,
)
from libp2p.stream_muxer.yamux.yamux import PROTOCOL_ID as YAMUX_PROTOCOL_ID
from libp2p.transport.tcp.tcp import (
TCP,
)
@ -54,6 +72,60 @@ from libp2p.utils.logging import (
# Initialize logging configuration
setup_logging()
# Default multiplexer choice
DEFAULT_MUXER = "YAMUX"
# Multiplexer options
MUXER_YAMUX = "YAMUX"
MUXER_MPLEX = "MPLEX"
def set_default_muxer(muxer_name: Literal["YAMUX", "MPLEX"]) -> None:
"""
Set the default multiplexer protocol to use.
:param muxer_name: Either "YAMUX" or "MPLEX"
:raise ValueError: If an unsupported muxer name is provided
"""
global DEFAULT_MUXER
muxer_upper = muxer_name.upper()
if muxer_upper not in [MUXER_YAMUX, MUXER_MPLEX]:
raise ValueError(f"Unknown muxer: {muxer_name}. Use 'YAMUX' or 'MPLEX'.")
DEFAULT_MUXER = muxer_upper
def get_default_muxer() -> str:
"""
Returns the currently selected default muxer.
:return: Either "YAMUX" or "MPLEX"
"""
return DEFAULT_MUXER
def create_yamux_muxer_option() -> TMuxerOptions:
"""
Returns muxer options with Yamux as the primary choice.
:return: Muxer options with Yamux first
"""
return {
TProtocol(YAMUX_PROTOCOL_ID): Yamux, # Primary choice
TProtocol(MPLEX_PROTOCOL_ID): Mplex, # Fallback for compatibility
}
def create_mplex_muxer_option() -> TMuxerOptions:
"""
Returns muxer options with Mplex as the primary choice.
:return: Muxer options with Mplex first
"""
return {
TProtocol(MPLEX_PROTOCOL_ID): Mplex, # Primary choice
TProtocol(YAMUX_PROTOCOL_ID): Yamux, # Fallback
}
def generate_new_rsa_identity() -> KeyPair:
return create_new_key_pair()
@ -64,11 +136,24 @@ def generate_peer_id_from(key_pair: KeyPair) -> ID:
return ID.from_pubkey(public_key)
def get_default_muxer_options() -> TMuxerOptions:
"""
Returns the default muxer options based on the current default muxer setting.
:return: Muxer options with the preferred muxer first
"""
if DEFAULT_MUXER == "MPLEX":
return create_mplex_muxer_option()
else: # YAMUX is default
return create_yamux_muxer_option()
def new_swarm(
key_pair: KeyPair = None,
muxer_opt: TMuxerOptions = None,
sec_opt: TSecurityOptions = None,
peerstore_opt: IPeerStore = None,
key_pair: Optional[KeyPair] = None,
muxer_opt: Optional[TMuxerOptions] = None,
sec_opt: Optional[TSecurityOptions] = None,
peerstore_opt: Optional[IPeerStore] = None,
muxer_preference: Optional[Literal["YAMUX", "MPLEX"]] = None,
) -> INetworkService:
"""
Create a swarm instance based on the parameters.
@ -77,7 +162,13 @@ def new_swarm(
:param muxer_opt: optional choice of stream muxer
:param sec_opt: optional choice of security upgrade
:param peerstore_opt: optional peerstore
:param muxer_preference: optional explicit muxer preference
:return: return a default swarm instance
Note: Yamux (/yamux/1.0.0) is the preferred stream multiplexer
due to its improved performance and features.
Mplex (/mplex/6.7.0) is retained for backward compatibility
but may be deprecated in the future.
"""
if key_pair is None:
key_pair = generate_new_rsa_identity()
@ -87,13 +178,41 @@ def new_swarm(
# TODO: Parse `listen_addrs` to determine transport
transport = TCP()
muxer_transports_by_protocol = muxer_opt or {MPLEX_PROTOCOL_ID: Mplex}
security_transports_by_protocol = sec_opt or {
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair),
# Generate X25519 keypair for Noise
noise_key_pair = create_new_x25519_key_pair()
# Default security transports (using Noise as primary)
secure_transports_by_protocol: Mapping[TProtocol, ISecureTransport] = sec_opt or {
NOISE_PROTOCOL_ID: NoiseTransport(
key_pair, noise_privkey=noise_key_pair.private_key
),
TProtocol(secio.ID): secio.Transport(key_pair),
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair),
}
# Use given muxer preference if provided, otherwise use global default
if muxer_preference is not None:
temp_pref = muxer_preference.upper()
if temp_pref not in [MUXER_YAMUX, MUXER_MPLEX]:
raise ValueError(
f"Unknown muxer: {muxer_preference}. Use 'YAMUX' or 'MPLEX'."
)
active_preference = temp_pref
else:
active_preference = DEFAULT_MUXER
# Use provided muxer options if given, otherwise create based on preference
if muxer_opt is not None:
muxer_transports_by_protocol = muxer_opt
else:
if active_preference == MUXER_MPLEX:
muxer_transports_by_protocol = create_mplex_muxer_option()
else: # YAMUX is default
muxer_transports_by_protocol = create_yamux_muxer_option()
upgrader = TransportUpgrader(
security_transports_by_protocol, muxer_transports_by_protocol
secure_transports_by_protocol=secure_transports_by_protocol,
muxer_transports_by_protocol=muxer_transports_by_protocol,
)
peerstore = peerstore_opt or PeerStore()
@ -104,11 +223,12 @@ def new_swarm(
def new_host(
key_pair: KeyPair = None,
muxer_opt: TMuxerOptions = None,
sec_opt: TSecurityOptions = None,
peerstore_opt: IPeerStore = None,
disc_opt: IPeerRouting = None,
key_pair: Optional[KeyPair] = None,
muxer_opt: Optional[TMuxerOptions] = None,
sec_opt: Optional[TSecurityOptions] = None,
peerstore_opt: Optional[IPeerStore] = None,
disc_opt: Optional[IPeerRouting] = None,
muxer_preference: Optional[Literal["YAMUX", "MPLEX"]] = None,
) -> IHost:
"""
Create a new libp2p host based on the given parameters.
@ -118,6 +238,7 @@ def new_host(
:param sec_opt: optional choice of security upgrade
:param peerstore_opt: optional peerstore
:param disc_opt: optional discovery
:param muxer_preference: optional explicit muxer preference
:return: return a host instance
"""
swarm = new_swarm(
@ -125,13 +246,12 @@ def new_host(
muxer_opt=muxer_opt,
sec_opt=sec_opt,
peerstore_opt=peerstore_opt,
muxer_preference=muxer_preference,
)
host: IHost
if disc_opt:
host = RoutedHost(swarm, disc_opt)
else:
host = BasicHost(swarm)
return host
if disc_opt is not None:
return RoutedHost(swarm, disc_opt)
return BasicHost(swarm)
__version__ = __version("libp2p")

View File

@ -1,3 +1,5 @@
"""Key types and interfaces."""
from abc import (
ABC,
abstractmethod,
@ -9,17 +11,24 @@ from enum import (
Enum,
unique,
)
from typing import (
cast,
)
from .pb import crypto_pb2 as protobuf
from libp2p.crypto.pb import (
crypto_pb2,
)
@unique
class KeyType(Enum):
RSA = protobuf.KeyType.RSA
Ed25519 = protobuf.KeyType.Ed25519
Secp256k1 = protobuf.KeyType.Secp256k1
ECDSA = protobuf.KeyType.ECDSA
ECC_P256 = protobuf.KeyType.ECC_P256
RSA = crypto_pb2.KeyType.RSA
Ed25519 = crypto_pb2.KeyType.Ed25519
Secp256k1 = crypto_pb2.KeyType.Secp256k1
ECDSA = crypto_pb2.KeyType.ECDSA
ECC_P256 = crypto_pb2.KeyType.ECC_P256
# X25519 is added for Noise protocol
X25519 = cast(crypto_pb2.KeyType.ValueType, 5)
class Key(ABC):
@ -52,11 +61,11 @@ class PublicKey(Key):
"""
...
def _serialize_to_protobuf(self) -> protobuf.PublicKey:
def _serialize_to_protobuf(self) -> crypto_pb2.PublicKey:
"""Return the protobuf representation of this ``Key``."""
key_type = self.get_type().value
data = self.to_bytes()
protobuf_key = protobuf.PublicKey(key_type=key_type, data=data)
protobuf_key = crypto_pb2.PublicKey(key_type=key_type, data=data)
return protobuf_key
def serialize(self) -> bytes:
@ -64,8 +73,8 @@ class PublicKey(Key):
return self._serialize_to_protobuf().SerializeToString()
@classmethod
def deserialize_from_protobuf(cls, protobuf_data: bytes) -> protobuf.PublicKey:
return protobuf.PublicKey.FromString(protobuf_data)
def deserialize_from_protobuf(cls, protobuf_data: bytes) -> crypto_pb2.PublicKey:
return crypto_pb2.PublicKey.FromString(protobuf_data)
class PrivateKey(Key):
@ -79,11 +88,11 @@ class PrivateKey(Key):
def get_public_key(self) -> PublicKey:
...
def _serialize_to_protobuf(self) -> protobuf.PrivateKey:
def _serialize_to_protobuf(self) -> crypto_pb2.PrivateKey:
"""Return the protobuf representation of this ``Key``."""
key_type = self.get_type().value
data = self.to_bytes()
protobuf_key = protobuf.PrivateKey(key_type=key_type, data=data)
protobuf_key = crypto_pb2.PrivateKey(key_type=key_type, data=data)
return protobuf_key
def serialize(self) -> bytes:
@ -91,8 +100,8 @@ class PrivateKey(Key):
return self._serialize_to_protobuf().SerializeToString()
@classmethod
def deserialize_from_protobuf(cls, protobuf_data: bytes) -> protobuf.PrivateKey:
return protobuf.PrivateKey.FromString(protobuf_data)
def deserialize_from_protobuf(cls, protobuf_data: bytes) -> crypto_pb2.PrivateKey:
return crypto_pb2.PrivateKey.FromString(protobuf_data)
@dataclass(frozen=True)

View File

@ -8,6 +8,7 @@ enum KeyType {
Secp256k1 = 2;
ECDSA = 3;
ECC_P256 = 4;
X25519 = 5;
}
message PublicKey {

69
libp2p/crypto/x25519.py Normal file
View File

@ -0,0 +1,69 @@
from cryptography.hazmat.primitives import (
serialization,
)
from cryptography.hazmat.primitives.asymmetric import (
x25519,
)
from libp2p.crypto.keys import (
KeyPair,
KeyType,
PrivateKey,
PublicKey,
)
class X25519PublicKey(PublicKey):
def __init__(self, impl: x25519.X25519PublicKey) -> None:
self.impl = impl
def to_bytes(self) -> bytes:
return self.impl.public_bytes(
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
)
@classmethod
def from_bytes(cls, data: bytes) -> "X25519PublicKey":
return cls(x25519.X25519PublicKey.from_public_bytes(data))
def get_type(self) -> KeyType:
# Not in protobuf, but for Noise use only
return KeyType.X25519 # Or define KeyType.X25519 if you want to extend
def verify(self, data: bytes, signature: bytes) -> bool:
raise NotImplementedError("X25519 does not support signatures.")
class X25519PrivateKey(PrivateKey):
def __init__(self, impl: x25519.X25519PrivateKey) -> None:
self.impl = impl
@classmethod
def new(cls) -> "X25519PrivateKey":
return cls(x25519.X25519PrivateKey.generate())
def to_bytes(self) -> bytes:
return self.impl.private_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PrivateFormat.Raw,
encryption_algorithm=serialization.NoEncryption(),
)
@classmethod
def from_bytes(cls, data: bytes) -> "X25519PrivateKey":
return cls(x25519.X25519PrivateKey.from_private_bytes(data))
def get_type(self) -> KeyType:
return KeyType.X25519
def sign(self, data: bytes) -> bytes:
raise NotImplementedError("X25519 does not support signatures.")
def get_public_key(self) -> PublicKey:
return X25519PublicKey(self.impl.public_key())
def create_new_key_pair() -> KeyPair:
priv = X25519PrivateKey.new()
pub = priv.get_public_key()
return KeyPair(priv, pub)

View File

@ -1,3 +1,4 @@
import logging
from typing import (
TYPE_CHECKING,
)
@ -37,30 +38,69 @@ class SwarmConn(INetConn):
self.streams = set()
self.event_closed = trio.Event()
self.event_started = trio.Event()
if hasattr(muxed_conn, "on_close"):
logging.debug(f"Setting on_close for peer {muxed_conn.peer_id}")
muxed_conn.on_close = self._on_muxed_conn_closed
else:
logging.error(
f"muxed_conn for peer {muxed_conn.peer_id} has no on_close attribute"
)
@property
def is_closed(self) -> bool:
return self.event_closed.is_set()
async def _on_muxed_conn_closed(self) -> None:
"""Handle closure of the underlying muxed connection."""
peer_id = self.muxed_conn.peer_id
logging.debug(f"SwarmConn closing for peer {peer_id} due to muxed_conn closure")
# Only call close if we're not already closing
if not self.event_closed.is_set():
await self.close()
async def close(self) -> None:
if self.event_closed.is_set():
return
logging.debug(f"Closing SwarmConn for peer {self.muxed_conn.peer_id}")
self.event_closed.set()
# Close the muxed connection
try:
await self.muxed_conn.close()
except Exception as e:
logging.warning(f"Error while closing muxed connection: {e}")
# Perform proper cleanup of resources
await self._cleanup()
async def _cleanup(self) -> None:
# Remove the connection from swarm
logging.debug(f"Removing connection for peer {self.muxed_conn.peer_id}")
self.swarm.remove_conn(self)
await self.muxed_conn.close()
# Only close the connection if it's not already closed
# Be defensive here to avoid exceptions during cleanup
try:
if not self.muxed_conn.is_closed:
await self.muxed_conn.close()
except Exception as e:
logging.warning(f"Error closing muxed connection: {e}")
# This is just for cleaning up state. The connection has already been closed.
# We *could* optimize this but it really isn't worth it.
logging.debug(f"Resetting streams for peer {self.muxed_conn.peer_id}")
for stream in self.streams.copy():
await stream.reset()
try:
await stream.reset()
except Exception as e:
logging.warning(f"Error resetting stream: {e}")
# Force context switch for stream handlers to process the stream reset event we
# just emit before we cancel the stream handler tasks.
await trio.sleep(0.1)
# Notify all listeners about the disconnection
logging.debug(f"Notifying disconnection for peer {self.muxed_conn.peer_id}")
await self._notify_disconnected()
async def _handle_new_streams(self) -> None:

View File

@ -12,6 +12,7 @@ from libp2p.custom_types import (
from libp2p.stream_muxer.exceptions import (
MuxedStreamClosed,
MuxedStreamEOF,
MuxedStreamError,
MuxedStreamReset,
)
@ -68,7 +69,7 @@ class NetStream(INetStream):
"""
try:
await self.muxed_stream.write(data)
except MuxedStreamClosed as error:
except (MuxedStreamClosed, MuxedStreamError) as error:
raise StreamClosed() from error
async def close(self) -> None:

View File

@ -313,7 +313,35 @@ class Swarm(Service, INetworkService):
return False
async def close(self) -> None:
await self.manager.stop()
"""
Close the swarm instance and cleanup resources.
"""
# Check if manager exists before trying to stop it
if hasattr(self, "_manager") and self._manager is not None:
await self._manager.stop()
else:
# Perform alternative cleanup if the manager isn't initialized
# Close all connections manually
if hasattr(self, "connections"):
for conn_id in list(self.connections.keys()):
conn = self.connections[conn_id]
await conn.close()
# Clear connection tracking dictionary
self.connections.clear()
# Close all listeners
if hasattr(self, "listeners"):
for listener in self.listeners.values():
await listener.close()
self.listeners.clear()
# Close the transport if it exists and has a close method
if hasattr(self, "transport") and self.transport is not None:
# Check if transport has close method before calling it
if hasattr(self.transport, "close"):
await self.transport.close()
logger.debug("swarm successfully closed")
async def close_peer(self, peer_id: ID) -> None:

View File

@ -242,45 +242,50 @@ class Pubsub(Service, IPubsub):
"""
peer_id = stream.muxed_conn.peer_id
while self.manager.is_running:
incoming: bytes = await read_varint_prefixed_bytes(stream)
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
rpc_incoming.ParseFromString(incoming)
if rpc_incoming.publish:
# deal with RPC.publish
for msg in rpc_incoming.publish:
if not self._is_subscribed_to_msg(msg):
continue
logger.debug(
"received `publish` message %s from peer %s", msg, peer_id
)
self.manager.run_task(self.push_msg, peer_id, msg)
try:
while self.manager.is_running:
incoming: bytes = await read_varint_prefixed_bytes(stream)
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
rpc_incoming.ParseFromString(incoming)
if rpc_incoming.publish:
# deal with RPC.publish
for msg in rpc_incoming.publish:
if not self._is_subscribed_to_msg(msg):
continue
logger.debug(
"received `publish` message %s from peer %s", msg, peer_id
)
self.manager.run_task(self.push_msg, peer_id, msg)
if rpc_incoming.subscriptions:
# deal with RPC.subscriptions
# We don't need to relay the subscription to our
# peers because a given node only needs its peers
# to know that it is subscribed to the topic (doesn't
# need everyone to know)
for message in rpc_incoming.subscriptions:
if rpc_incoming.subscriptions:
# deal with RPC.subscriptions
# We don't need to relay the subscription to our
# peers because a given node only needs its peers
# to know that it is subscribed to the topic (doesn't
# need everyone to know)
for message in rpc_incoming.subscriptions:
logger.debug(
"received `subscriptions` message %s from peer %s",
message,
peer_id,
)
self.handle_subscription(peer_id, message)
# NOTE: Check if `rpc_incoming.control` is set through `HasField`.
# This is necessary because `control` is an optional field in pb2.
# Ref: https://developers.google.com/protocol-buffers/docs/reference/python-generated#singular-fields-proto2 # noqa: E501
if rpc_incoming.HasField("control"):
# Pass rpc to router so router could perform custom logic
logger.debug(
"received `subscriptions` message %s from peer %s",
message,
"received `control` message %s from peer %s",
rpc_incoming.control,
peer_id,
)
self.handle_subscription(peer_id, message)
# NOTE: Check if `rpc_incoming.control` is set through `HasField`.
# This is necessary because `control` is an optional field in pb2.
# Ref: https://developers.google.com/protocol-buffers/docs/reference/python-generated#singular-fields-proto2 # noqa: E501
if rpc_incoming.HasField("control"):
# Pass rpc to router so router could perform custom logic
logger.debug(
"received `control` message %s from peer %s",
rpc_incoming.control,
peer_id,
)
await self.router.handle_rpc(rpc_incoming, peer_id)
await self.router.handle_rpc(rpc_incoming, peer_id)
except StreamEOF:
logger.debug(
f"Stream closed for peer {peer_id}, exiting read loop cleanly."
)
def set_topic_validator(
self, topic: str, validator: ValidatorFn, is_async_validator: bool

View File

@ -1,4 +1,5 @@
from typing import (
Optional,
cast,
)
@ -66,6 +67,14 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
async def close(self) -> None:
await self.read_writer.close()
def get_remote_address(self) -> Optional[tuple[str, int]]:
# Delegate to the underlying connection if possible
if hasattr(self.read_writer, "read_write_closer") and hasattr(
self.read_writer.read_write_closer, "get_remote_address"
):
return self.read_writer.read_write_closer.get_remote_address()
return None
class NoiseHandshakeReadWriter(BaseNoiseMsgReadWriter):
def encrypt(self, data: bytes) -> bytes:

View File

@ -2,6 +2,8 @@ from collections import (
OrderedDict,
)
import trio
from libp2p.abc import (
IMuxedConn,
IRawConnection,
@ -24,6 +26,10 @@ from libp2p.protocol_muxer.multiselect_client import (
from libp2p.protocol_muxer.multiselect_communicator import (
MultiselectCommunicator,
)
from libp2p.stream_muxer.yamux.yamux import (
PROTOCOL_ID,
Yamux,
)
# FIXME: add negotiate timeout to `MuxerMultistream`
DEFAULT_NEGOTIATE_TIMEOUT = 60
@ -44,7 +50,7 @@ class MuxerMultistream:
def __init__(self, muxer_transports_by_protocol: TMuxerOptions) -> None:
self.transports = OrderedDict()
self.multiselect = Multiselect()
self.multiselect_client = MultiselectClient()
self.multistream_client = MultiselectClient()
for protocol, transport in muxer_transports_by_protocol.items():
self.add_transport(protocol, transport)
@ -81,5 +87,18 @@ class MuxerMultistream:
return self.transports[protocol]
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
transport_class = await self.select_transport(conn)
communicator = MultiselectCommunicator(conn)
protocol = await self.multistream_client.select_one_of(
tuple(self.transports.keys()), communicator
)
transport_class = self.transports[protocol]
if protocol == PROTOCOL_ID:
async with trio.open_nursery():
def on_close() -> None:
pass
return Yamux(
conn, peer_id, is_initiator=conn.is_initiator, on_close=on_close
)
return transport_class(conn, peer_id)

View File

@ -0,0 +1,5 @@
from .yamux import (
Yamux,
)
__all__ = ["Yamux"]

View File

@ -0,0 +1,676 @@
"""
Yamux stream multiplexer implementation for py-libp2p.
This is the preferred multiplexing protocol due to its performance and feature set.
Mplex is also available for legacy compatibility but may be deprecated in the future.
"""
from collections.abc import (
Awaitable,
)
import inspect
import logging
import struct
from typing import (
Callable,
Optional,
)
import trio
from trio import (
MemoryReceiveChannel,
MemorySendChannel,
Nursery,
)
from libp2p.abc import (
IMuxedConn,
IMuxedStream,
ISecureConn,
)
from libp2p.io.exceptions import (
IncompleteReadError,
)
from libp2p.network.connection.exceptions import (
RawConnError,
)
from libp2p.peer.id import (
ID,
)
from libp2p.stream_muxer.exceptions import (
MuxedStreamEOF,
MuxedStreamError,
MuxedStreamReset,
)
PROTOCOL_ID = "/yamux/1.0.0"
TYPE_DATA = 0x0
TYPE_WINDOW_UPDATE = 0x1
TYPE_PING = 0x2
TYPE_GO_AWAY = 0x3
FLAG_SYN = 0x1
FLAG_ACK = 0x2
FLAG_FIN = 0x4
FLAG_RST = 0x8
HEADER_SIZE = 12
# Network byte order: version (B), type (B), flags (H), stream_id (I), length (I)
YAMUX_HEADER_FORMAT = "!BBHII"
DEFAULT_WINDOW_SIZE = 256 * 1024
GO_AWAY_NORMAL = 0x0
GO_AWAY_PROTOCOL_ERROR = 0x1
GO_AWAY_INTERNAL_ERROR = 0x2
class YamuxStream(IMuxedStream):
def __init__(self, stream_id: int, conn: "Yamux", is_initiator: bool) -> None:
self.stream_id = stream_id
self.conn = conn
self.muxed_conn = conn
self.is_initiator = is_initiator
self.closed = False
self.send_closed = False
self.recv_closed = False
self.reset_received = False # Track if RST was received
self.send_window = DEFAULT_WINDOW_SIZE
self.recv_window = DEFAULT_WINDOW_SIZE
self.window_lock = trio.Lock()
async def write(self, data: bytes) -> None:
if self.send_closed:
raise MuxedStreamError("Stream is closed for sending")
# Flow control: Check if we have enough send window
total_len = len(data)
sent = 0
while sent < total_len:
async with self.window_lock:
# Wait for available window
while self.send_window == 0 and not self.closed:
# Release lock while waiting
self.window_lock.release()
await trio.sleep(0.01)
await self.window_lock.acquire()
if self.closed:
raise MuxedStreamError("Stream is closed")
# Calculate how much we can send now
to_send = min(self.send_window, total_len - sent)
chunk = data[sent : sent + to_send]
self.send_window -= to_send
# Send the data
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk)
)
await self.conn.secured_conn.write(header + chunk)
sent += to_send
# If window is getting low, consider updating
if self.send_window < DEFAULT_WINDOW_SIZE // 2:
await self.send_window_update()
async def send_window_update(self, increment: Optional[int] = None) -> None:
"""Send a window update to peer."""
if increment is None:
increment = DEFAULT_WINDOW_SIZE - self.recv_window
if increment <= 0:
return
async with self.window_lock:
self.recv_window += increment
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_WINDOW_UPDATE, 0, self.stream_id, increment
)
await self.conn.secured_conn.write(header)
async def read(self, n: int = -1) -> bytes:
# Handle None value for n by converting it to -1
if n is None:
n = -1
# If the stream is closed for receiving and the buffer is empty, raise EOF
if self.recv_closed and not self.conn.stream_buffers.get(self.stream_id):
logging.debug(
f"Stream {self.stream_id}: Stream closed for receiving and buffer empty"
)
raise MuxedStreamEOF("Stream is closed for receiving")
# If reading until EOF (n == -1), block until stream is closed
if n == -1:
while not self.recv_closed and not self.conn.event_shutting_down.is_set():
# Check if there's data in the buffer
buffer = self.conn.stream_buffers.get(self.stream_id)
if buffer and len(buffer) > 0:
# Wait for closure even if data is available
logging.debug(
f"Stream {self.stream_id}:"
f"Waiting for FIN before returning data"
)
await self.conn.stream_events[self.stream_id].wait()
self.conn.stream_events[self.stream_id] = trio.Event()
else:
# No data, wait for data or closure
logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
await self.conn.stream_events[self.stream_id].wait()
self.conn.stream_events[self.stream_id] = trio.Event()
# After loop, check if stream is closed or shutting down
async with self.conn.streams_lock:
if self.conn.event_shutting_down.is_set():
logging.debug(f"Stream {self.stream_id}: Connection shutting down")
raise MuxedStreamEOF("Connection shut down")
if self.closed:
if self.reset_received:
logging.debug(f"Stream {self.stream_id}: Stream was reset")
raise MuxedStreamReset("Stream was reset")
else:
logging.debug(
f"Stream {self.stream_id}: Stream closed cleanly (EOF)"
)
raise MuxedStreamEOF("Stream closed cleanly (EOF)")
buffer = self.conn.stream_buffers.get(self.stream_id)
if buffer is None:
logging.debug(
f"Stream {self.stream_id}: Buffer gone, assuming closed"
)
raise MuxedStreamEOF("Stream buffer closed")
if self.recv_closed and len(buffer) == 0:
logging.debug(f"Stream {self.stream_id}: EOF reached")
raise MuxedStreamEOF("Stream is closed for receiving")
# Return all buffered data
data = bytes(buffer)
buffer.clear()
logging.debug(f"Stream {self.stream_id}: Returning {len(data)} bytes")
return data
# For specific size read (n > 0), return available data immediately
return await self.conn.read_stream(self.stream_id, n)
async def close(self) -> None:
if not self.send_closed:
logging.debug(f"Half-closing stream {self.stream_id} (local end)")
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
)
await self.conn.secured_conn.write(header)
self.send_closed = True
# Only set fully closed if both directions are closed
if self.send_closed and self.recv_closed:
self.closed = True
else:
# Stream is half-closed but not fully closed
self.closed = False
async def reset(self) -> None:
if not self.closed:
logging.debug(f"Resetting stream {self.stream_id}")
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
)
await self.conn.secured_conn.write(header)
self.closed = True
self.reset_received = True # Mark as reset
def set_deadline(self, ttl: int) -> bool:
"""
Set a deadline for the stream. Yamux does not support deadlines natively,
so this method always returns False to indicate the operation is unsupported.
:param ttl: Time-to-live in seconds (ignored).
:return: False, as deadlines are not supported.
"""
raise NotImplementedError("Yamux does not support setting read deadlines")
def get_remote_address(self) -> Optional[tuple[str, int]]:
"""
Returns the remote address of the underlying connection.
"""
# Delegate to the secured_conn's get_remote_address method
if hasattr(self.conn.secured_conn, "get_remote_address"):
remote_addr = self.conn.secured_conn.get_remote_address()
# Ensure the return value matches tuple[str, int] | None
if (
remote_addr is None
or isinstance(remote_addr, tuple)
and len(remote_addr) == 2
):
return remote_addr
else:
raise ValueError(
"Underlying connection returned an unexpected address format"
)
else:
# Return None if the underlying connection doesn't provide this info
return None
class Yamux(IMuxedConn):
def __init__(
self,
secured_conn: ISecureConn,
peer_id: ID,
is_initiator: Optional[bool] = None,
on_close: Optional[Callable[[], Awaitable[None]]] = None,
) -> None:
self.secured_conn = secured_conn
self.peer_id = peer_id
self.stream_backlog_limit = 256
self.stream_backlog_semaphore = trio.Semaphore(256)
self.on_close = on_close
# Per Yamux spec
# (https://github.com/hashicorp/yamux/blob/master/spec.md#streamid-field):
# Initiators assign odd stream IDs (starting at 1),
# responders use even IDs (starting at 2).
self.is_initiator_value = (
is_initiator if is_initiator is not None else secured_conn.is_initiator
)
self.next_stream_id = 1 if self.is_initiator_value else 2
self.streams: dict[int, YamuxStream] = {}
self.streams_lock = trio.Lock()
self.new_stream_send_channel: MemorySendChannel[YamuxStream]
self.new_stream_receive_channel: MemoryReceiveChannel[YamuxStream]
(
self.new_stream_send_channel,
self.new_stream_receive_channel,
) = trio.open_memory_channel(10)
self.event_shutting_down = trio.Event()
self.event_closed = trio.Event()
self.event_started = trio.Event()
self.stream_buffers: dict[int, bytearray] = {}
self.stream_events: dict[int, trio.Event] = {}
self._nursery: Optional[Nursery] = None
async def start(self) -> None:
logging.debug(f"Starting Yamux for {self.peer_id}")
if self.event_started.is_set():
return
async with trio.open_nursery() as nursery:
self._nursery = nursery
nursery.start_soon(self.handle_incoming)
self.event_started.set()
@property
def is_initiator(self) -> bool:
return self.is_initiator_value
async def close(self, error_code: int = GO_AWAY_NORMAL) -> None:
logging.debug(f"Closing Yamux connection with code {error_code}")
async with self.streams_lock:
if not self.event_shutting_down.is_set():
try:
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_GO_AWAY, 0, 0, error_code
)
await self.secured_conn.write(header)
except Exception as e:
logging.debug(f"Failed to send GO_AWAY: {e}")
self.event_shutting_down.set()
for stream in self.streams.values():
stream.closed = True
stream.send_closed = True
stream.recv_closed = True
self.streams.clear()
self.stream_buffers.clear()
self.stream_events.clear()
try:
await self.secured_conn.close()
logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
except Exception as e:
logging.debug(f"Error closing secured_conn for peer {self.peer_id}: {e}")
self.event_closed.set()
if self.on_close:
logging.debug(f"Calling on_close in Yamux.close for peer {self.peer_id}")
if inspect.iscoroutinefunction(self.on_close):
if self.on_close is not None:
await self.on_close()
else:
if self.on_close is not None:
await self.on_close()
await trio.sleep(0.1)
@property
def is_closed(self) -> bool:
return self.event_closed.is_set()
async def open_stream(self) -> YamuxStream:
# Wait for backlog slot
await self.stream_backlog_semaphore.acquire()
async with self.streams_lock:
stream_id = self.next_stream_id
self.next_stream_id += 2
stream = YamuxStream(stream_id, self, True)
self.streams[stream_id] = stream
self.stream_buffers[stream_id] = bytearray()
self.stream_events[stream_id] = trio.Event()
# If stream is rejected or errors, release the semaphore
try:
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_SYN, stream_id, 0
)
logging.debug(f"Sending SYN header for stream {stream_id}")
await self.secured_conn.write(header)
return stream
except Exception as e:
self.stream_backlog_semaphore.release()
raise e
async def accept_stream(self) -> IMuxedStream:
logging.debug("Waiting for new stream")
try:
stream = await self.new_stream_receive_channel.receive()
logging.debug(f"Received stream {stream.stream_id}")
return stream
except trio.EndOfChannel:
raise MuxedStreamError("No new streams available")
async def read_stream(self, stream_id: int, n: int = -1) -> bytes:
logging.debug(f"Reading from stream {self.peer_id}:{stream_id}, n={n}")
if n is None:
n = -1
while True:
async with self.streams_lock:
if stream_id not in self.streams:
logging.debug(f"Stream {self.peer_id}:{stream_id} unknown")
raise MuxedStreamEOF("Stream closed")
if self.event_shutting_down.is_set():
logging.debug(
f"Stream {self.peer_id}:{stream_id}: connection shutting down"
)
raise MuxedStreamEOF("Connection shut down")
stream = self.streams[stream_id]
buffer = self.stream_buffers.get(stream_id)
logging.debug(
f"Stream {self.peer_id}:{stream_id}: "
f"closed={stream.closed}, "
f"recv_closed={stream.recv_closed}, "
f"reset_received={stream.reset_received}, "
f"buffer_len={len(buffer) if buffer else 0}"
)
if buffer is None:
logging.debug(
f"Stream {self.peer_id}:{stream_id}:"
f"Buffer gone, assuming closed"
)
raise MuxedStreamEOF("Stream buffer closed")
# If FIN received and buffer has data, return it
if stream.recv_closed and buffer and len(buffer) > 0:
if n == -1 or n >= len(buffer):
data = bytes(buffer)
buffer.clear()
else:
data = bytes(buffer[:n])
del buffer[:n]
logging.debug(
f"Returning {len(data)} bytes"
f"from stream {self.peer_id}:{stream_id}, "
f"buffer_len={len(buffer)}"
)
return data
# If reset received and buffer is empty, raise reset
if stream.reset_received:
logging.debug(
f"Stream {self.peer_id}:{stream_id}:"
f"reset_received=True, raising MuxedStreamReset"
)
raise MuxedStreamReset("Stream was reset")
# Check if we can return data (no FIN or reset)
if buffer and len(buffer) > 0:
if n == -1 or n >= len(buffer):
data = bytes(buffer)
buffer.clear()
else:
data = bytes(buffer[:n])
del buffer[:n]
logging.debug(
f"Returning {len(data)} bytes"
f"from stream {self.peer_id}:{stream_id}, "
f"buffer_len={len(buffer)}"
)
return data
# Check if stream is closed
if stream.closed:
logging.debug(
f"Stream {self.peer_id}:{stream_id}:"
f"closed=True, raising MuxedStreamReset"
)
raise MuxedStreamReset("Stream is reset or closed")
# Check if recv_closed and buffer empty
if stream.recv_closed:
logging.debug(
f"Stream {self.peer_id}:{stream_id}:"
f"recv_closed=True, buffer empty, raising EOF"
)
raise MuxedStreamEOF("Stream is closed for receiving")
# Wait for data if stream is still open
logging.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}")
await self.stream_events[stream_id].wait()
self.stream_events[stream_id] = trio.Event()
async def handle_incoming(self) -> None:
while not self.event_shutting_down.is_set():
try:
header = await self.secured_conn.read(HEADER_SIZE)
if not header or len(header) < HEADER_SIZE:
logging.debug(
f"Connection closed or"
f"incomplete header for peer {self.peer_id}"
)
self.event_shutting_down.set()
await self._cleanup_on_error()
break
version, typ, flags, stream_id, length = struct.unpack(
YAMUX_HEADER_FORMAT, header
)
logging.debug(
f"Received header for peer {self.peer_id}:"
f"type={typ}, flags={flags}, stream_id={stream_id},"
f"length={length}"
)
if typ == TYPE_DATA and flags & FLAG_SYN:
async with self.streams_lock:
if stream_id not in self.streams:
stream = YamuxStream(stream_id, self, False)
self.streams[stream_id] = stream
self.stream_buffers[stream_id] = bytearray()
self.stream_events[stream_id] = trio.Event()
ack_header = struct.pack(
YAMUX_HEADER_FORMAT,
0,
TYPE_DATA,
FLAG_ACK,
stream_id,
0,
)
await self.secured_conn.write(ack_header)
logging.debug(
f"Sending stream {stream_id}"
f"to channel for peer {self.peer_id}"
)
await self.new_stream_send_channel.send(stream)
else:
rst_header = struct.pack(
YAMUX_HEADER_FORMAT,
0,
TYPE_DATA,
FLAG_RST,
stream_id,
0,
)
await self.secured_conn.write(rst_header)
elif typ == TYPE_DATA and flags & FLAG_RST:
async with self.streams_lock:
if stream_id in self.streams:
logging.debug(
f"Resetting stream {stream_id} for peer {self.peer_id}"
)
self.streams[stream_id].closed = True
self.streams[stream_id].reset_received = True
self.stream_events[stream_id].set()
elif typ == TYPE_DATA and flags & FLAG_ACK:
async with self.streams_lock:
if stream_id in self.streams:
logging.debug(
f"Received ACK for stream"
f"{stream_id} for peer {self.peer_id}"
)
elif typ == TYPE_GO_AWAY:
error_code = length
if error_code == GO_AWAY_NORMAL:
logging.debug(
f"Received GO_AWAY for peer"
f"{self.peer_id}: Normal termination"
)
elif error_code == GO_AWAY_PROTOCOL_ERROR:
logging.error(
f"Received GO_AWAY for peer"
f"{self.peer_id}: Protocol error"
)
elif error_code == GO_AWAY_INTERNAL_ERROR:
logging.error(
f"Received GO_AWAY for peer {self.peer_id}: Internal error"
)
else:
logging.error(
f"Received GO_AWAY for peer {self.peer_id}"
f"with unknown error code: {error_code}"
)
self.event_shutting_down.set()
await self._cleanup_on_error()
break
elif typ == TYPE_PING:
if flags & FLAG_SYN:
logging.debug(
f"Received ping request with value"
f"{length} for peer {self.peer_id}"
)
ping_header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_PING, FLAG_ACK, 0, length
)
await self.secured_conn.write(ping_header)
elif flags & FLAG_ACK:
logging.debug(
f"Received ping response with value"
f"{length} for peer {self.peer_id}"
)
elif typ == TYPE_DATA:
try:
data = (
await self.secured_conn.read(length) if length > 0 else b""
)
async with self.streams_lock:
if stream_id in self.streams:
self.stream_buffers[stream_id].extend(data)
self.stream_events[stream_id].set()
if flags & FLAG_FIN:
logging.debug(
f"Received FIN for stream {self.peer_id}:"
f"{stream_id}, marking recv_closed"
)
self.streams[stream_id].recv_closed = True
if self.streams[stream_id].send_closed:
self.streams[stream_id].closed = True
except Exception as e:
logging.error(f"Error reading data for stream {stream_id}: {e}")
# Mark stream as closed on read error
async with self.streams_lock:
if stream_id in self.streams:
self.streams[stream_id].recv_closed = True
if self.streams[stream_id].send_closed:
self.streams[stream_id].closed = True
self.stream_events[stream_id].set()
elif typ == TYPE_WINDOW_UPDATE:
increment = length
async with self.streams_lock:
if stream_id in self.streams:
stream = self.streams[stream_id]
async with stream.window_lock:
logging.debug(
f"Received window update for stream"
f"{self.peer_id}:{stream_id},"
f" increment: {increment}"
)
stream.send_window += increment
except Exception as e:
# Special handling for expected IncompleteReadError on stream close
if isinstance(e, IncompleteReadError):
details = getattr(e, "args", [{}])[0]
if (
isinstance(details, dict)
and details.get("requested_count") == 2
and details.get("received_count") == 0
):
logging.info(
f"Stream closed cleanly for peer {self.peer_id}"
+ f" (IncompleteReadError: {details})"
)
self.event_shutting_down.set()
await self._cleanup_on_error()
break
else:
logging.error(
f"Error in handle_incoming for peer {self.peer_id}: "
+ f"{type(e).__name__}: {str(e)}"
)
else:
logging.error(
f"Error in handle_incoming for peer {self.peer_id}: "
+ f"{type(e).__name__}: {str(e)}"
)
# Don't crash the whole connection for temporary errors
if self.event_shutting_down.is_set() or isinstance(
e, (RawConnError, OSError)
):
await self._cleanup_on_error()
break
# For other errors, log and continue
await trio.sleep(0.01)
async def _cleanup_on_error(self) -> None:
# Set shutdown flag first to prevent other operations
self.event_shutting_down.set()
# Clean up streams
async with self.streams_lock:
for stream in self.streams.values():
stream.closed = True
stream.send_closed = True
stream.recv_closed = True
# Set the event so any waiters are woken up
if stream.stream_id in self.stream_events:
self.stream_events[stream.stream_id].set()
# Clear buffers and events
self.stream_buffers.clear()
self.stream_events.clear()
# Close the secured connection
try:
await self.secured_conn.close()
logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
except Exception as close_error:
logging.error(
f"Error closing secured_conn for peer {self.peer_id}: {close_error}"
)
# Set closed flag
self.event_closed.set()
# Call on_close callback if provided
if self.on_close:
logging.debug(f"Calling on_close for peer {self.peer_id}")
try:
if inspect.iscoroutinefunction(self.on_close):
await self.on_close()
else:
self.on_close()
except Exception as callback_error:
logging.error(f"Error in on_close callback: {callback_error}")
# Cancel nursery tasks
if self._nursery:
self._nursery.cancel_scope.cancel()

View File

@ -1,10 +1,13 @@
from collections.abc import (
Awaitable,
)
import logging
from typing import (
Callable,
)
import trio
from libp2p.abc import (
IHost,
INetStream,
@ -32,16 +35,104 @@ async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm) -> None:
for addr in transport.get_addrs()
)
swarm_0.peerstore.add_addrs(peer_id, addrs, 10000)
await swarm_0.dial_peer(peer_id)
assert swarm_0.get_peer_id() in swarm_1.connections
assert swarm_1.get_peer_id() in swarm_0.connections
# Add retry logic for more robust connection
max_retries = 3
retry_delay = 0.2
last_error = None
for attempt in range(max_retries):
try:
await swarm_0.dial_peer(peer_id)
# Verify connection is established in both directions
if (
swarm_0.get_peer_id() in swarm_1.connections
and swarm_1.get_peer_id() in swarm_0.connections
):
return
# Connection partially established, wait a bit for it to complete
await trio.sleep(0.1)
if (
swarm_0.get_peer_id() in swarm_1.connections
and swarm_1.get_peer_id() in swarm_0.connections
):
return
logging.debug(
"Swarm connection verification failed on attempt"
+ f" {attempt+1}, retrying..."
)
except Exception as e:
last_error = e
logging.debug(f"Swarm connection attempt {attempt+1} failed: {e}")
await trio.sleep(retry_delay)
# If we got here, all retries failed
if last_error:
raise RuntimeError(
f"Failed to connect swarms after {max_retries} attempts"
) from last_error
else:
err_msg = (
"Failed to establish bidirectional swarm connection"
+ f" after {max_retries} attempts"
)
raise RuntimeError(err_msg)
async def connect(node1: IHost, node2: IHost) -> None:
"""Connect node1 to node2."""
addr = node2.get_addrs()[0]
info = info_from_p2p_addr(addr)
await node1.connect(info)
# Add retry logic for more robust connection
max_retries = 3
retry_delay = 0.2
last_error = None
for attempt in range(max_retries):
try:
await node1.connect(info)
# Verify connection is established in both directions
if (
node2.get_id() in node1.get_network().connections
and node1.get_id() in node2.get_network().connections
):
return
# Connection partially established, wait a bit for it to complete
await trio.sleep(0.1)
if (
node2.get_id() in node1.get_network().connections
and node1.get_id() in node2.get_network().connections
):
return
logging.debug(
f"Connection verification failed on attempt {attempt+1}, retrying..."
)
except Exception as e:
last_error = e
logging.debug(f"Connection attempt {attempt+1} failed: {e}")
await trio.sleep(retry_delay)
# If we got here, all retries failed
if last_error:
raise RuntimeError(
f"Failed to connect after {max_retries} attempts"
) from last_error
else:
err_msg = (
f"Failed to establish bidirectional connection after {max_retries} attempts"
)
raise RuntimeError(err_msg)
def create_echo_stream_handler(