Updated Yamux implementation; added tests for Yamux and Mplex

kaneki003
2025-05-31 19:02:18 +05:30
parent 286752c517
commit e397ce25a6
5 changed files with 458 additions and 67 deletions

View File

@@ -46,8 +46,9 @@ class MplexStream(IMuxedStream):
     read_deadline: int | None
     write_deadline: int | None

-    # TODO: Add lock for read/write to avoid interleaving receiving messages?
     close_lock: trio.Lock
+    read_lock: trio.Lock
+    write_lock: trio.Lock

     # NOTE: `dataIn` is size of 8 in Go implementation.
     incoming_data_channel: "trio.MemoryReceiveChannel[bytes]"
@@ -80,6 +81,8 @@ class MplexStream(IMuxedStream):
         self.event_remote_closed = trio.Event()
         self.event_reset = trio.Event()
         self.close_lock = trio.Lock()
+        self.read_lock = trio.Lock()
+        self.write_lock = trio.Lock()
         self.incoming_data_channel = incoming_data_channel
         self._buf = bytearray()
@@ -113,48 +116,49 @@ class MplexStream(IMuxedStream):
         :param n: number of bytes to read
         :return: bytes actually read
         """
-        if n is not None and n < 0:
-            raise ValueError(
-                "the number of bytes to read `n` must be non-negative or "
-                f"`None` to indicate read until EOF, got n={n}"
-            )
-        if self.event_reset.is_set():
-            raise MplexStreamReset
-        if n is None:
-            return await self._read_until_eof()
-        if len(self._buf) == 0:
-            data: bytes
-            # Peek whether there is data available. If yes, we just read until there is
-            # no data, then return.
-            try:
-                data = self.incoming_data_channel.receive_nowait()
-                self._buf.extend(data)
-            except trio.EndOfChannel:
-                raise MplexStreamEOF
-            except trio.WouldBlock:
-                # We know `receive` will be blocked here. Wait for data here with
-                # `receive` and catch all kinds of errors here.
-                try:
-                    data = await self.incoming_data_channel.receive()
-                    self._buf.extend(data)
-                except trio.EndOfChannel:
-                    if self.event_reset.is_set():
-                        raise MplexStreamReset
-                    if self.event_remote_closed.is_set():
-                        raise MplexStreamEOF
-                except trio.ClosedResourceError as error:
-                    # Probably `incoming_data_channel` is closed in `reset` when we are
-                    # waiting for `receive`.
-                    if self.event_reset.is_set():
-                        raise MplexStreamReset
-                    raise Exception(
-                        "`incoming_data_channel` is closed but stream is not reset. "
-                        "This should never happen."
-                    ) from error
-        self._buf.extend(self._read_return_when_blocked())
-        payload = self._buf[:n]
-        self._buf = self._buf[len(payload) :]
-        return bytes(payload)
+        async with self.read_lock:
+            if n is not None and n < 0:
+                raise ValueError(
+                    "the number of bytes to read `n` must be non-negative or "
+                    f"`None` to indicate read until EOF, got n={n}"
+                )
+            if self.event_reset.is_set():
+                raise MplexStreamReset
+            if n is None:
+                return await self._read_until_eof()
+            if len(self._buf) == 0:
+                data: bytes
+                # Peek whether there is data available. If yes, we just read until
+                # there is no data, then return.
+                try:
+                    data = self.incoming_data_channel.receive_nowait()
+                    self._buf.extend(data)
+                except trio.EndOfChannel:
+                    raise MplexStreamEOF
+                except trio.WouldBlock:
+                    # We know `receive` will be blocked here. Wait for data here with
+                    # `receive` and catch all kinds of errors here.
+                    try:
+                        data = await self.incoming_data_channel.receive()
+                        self._buf.extend(data)
+                    except trio.EndOfChannel:
+                        if self.event_reset.is_set():
+                            raise MplexStreamReset
+                        if self.event_remote_closed.is_set():
+                            raise MplexStreamEOF
+                    except trio.ClosedResourceError as error:
+                        # Probably `incoming_data_channel` is closed in `reset` when
+                        # we are waiting for `receive`.
+                        if self.event_reset.is_set():
+                            raise MplexStreamReset
+                        raise Exception(
+                            "`incoming_data_channel` is closed but stream is not "
+                            "reset. This should never happen."
+                        ) from error
+            self._buf.extend(self._read_return_when_blocked())
+            payload = self._buf[:n]
+            self._buf = self._buf[len(payload) :]
+            return bytes(payload)

     async def write(self, data: bytes) -> None:
         """
@@ -162,14 +166,15 @@ class MplexStream(IMuxedStream):
         :return: number of bytes written
         """
-        if self.event_local_closed.is_set():
-            raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
-        flag = (
-            HeaderTags.MessageInitiator
-            if self.is_initiator
-            else HeaderTags.MessageReceiver
-        )
-        await self.muxed_conn.send_message(flag, data, self.stream_id)
+        async with self.write_lock:
+            if self.event_local_closed.is_set():
+                raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
+            flag = (
+                HeaderTags.MessageInitiator
+                if self.is_initiator
+                else HeaderTags.MessageReceiver
+            )
+            await self.muxed_conn.send_message(flag, data, self.stream_id)

     async def close(self) -> None:
         """

View File

@@ -77,6 +77,8 @@ class YamuxStream(IMuxedStream):
         self.send_window = DEFAULT_WINDOW_SIZE
         self.recv_window = DEFAULT_WINDOW_SIZE
         self.window_lock = trio.Lock()
+        self.read_lock = trio.Lock()
+        self.write_lock = trio.Lock()

     async def __aenter__(self) -> "YamuxStream":
         """Enter the async context manager."""
@@ -98,16 +100,32 @@ class YamuxStream(IMuxedStream):
         # Flow control: Check if we have enough send window
         total_len = len(data)
         sent = 0
+        logging.debug(f"Stream {self.stream_id}: Starting to write {total_len} bytes")

         while sent < total_len:
+            # Wait for an available window, with a timeout
+            timeout = False
             async with self.window_lock:
-                # Wait for available window
-                while self.send_window == 0 and not self.closed:
-                    # Release lock while waiting
+                if self.send_window == 0:
+                    logging.debug(
+                        f"Stream {self.stream_id}: Window is zero, waiting for update"
+                    )
+                    # Release the lock and poll for a window update under a
+                    # timeout, instead of re-acquiring the lock immediately.
                     self.window_lock.release()
-                    await trio.sleep(0.01)
+                    with trio.move_on_after(5.0) as cancel_scope:
+                        while self.send_window == 0 and not self.closed:
+                            await trio.sleep(0.01)
+                    # Record whether the wait timed out
+                    timeout = cancel_scope.cancelled_caught
+                    # Re-acquire the lock
                     await self.window_lock.acquire()
+                # If we timed out waiting for a window update, raise an error
+                if timeout:
+                    raise MuxedStreamError(
+                        "Timed out waiting for window update after 5 seconds."
+                    )

                 if self.closed:
                     raise MuxedStreamError("Stream is closed")
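
The timed wait above follows a common trio pattern: release the lock, poll inside trio.move_on_after, then re-acquire and check cancelled_caught. A minimal standalone sketch of that pattern (the names wait_for_window and get_window are illustrative, not from this diff):

import trio

async def wait_for_window(get_window, lock: trio.Lock, timeout: float = 5.0) -> None:
    # Caller holds `lock`; release it so another task can grow the window.
    lock.release()
    try:
        with trio.move_on_after(timeout) as scope:
            while get_window() == 0:
                await trio.sleep(0.01)
    finally:
        # Re-acquire so the caller's `async with lock:` block exits cleanly.
        await lock.acquire()
    if scope.cancelled_caught:
        raise TimeoutError(f"no window update within {timeout}s")

async def main() -> None:
    lock = trio.Lock()
    window = {"size": 0}

    async def grow() -> None:
        await trio.sleep(0.05)
        window["size"] = 1024

    async with trio.open_nursery() as nursery:
        nursery.start_soon(grow)
        async with lock:
            await wait_for_window(lambda: window["size"], lock)

trio.run(main)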
@@ -123,25 +141,53 @@ class YamuxStream(IMuxedStream):
                 await self.conn.secured_conn.write(header + chunk)
                 sent += to_send

-        # If window is getting low, consider updating
-        if self.send_window < DEFAULT_WINDOW_SIZE // 2:
-            await self.send_window_update()
-
-    async def send_window_update(self, increment: int | None = None) -> None:
-        """Send a window update to peer."""
+    async def send_window_update(
+        self, increment: int | None, skip_lock: bool = False
+    ) -> None:
+        """
+        Send a window update to the peer.
+
+        :param increment: The amount to increment the window size by. If None,
+            uses the difference between DEFAULT_WINDOW_SIZE and the current
+            receive window.
+        :param skip_lock: If True, skip acquiring window_lock. This should only
+            be used when calling from a context that already holds the lock.
+        """
         if increment is None:
-            increment = DEFAULT_WINDOW_SIZE - self.recv_window
-
-        if increment <= 0:
+            increment_value = DEFAULT_WINDOW_SIZE - self.recv_window
+        else:
+            increment_value = increment
+
+        if increment_value <= 0:
+            # If the increment is zero or negative, skip sending an update
+            logging.debug(
+                f"Stream {self.stream_id}: Skipping window update "
+                f"(increment={increment})"
+            )
             return

-        async with self.window_lock:
-            self.recv_window += increment
+        logging.debug(
+            f"Stream {self.stream_id}: Sending window update "
+            f"with increment={increment_value}"
+        )
+
+        async def _do_window_update() -> None:
+            self.recv_window += increment_value
             header = struct.pack(
-                YAMUX_HEADER_FORMAT, 0, TYPE_WINDOW_UPDATE, 0, self.stream_id, increment
+                YAMUX_HEADER_FORMAT,
+                0,
+                TYPE_WINDOW_UPDATE,
+                0,
+                self.stream_id,
+                increment_value,
             )
             await self.conn.secured_conn.write(header)

+        if skip_lock:
+            await _do_window_update()
+        else:
+            async with self.window_lock:
+                await _do_window_update()
+
     async def read(self, n: int | None = -1) -> bytes:
         # Handle None value for n by converting it to -1
         if n is None:
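
For reference, a yamux frame header is 12 bytes: version, type, flags, stream ID, and a length field that doubles as the window increment for WINDOW_UPDATE frames. A hedged sketch of the packing, assuming YAMUX_HEADER_FORMAT is the spec's big-endian layout (not shown in this diff):

import struct

YAMUX_HEADER_FORMAT = "!BBHII"  # version, type, flags, stream_id, length (big-endian)
TYPE_WINDOW_UPDATE = 0x1  # frame type per the yamux spec

def pack_window_update(stream_id: int, increment: int) -> bytes:
    # For WINDOW_UPDATE frames the length field carries the increment.
    return struct.pack(
        YAMUX_HEADER_FORMAT, 0, TYPE_WINDOW_UPDATE, 0, stream_id, increment
    )

header = pack_window_update(stream_id=1, increment=256 * 1024)
assert len(header) == 12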
@@ -198,11 +244,19 @@ class YamuxStream(IMuxedStream):
             # Return all buffered data
             data = bytes(buffer)
             buffer.clear()
+            logging.debug(f"Stream {self.stream_id}: Returning {len(data)} bytes")
             return data

-        # For specific size read (n > 0), return available data immediately
-        return await self.conn.read_stream(self.stream_id, n)
+        data = await self.conn.read_stream(self.stream_id, n)
+        async with self.window_lock:
+            self.recv_window -= len(data)
+            # Automatically send a window update if recv_window is low
+            if self.recv_window <= DEFAULT_WINDOW_SIZE // 2:
+                logging.debug(
+                    f"Stream {self.stream_id}: "
+                    f"Low recv_window ({self.recv_window}), sending update"
+                )
+                await self.send_window_update(None, skip_lock=True)
+        return data

     async def close(self) -> None:
         if not self.send_closed:
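
The read path now keeps recv_window in sync with consumed bytes and replenishes it once usage crosses the half-window mark, avoiding one update per read. A toy model of that bookkeeping (the constant is assumed to match the muxer's DEFAULT_WINDOW_SIZE of 256 KiB):

DEFAULT_WINDOW_SIZE = 256 * 1024  # assumed default

class WindowAccounting:
    def __init__(self) -> None:
        self.recv_window = DEFAULT_WINDOW_SIZE
        self.updates_sent: list[int] = []

    def on_read(self, nbytes: int) -> None:
        self.recv_window -= nbytes
        if self.recv_window <= DEFAULT_WINDOW_SIZE // 2:
            # Top the window back up to the default in a single update.
            increment = DEFAULT_WINDOW_SIZE - self.recv_window
            self.updates_sent.append(increment)
            self.recv_window += increment

acct = WindowAccounting()
for _ in range(4):
    acct.on_read(64 * 1024)
# Two updates instead of four: one per half-window consumed.
assert acct.updates_sent == [DEFAULT_WINDOW_SIZE // 2] * 2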

View File

@@ -0,0 +1 @@
Added separate read and write locks to the `MplexStream` and `YamuxStream` classes. This ensures thread-safe access and data integrity when multiple coroutines interact with the same stream instance.
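
A minimal sketch of what these locks guarantee, using a hypothetical LockedStream whose multi-step send would interleave frames from concurrent writers without the lock:

import trio

class LockedStream:
    def __init__(self) -> None:
        self.write_lock = trio.Lock()
        self.frames: list[bytes] = []

    async def send_message(self, data: bytes) -> None:
        # Simulate a multi-step send that could interleave without the lock.
        self.frames.append(b"<hdr>")
        await trio.sleep(0.001)
        self.frames.append(data)

    async def write(self, data: bytes) -> None:
        async with self.write_lock:  # serialize whole-frame sends
            await self.send_message(data)

async def main() -> None:
    s = LockedStream()
    async with trio.open_nursery() as nursery:
        for i in range(3):
            nursery.start_soon(s.write, f"msg-{i}".encode())
    # Every header is immediately followed by its payload.
    assert all(s.frames[i] == b"<hdr>" for i in range(0, len(s.frames), 2))

trio.run(main)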

View File

@@ -0,0 +1,124 @@
import pytest
import trio

from libp2p.abc import ISecureConn
from libp2p.crypto.keys import PrivateKey, PublicKey
from libp2p.peer.id import ID
from libp2p.stream_muxer.mplex.constants import (
    HeaderTags,
)
from libp2p.stream_muxer.mplex.datastructures import (
    StreamID,
)
from libp2p.stream_muxer.mplex.mplex import (
    Mplex,
)
from libp2p.stream_muxer.mplex.mplex_stream import (
    MplexStream,
)


class DummySecureConn(ISecureConn):
    """A minimal implementation of ISecureConn for testing."""

    async def write(self, data: bytes) -> None:
        pass

    async def read(self, n: int | None = -1) -> bytes:
        return b""

    async def close(self) -> None:
        pass

    def get_remote_address(self) -> tuple[str, int] | None:
        return None

    def get_local_peer(self) -> ID:
        return ID(b"local")

    def get_local_private_key(self) -> PrivateKey:
        return PrivateKey()  # Dummy key for testing

    def get_remote_peer(self) -> ID:
        return ID(b"remote")

    def get_remote_public_key(self) -> PublicKey:
        return PublicKey()  # Dummy key for testing


class DummyMuxedConn(Mplex):
    """A minimal mock of Mplex for testing read/write locks."""

    def __init__(self) -> None:
        self.secured_conn = DummySecureConn()
        self.peer_id = ID(b"dummy")
        self.streams = {}
        self.streams_lock = trio.Lock()
        self.event_shutting_down = trio.Event()
        self.event_closed = trio.Event()
        self.event_started = trio.Event()
        self.stream_backlog_limit = 256
        self.stream_backlog_semaphore = trio.Semaphore(256)
        channels = trio.open_memory_channel[MplexStream](0)
        self.new_stream_send_channel, self.new_stream_receive_channel = channels

    async def send_message(
        self, flag: HeaderTags, data: bytes, stream_id: StreamID
    ) -> None:
        await trio.sleep(0.01)


@pytest.mark.trio
async def test_concurrent_writes_are_serialized():
    stream_id = StreamID(1, True)
    send_log = []

    class LoggingMuxedConn(DummyMuxedConn):
        async def send_message(
            self, flag: HeaderTags, data: bytes, stream_id: StreamID
        ) -> None:
            send_log.append(data)
            await trio.sleep(0.01)

    memory_send, memory_recv = trio.open_memory_channel(8)
    stream = MplexStream(
        name="test",
        stream_id=stream_id,
        muxed_conn=LoggingMuxedConn(),
        incoming_data_channel=memory_recv,
    )

    async def writer(data):
        await stream.write(data)

    async with trio.open_nursery() as nursery:
        for i in range(5):
            nursery.start_soon(writer, f"msg-{i}".encode())

    # Order doesn't matter due to concurrent execution
    assert sorted(send_log) == sorted([f"msg-{i}".encode() for i in range(5)])


@pytest.mark.trio
async def test_concurrent_reads_are_serialized():
    stream_id = StreamID(2, True)
    muxed_conn = DummyMuxedConn()
    memory_send, memory_recv = trio.open_memory_channel(8)
    results = []
    stream = MplexStream(
        name="test",
        stream_id=stream_id,
        muxed_conn=muxed_conn,
        incoming_data_channel=memory_recv,
    )

    for i in range(5):
        await memory_send.send(f"data-{i}".encode())
    await memory_send.aclose()

    async def reader():
        data = await stream.read(6)
        results.append(data)

    async with trio.open_nursery() as nursery:
        for _ in range(5):
            nursery.start_soon(reader)

    assert sorted(results) == [f"data-{i}".encode() for i in range(5)]

View File

@@ -0,0 +1,207 @@
import logging

import pytest
import trio
from trio.testing import (
    memory_stream_pair,
)

from libp2p.abc import IRawConnection
from libp2p.crypto.ed25519 import (
    create_new_key_pair,
)
from libp2p.peer.id import (
    ID,
)
from libp2p.security.insecure.transport import (
    InsecureTransport,
)
from libp2p.stream_muxer.yamux.yamux import (
    Yamux,
    YamuxStream,
)


class TrioStreamAdapter(IRawConnection):
    """Adapter to make trio memory streams work with libp2p."""

    def __init__(self, send_stream, receive_stream, is_initiator=False):
        self.send_stream = send_stream
        self.receive_stream = receive_stream
        self.is_initiator = is_initiator

    async def write(self, data: bytes) -> None:
        logging.debug(f"Attempting to write {len(data)} bytes")
        with trio.move_on_after(2):
            await self.send_stream.send_all(data)

    async def read(self, n: int | None = None) -> bytes:
        if n is None or n <= 0:
            raise ValueError("Reading unbounded or zero bytes not supported")
        logging.debug(f"Attempting to read {n} bytes")
        data = b""
        with trio.move_on_after(2):
            data = await self.receive_stream.receive_some(n)
            logging.debug(f"Read {len(data)} bytes")
        return data

    async def close(self) -> None:
        logging.debug("Closing stream")
        await self.send_stream.aclose()
        await self.receive_stream.aclose()

    def get_remote_address(self) -> tuple[str, int] | None:
        """Return None since this is a test adapter without real network info."""
        return None


@pytest.fixture
def key_pair():
    return create_new_key_pair()


@pytest.fixture
def peer_id(key_pair):
    return ID.from_pubkey(key_pair.public_key)


@pytest.fixture
async def secure_conn_pair(key_pair, peer_id):
    """Create a pair of secure connections for testing."""
    logging.debug("Setting up secure_conn_pair")
    client_send, server_receive = memory_stream_pair()
    server_send, client_receive = memory_stream_pair()
    client_rw = TrioStreamAdapter(client_send, client_receive)
    server_rw = TrioStreamAdapter(server_send, server_receive)
    insecure_transport = InsecureTransport(key_pair)

    async def run_outbound(nursery_results):
        with trio.move_on_after(5):
            client_conn = await insecure_transport.secure_outbound(client_rw, peer_id)
            logging.debug("Outbound handshake complete")
            nursery_results["client"] = client_conn

    async def run_inbound(nursery_results):
        with trio.move_on_after(5):
            server_conn = await insecure_transport.secure_inbound(server_rw)
            logging.debug("Inbound handshake complete")
            nursery_results["server"] = server_conn

    nursery_results = {}
    # The nursery block waits for both handshake tasks to finish
    async with trio.open_nursery() as nursery:
        nursery.start_soon(run_outbound, nursery_results)
        nursery.start_soon(run_inbound, nursery_results)

    client_conn = nursery_results.get("client")
    server_conn = nursery_results.get("server")
    if client_conn is None or server_conn is None:
        raise RuntimeError("Handshake failed: client_conn or server_conn is None")
    logging.debug("secure_conn_pair setup complete")
    return client_conn, server_conn


@pytest.fixture
async def yamux_pair(secure_conn_pair, peer_id):
    """Create a pair of Yamux multiplexers for testing."""
    logging.debug("Setting up yamux_pair")
    client_conn, server_conn = secure_conn_pair
    client_yamux = Yamux(client_conn, peer_id, is_initiator=True)
    server_yamux = Yamux(server_conn, peer_id, is_initiator=False)

    async with trio.open_nursery() as nursery:
        nursery.start_soon(client_yamux.start)
        nursery.start_soon(server_yamux.start)
        await trio.sleep(0.1)
        logging.debug("yamux_pair started")
        yield client_yamux, server_yamux
        logging.debug("yamux_pair cleanup")


@pytest.mark.trio
async def test_yamux_race_condition_without_locks(yamux_pair):
    """
    Test for races and interleaving on Yamux streams.

    This launches concurrent writers and readers on both sides of a stream,
    sends structured messages, and verifies that they are received intact and
    in order. If the read/write locks were missing, concurrent operations
    could corrupt or interleave messages, which this test would catch.
    """
    client_yamux, server_yamux = yamux_pair
    client_stream: YamuxStream = await client_yamux.open_stream()
    server_stream: YamuxStream = await server_yamux.accept_stream()

    MSG_COUNT = 10
    MSG_SIZE = 256 * 1024
    client_msgs = [
        f"CLIENT-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"C") for i in range(MSG_COUNT)
    ]
    server_msgs = [
        f"SERVER-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"S") for i in range(MSG_COUNT)
    ]
    client_received = []
    server_received = []

    async def writer(stream, msgs, name):
        """Write messages with minimal delays to encourage race conditions."""
        for i, msg in enumerate(msgs):
            await stream.write(msg)
            # Yield control frequently to encourage interleaving
            if i % 5 == 0:
                await trio.sleep(0.005)

    async def reader(stream, received, name):
        """Read messages and store them for verification."""
        for i in range(MSG_COUNT):
            data = await stream.read(MSG_SIZE)
            received.append(data)
            if i % 3 == 0:
                await trio.sleep(0.001)

    # Run all four operations concurrently
    async with trio.open_nursery() as nursery:
        nursery.start_soon(writer, client_stream, client_msgs, "client")
        nursery.start_soon(writer, server_stream, server_msgs, "server")
        nursery.start_soon(reader, client_stream, client_received, "client")
        nursery.start_soon(reader, server_stream, server_received, "server")

    assert len(client_received) == MSG_COUNT, (
        f"Client received {len(client_received)} messages, expected {MSG_COUNT}"
    )
    assert len(server_received) == MSG_COUNT, (
        f"Server received {len(server_received)} messages, expected {MSG_COUNT}"
    )
    assert client_received == server_msgs, (
        "Client did not receive server messages in order or intact!"
    )
    assert server_received == client_msgs, (
        "Server did not receive client messages in order or intact!"
    )

    for i, msg in enumerate(client_received):
        assert len(msg) == MSG_SIZE, (
            f"Client message {i} has wrong size: {len(msg)} != {MSG_SIZE}"
        )
        assert msg.startswith(b"SERVER-MSG-"), (
            f"Client message {i} doesn't start with expected prefix"
        )
    for i, msg in enumerate(server_received):
        assert len(msg) == MSG_SIZE, (
            f"Server message {i} has wrong size: {len(msg)} != {MSG_SIZE}"
        )
        assert msg.startswith(b"CLIENT-MSG-"), (
            f"Server message {i} doesn't start with expected prefix"
        )

    await client_stream.close()
    await server_stream.close()