Merge branch 'main' into add-ws-transport

This commit is contained in:
GautamBytes
2025-07-20 09:44:11 +00:00
10 changed files with 1144 additions and 480 deletions

View File

@ -72,13 +72,46 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No
client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/")
format_name = "length-prefixed" if use_varint_format else "raw protobuf"
format_flag = "--raw-format" if not use_varint_format else ""
print(
f"First host listening (using {format_name} format). "
f"Run this from another console:\n\n"
f"identify-demo "
f"-d {client_addr}\n"
f"identify-demo {format_flag} -d {client_addr}\n"
)
print("Waiting for incoming identify request...")
# Add a custom handler to show connection events
async def custom_identify_handler(stream):
peer_id = stream.muxed_conn.peer_id
print(f"\n🔗 Received identify request from peer: {peer_id}")
# Show remote address in multiaddr format
try:
from libp2p.identity.identify.identify import (
_remote_address_to_multiaddr,
)
remote_address = stream.get_remote_address()
if remote_address:
observed_multiaddr = _remote_address_to_multiaddr(
remote_address
)
# Add the peer ID to create a complete multiaddr
complete_multiaddr = f"{observed_multiaddr}/p2p/{peer_id}"
print(f" Remote address: {complete_multiaddr}")
else:
print(f" Remote address: {remote_address}")
except Exception:
print(f" Remote address: {stream.get_remote_address()}")
# Call the original handler
await identify_handler(stream)
print(f"✅ Successfully processed identify request from {peer_id}")
# Replace the handler with our custom one
host_a.set_stream_handler(IDENTIFY_PROTOCOL_ID, custom_identify_handler)
await trio.sleep_forever()
else:
@ -93,25 +126,99 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No
info = info_from_p2p_addr(maddr)
print(f"Second host connecting to peer: {info.peer_id}")
await host_b.connect(info)
try:
await host_b.connect(info)
except Exception as e:
error_msg = str(e)
if "unable to connect" in error_msg or "SwarmException" in error_msg:
print(f"\n❌ Cannot connect to peer: {info.peer_id}")
print(f" Address: {destination}")
print(f" Error: {error_msg}")
print(
"\n💡 Make sure the peer is running and the address is correct."
)
return
else:
# Re-raise other exceptions
raise
stream = await host_b.new_stream(info.peer_id, (IDENTIFY_PROTOCOL_ID,))
try:
print("Starting identify protocol...")
# Read the complete response (could be either format)
# Read a larger chunk to get all the data before stream closes
response = await stream.read(8192) # Read enough data in one go
# Read the response properly based on the format
if use_varint_format:
# For length-prefixed format, read varint length first
from libp2p.utils.varint import decode_varint_from_bytes
# Read varint length prefix
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
raise Exception("Stream closed while reading varint length")
length_bytes += b
if b[0] & 0x80 == 0:
break
msg_length = decode_varint_from_bytes(length_bytes)
print(f"Expected message length: {msg_length} bytes")
# Read the protobuf message
response = await stream.read(msg_length)
if len(response) != msg_length:
raise Exception(
f"Incomplete message: expected {msg_length} bytes, "
f"got {len(response)}"
)
# Combine length prefix and message
full_response = length_bytes + response
else:
# For raw format, read all available data
response = await stream.read(8192)
full_response = response
await stream.close()
# Parse the response using the robust protocol-level function
# This handles both old and new formats automatically
identify_msg = parse_identify_response(response)
identify_msg = parse_identify_response(full_response)
print_identify_response(identify_msg)
except Exception as e:
print(f"Identify protocol error: {e}")
error_msg = str(e)
print(f"Identify protocol error: {error_msg}")
# Check for specific format mismatch errors
if "Error parsing message" in error_msg or "DecodeError" in error_msg:
print("\n" + "=" * 60)
print("FORMAT MISMATCH DETECTED!")
print("=" * 60)
if use_varint_format:
print(
"You are using length-prefixed format (default) but the "
"listener"
)
print("is using raw protobuf format.")
print(
"\nTo fix this, run the dialer with the --raw-format flag:"
)
print(f"identify-demo --raw-format -d {destination}")
else:
print("You are using raw protobuf format but the listener")
print("is using length-prefixed format (default).")
print(
"\nTo fix this, run the dialer without the --raw-format "
"flag:"
)
print(f"identify-demo -d {destination}")
print("=" * 60)
else:
import traceback
traceback.print_exc()
return

View File

@ -102,6 +102,9 @@ class TopicValidator(NamedTuple):
is_async: bool
MAX_CONCURRENT_VALIDATORS = 10
class Pubsub(Service, IPubsub):
host: IHost
@ -109,6 +112,7 @@ class Pubsub(Service, IPubsub):
peer_receive_channel: trio.MemoryReceiveChannel[ID]
dead_peer_receive_channel: trio.MemoryReceiveChannel[ID]
_validator_semaphore: trio.Semaphore
seen_messages: LastSeenCache
@ -143,6 +147,7 @@ class Pubsub(Service, IPubsub):
msg_id_constructor: Callable[
[rpc_pb2.Message], bytes
] = get_peer_and_seqno_msg_id,
max_concurrent_validator_count: int = MAX_CONCURRENT_VALIDATORS,
) -> None:
"""
Construct a new Pubsub object, which is responsible for handling all
@ -168,6 +173,7 @@ class Pubsub(Service, IPubsub):
# Therefore, we can only close from the receive side.
self.peer_receive_channel = peer_receive
self.dead_peer_receive_channel = dead_peer_receive
self._validator_semaphore = trio.Semaphore(max_concurrent_validator_count)
# Register a notifee
self.host.get_network().register_notifee(
PubsubNotifee(peer_send, dead_peer_send)
@ -657,7 +663,11 @@ class Pubsub(Service, IPubsub):
logger.debug("successfully published message %s", msg)
async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
async def validate_msg(
self,
msg_forwarder: ID,
msg: rpc_pb2.Message,
) -> None:
"""
Validate the received message.
@ -680,23 +690,34 @@ class Pubsub(Service, IPubsub):
if not validator(msg_forwarder, msg):
raise ValidationError(f"Validation failed for msg={msg}")
# TODO: Implement throttle on async validators
if len(async_topic_validators) > 0:
# Appends to lists are thread safe in CPython
results = []
async def run_async_validator(func: AsyncValidatorFn) -> None:
result = await func(msg_forwarder, msg)
results.append(result)
results: list[bool] = []
async with trio.open_nursery() as nursery:
for async_validator in async_topic_validators:
nursery.start_soon(run_async_validator, async_validator)
nursery.start_soon(
self._run_async_validator,
async_validator,
msg_forwarder,
msg,
results,
)
if not all(results):
raise ValidationError(f"Validation failed for msg={msg}")
async def _run_async_validator(
self,
func: AsyncValidatorFn,
msg_forwarder: ID,
msg: rpc_pb2.Message,
results: list[bool],
) -> None:
async with self._validator_semaphore:
result = await func(msg_forwarder, msg)
results.append(result)
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
"""
Push a pubsub message to others.

View File

@ -1,3 +1,5 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from types import (
TracebackType,
)
@ -32,6 +34,72 @@ if TYPE_CHECKING:
)
class ReadWriteLock:
"""
A read-write lock that allows multiple concurrent readers
or one exclusive writer, implemented using Trio primitives.
"""
def __init__(self) -> None:
self._readers = 0
self._readers_lock = trio.Lock() # Protects access to _readers count
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
async def acquire_read(self) -> None:
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
try:
async with self._readers_lock:
if self._readers == 0:
await self._writer_lock.acquire()
self._readers += 1
except trio.Cancelled:
raise
async def release_read(self) -> None:
"""Release a read lock."""
async with self._readers_lock:
if self._readers == 1:
self._writer_lock.release()
self._readers -= 1
async def acquire_write(self) -> None:
"""Acquire an exclusive write lock."""
try:
await self._writer_lock.acquire()
except trio.Cancelled:
raise
def release_write(self) -> None:
"""Release the exclusive write lock."""
self._writer_lock.release()
@asynccontextmanager
async def read_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a read lock safely."""
acquire = False
try:
await self.acquire_read()
acquire = True
yield
finally:
if acquire:
with trio.CancelScope() as scope:
scope.shield = True
await self.release_read()
@asynccontextmanager
async def write_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a write lock safely."""
acquire = False
try:
await self.acquire_write()
acquire = True
yield
finally:
if acquire:
self.release_write()
class MplexStream(IMuxedStream):
"""
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
@ -46,7 +114,7 @@ class MplexStream(IMuxedStream):
read_deadline: int | None
write_deadline: int | None
# TODO: Add lock for read/write to avoid interleaving receiving messages?
rw_lock: ReadWriteLock
close_lock: trio.Lock
# NOTE: `dataIn` is size of 8 in Go implementation.
@ -80,6 +148,7 @@ class MplexStream(IMuxedStream):
self.event_remote_closed = trio.Event()
self.event_reset = trio.Event()
self.close_lock = trio.Lock()
self.rw_lock = ReadWriteLock()
self.incoming_data_channel = incoming_data_channel
self._buf = bytearray()
@ -113,48 +182,49 @@ class MplexStream(IMuxedStream):
:param n: number of bytes to read
:return: bytes actually read
"""
if n is not None and n < 0:
raise ValueError(
"the number of bytes to read `n` must be non-negative or "
f"`None` to indicate read until EOF, got n={n}"
)
if self.event_reset.is_set():
raise MplexStreamReset
if n is None:
return await self._read_until_eof()
if len(self._buf) == 0:
data: bytes
# Peek whether there is data available. If yes, we just read until there is
# no data, then return.
try:
data = self.incoming_data_channel.receive_nowait()
self._buf.extend(data)
except trio.EndOfChannel:
raise MplexStreamEOF
except trio.WouldBlock:
# We know `receive` will be blocked here. Wait for data here with
# `receive` and catch all kinds of errors here.
async with self.rw_lock.read_lock():
if n is not None and n < 0:
raise ValueError(
"the number of bytes to read `n` must be non-negative or "
f"`None` to indicate read until EOF, got n={n}"
)
if self.event_reset.is_set():
raise MplexStreamReset
if n is None:
return await self._read_until_eof()
if len(self._buf) == 0:
data: bytes
# Peek whether there is data available. If yes, we just read until
# there is no data, then return.
try:
data = await self.incoming_data_channel.receive()
data = self.incoming_data_channel.receive_nowait()
self._buf.extend(data)
except trio.EndOfChannel:
if self.event_reset.is_set():
raise MplexStreamReset
if self.event_remote_closed.is_set():
raise MplexStreamEOF
except trio.ClosedResourceError as error:
# Probably `incoming_data_channel` is closed in `reset` when we are
# waiting for `receive`.
if self.event_reset.is_set():
raise MplexStreamReset
raise Exception(
"`incoming_data_channel` is closed but stream is not reset. "
"This should never happen."
) from error
self._buf.extend(self._read_return_when_blocked())
payload = self._buf[:n]
self._buf = self._buf[len(payload) :]
return bytes(payload)
raise MplexStreamEOF
except trio.WouldBlock:
# We know `receive` will be blocked here. Wait for data here with
# `receive` and catch all kinds of errors here.
try:
data = await self.incoming_data_channel.receive()
self._buf.extend(data)
except trio.EndOfChannel:
if self.event_reset.is_set():
raise MplexStreamReset
if self.event_remote_closed.is_set():
raise MplexStreamEOF
except trio.ClosedResourceError as error:
# Probably `incoming_data_channel` is closed in `reset` when
# we are waiting for `receive`.
if self.event_reset.is_set():
raise MplexStreamReset
raise Exception(
"`incoming_data_channel` is closed but stream is not reset."
"This should never happen."
) from error
self._buf.extend(self._read_return_when_blocked())
payload = self._buf[:n]
self._buf = self._buf[len(payload) :]
return bytes(payload)
async def write(self, data: bytes) -> None:
"""
@ -162,14 +232,15 @@ class MplexStream(IMuxedStream):
:return: number of bytes written
"""
if self.event_local_closed.is_set():
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
flag = (
HeaderTags.MessageInitiator
if self.is_initiator
else HeaderTags.MessageReceiver
)
await self.muxed_conn.send_message(flag, data, self.stream_id)
async with self.rw_lock.write_lock():
if self.event_local_closed.is_set():
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
flag = (
HeaderTags.MessageInitiator
if self.is_initiator
else HeaderTags.MessageReceiver
)
await self.muxed_conn.send_message(flag, data, self.stream_id)
async def close(self) -> None:
"""

View File

@ -0,0 +1 @@
Add lock for read/write to avoid interleaving receiving messages in mplex_stream.py

View File

@ -0,0 +1,2 @@
Added throttling for async topic validators in validate_msg, enforcing a
concurrency limit to prevent resource exhaustion under heavy load.

View File

@ -0,0 +1 @@
Fixed incorrect handling of raw protobuf format in identify protocol. The identify example now properly handles both raw and length-prefixed (varint) message formats, provides better error messages, and displays connection status with peer IDs. Replaced mock-based tests with comprehensive real network integration tests for both formats.

View File

@ -0,0 +1,241 @@
import logging
import pytest
from libp2p.custom_types import TProtocol
from libp2p.identity.identify.identify import (
AGENT_VERSION,
ID,
PROTOCOL_VERSION,
_multiaddr_to_bytes,
identify_handler_for,
parse_identify_response,
)
from tests.utils.factories import host_pair_factory
logger = logging.getLogger("libp2p.identity.identify-integration-test")
@pytest.mark.trio
async def test_identify_protocol_varint_format_integration(security_protocol):
"""Test identify protocol with varint format in real network scenario."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=True)
)
# Make identify request
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
assert result.listen_addrs == [
_multiaddr_to_bytes(addr) for addr in host_a.get_addrs()
]
@pytest.mark.trio
async def test_identify_protocol_raw_format_integration(security_protocol):
"""Test identify protocol with raw format in real network scenario."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=False)
)
# Make identify request
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
assert result.listen_addrs == [
_multiaddr_to_bytes(addr) for addr in host_a.get_addrs()
]
@pytest.mark.trio
async def test_identify_default_format_behavior(security_protocol):
"""Test identify protocol uses correct default format."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Use default identify handler (should use varint format)
host_a.set_stream_handler(ID, identify_handler_for(host_a))
# Make identify request
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_cross_format_compatibility_varint_to_raw(security_protocol):
"""Test varint dialer with raw listener compatibility."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Host A uses raw format
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=False)
)
# Host B makes request (will automatically detect format)
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response (should work with automatic format detection)
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_cross_format_compatibility_raw_to_varint(security_protocol):
"""Test raw dialer with varint listener compatibility."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Host A uses varint format
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=True)
)
# Host B makes request (will automatically detect format)
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response (should work with automatic format detection)
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_format_detection_robustness(security_protocol):
"""Test identify protocol format detection is robust with various message sizes."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Test both formats with different message sizes
for use_varint in [True, False]:
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=use_varint)
)
# Make identify request
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_large_message_handling(security_protocol):
"""Test identify protocol handles large messages with many protocols."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add many protocols to make the message larger
async def dummy_handler(stream):
pass
for i in range(10):
host_a.set_stream_handler(TProtocol(f"/test/protocol/{i}"), dummy_handler)
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=True)
)
# Make identify request
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_message_equivalence_real_network(security_protocol):
"""Test that both formats produce equivalent messages in real network."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Test varint format
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=True)
)
stream_varint = await host_b.new_stream(host_a.get_id(), (ID,))
response_varint = await stream_varint.read(8192)
await stream_varint.close()
# Test raw format
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=False)
)
stream_raw = await host_b.new_stream(host_a.get_id(), (ID,))
response_raw = await stream_raw.read(8192)
await stream_raw.close()
# Parse both responses
result_varint = parse_identify_response(response_varint)
result_raw = parse_identify_response(response_raw)
# Both should produce identical parsed results
assert result_varint.agent_version == result_raw.agent_version
assert result_varint.protocol_version == result_raw.protocol_version
assert result_varint.public_key == result_raw.public_key
assert result_varint.listen_addrs == result_raw.listen_addrs

View File

@ -1,410 +0,0 @@
import pytest
from libp2p.identity.identify.identify import (
_mk_identify_protobuf,
)
from libp2p.identity.identify.pb.identify_pb2 import (
Identify,
)
from libp2p.io.abc import Closer, Reader, Writer
from libp2p.utils.varint import (
decode_varint_from_bytes,
encode_varint_prefixed,
)
from tests.utils.factories import (
host_pair_factory,
)
class MockStream(Reader, Writer, Closer):
"""Mock stream for testing identify protocol compatibility."""
def __init__(self, data: bytes):
self.data = data
self.position = 0
self.closed = False
async def read(self, n: int | None = None) -> bytes:
if self.closed or self.position >= len(self.data):
return b""
if n is None:
n = len(self.data) - self.position
result = self.data[self.position : self.position + n]
self.position += len(result)
return result
async def write(self, data: bytes) -> None:
# Mock write - just store the data
pass
async def close(self) -> None:
self.closed = True
def create_identify_message(host, observed_multiaddr=None):
"""Create an identify protobuf message."""
return _mk_identify_protobuf(host, observed_multiaddr)
def create_new_format_message(identify_msg):
"""Create a new format (length-prefixed) identify message."""
msg_bytes = identify_msg.SerializeToString()
return encode_varint_prefixed(msg_bytes)
def create_old_format_message(identify_msg):
"""Create an old format (raw protobuf) identify message."""
return identify_msg.SerializeToString()
async def read_new_format_message(stream) -> bytes:
"""Read a new format (length-prefixed) identify message."""
# Read varint length prefix
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
break
length_bytes += b
if b[0] & 0x80 == 0:
break
if not length_bytes:
raise ValueError("No length prefix received")
msg_length = decode_varint_from_bytes(length_bytes)
# Read the protobuf message
response = await stream.read(msg_length)
if len(response) != msg_length:
raise ValueError("Incomplete message received")
return response
async def read_old_format_message(stream) -> bytes:
"""Read an old format (raw protobuf) identify message."""
# Read all available data
response = b""
while True:
chunk = await stream.read(4096)
if not chunk:
break
response += chunk
return response
async def read_compatible_message(stream) -> bytes:
"""Read an identify message in either old or new format."""
# Try to read a few bytes to detect the format
first_bytes = await stream.read(10)
if not first_bytes:
raise ValueError("No data received")
# Try to decode as varint length prefix (new format)
try:
msg_length = decode_varint_from_bytes(first_bytes)
# Validate that the length is reasonable (not too large)
if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB
# Calculate how many bytes the varint consumed
varint_len = 0
for i, byte in enumerate(first_bytes):
varint_len += 1
if (byte & 0x80) == 0:
break
# Read the remaining protobuf message
remaining_bytes = await stream.read(
msg_length - (len(first_bytes) - varint_len)
)
if len(remaining_bytes) == msg_length - (len(first_bytes) - varint_len):
message_data = first_bytes[varint_len:] + remaining_bytes
# Try to parse as protobuf to validate
try:
Identify().ParseFromString(message_data)
return message_data
except Exception:
# If protobuf parsing fails, fall back to old format
pass
except Exception:
pass
# Fall back to old format (raw protobuf)
response = first_bytes
# Read more data if available
while True:
chunk = await stream.read(4096)
if not chunk:
break
response += chunk
return response
async def read_compatible_message_simple(stream) -> bytes:
"""Read a message in either old or new format (simplified version for testing)."""
# Try to read a few bytes to detect the format
first_bytes = await stream.read(10)
if not first_bytes:
raise ValueError("No data received")
# Try to decode as varint length prefix (new format)
try:
msg_length = decode_varint_from_bytes(first_bytes)
# Validate that the length is reasonable (not too large)
if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB
# Calculate how many bytes the varint consumed
varint_len = 0
for i, byte in enumerate(first_bytes):
varint_len += 1
if (byte & 0x80) == 0:
break
# Read the remaining message
remaining_bytes = await stream.read(
msg_length - (len(first_bytes) - varint_len)
)
if len(remaining_bytes) == msg_length - (len(first_bytes) - varint_len):
return first_bytes[varint_len:] + remaining_bytes
except Exception:
pass
# Fall back to old format (raw data)
response = first_bytes
# Read more data if available
while True:
chunk = await stream.read(4096)
if not chunk:
break
response += chunk
return response
def detect_format(data):
"""Detect if data is in new or old format (varint-prefixed or raw protobuf)."""
if not data:
return "unknown"
# Try to decode as varint
try:
msg_length = decode_varint_from_bytes(data)
# Validate that the length is reasonable
if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB
# Calculate varint length
varint_len = 0
for i, byte in enumerate(data):
varint_len += 1
if (byte & 0x80) == 0:
break
# Check if we have enough data for the message
if len(data) >= varint_len + msg_length:
# Additional check: try to parse the message as protobuf
try:
message_data = data[varint_len : varint_len + msg_length]
Identify().ParseFromString(message_data)
return "new"
except Exception:
# If protobuf parsing fails, it's probably not a valid new format
pass
except Exception:
pass
# If varint decoding fails or length is unreasonable, assume old format
return "old"
@pytest.mark.trio
async def test_identify_new_format_compatibility(security_protocol):
"""Test that identify protocol works with new format (length-prefixed) messages."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Create new format message
new_format_data = create_new_format_message(identify_msg)
# Create mock stream with new format data
stream = MockStream(new_format_data)
# Read using new format reader
response = await read_new_format_message(stream)
# Parse the response
parsed_msg = Identify()
parsed_msg.ParseFromString(response)
# Verify the message content
assert parsed_msg.protocol_version == identify_msg.protocol_version
assert parsed_msg.agent_version == identify_msg.agent_version
assert parsed_msg.public_key == identify_msg.public_key
@pytest.mark.trio
async def test_identify_old_format_compatibility(security_protocol):
"""Test that identify protocol works with old format (raw protobuf) messages."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Create old format message
old_format_data = create_old_format_message(identify_msg)
# Create mock stream with old format data
stream = MockStream(old_format_data)
# Read using old format reader
response = await read_old_format_message(stream)
# Parse the response
parsed_msg = Identify()
parsed_msg.ParseFromString(response)
# Verify the message content
assert parsed_msg.protocol_version == identify_msg.protocol_version
assert parsed_msg.agent_version == identify_msg.agent_version
assert parsed_msg.public_key == identify_msg.public_key
@pytest.mark.trio
async def test_identify_backward_compatibility_old_format(security_protocol):
"""Test backward compatibility reader with old format messages."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Create old format message
old_format_data = create_old_format_message(identify_msg)
# Create mock stream with old format data
stream = MockStream(old_format_data)
# Read using old format reader (which should work reliably)
response = await read_old_format_message(stream)
# Parse the response
parsed_msg = Identify()
parsed_msg.ParseFromString(response)
# Verify the message content
assert parsed_msg.protocol_version == identify_msg.protocol_version
assert parsed_msg.agent_version == identify_msg.agent_version
assert parsed_msg.public_key == identify_msg.public_key
@pytest.mark.trio
async def test_identify_backward_compatibility_new_format(security_protocol):
"""Test backward compatibility reader with new format messages."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Create new format message
new_format_data = create_new_format_message(identify_msg)
# Create mock stream with new format data
stream = MockStream(new_format_data)
# Read using new format reader (which should work reliably)
response = await read_new_format_message(stream)
# Parse the response
parsed_msg = Identify()
parsed_msg.ParseFromString(response)
# Verify the message content
assert parsed_msg.protocol_version == identify_msg.protocol_version
assert parsed_msg.agent_version == identify_msg.agent_version
assert parsed_msg.public_key == identify_msg.public_key
@pytest.mark.trio
async def test_identify_format_detection(security_protocol):
"""Test that the format detection works correctly."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Test new format detection
new_format_data = create_new_format_message(identify_msg)
format_type = detect_format(new_format_data)
assert format_type == "new", "New format should be detected correctly"
# Test old format detection
old_format_data = create_old_format_message(identify_msg)
format_type = detect_format(old_format_data)
assert format_type == "old", "Old format should be detected correctly"
@pytest.mark.trio
async def test_identify_error_handling(security_protocol):
"""Test error handling for malformed messages."""
from libp2p.exceptions import ParseError
# Test with empty data
stream = MockStream(b"")
with pytest.raises(ValueError, match="No data received"):
await read_compatible_message(stream)
# Test with incomplete varint
stream = MockStream(b"\x80") # Incomplete varint
with pytest.raises(ParseError, match="Unexpected end of data"):
await read_new_format_message(stream)
# Test with invalid protobuf data
stream = MockStream(b"\x05invalid") # Length prefix but invalid protobuf
with pytest.raises(Exception): # Should fail when parsing protobuf
response = await read_new_format_message(stream)
Identify().ParseFromString(response)
@pytest.mark.trio
async def test_identify_message_equivalence(security_protocol):
"""Test that old and new format messages are equivalent."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Create both formats
new_format_data = create_new_format_message(identify_msg)
old_format_data = create_old_format_message(identify_msg)
# Extract the protobuf message from new format
varint_len = 0
for i, byte in enumerate(new_format_data):
varint_len += 1
if (byte & 0x80) == 0:
break
new_format_protobuf = new_format_data[varint_len:]
# The protobuf messages should be identical
assert new_format_protobuf == old_format_data, (
"Protobuf messages should be identical in both formats"
)

View File

@ -5,10 +5,12 @@ import inspect
from typing import (
NamedTuple,
)
from unittest.mock import patch
import pytest
import trio
from libp2p.custom_types import AsyncValidatorFn
from libp2p.exceptions import (
ValidationError,
)
@ -243,7 +245,37 @@ async def test_get_msg_validators():
((False, True), (True, False), (True, True)),
)
@pytest.mark.trio
async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
async def test_validate_msg_with_throttle_condition(
is_topic_1_val_passed, is_topic_2_val_passed
):
CONCURRENCY_LIMIT = 10
state = {
"concurrency_counter": 0,
"max_observed": 0,
}
lock = trio.Lock()
async def mock_run_async_validator(
self,
func: AsyncValidatorFn,
msg_forwarder: ID,
msg: rpc_pb2.Message,
results: list[bool],
) -> None:
async with self._validator_semaphore:
async with lock:
state["concurrency_counter"] += 1
if state["concurrency_counter"] > state["max_observed"]:
state["max_observed"] = state["concurrency_counter"]
try:
result = await func(msg_forwarder, msg)
results.append(result)
finally:
async with lock:
state["concurrency_counter"] -= 1
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
def passed_sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
@ -280,11 +312,19 @@ async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
seqno=b"\x00" * 8,
)
if is_topic_1_val_passed and is_topic_2_val_passed:
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
else:
with pytest.raises(ValidationError):
with patch(
"libp2p.pubsub.pubsub.Pubsub._run_async_validator",
new=mock_run_async_validator,
):
if is_topic_1_val_passed and is_topic_2_val_passed:
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
else:
with pytest.raises(ValidationError):
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
assert state["max_observed"] <= CONCURRENCY_LIMIT, (
f"Max concurrency observed: {state['max_observed']}"
)
@pytest.mark.trio

View File

@ -0,0 +1,590 @@
from typing import Any, cast
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import trio
from trio.testing import wait_all_tasks_blocked
from libp2p.stream_muxer.exceptions import (
MuxedConnUnavailable,
)
from libp2p.stream_muxer.mplex.constants import HeaderTags
from libp2p.stream_muxer.mplex.datastructures import StreamID
from libp2p.stream_muxer.mplex.exceptions import (
MplexStreamClosed,
MplexStreamEOF,
MplexStreamReset,
)
from libp2p.stream_muxer.mplex.mplex_stream import MplexStream
class MockMuxedConn:
"""A mock Mplex connection for testing purposes."""
def __init__(self):
self.sent_messages = []
self.streams: dict[StreamID, MplexStream] = {}
self.streams_lock = trio.Lock()
self.is_unavailable = False
async def send_message(
self, flag: HeaderTags, data: bytes | None, stream_id: StreamID
) -> None:
"""Mocks sending a message over the connection."""
if self.is_unavailable:
raise MuxedConnUnavailable("Connection is unavailable")
self.sent_messages.append((flag, data, stream_id))
# Yield to allow other tasks to run
await trio.lowlevel.checkpoint()
def get_remote_address(self) -> tuple[str, int]:
"""Mocks getting the remote address."""
return "127.0.0.1", 4001
@pytest.fixture
async def mplex_stream():
"""Provides a fully initialized MplexStream and its communication channels."""
# Use a buffered channel to prevent deadlocks in simple tests
send_chan, recv_chan = trio.open_memory_channel(10)
stream_id = StreamID(1, is_initiator=True)
muxed_conn = MockMuxedConn()
stream = MplexStream("test-stream", stream_id, cast(Any, muxed_conn), recv_chan)
muxed_conn.streams[stream_id] = stream
yield stream, send_chan, muxed_conn
# Cleanup: Close channels and reset stream state
await send_chan.aclose()
await recv_chan.aclose()
# Reset stream state to prevent cross-test contamination
stream.event_local_closed = trio.Event()
stream.event_remote_closed = trio.Event()
stream.event_reset = trio.Event()
# ===============================================
# 1. Tests for Stream-Level Lock Integration
# ===============================================
@pytest.mark.trio
async def test_stream_write_is_protected_by_rwlock(mplex_stream):
"""Verify that stream.write() acquires and releases the write lock."""
stream, _, muxed_conn = mplex_stream
# Mock lock methods
original_acquire = stream.rw_lock.acquire_write
original_release = stream.rw_lock.release_write
stream.rw_lock.acquire_write = AsyncMock(wraps=original_acquire)
stream.rw_lock.release_write = MagicMock(wraps=original_release)
await stream.write(b"test data")
stream.rw_lock.acquire_write.assert_awaited_once()
stream.rw_lock.release_write.assert_called_once()
# Verify the message was actually sent
assert len(muxed_conn.sent_messages) == 1
flag, data, stream_id = muxed_conn.sent_messages[0]
assert flag == HeaderTags.MessageInitiator
assert data == b"test data"
assert stream_id == stream.stream_id
@pytest.mark.trio
async def test_stream_read_is_protected_by_rwlock(mplex_stream):
"""Verify that stream.read() acquires and releases the read lock."""
stream, send_chan, _ = mplex_stream
# Mock lock methods
original_acquire = stream.rw_lock.acquire_read
original_release = stream.rw_lock.release_read
stream.rw_lock.acquire_read = AsyncMock(wraps=original_acquire)
stream.rw_lock.release_read = AsyncMock(wraps=original_release)
await send_chan.send(b"hello")
result = await stream.read(5)
stream.rw_lock.acquire_read.assert_awaited_once()
stream.rw_lock.release_read.assert_awaited_once()
assert result == b"hello"
@pytest.mark.trio
async def test_multiple_readers_can_coexist(mplex_stream):
"""Verify multiple readers can operate concurrently."""
stream, send_chan, _ = mplex_stream
# Send enough data for both reads
await send_chan.send(b"data1")
await send_chan.send(b"data2")
# Track lock acquisition order
acquisition_order = []
release_order = []
# Patch lock methods to track concurrency
original_acquire = stream.rw_lock.acquire_read
original_release = stream.rw_lock.release_read
async def tracked_acquire():
nonlocal acquisition_order
acquisition_order.append("start")
await original_acquire()
acquisition_order.append("acquired")
async def tracked_release():
nonlocal release_order
release_order.append("start")
await original_release()
release_order.append("released")
with (
patch.object(
stream.rw_lock, "acquire_read", side_effect=tracked_acquire, autospec=True
),
patch.object(
stream.rw_lock, "release_read", side_effect=tracked_release, autospec=True
),
):
# Execute concurrent reads
async with trio.open_nursery() as nursery:
nursery.start_soon(stream.read, 5)
nursery.start_soon(stream.read, 5)
# Verify both reads happened
assert acquisition_order.count("start") == 2
assert acquisition_order.count("acquired") == 2
assert release_order.count("start") == 2
assert release_order.count("released") == 2
@pytest.mark.trio
async def test_writer_blocks_readers(mplex_stream):
"""Verify that a writer blocks all readers and new readers queue behind."""
stream, send_chan, _ = mplex_stream
writer_acquired = trio.Event()
readers_ready = trio.Event()
writer_finished = trio.Event()
all_readers_started = trio.Event()
all_readers_done = trio.Event()
counters = {"reader_start_count": 0, "reader_done_count": 0}
reader_target = 3
reader_start_lock = trio.Lock()
# Patch write lock to control test flow
original_acquire_write = stream.rw_lock.acquire_write
original_release_write = stream.rw_lock.release_write
async def tracked_acquire_write():
await original_acquire_write()
writer_acquired.set()
# Wait for readers to queue up
await readers_ready.wait()
# Must be synchronous since real release_write is sync
def tracked_release_write():
original_release_write()
writer_finished.set()
with (
patch.object(
stream.rw_lock, "acquire_write", side_effect=tracked_acquire_write
),
patch.object(
stream.rw_lock, "release_write", side_effect=tracked_release_write
),
):
async with trio.open_nursery() as nursery:
# Start writer
nursery.start_soon(stream.write, b"test")
await writer_acquired.wait()
# Start readers
async def reader_task():
async with reader_start_lock:
counters["reader_start_count"] += 1
if counters["reader_start_count"] == reader_target:
all_readers_started.set()
try:
# This will block until data is available
await stream.read(5)
except (MplexStreamReset, MplexStreamEOF):
pass
finally:
async with reader_start_lock:
counters["reader_done_count"] += 1
if counters["reader_done_count"] == reader_target:
all_readers_done.set()
for _ in range(reader_target):
nursery.start_soon(reader_task)
# Wait until all readers are started
await all_readers_started.wait()
# Let the writer finish and release the lock
readers_ready.set()
await writer_finished.wait()
# Send data to unblock the readers
for i in range(reader_target):
await send_chan.send(b"data" + str(i).encode())
# Wait for all readers to finish
await all_readers_done.wait()
@pytest.mark.trio
async def test_writer_waits_for_readers(mplex_stream):
"""Verify a writer waits for existing readers to complete."""
stream, send_chan, _ = mplex_stream
readers_started = trio.Event()
writer_entered = trio.Event()
writer_acquiring = trio.Event()
readers_finished = trio.Event()
# Send data for readers
await send_chan.send(b"data1")
await send_chan.send(b"data2")
# Patch read lock to control test flow
original_acquire_read = stream.rw_lock.acquire_read
async def tracked_acquire_read():
await original_acquire_read()
readers_started.set()
# Wait until readers are allowed to finish
await readers_finished.wait()
# Patch write lock to detect when writer is blocked
original_acquire_write = stream.rw_lock.acquire_write
async def tracked_acquire_write():
writer_acquiring.set()
await original_acquire_write()
writer_entered.set()
with (
patch.object(stream.rw_lock, "acquire_read", side_effect=tracked_acquire_read),
patch.object(
stream.rw_lock, "acquire_write", side_effect=tracked_acquire_write
),
):
async with trio.open_nursery() as nursery:
# Start readers
nursery.start_soon(stream.read, 5)
nursery.start_soon(stream.read, 5)
# Wait for at least one reader to acquire the lock
await readers_started.wait()
# Start writer (should block)
nursery.start_soon(stream.write, b"test")
# Wait for writer to start acquiring lock
await writer_acquiring.wait()
# Verify writer hasn't entered critical section
assert not writer_entered.is_set()
# Allow readers to finish
readers_finished.set()
# Verify writer can proceed
await writer_entered.wait()
@pytest.mark.trio
async def test_lock_behavior_during_cancellation(mplex_stream):
"""Verify that a lock is released when a task holding it is cancelled."""
stream, _, _ = mplex_stream
reader_acquired_lock = trio.Event()
async def cancellable_reader(task_status):
async with stream.rw_lock.read_lock():
reader_acquired_lock.set()
task_status.started()
# Wait indefinitely until cancelled.
await trio.sleep_forever()
async with trio.open_nursery() as nursery:
# Start the reader and wait for it to acquire the lock.
await nursery.start(cancellable_reader)
await reader_acquired_lock.wait()
# Now that the reader has the lock, cancel the nursery.
# This will cancel the reader task, and its lock should be released.
nursery.cancel_scope.cancel()
# After the nursery is cancelled, the reader should have released the lock.
# To verify, we try to acquire a write lock. If the read lock was not
# released, this will time out.
with trio.move_on_after(1) as cancel_scope:
async with stream.rw_lock.write_lock():
pass
if cancel_scope.cancelled_caught:
pytest.fail(
"Write lock could not be acquired after a cancelled reader, "
"indicating the read lock was not released."
)
@pytest.mark.trio
async def test_concurrent_read_write_sequence(mplex_stream):
"""Verify complex sequence of interleaved reads and writes."""
stream, send_chan, _ = mplex_stream
results = []
# Use a mock to intercept writes and feed them back to the read channel
original_write = stream.write
reader1_finished = trio.Event()
writer1_finished = trio.Event()
reader2_finished = trio.Event()
async def mocked_write(data: bytes) -> None:
await original_write(data)
# Simulate the other side receiving the data and sending a response
# by putting data into the read channel.
await send_chan.send(data)
with patch.object(stream, "write", wraps=mocked_write) as patched_write:
async with trio.open_nursery() as nursery:
# Test scenario:
# 1. Reader 1 starts, waits for data.
# 2. Writer 1 writes, which gets fed back to the stream.
# 3. Reader 2 starts, reads what Writer 1 wrote.
# 4. Writer 2 writes.
async def reader1():
nonlocal results
results.append("R1 start")
data = await stream.read(5)
results.append(data)
results.append("R1 done")
reader1_finished.set()
async def writer1():
nonlocal results
await reader1_finished.wait()
results.append("W1 start")
await stream.write(b"write1")
results.append("W1 done")
writer1_finished.set()
async def reader2():
nonlocal results
await writer1_finished.wait()
# This will read the data from writer1
results.append("R2 start")
data = await stream.read(6)
results.append(data)
results.append("R2 done")
reader2_finished.set()
async def writer2():
nonlocal results
await reader2_finished.wait()
results.append("W2 start")
await stream.write(b"write2")
results.append("W2 done")
# Execute sequence
nursery.start_soon(reader1)
nursery.start_soon(writer1)
nursery.start_soon(reader2)
nursery.start_soon(writer2)
await send_chan.send(b"data1")
# Verify sequence and that write was called
assert patched_write.call_count == 2
assert results == [
"R1 start",
b"data1",
"R1 done",
"W1 start",
"W1 done",
"R2 start",
b"write1",
"R2 done",
"W2 start",
"W2 done",
]
# ===============================================
# 2. Tests for Reset, EOF, and Close Interactions
# ===============================================
@pytest.mark.trio
async def test_read_after_remote_close_triggers_eof(mplex_stream):
"""Verify reading from a remotely closed stream returns EOF correctly."""
stream, send_chan, _ = mplex_stream
# Send some data that can be read first
await send_chan.send(b"data")
# Close the channel to signify no more data will ever arrive
await send_chan.aclose()
# Mark the stream as remotely closed
stream.event_remote_closed.set()
# The first read should succeed, consuming the buffered data
data = await stream.read(4)
assert data == b"data"
# Now that the buffer is empty and the channel is closed, this should raise EOF
with pytest.raises(MplexStreamEOF):
await stream.read(1)
@pytest.mark.trio
async def test_read_on_closed_stream_raises_eof(mplex_stream):
"""Test that reading from a closed stream with no data raises EOF."""
stream, send_chan, _ = mplex_stream
stream.event_remote_closed.set()
await send_chan.aclose() # Ensure the channel is closed
# Reading from a stream that is closed and has no data should raise EOF
with pytest.raises(MplexStreamEOF):
await stream.read(100)
@pytest.mark.trio
async def test_write_to_locally_closed_stream_raises(mplex_stream):
"""Verify writing to a locally closed stream raises MplexStreamClosed."""
stream, _, _ = mplex_stream
stream.event_local_closed.set()
with pytest.raises(MplexStreamClosed):
await stream.write(b"this should fail")
@pytest.mark.trio
async def test_read_from_reset_stream_raises(mplex_stream):
"""Verify reading from a reset stream raises MplexStreamReset."""
stream, _, _ = mplex_stream
stream.event_reset.set()
with pytest.raises(MplexStreamReset):
await stream.read(10)
@pytest.mark.trio
async def test_write_to_reset_stream_raises(mplex_stream):
"""Verify writing to a reset stream raises MplexStreamClosed."""
stream, _, _ = mplex_stream
# A stream reset implies it's also locally closed.
await stream.reset()
# The `write` method checks `event_local_closed`, which `reset` sets.
with pytest.raises(MplexStreamClosed):
await stream.write(b"this should also fail")
@pytest.mark.trio
async def test_stream_reset_cleans_up_resources(mplex_stream):
"""Verify reset() cleans up stream state and resources."""
stream, _, muxed_conn = mplex_stream
stream_id = stream.stream_id
assert stream_id in muxed_conn.streams
await stream.reset()
assert stream.event_reset.is_set()
assert stream.event_local_closed.is_set()
assert stream.event_remote_closed.is_set()
assert (HeaderTags.ResetInitiator, None, stream_id) in muxed_conn.sent_messages
assert stream_id not in muxed_conn.streams
# Verify the underlying data channel is closed
with pytest.raises(trio.ClosedResourceError):
await stream.incoming_data_channel.receive()
# ===============================================
# 3. Rigorous Concurrency Tests with Events
# ===============================================
@pytest.mark.trio
async def test_writer_is_blocked_by_reader_using_events(mplex_stream):
"""Verify a writer must wait for a reader using trio.Event for synchronization."""
stream, _, _ = mplex_stream
reader_has_lock = trio.Event()
writer_finished = trio.Event()
async def reader():
async with stream.rw_lock.read_lock():
reader_has_lock.set()
# Hold the lock until the writer has finished its attempt
await writer_finished.wait()
async def writer():
await reader_has_lock.wait()
# This call will now block until the reader releases the lock
await stream.write(b"data")
writer_finished.set()
async with trio.open_nursery() as nursery:
nursery.start_soon(reader)
nursery.start_soon(writer)
# Verify writer is blocked
await wait_all_tasks_blocked()
assert not writer_finished.is_set()
# Signal the reader to finish
writer_finished.set()
@pytest.mark.trio
async def test_multiple_readers_can_read_concurrently_using_events(mplex_stream):
"""Verify that multiple readers can acquire a read lock simultaneously."""
stream, _, _ = mplex_stream
counters = {"readers_in_critical_section": 0}
lock = trio.Lock() # To safely mutate the counter
reader1_acquired = trio.Event()
reader2_acquired = trio.Event()
all_readers_finished = trio.Event()
async def concurrent_reader(event_to_set: trio.Event):
async with stream.rw_lock.read_lock():
async with lock:
counters["readers_in_critical_section"] += 1
event_to_set.set()
# Wait until all readers have finished before exiting the lock context
await all_readers_finished.wait()
async with lock:
counters["readers_in_critical_section"] -= 1
async with trio.open_nursery() as nursery:
nursery.start_soon(concurrent_reader, reader1_acquired)
nursery.start_soon(concurrent_reader, reader2_acquired)
# Wait for both readers to acquire their locks
await reader1_acquired.wait()
await reader2_acquired.wait()
# Check that both were in the critical section at the same time
async with lock:
assert counters["readers_in_critical_section"] == 2
# Signal for all readers to finish
all_readers_finished.set()
# Verify they exit cleanly
await wait_all_tasks_blocked()
async with lock:
assert counters["readers_in_critical_section"] == 0