Merge branch 'main' into feature/bootstrap

Manu Sheel Gupta
2025-07-20 04:39:56 -07:00
committed by GitHub
14 changed files with 1163 additions and 491 deletions

View File

@@ -72,13 +72,46 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No
client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/")
format_name = "length-prefixed" if use_varint_format else "raw protobuf"
format_flag = "--raw-format" if not use_varint_format else ""
print(
f"First host listening (using {format_name} format). "
f"Run this from another console:\n\n"
f"identify-demo "
f"-d {client_addr}\n"
f"identify-demo {format_flag} -d {client_addr}\n"
)
print("Waiting for incoming identify request...")
# Add a custom handler to show connection events
async def custom_identify_handler(stream):
peer_id = stream.muxed_conn.peer_id
print(f"\n🔗 Received identify request from peer: {peer_id}")
# Show remote address in multiaddr format
try:
from libp2p.identity.identify.identify import (
_remote_address_to_multiaddr,
)
remote_address = stream.get_remote_address()
if remote_address:
observed_multiaddr = _remote_address_to_multiaddr(
remote_address
)
# Add the peer ID to create a complete multiaddr
complete_multiaddr = f"{observed_multiaddr}/p2p/{peer_id}"
print(f" Remote address: {complete_multiaddr}")
else:
print(f" Remote address: {remote_address}")
except Exception:
print(f" Remote address: {stream.get_remote_address()}")
# Call the original handler
await identify_handler(stream)
print(f"✅ Successfully processed identify request from {peer_id}")
# Replace the handler with our custom one
host_a.set_stream_handler(IDENTIFY_PROTOCOL_ID, custom_identify_handler)
await trio.sleep_forever()
else:
@@ -93,25 +126,99 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No
info = info_from_p2p_addr(maddr)
print(f"Second host connecting to peer: {info.peer_id}")
await host_b.connect(info)
try:
await host_b.connect(info)
except Exception as e:
error_msg = str(e)
if "unable to connect" in error_msg or "SwarmException" in error_msg:
print(f"\n❌ Cannot connect to peer: {info.peer_id}")
print(f" Address: {destination}")
print(f" Error: {error_msg}")
print(
"\n💡 Make sure the peer is running and the address is correct."
)
return
else:
# Re-raise other exceptions
raise
stream = await host_b.new_stream(info.peer_id, (IDENTIFY_PROTOCOL_ID,))
try:
print("Starting identify protocol...")
# Read the complete response (could be either format)
# Read a larger chunk to get all the data before stream closes
response = await stream.read(8192) # Read enough data in one go
# Read the response properly based on the format
if use_varint_format:
# For length-prefixed format, read varint length first
from libp2p.utils.varint import decode_varint_from_bytes
# Read varint length prefix
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
raise Exception("Stream closed while reading varint length")
length_bytes += b
if b[0] & 0x80 == 0:
break
msg_length = decode_varint_from_bytes(length_bytes)
print(f"Expected message length: {msg_length} bytes")
# Read the protobuf message
response = await stream.read(msg_length)
if len(response) != msg_length:
raise Exception(
f"Incomplete message: expected {msg_length} bytes, "
f"got {len(response)}"
)
# Combine length prefix and message
full_response = length_bytes + response
else:
# For raw format, read all available data
response = await stream.read(8192)
full_response = response
await stream.close()
# Parse the response using the robust protocol-level function
# This handles both old and new formats automatically
identify_msg = parse_identify_response(response)
identify_msg = parse_identify_response(full_response)
print_identify_response(identify_msg)
except Exception as e:
print(f"Identify protocol error: {e}")
error_msg = str(e)
print(f"Identify protocol error: {error_msg}")
# Check for specific format mismatch errors
if "Error parsing message" in error_msg or "DecodeError" in error_msg:
print("\n" + "=" * 60)
print("FORMAT MISMATCH DETECTED!")
print("=" * 60)
if use_varint_format:
print(
"You are using length-prefixed format (default) but the "
"listener"
)
print("is using raw protobuf format.")
print(
"\nTo fix this, run the dialer with the --raw-format flag:"
)
print(f"identify-demo --raw-format -d {destination}")
else:
print("You are using raw protobuf format but the listener")
print("is using length-prefixed format (default).")
print(
"\nTo fix this, run the dialer without the --raw-format "
"flag:"
)
print(f"identify-demo -d {destination}")
print("=" * 60)
else:
import traceback
traceback.print_exc()
return
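Editor's note: the dialer logic above hinges on the unsigned-varint framing — in the default format each identify message is preceded by its varint-encoded length, while the raw format sends the bare protobuf. A minimal sketch of the round trip, assuming the `libp2p.utils.varint` helpers behave as they are used in the demo (both are imported elsewhere in this commit):

```
from libp2p.utils.varint import decode_varint_from_bytes, encode_varint_prefixed

payload = b"\x0a\x04test"  # stand-in for a serialized Identify protobuf

framed = encode_varint_prefixed(payload)  # length-prefixed (default) format

# The varint prefix ends at the first byte whose high bit is clear.
prefix_len = next(i for i, b in enumerate(framed) if b & 0x80 == 0) + 1

assert decode_varint_from_bytes(framed[:prefix_len]) == len(payload)
assert framed[prefix_len:] == payload  # the raw format is just the payload
```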

View File

@@ -1,3 +1,4 @@
import functools
import hashlib
import base58
@@ -36,25 +37,23 @@ if ENABLE_INLINING:
class ID:
_bytes: bytes
_xor_id: int | None = None
_b58_str: str | None = None
def __init__(self, peer_id_bytes: bytes) -> None:
self._bytes = peer_id_bytes
@property
@functools.cached_property
def xor_id(self) -> int:
if not self._xor_id:
self._xor_id = int(sha256_digest(self._bytes).hex(), 16)
return self._xor_id
return int(sha256_digest(self._bytes).hex(), 16)
@functools.cached_property
def base58(self) -> str:
return base58.b58encode(self._bytes).decode()
def to_bytes(self) -> bytes:
return self._bytes
def to_base58(self) -> str:
if not self._b58_str:
self._b58_str = base58.b58encode(self._bytes).decode()
return self._b58_str
return self.base58
def __repr__(self) -> str:
return f"<libp2p.peer.id.ID ({self!s})>"

View File

@@ -102,6 +102,9 @@ class TopicValidator(NamedTuple):
is_async: bool
MAX_CONCURRENT_VALIDATORS = 10
class Pubsub(Service, IPubsub):
host: IHost
@@ -109,6 +112,7 @@ class Pubsub(Service, IPubsub):
peer_receive_channel: trio.MemoryReceiveChannel[ID]
dead_peer_receive_channel: trio.MemoryReceiveChannel[ID]
_validator_semaphore: trio.Semaphore
seen_messages: LastSeenCache
@@ -143,6 +147,7 @@ class Pubsub(Service, IPubsub):
msg_id_constructor: Callable[
[rpc_pb2.Message], bytes
] = get_peer_and_seqno_msg_id,
max_concurrent_validator_count: int = MAX_CONCURRENT_VALIDATORS,
) -> None:
"""
Construct a new Pubsub object, which is responsible for handling all
@@ -168,6 +173,7 @@ class Pubsub(Service, IPubsub):
# Therefore, we can only close from the receive side.
self.peer_receive_channel = peer_receive
self.dead_peer_receive_channel = dead_peer_receive
self._validator_semaphore = trio.Semaphore(max_concurrent_validator_count)
# Register a notifee
self.host.get_network().register_notifee(
PubsubNotifee(peer_send, dead_peer_send)
@@ -657,7 +663,11 @@ class Pubsub(Service, IPubsub):
logger.debug("successfully published message %s", msg)
async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
async def validate_msg(
self,
msg_forwarder: ID,
msg: rpc_pb2.Message,
) -> None:
"""
Validate the received message.
@@ -680,23 +690,34 @@ class Pubsub(Service, IPubsub):
if not validator(msg_forwarder, msg):
raise ValidationError(f"Validation failed for msg={msg}")
# TODO: Implement throttle on async validators
if len(async_topic_validators) > 0:
# Appends to lists are thread safe in CPython
results = []
async def run_async_validator(func: AsyncValidatorFn) -> None:
result = await func(msg_forwarder, msg)
results.append(result)
results: list[bool] = []
async with trio.open_nursery() as nursery:
for async_validator in async_topic_validators:
nursery.start_soon(run_async_validator, async_validator)
nursery.start_soon(
self._run_async_validator,
async_validator,
msg_forwarder,
msg,
results,
)
if not all(results):
raise ValidationError(f"Validation failed for msg={msg}")
async def _run_async_validator(
self,
func: AsyncValidatorFn,
msg_forwarder: ID,
msg: rpc_pb2.Message,
results: list[bool],
) -> None:
async with self._validator_semaphore:
result = await func(msg_forwarder, msg)
results.append(result)
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
"""
Push a pubsub message to others.
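Editor's note: the throttle added here is just a shared `trio.Semaphore` — the nursery still starts every validator, but at most `max_concurrent_validator_count` bodies run at once. A standalone sketch of the mechanism, using nothing beyond trio itself:

```
import trio

MAX_CONCURRENT = 10

async def main() -> None:
    semaphore = trio.Semaphore(MAX_CONCURRENT)
    results: list[bool] = []

    async def throttled_validator(n: int) -> None:
        async with semaphore:  # at most MAX_CONCURRENT bodies run at once
            await trio.sleep(0.01)  # stand-in for an async validator call
            results.append(True)

    async with trio.open_nursery() as nursery:
        for n in range(50):
            nursery.start_soon(throttled_validator, n)

    assert len(results) == 50 and all(results)

trio.run(main)
```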

View File

@@ -1,3 +1,5 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from types import (
TracebackType,
)
@@ -32,6 +34,72 @@ if TYPE_CHECKING:
)
class ReadWriteLock:
"""
A read-write lock that allows multiple concurrent readers
or one exclusive writer, implemented using Trio primitives.
"""
def __init__(self) -> None:
self._readers = 0
self._readers_lock = trio.Lock() # Protects access to _readers count
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
async def acquire_read(self) -> None:
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
try:
async with self._readers_lock:
if self._readers == 0:
await self._writer_lock.acquire()
self._readers += 1
except trio.Cancelled:
raise
async def release_read(self) -> None:
"""Release a read lock."""
async with self._readers_lock:
if self._readers == 1:
self._writer_lock.release()
self._readers -= 1
async def acquire_write(self) -> None:
"""Acquire an exclusive write lock."""
try:
await self._writer_lock.acquire()
except trio.Cancelled:
raise
def release_write(self) -> None:
"""Release the exclusive write lock."""
self._writer_lock.release()
@asynccontextmanager
async def read_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a read lock safely."""
acquired = False
try:
await self.acquire_read()
acquired = True
yield
finally:
if acquired:
with trio.CancelScope(shield=True):
await self.release_read()
@asynccontextmanager
async def write_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a write lock safely."""
acquired = False
try:
await self.acquire_write()
acquired = True
yield
finally:
if acquired:
self.release_write()
class MplexStream(IMuxedStream):
"""
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
@@ -46,7 +114,7 @@ class MplexStream(IMuxedStream):
read_deadline: int | None
write_deadline: int | None
# TODO: Add lock for read/write to avoid interleaving receiving messages?
rw_lock: ReadWriteLock
close_lock: trio.Lock
# NOTE: `dataIn` is size of 8 in Go implementation.
@@ -80,6 +148,7 @@ class MplexStream(IMuxedStream):
self.event_remote_closed = trio.Event()
self.event_reset = trio.Event()
self.close_lock = trio.Lock()
self.rw_lock = ReadWriteLock()
self.incoming_data_channel = incoming_data_channel
self._buf = bytearray()
@@ -113,48 +182,49 @@ class MplexStream(IMuxedStream):
:param n: number of bytes to read
:return: bytes actually read
"""
if n is not None and n < 0:
raise ValueError(
"the number of bytes to read `n` must be non-negative or "
f"`None` to indicate read until EOF, got n={n}"
)
if self.event_reset.is_set():
raise MplexStreamReset
if n is None:
return await self._read_until_eof()
if len(self._buf) == 0:
data: bytes
# Peek whether there is data available. If yes, we just read until there is
# no data, then return.
try:
data = self.incoming_data_channel.receive_nowait()
self._buf.extend(data)
except trio.EndOfChannel:
raise MplexStreamEOF
except trio.WouldBlock:
# We know `receive` will be blocked here. Wait for data here with
# `receive` and catch all kinds of errors here.
async with self.rw_lock.read_lock():
if n is not None and n < 0:
raise ValueError(
"the number of bytes to read `n` must be non-negative or "
f"`None` to indicate read until EOF, got n={n}"
)
if self.event_reset.is_set():
raise MplexStreamReset
if n is None:
return await self._read_until_eof()
if len(self._buf) == 0:
data: bytes
# Peek whether there is data available. If yes, we just read until
# there is no data, then return.
try:
data = await self.incoming_data_channel.receive()
data = self.incoming_data_channel.receive_nowait()
self._buf.extend(data)
except trio.EndOfChannel:
if self.event_reset.is_set():
raise MplexStreamReset
if self.event_remote_closed.is_set():
raise MplexStreamEOF
except trio.ClosedResourceError as error:
# Probably `incoming_data_channel` is closed in `reset` when we are
# waiting for `receive`.
if self.event_reset.is_set():
raise MplexStreamReset
raise Exception(
"`incoming_data_channel` is closed but stream is not reset. "
"This should never happen."
) from error
self._buf.extend(self._read_return_when_blocked())
payload = self._buf[:n]
self._buf = self._buf[len(payload) :]
return bytes(payload)
raise MplexStreamEOF
except trio.WouldBlock:
# We know `receive` will be blocked here. Wait for data here with
# `receive` and catch all kinds of errors here.
try:
data = await self.incoming_data_channel.receive()
self._buf.extend(data)
except trio.EndOfChannel:
if self.event_reset.is_set():
raise MplexStreamReset
if self.event_remote_closed.is_set():
raise MplexStreamEOF
except trio.ClosedResourceError as error:
# Probably `incoming_data_channel` is closed in `reset` when
# we are waiting for `receive`.
if self.event_reset.is_set():
raise MplexStreamReset
raise Exception(
"`incoming_data_channel` is closed but stream is not reset."
"This should never happen."
) from error
self._buf.extend(self._read_return_when_blocked())
payload = self._buf[:n]
self._buf = self._buf[len(payload) :]
return bytes(payload)
async def write(self, data: bytes) -> None:
"""
@@ -162,14 +232,15 @@ class MplexStream(IMuxedStream):
:return: number of bytes written
"""
if self.event_local_closed.is_set():
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
flag = (
HeaderTags.MessageInitiator
if self.is_initiator
else HeaderTags.MessageReceiver
)
await self.muxed_conn.send_message(flag, data, self.stream_id)
async with self.rw_lock.write_lock():
if self.event_local_closed.is_set():
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
flag = (
HeaderTags.MessageInitiator
if self.is_initiator
else HeaderTags.MessageReceiver
)
await self.muxed_conn.send_message(flag, data, self.stream_id)
async def close(self) -> None:
"""

View File

@@ -0,0 +1 @@
Add a read/write lock to mplex_stream.py to prevent interleaving of received messages.

View File

@@ -0,0 +1,2 @@
Added throttling for async topic validators in validate_msg, enforcing a
concurrency limit to prevent resource exhaustion under heavy load.

View File

@@ -0,0 +1 @@
Replace the libp2p.peer.ID cache attributes with the functools.cached_property decorator.

View File

@@ -0,0 +1 @@
Clarified the requirement for a trailing newline in newsfragments to pass lint checks.

View File

@@ -0,0 +1 @@
Fixed incorrect handling of the raw protobuf format in the identify protocol. The identify example now properly handles both raw and length-prefixed (varint) message formats, provides better error messages, and displays connection status with peer IDs. Replaced mock-based tests with comprehensive real-network integration tests for both formats.

View File

@@ -18,12 +18,19 @@ Each file should be named like `<ISSUE>.<TYPE>.rst`, where
- `performance`
- `removal`
So for example: `123.feature.rst`, `456.bugfix.rst`
So for example: `1024.feature.rst`
**Important**: Ensure the file ends with a newline character (`\n`) to pass the tox lint checks run in CI.
```
Added support for Ed25519 key generation in libp2p peer identity creation.
```
If the PR fixes an issue, use that number here. If there is no issue,
then open up the PR first and use the PR number for the newsfragment.
Note that the `towncrier` tool will automatically
**Note** that the `towncrier` tool will automatically
reflow your text, so don't try to do any fancy formatting. Run
`towncrier build --draft` to get a preview of what the release notes entry
will look like in the final release notes.
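Editor's note: a quick local check for the trailing-newline requirement described above — a sketch only, reusing the example file name from this document (the authoritative check remains the tox lint):

```
from pathlib import Path

# Hypothetical newsfragment path, reusing the example name above.
path = Path("newsfragments/1024.feature.rst")
assert path.read_bytes().endswith(b"\n"), f"{path} lacks a trailing newline"
```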

View File

@@ -0,0 +1,241 @@
import logging
import pytest
from libp2p.custom_types import TProtocol
from libp2p.identity.identify.identify import (
AGENT_VERSION,
ID,
PROTOCOL_VERSION,
_multiaddr_to_bytes,
identify_handler_for,
parse_identify_response,
)
from tests.utils.factories import host_pair_factory
logger = logging.getLogger("libp2p.identity.identify-integration-test")
@pytest.mark.trio
async def test_identify_protocol_varint_format_integration(security_protocol):
"""Test identify protocol with varint format in real network scenario."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=True)
)
# Make identify request
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
assert result.listen_addrs == [
_multiaddr_to_bytes(addr) for addr in host_a.get_addrs()
]
@pytest.mark.trio
async def test_identify_protocol_raw_format_integration(security_protocol):
"""Test identify protocol with raw format in real network scenario."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=False)
)
# Make identify request
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
assert result.listen_addrs == [
_multiaddr_to_bytes(addr) for addr in host_a.get_addrs()
]
@pytest.mark.trio
async def test_identify_default_format_behavior(security_protocol):
"""Test identify protocol uses correct default format."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Use default identify handler (should use varint format)
host_a.set_stream_handler(ID, identify_handler_for(host_a))
# Make identify request
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_cross_format_compatibility_varint_to_raw(security_protocol):
"""Test varint dialer with raw listener compatibility."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Host A uses raw format
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=False)
)
# Host B makes request (will automatically detect format)
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response (should work with automatic format detection)
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_cross_format_compatibility_raw_to_varint(security_protocol):
"""Test raw dialer with varint listener compatibility."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Host A uses varint format
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=True)
)
# Host B makes request (will automatically detect format)
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response (should work with automatic format detection)
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_format_detection_robustness(security_protocol):
"""Test identify protocol format detection is robust with various message sizes."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Test both formats with different message sizes
for use_varint in [True, False]:
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=use_varint)
)
# Make identify request
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_large_message_handling(security_protocol):
"""Test identify protocol handles large messages with many protocols."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add many protocols to make the message larger
async def dummy_handler(stream):
pass
for i in range(10):
host_a.set_stream_handler(TProtocol(f"/test/protocol/{i}"), dummy_handler)
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=True)
)
# Make identify request
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_message_equivalence_real_network(security_protocol):
"""Test that both formats produce equivalent messages in real network."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Test varint format
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=True)
)
stream_varint = await host_b.new_stream(host_a.get_id(), (ID,))
response_varint = await stream_varint.read(8192)
await stream_varint.close()
# Test raw format
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=False)
)
stream_raw = await host_b.new_stream(host_a.get_id(), (ID,))
response_raw = await stream_raw.read(8192)
await stream_raw.close()
# Parse both responses
result_varint = parse_identify_response(response_varint)
result_raw = parse_identify_response(response_raw)
# Both should produce identical parsed results
assert result_varint.agent_version == result_raw.agent_version
assert result_varint.protocol_version == result_raw.protocol_version
assert result_varint.public_key == result_raw.public_key
assert result_varint.listen_addrs == result_raw.listen_addrs

View File

@@ -1,410 +0,0 @@
import pytest
from libp2p.identity.identify.identify import (
_mk_identify_protobuf,
)
from libp2p.identity.identify.pb.identify_pb2 import (
Identify,
)
from libp2p.io.abc import Closer, Reader, Writer
from libp2p.utils.varint import (
decode_varint_from_bytes,
encode_varint_prefixed,
)
from tests.utils.factories import (
host_pair_factory,
)
class MockStream(Reader, Writer, Closer):
"""Mock stream for testing identify protocol compatibility."""
def __init__(self, data: bytes):
self.data = data
self.position = 0
self.closed = False
async def read(self, n: int | None = None) -> bytes:
if self.closed or self.position >= len(self.data):
return b""
if n is None:
n = len(self.data) - self.position
result = self.data[self.position : self.position + n]
self.position += len(result)
return result
async def write(self, data: bytes) -> None:
# Mock write - just store the data
pass
async def close(self) -> None:
self.closed = True
def create_identify_message(host, observed_multiaddr=None):
"""Create an identify protobuf message."""
return _mk_identify_protobuf(host, observed_multiaddr)
def create_new_format_message(identify_msg):
"""Create a new format (length-prefixed) identify message."""
msg_bytes = identify_msg.SerializeToString()
return encode_varint_prefixed(msg_bytes)
def create_old_format_message(identify_msg):
"""Create an old format (raw protobuf) identify message."""
return identify_msg.SerializeToString()
async def read_new_format_message(stream) -> bytes:
"""Read a new format (length-prefixed) identify message."""
# Read varint length prefix
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
break
length_bytes += b
if b[0] & 0x80 == 0:
break
if not length_bytes:
raise ValueError("No length prefix received")
msg_length = decode_varint_from_bytes(length_bytes)
# Read the protobuf message
response = await stream.read(msg_length)
if len(response) != msg_length:
raise ValueError("Incomplete message received")
return response
async def read_old_format_message(stream) -> bytes:
"""Read an old format (raw protobuf) identify message."""
# Read all available data
response = b""
while True:
chunk = await stream.read(4096)
if not chunk:
break
response += chunk
return response
async def read_compatible_message(stream) -> bytes:
"""Read an identify message in either old or new format."""
# Try to read a few bytes to detect the format
first_bytes = await stream.read(10)
if not first_bytes:
raise ValueError("No data received")
# Try to decode as varint length prefix (new format)
try:
msg_length = decode_varint_from_bytes(first_bytes)
# Validate that the length is reasonable (not too large)
if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB
# Calculate how many bytes the varint consumed
varint_len = 0
for i, byte in enumerate(first_bytes):
varint_len += 1
if (byte & 0x80) == 0:
break
# Read the remaining protobuf message
remaining_bytes = await stream.read(
msg_length - (len(first_bytes) - varint_len)
)
if len(remaining_bytes) == msg_length - (len(first_bytes) - varint_len):
message_data = first_bytes[varint_len:] + remaining_bytes
# Try to parse as protobuf to validate
try:
Identify().ParseFromString(message_data)
return message_data
except Exception:
# If protobuf parsing fails, fall back to old format
pass
except Exception:
pass
# Fall back to old format (raw protobuf)
response = first_bytes
# Read more data if available
while True:
chunk = await stream.read(4096)
if not chunk:
break
response += chunk
return response
async def read_compatible_message_simple(stream) -> bytes:
"""Read a message in either old or new format (simplified version for testing)."""
# Try to read a few bytes to detect the format
first_bytes = await stream.read(10)
if not first_bytes:
raise ValueError("No data received")
# Try to decode as varint length prefix (new format)
try:
msg_length = decode_varint_from_bytes(first_bytes)
# Validate that the length is reasonable (not too large)
if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB
# Calculate how many bytes the varint consumed
varint_len = 0
for i, byte in enumerate(first_bytes):
varint_len += 1
if (byte & 0x80) == 0:
break
# Read the remaining message
remaining_bytes = await stream.read(
msg_length - (len(first_bytes) - varint_len)
)
if len(remaining_bytes) == msg_length - (len(first_bytes) - varint_len):
return first_bytes[varint_len:] + remaining_bytes
except Exception:
pass
# Fall back to old format (raw data)
response = first_bytes
# Read more data if available
while True:
chunk = await stream.read(4096)
if not chunk:
break
response += chunk
return response
def detect_format(data):
"""Detect if data is in new or old format (varint-prefixed or raw protobuf)."""
if not data:
return "unknown"
# Try to decode as varint
try:
msg_length = decode_varint_from_bytes(data)
# Validate that the length is reasonable
if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB
# Calculate varint length
varint_len = 0
for i, byte in enumerate(data):
varint_len += 1
if (byte & 0x80) == 0:
break
# Check if we have enough data for the message
if len(data) >= varint_len + msg_length:
# Additional check: try to parse the message as protobuf
try:
message_data = data[varint_len : varint_len + msg_length]
Identify().ParseFromString(message_data)
return "new"
except Exception:
# If protobuf parsing fails, it's probably not a valid new format
pass
except Exception:
pass
# If varint decoding fails or length is unreasonable, assume old format
return "old"
@pytest.mark.trio
async def test_identify_new_format_compatibility(security_protocol):
"""Test that identify protocol works with new format (length-prefixed) messages."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Create new format message
new_format_data = create_new_format_message(identify_msg)
# Create mock stream with new format data
stream = MockStream(new_format_data)
# Read using new format reader
response = await read_new_format_message(stream)
# Parse the response
parsed_msg = Identify()
parsed_msg.ParseFromString(response)
# Verify the message content
assert parsed_msg.protocol_version == identify_msg.protocol_version
assert parsed_msg.agent_version == identify_msg.agent_version
assert parsed_msg.public_key == identify_msg.public_key
@pytest.mark.trio
async def test_identify_old_format_compatibility(security_protocol):
"""Test that identify protocol works with old format (raw protobuf) messages."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Create old format message
old_format_data = create_old_format_message(identify_msg)
# Create mock stream with old format data
stream = MockStream(old_format_data)
# Read using old format reader
response = await read_old_format_message(stream)
# Parse the response
parsed_msg = Identify()
parsed_msg.ParseFromString(response)
# Verify the message content
assert parsed_msg.protocol_version == identify_msg.protocol_version
assert parsed_msg.agent_version == identify_msg.agent_version
assert parsed_msg.public_key == identify_msg.public_key
@pytest.mark.trio
async def test_identify_backward_compatibility_old_format(security_protocol):
"""Test backward compatibility reader with old format messages."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Create old format message
old_format_data = create_old_format_message(identify_msg)
# Create mock stream with old format data
stream = MockStream(old_format_data)
# Read using old format reader (which should work reliably)
response = await read_old_format_message(stream)
# Parse the response
parsed_msg = Identify()
parsed_msg.ParseFromString(response)
# Verify the message content
assert parsed_msg.protocol_version == identify_msg.protocol_version
assert parsed_msg.agent_version == identify_msg.agent_version
assert parsed_msg.public_key == identify_msg.public_key
@pytest.mark.trio
async def test_identify_backward_compatibility_new_format(security_protocol):
"""Test backward compatibility reader with new format messages."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Create new format message
new_format_data = create_new_format_message(identify_msg)
# Create mock stream with new format data
stream = MockStream(new_format_data)
# Read using new format reader (which should work reliably)
response = await read_new_format_message(stream)
# Parse the response
parsed_msg = Identify()
parsed_msg.ParseFromString(response)
# Verify the message content
assert parsed_msg.protocol_version == identify_msg.protocol_version
assert parsed_msg.agent_version == identify_msg.agent_version
assert parsed_msg.public_key == identify_msg.public_key
@pytest.mark.trio
async def test_identify_format_detection(security_protocol):
"""Test that the format detection works correctly."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Test new format detection
new_format_data = create_new_format_message(identify_msg)
format_type = detect_format(new_format_data)
assert format_type == "new", "New format should be detected correctly"
# Test old format detection
old_format_data = create_old_format_message(identify_msg)
format_type = detect_format(old_format_data)
assert format_type == "old", "Old format should be detected correctly"
@pytest.mark.trio
async def test_identify_error_handling(security_protocol):
"""Test error handling for malformed messages."""
from libp2p.exceptions import ParseError
# Test with empty data
stream = MockStream(b"")
with pytest.raises(ValueError, match="No data received"):
await read_compatible_message(stream)
# Test with incomplete varint
stream = MockStream(b"\x80") # Incomplete varint
with pytest.raises(ParseError, match="Unexpected end of data"):
await read_new_format_message(stream)
# Test with invalid protobuf data
stream = MockStream(b"\x05invalid") # Length prefix but invalid protobuf
with pytest.raises(Exception): # Should fail when parsing protobuf
response = await read_new_format_message(stream)
Identify().ParseFromString(response)
@pytest.mark.trio
async def test_identify_message_equivalence(security_protocol):
"""Test that old and new format messages are equivalent."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Create both formats
new_format_data = create_new_format_message(identify_msg)
old_format_data = create_old_format_message(identify_msg)
# Extract the protobuf message from new format
varint_len = 0
for i, byte in enumerate(new_format_data):
varint_len += 1
if (byte & 0x80) == 0:
break
new_format_protobuf = new_format_data[varint_len:]
# The protobuf messages should be identical
assert new_format_protobuf == old_format_data, (
"Protobuf messages should be identical in both formats"
)

View File

@@ -5,10 +5,12 @@ import inspect
from typing import (
NamedTuple,
)
from unittest.mock import patch
import pytest
import trio
from libp2p.custom_types import AsyncValidatorFn
from libp2p.exceptions import (
ValidationError,
)
@@ -243,7 +245,37 @@ async def test_get_msg_validators():
((False, True), (True, False), (True, True)),
)
@pytest.mark.trio
async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
async def test_validate_msg_with_throttle_condition(
is_topic_1_val_passed, is_topic_2_val_passed
):
CONCURRENCY_LIMIT = 10
state = {
"concurrency_counter": 0,
"max_observed": 0,
}
lock = trio.Lock()
async def mock_run_async_validator(
self,
func: AsyncValidatorFn,
msg_forwarder: ID,
msg: rpc_pb2.Message,
results: list[bool],
) -> None:
async with self._validator_semaphore:
async with lock:
state["concurrency_counter"] += 1
if state["concurrency_counter"] > state["max_observed"]:
state["max_observed"] = state["concurrency_counter"]
try:
result = await func(msg_forwarder, msg)
results.append(result)
finally:
async with lock:
state["concurrency_counter"] -= 1
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
def passed_sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
@@ -280,11 +312,19 @@ async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
seqno=b"\x00" * 8,
)
if is_topic_1_val_passed and is_topic_2_val_passed:
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
else:
with pytest.raises(ValidationError):
with patch(
"libp2p.pubsub.pubsub.Pubsub._run_async_validator",
new=mock_run_async_validator,
):
if is_topic_1_val_passed and is_topic_2_val_passed:
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
else:
with pytest.raises(ValidationError):
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
assert state["max_observed"] <= CONCURRENCY_LIMIT, (
f"Max concurrency observed: {state['max_observed']}"
)
@pytest.mark.trio

View File

@@ -0,0 +1,590 @@
from typing import Any, cast
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import trio
from trio.testing import wait_all_tasks_blocked
from libp2p.stream_muxer.exceptions import (
MuxedConnUnavailable,
)
from libp2p.stream_muxer.mplex.constants import HeaderTags
from libp2p.stream_muxer.mplex.datastructures import StreamID
from libp2p.stream_muxer.mplex.exceptions import (
MplexStreamClosed,
MplexStreamEOF,
MplexStreamReset,
)
from libp2p.stream_muxer.mplex.mplex_stream import MplexStream
class MockMuxedConn:
"""A mock Mplex connection for testing purposes."""
def __init__(self):
self.sent_messages = []
self.streams: dict[StreamID, MplexStream] = {}
self.streams_lock = trio.Lock()
self.is_unavailable = False
async def send_message(
self, flag: HeaderTags, data: bytes | None, stream_id: StreamID
) -> None:
"""Mocks sending a message over the connection."""
if self.is_unavailable:
raise MuxedConnUnavailable("Connection is unavailable")
self.sent_messages.append((flag, data, stream_id))
# Yield to allow other tasks to run
await trio.lowlevel.checkpoint()
def get_remote_address(self) -> tuple[str, int]:
"""Mocks getting the remote address."""
return "127.0.0.1", 4001
@pytest.fixture
async def mplex_stream():
"""Provides a fully initialized MplexStream and its communication channels."""
# Use a buffered channel to prevent deadlocks in simple tests
send_chan, recv_chan = trio.open_memory_channel(10)
stream_id = StreamID(1, is_initiator=True)
muxed_conn = MockMuxedConn()
stream = MplexStream("test-stream", stream_id, cast(Any, muxed_conn), recv_chan)
muxed_conn.streams[stream_id] = stream
yield stream, send_chan, muxed_conn
# Cleanup: Close channels and reset stream state
await send_chan.aclose()
await recv_chan.aclose()
# Reset stream state to prevent cross-test contamination
stream.event_local_closed = trio.Event()
stream.event_remote_closed = trio.Event()
stream.event_reset = trio.Event()
# ===============================================
# 1. Tests for Stream-Level Lock Integration
# ===============================================
@pytest.mark.trio
async def test_stream_write_is_protected_by_rwlock(mplex_stream):
"""Verify that stream.write() acquires and releases the write lock."""
stream, _, muxed_conn = mplex_stream
# Mock lock methods
original_acquire = stream.rw_lock.acquire_write
original_release = stream.rw_lock.release_write
stream.rw_lock.acquire_write = AsyncMock(wraps=original_acquire)
stream.rw_lock.release_write = MagicMock(wraps=original_release)
await stream.write(b"test data")
stream.rw_lock.acquire_write.assert_awaited_once()
stream.rw_lock.release_write.assert_called_once()
# Verify the message was actually sent
assert len(muxed_conn.sent_messages) == 1
flag, data, stream_id = muxed_conn.sent_messages[0]
assert flag == HeaderTags.MessageInitiator
assert data == b"test data"
assert stream_id == stream.stream_id
@pytest.mark.trio
async def test_stream_read_is_protected_by_rwlock(mplex_stream):
"""Verify that stream.read() acquires and releases the read lock."""
stream, send_chan, _ = mplex_stream
# Mock lock methods
original_acquire = stream.rw_lock.acquire_read
original_release = stream.rw_lock.release_read
stream.rw_lock.acquire_read = AsyncMock(wraps=original_acquire)
stream.rw_lock.release_read = AsyncMock(wraps=original_release)
await send_chan.send(b"hello")
result = await stream.read(5)
stream.rw_lock.acquire_read.assert_awaited_once()
stream.rw_lock.release_read.assert_awaited_once()
assert result == b"hello"
@pytest.mark.trio
async def test_multiple_readers_can_coexist(mplex_stream):
"""Verify multiple readers can operate concurrently."""
stream, send_chan, _ = mplex_stream
# Send enough data for both reads
await send_chan.send(b"data1")
await send_chan.send(b"data2")
# Track lock acquisition order
acquisition_order = []
release_order = []
# Patch lock methods to track concurrency
original_acquire = stream.rw_lock.acquire_read
original_release = stream.rw_lock.release_read
async def tracked_acquire():
nonlocal acquisition_order
acquisition_order.append("start")
await original_acquire()
acquisition_order.append("acquired")
async def tracked_release():
nonlocal release_order
release_order.append("start")
await original_release()
release_order.append("released")
with (
patch.object(
stream.rw_lock, "acquire_read", side_effect=tracked_acquire, autospec=True
),
patch.object(
stream.rw_lock, "release_read", side_effect=tracked_release, autospec=True
),
):
# Execute concurrent reads
async with trio.open_nursery() as nursery:
nursery.start_soon(stream.read, 5)
nursery.start_soon(stream.read, 5)
# Verify both reads happened
assert acquisition_order.count("start") == 2
assert acquisition_order.count("acquired") == 2
assert release_order.count("start") == 2
assert release_order.count("released") == 2
@pytest.mark.trio
async def test_writer_blocks_readers(mplex_stream):
"""Verify that a writer blocks all readers and new readers queue behind."""
stream, send_chan, _ = mplex_stream
writer_acquired = trio.Event()
readers_ready = trio.Event()
writer_finished = trio.Event()
all_readers_started = trio.Event()
all_readers_done = trio.Event()
counters = {"reader_start_count": 0, "reader_done_count": 0}
reader_target = 3
reader_start_lock = trio.Lock()
# Patch write lock to control test flow
original_acquire_write = stream.rw_lock.acquire_write
original_release_write = stream.rw_lock.release_write
async def tracked_acquire_write():
await original_acquire_write()
writer_acquired.set()
# Wait for readers to queue up
await readers_ready.wait()
# Must be synchronous since real release_write is sync
def tracked_release_write():
original_release_write()
writer_finished.set()
with (
patch.object(
stream.rw_lock, "acquire_write", side_effect=tracked_acquire_write
),
patch.object(
stream.rw_lock, "release_write", side_effect=tracked_release_write
),
):
async with trio.open_nursery() as nursery:
# Start writer
nursery.start_soon(stream.write, b"test")
await writer_acquired.wait()
# Start readers
async def reader_task():
async with reader_start_lock:
counters["reader_start_count"] += 1
if counters["reader_start_count"] == reader_target:
all_readers_started.set()
try:
# This will block until data is available
await stream.read(5)
except (MplexStreamReset, MplexStreamEOF):
pass
finally:
async with reader_start_lock:
counters["reader_done_count"] += 1
if counters["reader_done_count"] == reader_target:
all_readers_done.set()
for _ in range(reader_target):
nursery.start_soon(reader_task)
# Wait until all readers are started
await all_readers_started.wait()
# Let the writer finish and release the lock
readers_ready.set()
await writer_finished.wait()
# Send data to unblock the readers
for i in range(reader_target):
await send_chan.send(b"data" + str(i).encode())
# Wait for all readers to finish
await all_readers_done.wait()
@pytest.mark.trio
async def test_writer_waits_for_readers(mplex_stream):
"""Verify a writer waits for existing readers to complete."""
stream, send_chan, _ = mplex_stream
readers_started = trio.Event()
writer_entered = trio.Event()
writer_acquiring = trio.Event()
readers_finished = trio.Event()
# Send data for readers
await send_chan.send(b"data1")
await send_chan.send(b"data2")
# Patch read lock to control test flow
original_acquire_read = stream.rw_lock.acquire_read
async def tracked_acquire_read():
await original_acquire_read()
readers_started.set()
# Wait until readers are allowed to finish
await readers_finished.wait()
# Patch write lock to detect when writer is blocked
original_acquire_write = stream.rw_lock.acquire_write
async def tracked_acquire_write():
writer_acquiring.set()
await original_acquire_write()
writer_entered.set()
with (
patch.object(stream.rw_lock, "acquire_read", side_effect=tracked_acquire_read),
patch.object(
stream.rw_lock, "acquire_write", side_effect=tracked_acquire_write
),
):
async with trio.open_nursery() as nursery:
# Start readers
nursery.start_soon(stream.read, 5)
nursery.start_soon(stream.read, 5)
# Wait for at least one reader to acquire the lock
await readers_started.wait()
# Start writer (should block)
nursery.start_soon(stream.write, b"test")
# Wait for writer to start acquiring lock
await writer_acquiring.wait()
# Verify writer hasn't entered critical section
assert not writer_entered.is_set()
# Allow readers to finish
readers_finished.set()
# Verify writer can proceed
await writer_entered.wait()
@pytest.mark.trio
async def test_lock_behavior_during_cancellation(mplex_stream):
"""Verify that a lock is released when a task holding it is cancelled."""
stream, _, _ = mplex_stream
reader_acquired_lock = trio.Event()
async def cancellable_reader(task_status):
async with stream.rw_lock.read_lock():
reader_acquired_lock.set()
task_status.started()
# Wait indefinitely until cancelled.
await trio.sleep_forever()
async with trio.open_nursery() as nursery:
# Start the reader and wait for it to acquire the lock.
await nursery.start(cancellable_reader)
await reader_acquired_lock.wait()
# Now that the reader has the lock, cancel the nursery.
# This will cancel the reader task, and its lock should be released.
nursery.cancel_scope.cancel()
# After the nursery is cancelled, the reader should have released the lock.
# To verify, we try to acquire a write lock. If the read lock was not
# released, this will time out.
with trio.move_on_after(1) as cancel_scope:
async with stream.rw_lock.write_lock():
pass
if cancel_scope.cancelled_caught:
pytest.fail(
"Write lock could not be acquired after a cancelled reader, "
"indicating the read lock was not released."
)
@pytest.mark.trio
async def test_concurrent_read_write_sequence(mplex_stream):
"""Verify complex sequence of interleaved reads and writes."""
stream, send_chan, _ = mplex_stream
results = []
# Use a mock to intercept writes and feed them back to the read channel
original_write = stream.write
reader1_finished = trio.Event()
writer1_finished = trio.Event()
reader2_finished = trio.Event()
async def mocked_write(data: bytes) -> None:
await original_write(data)
# Simulate the other side receiving the data and sending a response
# by putting data into the read channel.
await send_chan.send(data)
with patch.object(stream, "write", wraps=mocked_write) as patched_write:
async with trio.open_nursery() as nursery:
# Test scenario:
# 1. Reader 1 starts, waits for data.
# 2. Writer 1 writes, which gets fed back to the stream.
# 3. Reader 2 starts, reads what Writer 1 wrote.
# 4. Writer 2 writes.
async def reader1():
nonlocal results
results.append("R1 start")
data = await stream.read(5)
results.append(data)
results.append("R1 done")
reader1_finished.set()
async def writer1():
nonlocal results
await reader1_finished.wait()
results.append("W1 start")
await stream.write(b"write1")
results.append("W1 done")
writer1_finished.set()
async def reader2():
nonlocal results
await writer1_finished.wait()
# This will read the data from writer1
results.append("R2 start")
data = await stream.read(6)
results.append(data)
results.append("R2 done")
reader2_finished.set()
async def writer2():
nonlocal results
await reader2_finished.wait()
results.append("W2 start")
await stream.write(b"write2")
results.append("W2 done")
# Execute sequence
nursery.start_soon(reader1)
nursery.start_soon(writer1)
nursery.start_soon(reader2)
nursery.start_soon(writer2)
await send_chan.send(b"data1")
# Verify sequence and that write was called
assert patched_write.call_count == 2
assert results == [
"R1 start",
b"data1",
"R1 done",
"W1 start",
"W1 done",
"R2 start",
b"write1",
"R2 done",
"W2 start",
"W2 done",
]
# ===============================================
# 2. Tests for Reset, EOF, and Close Interactions
# ===============================================
@pytest.mark.trio
async def test_read_after_remote_close_triggers_eof(mplex_stream):
"""Verify reading from a remotely closed stream returns EOF correctly."""
stream, send_chan, _ = mplex_stream
# Send some data that can be read first
await send_chan.send(b"data")
# Close the channel to signify no more data will ever arrive
await send_chan.aclose()
# Mark the stream as remotely closed
stream.event_remote_closed.set()
# The first read should succeed, consuming the buffered data
data = await stream.read(4)
assert data == b"data"
# Now that the buffer is empty and the channel is closed, this should raise EOF
with pytest.raises(MplexStreamEOF):
await stream.read(1)
@pytest.mark.trio
async def test_read_on_closed_stream_raises_eof(mplex_stream):
"""Test that reading from a closed stream with no data raises EOF."""
stream, send_chan, _ = mplex_stream
stream.event_remote_closed.set()
await send_chan.aclose() # Ensure the channel is closed
# Reading from a stream that is closed and has no data should raise EOF
with pytest.raises(MplexStreamEOF):
await stream.read(100)
@pytest.mark.trio
async def test_write_to_locally_closed_stream_raises(mplex_stream):
"""Verify writing to a locally closed stream raises MplexStreamClosed."""
stream, _, _ = mplex_stream
stream.event_local_closed.set()
with pytest.raises(MplexStreamClosed):
await stream.write(b"this should fail")
@pytest.mark.trio
async def test_read_from_reset_stream_raises(mplex_stream):
"""Verify reading from a reset stream raises MplexStreamReset."""
stream, _, _ = mplex_stream
stream.event_reset.set()
with pytest.raises(MplexStreamReset):
await stream.read(10)
@pytest.mark.trio
async def test_write_to_reset_stream_raises(mplex_stream):
"""Verify writing to a reset stream raises MplexStreamClosed."""
stream, _, _ = mplex_stream
# A stream reset implies it's also locally closed.
await stream.reset()
# The `write` method checks `event_local_closed`, which `reset` sets.
with pytest.raises(MplexStreamClosed):
await stream.write(b"this should also fail")
@pytest.mark.trio
async def test_stream_reset_cleans_up_resources(mplex_stream):
"""Verify reset() cleans up stream state and resources."""
stream, _, muxed_conn = mplex_stream
stream_id = stream.stream_id
assert stream_id in muxed_conn.streams
await stream.reset()
assert stream.event_reset.is_set()
assert stream.event_local_closed.is_set()
assert stream.event_remote_closed.is_set()
assert (HeaderTags.ResetInitiator, None, stream_id) in muxed_conn.sent_messages
assert stream_id not in muxed_conn.streams
# Verify the underlying data channel is closed
with pytest.raises(trio.ClosedResourceError):
await stream.incoming_data_channel.receive()
# ===============================================
# 3. Rigorous Concurrency Tests with Events
# ===============================================
@pytest.mark.trio
async def test_writer_is_blocked_by_reader_using_events(mplex_stream):
"""Verify a writer must wait for a reader using trio.Event for synchronization."""
stream, _, _ = mplex_stream
reader_has_lock = trio.Event()
writer_finished = trio.Event()
async def reader():
async with stream.rw_lock.read_lock():
reader_has_lock.set()
# Hold the lock until the writer has finished its attempt
await writer_finished.wait()
async def writer():
await reader_has_lock.wait()
# This call will now block until the reader releases the lock
await stream.write(b"data")
writer_finished.set()
async with trio.open_nursery() as nursery:
nursery.start_soon(reader)
nursery.start_soon(writer)
# Verify writer is blocked
await wait_all_tasks_blocked()
assert not writer_finished.is_set()
# Signal the reader to finish
writer_finished.set()
@pytest.mark.trio
async def test_multiple_readers_can_read_concurrently_using_events(mplex_stream):
"""Verify that multiple readers can acquire a read lock simultaneously."""
stream, _, _ = mplex_stream
counters = {"readers_in_critical_section": 0}
lock = trio.Lock() # To safely mutate the counter
reader1_acquired = trio.Event()
reader2_acquired = trio.Event()
all_readers_finished = trio.Event()
async def concurrent_reader(event_to_set: trio.Event):
async with stream.rw_lock.read_lock():
async with lock:
counters["readers_in_critical_section"] += 1
event_to_set.set()
# Wait until all readers have finished before exiting the lock context
await all_readers_finished.wait()
async with lock:
counters["readers_in_critical_section"] -= 1
async with trio.open_nursery() as nursery:
nursery.start_soon(concurrent_reader, reader1_acquired)
nursery.start_soon(concurrent_reader, reader2_acquired)
# Wait for both readers to acquire their locks
await reader1_acquired.wait()
await reader2_acquired.wait()
# Check that both were in the critical section at the same time
async with lock:
assert counters["readers_in_critical_section"] == 2
# Signal for all readers to finish
all_readers_finished.set()
# Verify they exit cleanly
await wait_all_tasks_blocked()
async with lock:
assert counters["readers_in_critical_section"] == 0