Merge branch 'main' into async-validators

Manu Sheel Gupta authored on 2025-07-05 14:50:18 -07:00 · committed by GitHub
4 changed files with 505 additions and 56 deletions


@@ -98,16 +98,32 @@ class YamuxStream(IMuxedStream):
        # Flow control: Check if we have enough send window
        total_len = len(data)
        sent = 0
+        logging.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes")
        while sent < total_len:
+            # Wait for available window with timeout
+            timeout = False
            async with self.window_lock:
-                # Wait for available window
-                while self.send_window == 0 and not self.closed:
-                    # Release lock while waiting
+                if self.send_window == 0:
+                    logging.debug(
+                        f"Stream {self.stream_id}: Window is zero, waiting for update"
+                    )
+                    # Release lock and wait with timeout
                    self.window_lock.release()
-                    await trio.sleep(0.01)
+                    # To avoid re-acquiring the lock immediately, poll inside
+                    # a timeout scope
+                    with trio.move_on_after(5.0) as cancel_scope:
+                        while self.send_window == 0 and not self.closed:
+                            await trio.sleep(0.01)
+                    # Record whether we timed out waiting for a window update
+                    timeout = cancel_scope.cancelled_caught
                    # Re-acquire lock
                    await self.window_lock.acquire()
+                # If we timed out waiting for a window update, raise an error
+                if timeout:
+                    raise MuxedStreamError(
+                        "Timed out waiting for window update after 5 seconds."
+                    )
                if self.closed:
                    raise MuxedStreamError("Stream is closed")
@@ -123,25 +139,45 @@ class YamuxStream(IMuxedStream):
                await self.conn.secured_conn.write(header + chunk)
                sent += to_send
-                # If window is getting low, consider updating
-                if self.send_window < DEFAULT_WINDOW_SIZE // 2:
-                    await self.send_window_update()
-    async def send_window_update(self, increment: int | None = None) -> None:
-        """Send a window update to peer."""
-        if increment is None:
-            increment = DEFAULT_WINDOW_SIZE - self.recv_window
+    async def send_window_update(self, increment: int, skip_lock: bool = False) -> None:
+        """
+        Send a window update to peer.
+        :param increment: The amount to increment the window size by.
+        :param skip_lock: If True, skips acquiring window_lock. This should only
+            be used when calling from a context that already holds the lock.
+        """
+        if increment <= 0:
+            # If increment is zero or negative, skip sending update
+            logging.debug(
+                f"Stream {self.stream_id}: Skipping window update "
+                f"(increment={increment})"
+            )
+            return
+        logging.debug(
+            f"Stream {self.stream_id}: Sending window update with increment={increment}"
+        )
-        async with self.window_lock:
-            self.recv_window += increment
+        async def _do_window_update() -> None:
            header = struct.pack(
-                YAMUX_HEADER_FORMAT, 0, TYPE_WINDOW_UPDATE, 0, self.stream_id, increment
+                YAMUX_HEADER_FORMAT,
+                0,
+                TYPE_WINDOW_UPDATE,
+                0,
+                self.stream_id,
+                increment,
            )
            await self.conn.secured_conn.write(header)
+        if skip_lock:
+            await _do_window_update()
+        else:
+            async with self.window_lock:
+                await _do_window_update()
    async def read(self, n: int | None = -1) -> bytes:
        # Handle None value for n by converting it to -1
        if n is None:
@@ -154,55 +190,68 @@ class YamuxStream(IMuxedStream):
            )
            raise MuxedStreamEOF("Stream is closed for receiving")
        # If reading until EOF (n == -1), block until stream is closed
        if n == -1:
-            while not self.recv_closed and not self.conn.event_shutting_down.is_set():
+            data = b""
+            while not self.conn.event_shutting_down.is_set():
                # Check if there's data in the buffer
                buffer = self.conn.stream_buffers.get(self.stream_id)
-                if buffer and len(buffer) > 0:
-                    # Wait for closure even if data is available
-                    logging.debug(
-                        f"Stream {self.stream_id}:Waiting for FIN before returning data"
-                    )
-                    await self.conn.stream_events[self.stream_id].wait()
-                    self.conn.stream_events[self.stream_id] = trio.Event()
-                else:
-                    # No data, wait for data or closure
-                    logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
-                    await self.conn.stream_events[self.stream_id].wait()
-                    self.conn.stream_events[self.stream_id] = trio.Event()
-            # After loop, check if stream is closed or shutting down
-            async with self.conn.streams_lock:
-                if self.conn.event_shutting_down.is_set():
-                    logging.debug(f"Stream {self.stream_id}: Connection shutting down")
-                    raise MuxedStreamEOF("Connection shut down")
-                if self.closed:
-                    if self.reset_received:
-                        logging.debug(f"Stream {self.stream_id}: Stream was reset")
-                        raise MuxedStreamReset("Stream was reset")
-                    else:
-                        logging.debug(
-                            f"Stream {self.stream_id}: Stream closed cleanly (EOF)"
-                        )
-                        raise MuxedStreamEOF("Stream closed cleanly (EOF)")
-                buffer = self.conn.stream_buffers.get(self.stream_id)
+                # If buffer is not available, check if stream is closed
                if buffer is None:
-                    logging.debug(
-                        f"Stream {self.stream_id}: Buffer gone, assuming closed"
-                    )
+                    logging.debug(f"Stream {self.stream_id}: No buffer available")
                    raise MuxedStreamEOF("Stream buffer closed")
+                # If we have data in buffer, process it
+                if len(buffer) > 0:
+                    chunk = bytes(buffer)
+                    buffer.clear()
+                    data += chunk
+                    # Send window update for the chunk we just read
+                    async with self.window_lock:
+                        self.recv_window += len(chunk)
+                        logging.debug(f"Stream {self.stream_id}: Update {len(chunk)}")
+                        await self.send_window_update(len(chunk), skip_lock=True)
+                # If stream is closed (FIN received) and buffer is empty, break
                if self.recv_closed and len(buffer) == 0:
-                    logging.debug(f"Stream {self.stream_id}: EOF reached")
-                    raise MuxedStreamEOF("Stream is closed for receiving")
-                # Return all buffered data
-                data = bytes(buffer)
-                buffer.clear()
-                logging.debug(f"Stream {self.stream_id}: Returning {len(data)} bytes")
+                    logging.debug(f"Stream {self.stream_id}: Closed with empty buffer")
+                    break
+                # If stream was reset, raise reset error
+                if self.reset_received:
+                    logging.debug(f"Stream {self.stream_id}: Stream was reset")
+                    raise MuxedStreamReset("Stream was reset")
+                # Wait for more data or stream closure
+                logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
+                await self.conn.stream_events[self.stream_id].wait()
+                self.conn.stream_events[self.stream_id] = trio.Event()
+            # After loop exit, first check if we have data to return
+            if data:
+                logging.debug(
+                    f"Stream {self.stream_id}: Returning {len(data)} bytes after loop"
+                )
                return data
-        # For specific size read (n > 0), return available data immediately
-        return await self.conn.read_stream(self.stream_id, n)
+            # No data accumulated, now check why we exited the loop
+            if self.conn.event_shutting_down.is_set():
+                logging.debug(f"Stream {self.stream_id}: Connection shutting down")
+                raise MuxedStreamEOF("Connection shut down")
+            # Return empty data
+            return b""
+        else:
+            data = await self.conn.read_stream(self.stream_id, n)
+            async with self.window_lock:
+                self.recv_window += len(data)
+                logging.debug(
+                    f"Stream {self.stream_id}: Sending window update after read, "
+                    f"increment={len(data)}"
+                )
+                await self.send_window_update(len(data), skip_lock=True)
+            return data

    async def close(self) -> None:
        if not self.send_closed:


@@ -0,0 +1,6 @@
Fixed several flow-control and concurrency issues in the `YamuxStream` class. Previously, stress-testing revealed that transferring more than `DEFAULT_WINDOW_SIZE` bytes would break the stream due to inconsistent window-update handling and lock management. The fixes include the following (a simplified sketch of the flow-control idea follows this list):
- Removed sending of window updates during writes to maintain correct flow control.
- Added proper timeout handling when releasing and re-acquiring locks, to prevent concurrency errors.
- Corrected the `read` function to handle window updates properly for both read-until-EOF (`n == -1`) and fixed-size (`n > 0`) reads.
- Added debug logging when sending window updates and while waiting for window updates, for better observability.
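To make the bullets concrete, here is a minimal, self-contained sketch of the credit-based flow control these fixes restore. It is illustrative only: `ToyStream`, `WINDOW`, and `on_window_update` are hypothetical names, not part of the py-libp2p API; the real logic lives in `YamuxStream.write`/`read` shown above.

import trio

WINDOW = 256 * 1024  # stand-in for DEFAULT_WINDOW_SIZE


class ToyStream:
    """Toy model of Yamux-style credit-based flow control (illustrative only)."""

    def __init__(self) -> None:
        self.send_window = WINDOW           # credit granted by the peer
        self.window_updated = trio.Event()  # set whenever the peer grants more credit

    async def write(self, data: bytes) -> None:
        sent = 0
        while sent < len(data):
            if self.send_window == 0:
                # Like the fixed write(): wait, with a timeout, until the peer
                # sends a window update instead of spinning forever.
                with trio.move_on_after(5.0) as scope:
                    while self.send_window == 0:
                        await self.window_updated.wait()
                        self.window_updated = trio.Event()
                if scope.cancelled_caught:
                    raise RuntimeError("timed out waiting for a window update")
            chunk = min(len(data) - sent, self.send_window)
            self.send_window -= chunk  # spend credit; only the receiver refills it
            sent += chunk

    def on_window_update(self, increment: int) -> None:
        # Like the fixed read(): the receiver grants credit back for exactly
        # the bytes it has consumed from its buffer.
        if increment <= 0:
            return
        self.send_window += increment
        self.window_updated.set()

The key invariant is that the writer only spends credit and the reader only grants it, so transfers larger than the initial window keep flowing as long as the reader keeps consuming.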


@@ -0,0 +1,199 @@
import logging
import pytest
import trio
from trio.testing import (
memory_stream_pair,
)
from libp2p.abc import IRawConnection
from libp2p.crypto.ed25519 import (
create_new_key_pair,
)
from libp2p.peer.id import (
ID,
)
from libp2p.security.insecure.transport import (
InsecureTransport,
)
from libp2p.stream_muxer.yamux.yamux import (
Yamux,
YamuxStream,
)
class TrioStreamAdapter(IRawConnection):
"""Adapter to make trio memory streams work with libp2p."""
def __init__(self, send_stream, receive_stream, is_initiator=False):
self.send_stream = send_stream
self.receive_stream = receive_stream
self.is_initiator = is_initiator
async def write(self, data: bytes) -> None:
logging.debug(f"Attempting to write {len(data)} bytes")
with trio.move_on_after(2):
await self.send_stream.send_all(data)
async def read(self, n: int | None = None) -> bytes:
if n is None or n <= 0:
raise ValueError("Reading unbounded or zero bytes not supported")
logging.debug(f"Attempting to read {n} bytes")
with trio.move_on_after(2):
data = await self.receive_stream.receive_some(n)
logging.debug(f"Read {len(data)} bytes")
return data
async def close(self) -> None:
logging.debug("Closing stream")
await self.send_stream.aclose()
await self.receive_stream.aclose()
def get_remote_address(self) -> tuple[str, int] | None:
"""Return None since this is a test adapter without real network info."""
return None
@pytest.fixture
def key_pair():
return create_new_key_pair()
@pytest.fixture
def peer_id(key_pair):
return ID.from_pubkey(key_pair.public_key)
@pytest.fixture
async def secure_conn_pair(key_pair, peer_id):
"""Create a pair of secure connections for testing."""
logging.debug("Setting up secure_conn_pair")
client_send, server_receive = memory_stream_pair()
server_send, client_receive = memory_stream_pair()
client_rw = TrioStreamAdapter(client_send, client_receive)
server_rw = TrioStreamAdapter(server_send, server_receive)
insecure_transport = InsecureTransport(key_pair)
async def run_outbound(nursery_results):
with trio.move_on_after(5):
client_conn = await insecure_transport.secure_outbound(client_rw, peer_id)
logging.debug("Outbound handshake complete")
nursery_results["client"] = client_conn
async def run_inbound(nursery_results):
with trio.move_on_after(5):
server_conn = await insecure_transport.secure_inbound(server_rw)
logging.debug("Inbound handshake complete")
nursery_results["server"] = server_conn
nursery_results = {}
async with trio.open_nursery() as nursery:
nursery.start_soon(run_outbound, nursery_results)
nursery.start_soon(run_inbound, nursery_results)
await trio.sleep(0.1) # Give tasks a chance to finish
client_conn = nursery_results.get("client")
server_conn = nursery_results.get("server")
if client_conn is None or server_conn is None:
raise RuntimeError("Handshake failed: client_conn or server_conn is None")
logging.debug("secure_conn_pair setup complete")
return client_conn, server_conn
@pytest.fixture
async def yamux_pair(secure_conn_pair, peer_id):
"""Create a pair of Yamux multiplexers for testing."""
logging.debug("Setting up yamux_pair")
client_conn, server_conn = secure_conn_pair
client_yamux = Yamux(client_conn, peer_id, is_initiator=True)
server_yamux = Yamux(server_conn, peer_id, is_initiator=False)
async with trio.open_nursery() as nursery:
with trio.move_on_after(5):
nursery.start_soon(client_yamux.start)
nursery.start_soon(server_yamux.start)
await trio.sleep(0.1)
logging.debug("yamux_pair started")
yield client_yamux, server_yamux
logging.debug("yamux_pair cleanup")
@pytest.mark.trio
async def test_yamux_race_condition_without_locks(yamux_pair):
"""
Test for race conditions/interleaving in Yamux streams when reading
fixed-size segments of data.
This launches concurrent writers and readers on both sides of a stream,
sends structured messages, and verifies they are received intact and in
order. Without proper locking, concurrent read/write operations could
corrupt or interleave the data, which this test will catch.
"""
client_yamux, server_yamux = yamux_pair
client_stream: YamuxStream = await client_yamux.open_stream()
server_stream: YamuxStream = await server_yamux.accept_stream()
MSG_COUNT = 10
MSG_SIZE = 256 * 1024  # At most DEFAULT_WINDOW_SIZE bytes can be read at once
client_msgs = [
f"CLIENT-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"C") for i in range(MSG_COUNT)
]
server_msgs = [
f"SERVER-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"S") for i in range(MSG_COUNT)
]
client_received = []
server_received = []
async def writer(stream, msgs, name):
"""Write messages with minimal delays to encourage race conditions."""
for i, msg in enumerate(msgs):
await stream.write(msg)
# Yield control frequently to encourage interleaving
if i % 5 == 0:
await trio.sleep(0.005)
async def reader(stream, received, name):
"""Read messages and store them for verification."""
for i in range(MSG_COUNT):
data = await stream.read(MSG_SIZE)
received.append(data)
if i % 3 == 0:
await trio.sleep(0.001)
# Running all operations concurrently
async with trio.open_nursery() as nursery:
nursery.start_soon(writer, client_stream, client_msgs, "client")
nursery.start_soon(writer, server_stream, server_msgs, "server")
nursery.start_soon(reader, client_stream, client_received, "client")
nursery.start_soon(reader, server_stream, server_received, "server")
assert len(client_received) == MSG_COUNT, (
f"Client received {len(client_received)} messages, expected {MSG_COUNT}"
)
assert len(server_received) == MSG_COUNT, (
f"Server received {len(server_received)} messages, expected {MSG_COUNT}"
)
assert client_received == server_msgs, (
"Client did not receive server messages in order or intact!"
)
assert server_received == client_msgs, (
"Server did not receive client messages in order or intact!"
)
for i, msg in enumerate(client_received):
assert len(msg) == MSG_SIZE, (
f"Client message {i} has wrong size: {len(msg)} != {MSG_SIZE}"
)
for i, msg in enumerate(server_received):
assert len(msg) == MSG_SIZE, (
f"Server message {i} has wrong size: {len(msg)} != {MSG_SIZE}"
)
await client_stream.close()
await server_stream.close()


@@ -0,0 +1,195 @@
import logging
import pytest
import trio
from trio.testing import (
memory_stream_pair,
)
from libp2p.abc import IRawConnection
from libp2p.crypto.ed25519 import (
create_new_key_pair,
)
from libp2p.peer.id import (
ID,
)
from libp2p.security.insecure.transport import (
InsecureTransport,
)
from libp2p.stream_muxer.exceptions import MuxedStreamEOF
from libp2p.stream_muxer.yamux.yamux import (
Yamux,
YamuxStream,
)
class TrioStreamAdapter(IRawConnection):
"""Adapter to make trio memory streams work with libp2p."""
def __init__(self, send_stream, receive_stream, is_initiator=False):
self.send_stream = send_stream
self.receive_stream = receive_stream
self.is_initiator = is_initiator
async def write(self, data: bytes) -> None:
logging.debug(f"Attempting to write {len(data)} bytes")
with trio.move_on_after(2):
await self.send_stream.send_all(data)
async def read(self, n: int | None = None) -> bytes:
if n is None or n <= 0:
raise ValueError("Reading unbounded or zero bytes not supported")
logging.debug(f"Attempting to read {n} bytes")
with trio.move_on_after(2):
data = await self.receive_stream.receive_some(n)
logging.debug(f"Read {len(data)} bytes")
return data
async def close(self) -> None:
logging.debug("Closing stream")
await self.send_stream.aclose()
await self.receive_stream.aclose()
def get_remote_address(self) -> tuple[str, int] | None:
"""Return None since this is a test adapter without real network info."""
return None
@pytest.fixture
def key_pair():
return create_new_key_pair()
@pytest.fixture
def peer_id(key_pair):
return ID.from_pubkey(key_pair.public_key)
@pytest.fixture
async def secure_conn_pair(key_pair, peer_id):
"""Create a pair of secure connections for testing."""
logging.debug("Setting up secure_conn_pair")
client_send, server_receive = memory_stream_pair()
server_send, client_receive = memory_stream_pair()
client_rw = TrioStreamAdapter(client_send, client_receive)
server_rw = TrioStreamAdapter(server_send, server_receive)
insecure_transport = InsecureTransport(key_pair)
async def run_outbound(nursery_results):
with trio.move_on_after(5):
client_conn = await insecure_transport.secure_outbound(client_rw, peer_id)
logging.debug("Outbound handshake complete")
nursery_results["client"] = client_conn
async def run_inbound(nursery_results):
with trio.move_on_after(5):
server_conn = await insecure_transport.secure_inbound(server_rw)
logging.debug("Inbound handshake complete")
nursery_results["server"] = server_conn
nursery_results = {}
async with trio.open_nursery() as nursery:
nursery.start_soon(run_outbound, nursery_results)
nursery.start_soon(run_inbound, nursery_results)
await trio.sleep(0.1) # Give tasks a chance to finish
client_conn = nursery_results.get("client")
server_conn = nursery_results.get("server")
if client_conn is None or server_conn is None:
raise RuntimeError("Handshake failed: client_conn or server_conn is None")
logging.debug("secure_conn_pair setup complete")
return client_conn, server_conn
@pytest.fixture
async def yamux_pair(secure_conn_pair, peer_id):
"""Create a pair of Yamux multiplexers for testing."""
logging.debug("Setting up yamux_pair")
client_conn, server_conn = secure_conn_pair
client_yamux = Yamux(client_conn, peer_id, is_initiator=True)
server_yamux = Yamux(server_conn, peer_id, is_initiator=False)
async with trio.open_nursery() as nursery:
with trio.move_on_after(5):
nursery.start_soon(client_yamux.start)
nursery.start_soon(server_yamux.start)
await trio.sleep(0.1)
logging.debug("yamux_pair started")
yield client_yamux, server_yamux
logging.debug("yamux_pair cleanup")
@pytest.mark.trio
async def test_yamux_race_condition_without_locks(yamux_pair):
"""
Test for race conditions/interleaving in Yamux streams when reading
until EOF.
This launches concurrent writers and readers on both sides of a stream,
sends structured messages, and verifies they are received intact and in
order. Without proper locking, concurrent read/write operations could
corrupt or interleave the data, which this test will catch.
"""
client_yamux, server_yamux = yamux_pair
client_stream: YamuxStream = await client_yamux.open_stream()
server_stream: YamuxStream = await server_yamux.accept_stream()
MSG_COUNT = 1
MSG_SIZE = 512 * 1024
client_msgs = [
f"CLIENT-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"C") for i in range(MSG_COUNT)
]
server_msgs = [
f"SERVER-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"S") for i in range(MSG_COUNT)
]
client_received = []
server_received = []
async def writer(stream, msgs, name):
"""Write messages with minimal delays to encourage race conditions."""
for i, msg in enumerate(msgs):
await stream.write(msg)
# Yield control frequently to encourage interleaving
if i % 5 == 0:
await trio.sleep(0.005)
async def reader(stream, received, name):
"""Read messages and store them for verification."""
try:
data = await stream.read()
if data:
received.append(data)
except MuxedStreamEOF:
pass
# Running all operations concurrently
async with trio.open_nursery() as nursery:
nursery.start_soon(writer, client_stream, client_msgs, "client")
nursery.start_soon(writer, server_stream, server_msgs, "server")
nursery.start_soon(reader, client_stream, client_received, "client")
nursery.start_soon(reader, server_stream, server_received, "server")
assert client_received == server_msgs, (
"Client did not receive server messages in order or intact!"
)
assert server_received == client_msgs, (
"Server did not receive client messages in order or intact!"
)
for i, msg in enumerate(client_received):
assert len(msg) == MSG_SIZE, (
f"Client message {i} has wrong size: {len(msg)} != {MSG_SIZE}"
)
for i, msg in enumerate(server_received):
assert len(msg) == MSG_SIZE, (
f"Server message {i} has wrong size: {len(msg)} != {MSG_SIZE}"
)
await client_stream.close()
await server_stream.close()
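Both new test modules drive Yamux over in-memory trio streams and mark their tests with `@pytest.mark.trio`, so a trio-aware pytest plugin is required to run them (pytest-trio is assumed here; the commit does not show the project's test configuration). A minimal invocation sketch, run from a checkout:

# Sketch only: select the race-condition tests added in this commit by keyword.
# Assumes pytest and a trio plugin (e.g. pytest-trio) are installed.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(pytest.main(["-q", "-k", "yamux_race_condition"]))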