diff --git a/libp2p/io/trio.py b/libp2p/io/trio.py index 7d0584e4..e74e9ed2 100644 --- a/libp2p/io/trio.py +++ b/libp2p/io/trio.py @@ -23,6 +23,10 @@ class TrioReadWriteCloser(ReadWriteCloser): raise IOException(error) async def read(self, n: int = -1) -> bytes: + if n == 0: + # Check point + await trio.sleep(0) + return b"" max_bytes = n if n != -1 else None try: return await self.stream.receive_some(max_bytes) diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index b91783d1..46b3f6fd 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -50,8 +50,11 @@ class SwarmConn(INetConn, Service): await self._notify_disconnected() async def _handle_new_streams(self) -> None: - while True: + while self.manager.is_running: try: + print( + f"!@# SwarmConn._handle_new_streams: {self.muxed_conn._id}: waiting for new streams" + ) stream = await self.muxed_conn.accept_stream() except MuxedConnUnavailable: # If there is anything wrong in the MuxedConn, @@ -60,6 +63,9 @@ class SwarmConn(INetConn, Service): # Asynchronously handle the accepted stream, to avoid blocking the next stream. self.manager.run_task(self._handle_muxed_stream, stream) + print( + f"!@# SwarmConn._handle_new_streams: {self.muxed_conn._id}: out of the loop" + ) await self.close() async def _call_stream_handler(self, net_stream: NetStream) -> None: diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 337da896..78fb7fdc 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -206,8 +206,7 @@ class Swarm(INetwork, Service): logger.debug("successfully opened connection to peer %s", peer_id) # FIXME: This is a intentional barrier to prevent from the handler exiting and # closing the connection. - event = trio.Event() - await event.wait() + await trio.sleep_forever() try: # Success @@ -240,7 +239,7 @@ class Swarm(INetwork, Service): # await asyncio.gather( # *[connection.close() for connection in self.connections.values()] # ) - self.manager.stop() + await self.manager.stop() await self.manager.wait_finished() logger.debug("swarm successfully closed") diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 9ac56143..53d855b2 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -1,3 +1,4 @@ +import math import asyncio import logging from typing import Any # noqa: F401 @@ -18,7 +19,6 @@ from libp2p.utils import ( encode_uvarint, encode_varint_prefixed, read_varint_prefixed_bytes, - TrioQueue, ) from .constants import HeaderTags @@ -41,7 +41,10 @@ class Mplex(IMuxedConn, Service): next_channel_id: int streams: Dict[StreamID, MplexStream] streams_lock: trio.Lock - new_stream_queue: "TrioQueue[IMuxedStream]" + streams_msg_channels: Dict[StreamID, "trio.MemorySendChannel[bytes]"] + new_stream_send_channel: "trio.MemorySendChannel[IMuxedStream]" + new_stream_receive_channel: "trio.MemoryReceiveChannel[IMuxedStream]" + event_shutting_down: trio.Event event_closed: trio.Event @@ -64,7 +67,10 @@ class Mplex(IMuxedConn, Service): # Mapping from stream ID -> buffer of messages for that stream self.streams = {} self.streams_lock = trio.Lock() - self.new_stream_queue = TrioQueue() + self.streams_msg_channels = {} + send_channel, receive_channel = trio.open_memory_channel(math.inf) + self.new_stream_send_channel = send_channel + self.new_stream_receive_channel = receive_channel self.event_shutting_down = trio.Event() self.event_closed = trio.Event() @@ -105,9 +111,13 @@ class Mplex(IMuxedConn, Service): return next_id async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream: - stream = MplexStream(name, stream_id, self) + # Use an unbounded buffer, to avoid `handle_incoming` being blocked when doing + # `send_channel.send`. + send_channel, receive_channel = trio.open_memory_channel(math.inf) + stream = MplexStream(name, stream_id, self, receive_channel) async with self.streams_lock: self.streams[stream_id] = stream + self.streams_msg_channels[stream_id] = send_channel return stream async def open_stream(self) -> IMuxedStream: @@ -126,7 +136,10 @@ class Mplex(IMuxedConn, Service): async def accept_stream(self) -> IMuxedStream: """accepts a muxed stream opened by the other end.""" - return await self.new_stream_queue.get() + try: + return await self.new_stream_receive_channel.receive() + except (trio.ClosedResourceError, trio.EndOfChannel): + raise MplexUnavailable async def send_message( self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID @@ -138,6 +151,9 @@ class Mplex(IMuxedConn, Service): :param data: data to send in the message :param stream_id: stream the message is in """ + print( + f"!@# send_message: {self._id}: flag={flag}, data={data}, stream_id={stream_id}" + ) # << by 3, then or with flag header = encode_uvarint((stream_id.channel_id << 3) | flag.value) @@ -162,14 +178,21 @@ class Mplex(IMuxedConn, Service): """Read a message off of the secured connection and add it to the corresponding message buffer.""" - while True: + while self.manager.is_running: try: + print( + f"!@# handle_incoming: {self._id}: before _handle_incoming_message" + ) await self._handle_incoming_message() + print( + f"!@# handle_incoming: {self._id}: after _handle_incoming_message" + ) except MplexUnavailable as e: logger.debug("mplex unavailable while waiting for incoming: %s", e) + print(f"!@# handle_incoming: {self._id}: MplexUnavailable: {e}") break - # Force context switch - await trio.sleep(0) + + print(f"!@# handle_incoming: {self._id}: leaving") # If we enter here, it means this connection is shutting down. # We should clean things up. await self._cleanup() @@ -181,51 +204,73 @@ class Mplex(IMuxedConn, Service): :return: stream_id, flag, message contents """ - # FIXME: No timeout is used in Go implementation. try: header = await decode_uvarint_from_stream(self.secured_conn) + except (ParseError, RawConnError, IncompleteReadError) as error: + raise MplexUnavailable( + f"failed to read the header correctly from the underlying connection: {error}" + ) + try: message = await read_varint_prefixed_bytes(self.secured_conn) except (ParseError, RawConnError, IncompleteReadError) as error: raise MplexUnavailable( - "failed to read messages correctly from the underlying connection" - ) from error - except asyncio.TimeoutError as error: - raise MplexUnavailable( - "failed to read more message body within the timeout" - ) from error + "failed to read the message body correctly from the underlying connection: " + f"{error}" + ) flag = header & 0x07 channel_id = header >> 3 return channel_id, flag, message + @property + def _id(self) -> int: + return 0 if self.is_initiator else 1 + async def _handle_incoming_message(self) -> None: """ Read and handle a new incoming message. :raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down. """ + print(f"!@# _handle_incoming_message: {self._id}: before reading") channel_id, flag, message = await self.read_message() + print( + f"!@# _handle_incoming_message: {self._id}: channel_id={channel_id}, flag={flag}, message={message}" + ) stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1)) + print(f"!@# _handle_incoming_message: {self._id}: 2") if flag == HeaderTags.NewStream.value: + print(f"!@# _handle_incoming_message: {self._id}: 3") await self._handle_new_stream(stream_id, message) + print(f"!@# _handle_incoming_message: {self._id}: 4") elif flag in ( HeaderTags.MessageInitiator.value, HeaderTags.MessageReceiver.value, ): + print(f"!@# _handle_incoming_message: {self._id}: 5") await self._handle_message(stream_id, message) + print(f"!@# _handle_incoming_message: {self._id}: 6") elif flag in (HeaderTags.CloseInitiator.value, HeaderTags.CloseReceiver.value): + print(f"!@# _handle_incoming_message: {self._id}: 7") await self._handle_close(stream_id) + print(f"!@# _handle_incoming_message: {self._id}: 8") elif flag in (HeaderTags.ResetInitiator.value, HeaderTags.ResetReceiver.value): + print(f"!@# _handle_incoming_message: {self._id}: 9") await self._handle_reset(stream_id) + print(f"!@# _handle_incoming_message: {self._id}: 10") else: + print(f"!@# _handle_incoming_message: {self._id}: 11") # Receives messages with an unknown flag # TODO: logging async with self.streams_lock: + print(f"!@# _handle_incoming_message: {self._id}: 12") if stream_id in self.streams: + print(f"!@# _handle_incoming_message: {self._id}: 13") stream = self.streams[stream_id] await stream.reset() + print(f"!@# _handle_incoming_message: {self._id}: 14") async def _handle_new_stream(self, stream_id: StreamID, message: bytes) -> None: async with self.streams_lock: @@ -235,43 +280,65 @@ class Mplex(IMuxedConn, Service): f"received NewStream message for existing stream: {stream_id}" ) mplex_stream = await self._initialize_stream(stream_id, message.decode()) - await self.new_stream_queue.put(mplex_stream) + try: + await self.new_stream_send_channel.send(mplex_stream) + except (trio.BrokenResourceError, trio.EndOfChannel): + raise MplexUnavailable async def _handle_message(self, stream_id: StreamID, message: bytes) -> None: + print( + f"!@# _handle_message: {self._id}: stream_id={stream_id}, message={message}" + ) async with self.streams_lock: + print(f"!@# _handle_message: {self._id}: 1") if stream_id not in self.streams: # We receive a message of the stream `stream_id` which is not accepted # before. It is abnormal. Possibly disconnect? # TODO: Warn and emit logs about this. + print(f"!@# _handle_message: {self._id}: 2") return + print(f"!@# _handle_message: {self._id}: 3") stream = self.streams[stream_id] + send_channel = self.streams_msg_channels[stream_id] async with stream.close_lock: + print(f"!@# _handle_message: {self._id}: 4") if stream.event_remote_closed.is_set(): + print(f"!@# _handle_message: {self._id}: 5") # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501 return - await stream.incoming_data.put(message) + print(f"!@# _handle_message: {self._id}: 6") + await send_channel.send(message) + print(f"!@# _handle_message: {self._id}: 7") async def _handle_close(self, stream_id: StreamID) -> None: + print(f"!@# _handle_close: {self._id}: step=0") async with self.streams_lock: if stream_id not in self.streams: # Ignore unmatched messages for now. return stream = self.streams[stream_id] + send_channel = self.streams_msg_channels[stream_id] + print(f"!@# _handle_close: {self._id}: step=1") + await send_channel.aclose() + print(f"!@# _handle_close: {self._id}: step=2") # NOTE: If remote is already closed, then return: Technically a bug # on the other side. We should consider killing the connection. async with stream.close_lock: if stream.event_remote_closed.is_set(): return + print(f"!@# _handle_close: {self._id}: step=3") is_local_closed: bool async with stream.close_lock: stream.event_remote_closed.set() is_local_closed = stream.event_local_closed.is_set() + print(f"!@# _handle_close: {self._id}: step=4") # If local is also closed, both sides are closed. Then, we should clean up # the entry of this stream, to avoid others from accessing it. if is_local_closed: async with self.streams_lock: if stream_id in self.streams: del self.streams[stream_id] + print(f"!@# _handle_close: {self._id}: step=5") async def _handle_reset(self, stream_id: StreamID) -> None: async with self.streams_lock: @@ -279,11 +346,11 @@ class Mplex(IMuxedConn, Service): # This is *ok*. We forget the stream on reset. return stream = self.streams[stream_id] - + send_channel = self.streams_msg_channels[stream_id] + await send_channel.aclose() async with stream.close_lock: if not stream.event_remote_closed.is_set(): stream.event_reset.set() - stream.event_remote_closed.set() # If local is not closed, we should close it. if not stream.event_local_closed.is_set(): @@ -291,16 +358,21 @@ class Mplex(IMuxedConn, Service): async with self.streams_lock: if stream_id in self.streams: del self.streams[stream_id] + del self.streams_msg_channels[stream_id] async def _cleanup(self) -> None: if not self.event_shutting_down.is_set(): self.event_shutting_down.set() async with self.streams_lock: - for stream in self.streams.values(): + for stream_id, stream in self.streams.items(): async with stream.close_lock: if not stream.event_remote_closed.is_set(): stream.event_remote_closed.set() stream.event_reset.set() stream.event_local_closed.set() + send_channel = self.streams_msg_channels[stream_id] + await send_channel.aclose() self.streams = None self.event_closed.set() + await self.new_stream_send_channel.aclose() + await self.new_stream_receive_channel.aclose() diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 58da4040..6ecc4077 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING import trio from libp2p.stream_muxer.abc import IMuxedStream -from libp2p.utils import IQueue, TrioQueue from .constants import HeaderTags from .datastructures import StreamID @@ -24,10 +23,11 @@ class MplexStream(IMuxedStream): read_deadline: int write_deadline: int + # TODO: Add lock for read/write to avoid interleaving receiving messages? close_lock: trio.Lock # NOTE: `dataIn` is size of 8 in Go implementation. - incoming_data: IQueue[bytes] + incoming_data_channel: "trio.MemoryReceiveChannel[bytes]" event_local_closed: trio.Event event_remote_closed: trio.Event @@ -35,7 +35,13 @@ class MplexStream(IMuxedStream): _buf: bytearray - def __init__(self, name: str, stream_id: StreamID, muxed_conn: "Mplex") -> None: + def __init__( + self, + name: str, + stream_id: StreamID, + muxed_conn: "Mplex", + incoming_data_channel: "trio.MemoryReceiveChannel[bytes]", + ) -> None: """ create new MuxedStream in muxer. @@ -51,13 +57,30 @@ class MplexStream(IMuxedStream): self.event_remote_closed = trio.Event() self.event_reset = trio.Event() self.close_lock = trio.Lock() - self.incoming_data = TrioQueue() + self.incoming_data_channel = incoming_data_channel self._buf = bytearray() @property def is_initiator(self) -> bool: return self.stream_id.is_initiator + async def _read_until_eof(self) -> bytes: + async for data in self.incoming_data_channel: + self._buf.extend(data) + payload = self._buf + self._buf = self._buf[len(payload) :] + return bytes(payload) + + def _read_return_when_blocked(self) -> bytes: + buf = bytearray() + while True: + try: + data = self.incoming_data_channel.receive_nowait() + buf.extend(data) + except (trio.WouldBlock, trio.EndOfChannel): + break + return buf + async def read(self, n: int = -1) -> bytes: """ Read up to n bytes. Read possibly returns fewer than `n` bytes, if @@ -73,7 +96,40 @@ class MplexStream(IMuxedStream): ) if self.event_reset.is_set(): raise MplexStreamReset - return await self.incoming_data.get() + if n == -1: + return await self._read_until_eof() + if len(self._buf) == 0: + data: bytes + # Peek whether there is data available. If yes, we just read until there is no data, + # and then return. + try: + data = self.incoming_data_channel.receive_nowait() + except trio.EndOfChannel: + raise MplexStreamEOF + except trio.WouldBlock: + # We know `receive` will be blocked here. Wait for data here with `receive` and + # catch all kinds of errors here. + try: + data = await self.incoming_data_channel.receive() + except trio.EndOfChannel: + if self.event_reset.is_set(): + raise MplexStreamReset + if self.event_remote_closed.is_set(): + raise MplexStreamEOF + except trio.ClosedResourceError as error: + # Probably `incoming_data_channel` is closed in `reset` when we are waiting + # for `receive`. + if self.event_reset.is_set(): + raise MplexStreamReset + raise Exception( + "`incoming_data_channel` is closed but stream is not reset. " + "This should never happen." + ) from error + self._buf.extend(data) + self._buf.extend(self._read_return_when_blocked()) + payload = self._buf[:n] + self._buf = self._buf[len(payload) :] + return bytes(payload) async def write(self, data: bytes) -> int: """ @@ -99,22 +155,26 @@ class MplexStream(IMuxedStream): if self.event_local_closed.is_set(): return + print(f"!@# stream.close: {self.muxed_conn._id}: step=0") flag = ( HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver ) # TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown. await self.muxed_conn.send_message(flag, None, self.stream_id) + print(f"!@# stream.close: {self.muxed_conn._id}: step=1") _is_remote_closed: bool async with self.close_lock: self.event_local_closed.set() _is_remote_closed = self.event_remote_closed.is_set() + print(f"!@# stream.close: {self.muxed_conn._id}: step=2") if _is_remote_closed: # Both sides are closed, we can safely remove the buffer from the dict. async with self.muxed_conn.streams_lock: if self.stream_id in self.muxed_conn.streams: del self.muxed_conn.streams[self.stream_id] + print(f"!@# stream.close: {self.muxed_conn._id}: step=3") async def reset(self) -> None: """closes both ends of the stream tells this remote side to hang up.""" @@ -132,14 +192,15 @@ class MplexStream(IMuxedStream): if self.is_initiator else HeaderTags.ResetReceiver ) - async with trio.open_nursery() as nursery: - nursery.start_soon( - self.muxed_conn.send_message, flag, None, self.stream_id - ) + self.muxed_conn.manager.run_task( + self.muxed_conn.send_message, flag, None, self.stream_id + ) self.event_local_closed.set() self.event_remote_closed.set() + await self.incoming_data_channel.aclose() + async with self.muxed_conn.streams_lock: if ( self.muxed_conn.streams is not None diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index f324371a..448ec90d 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -205,7 +205,7 @@ async def mplex_stream_pair_factory(is_secure: bool) -> Tuple[MplexStream, Mplex stream_1: MplexStream async with mplex_conn_1.streams_lock: if len(mplex_conn_1.streams) != 1: - raise Exception("Mplex should not have any stream upon connection") + raise Exception("Mplex should not have any other stream") stream_1 = tuple(mplex_conn_1.streams.values())[0] yield cast(MplexStream, stream_0), cast(MplexStream, stream_1) diff --git a/libp2p/utils.py b/libp2p/utils.py index 27880c48..3d0794a1 100644 --- a/libp2p/utils.py +++ b/libp2p/utils.py @@ -1,8 +1,5 @@ import itertools import math -from typing import Generic, TypeVar - -import trio from libp2p.exceptions import ParseError from libp2p.io.abc import Reader @@ -98,25 +95,3 @@ async def read_fixedint_prefixed(reader: Reader) -> bytes: len_bytes = await reader.read(SIZE_LEN_BYTES) len_int = int.from_bytes(len_bytes, "big") return await reader.read(len_int) - - -TItem = TypeVar("TItem") - - -class IQueue(Generic[TItem]): - async def put(self, item: TItem): - ... - - async def get(self) -> TItem: - ... - - -class TrioQueue(IQueue): - def __init__(self): - self.send_channel, self.receive_channel = trio.open_memory_channel(0) - - async def put(self, item: TItem): - await self.send_channel.send(item) - - async def get(self) -> TItem: - return await self.receive_channel.receive() diff --git a/tests/conftest.py b/tests/conftest.py index 746fb026..48d705c2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,5 @@ -import asyncio - import pytest -from libp2p.tools.constants import LISTEN_MADDR from libp2p.tools.factories import HostFactory @@ -17,17 +14,6 @@ def num_hosts(): @pytest.fixture -async def hosts(num_hosts, is_host_secure): - _hosts = HostFactory.create_batch(num_hosts, is_secure=is_host_secure) - await asyncio.gather( - *[_host.get_network().listen(LISTEN_MADDR) for _host in _hosts] - ) - try: +async def hosts(num_hosts, is_host_secure, nursery): + async with HostFactory.create_batch_and_listen(is_host_secure, num_hosts) as _hosts: yield _hosts - finally: - # TODO: It's possible that `close` raises exceptions currently, - # due to the connection reset things. Though we don't care much about that when - # cleaning up the tasks, it is probably better to handle the exceptions properly. - await asyncio.gather( - *[_host.close() for _host in _hosts], return_exceptions=True - ) diff --git a/tests/network/conftest.py b/tests/network/conftest.py index c45dbdba..5aad36c9 100644 --- a/tests/network/conftest.py +++ b/tests/network/conftest.py @@ -1,5 +1,3 @@ -import asyncio - import pytest from libp2p.tools.factories import ( diff --git a/tests/stream_muxer/conftest.py b/tests/stream_muxer/conftest.py index 5c5bc2bb..248422d9 100644 --- a/tests/stream_muxer/conftest.py +++ b/tests/stream_muxer/conftest.py @@ -1,5 +1,3 @@ -import asyncio - import pytest from libp2p.tools.factories import mplex_conn_pair_factory, mplex_stream_pair_factory diff --git a/tests/stream_muxer/test_mplex_conn.py b/tests/stream_muxer/test_mplex_conn.py index 6dc98ad6..d48432d2 100644 --- a/tests/stream_muxer/test_mplex_conn.py +++ b/tests/stream_muxer/test_mplex_conn.py @@ -1,9 +1,9 @@ -import asyncio +import trio import pytest -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_conn(mplex_conn_pair): conn_0, conn_1 = mplex_conn_pair @@ -16,19 +16,19 @@ async def test_mplex_conn(mplex_conn_pair): # Test: Open a stream, and both side get 1 more stream. stream_0 = await conn_0.open_stream() - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert len(conn_0.streams) == 1 assert len(conn_1.streams) == 1 # Test: From another side. stream_1 = await conn_1.open_stream() - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert len(conn_0.streams) == 2 assert len(conn_1.streams) == 2 # Close from one side. await conn_0.close() # Sleep for a while for both side to handle `close`. - await asyncio.sleep(0.01) + await trio.sleep(0.01) # Test: Both side is closed. assert conn_0.event_shutting_down.is_set() assert conn_0.event_closed.is_set() diff --git a/tests/stream_muxer/test_mplex_stream.py b/tests/stream_muxer/test_mplex_stream.py index dae66571..e3d19a5d 100644 --- a/tests/stream_muxer/test_mplex_stream.py +++ b/tests/stream_muxer/test_mplex_stream.py @@ -1,5 +1,6 @@ import pytest import trio +from trio.testing import wait_all_tasks_blocked from libp2p.stream_muxer.mplex.exceptions import ( MplexStreamClosed, @@ -37,37 +38,65 @@ async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair): async def read_until_eof(): read_bytes.extend(await stream_1.read()) - task = trio.ensure_future(read_until_eof()) - expected_data = bytearray() - # Test: `read` doesn't return before `close` is called. - await stream_0.write(DATA) - expected_data.extend(DATA) - await trio.sleep(0.01) - assert len(read_bytes) == 0 - # Test: `read` doesn't return before `close` is called. - await stream_0.write(DATA) - expected_data.extend(DATA) - await trio.sleep(0.01) - assert len(read_bytes) == 0 + async with trio.open_nursery() as nursery: + nursery.start_soon(read_until_eof) + # Test: `read` doesn't return before `close` is called. + await stream_0.write(DATA) + expected_data.extend(DATA) + await trio.sleep(0.01) + assert len(read_bytes) == 0 + # Test: `read` doesn't return before `close` is called. + await stream_0.write(DATA) + expected_data.extend(DATA) + await trio.sleep(0.01) + assert len(read_bytes) == 0 + + # Test: Close the stream, `read` returns, and receive previous sent data. + await stream_0.close() - # Test: Close the stream, `read` returns, and receive previous sent data. - await stream_0.close() - await trio.sleep(0.01) assert read_bytes == expected_data - task.cancel() - @pytest.mark.trio async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair assert not stream_1.event_remote_closed.is_set() await stream_0.write(DATA) - await stream_0.close() + assert not stream_0.event_local_closed.is_set() await trio.sleep(0.01) + await wait_all_tasks_blocked() + await stream_0.close() + assert stream_0.event_local_closed.is_set() + await trio.sleep(0.01) + print( + "!@# ", + stream_0.muxed_conn.event_shutting_down.is_set(), + stream_0.muxed_conn.event_closed.is_set(), + stream_1.muxed_conn.event_shutting_down.is_set(), + stream_1.muxed_conn.event_closed.is_set(), + ) + # await trio.sleep(100000) + await wait_all_tasks_blocked() + print( + "!@# ", + stream_0.muxed_conn.event_shutting_down.is_set(), + stream_0.muxed_conn.event_closed.is_set(), + stream_1.muxed_conn.event_shutting_down.is_set(), + stream_1.muxed_conn.event_closed.is_set(), + ) + print("!@# sleeping") + print("!@# result=", stream_1.event_remote_closed.is_set()) + # await trio.sleep_forever() assert stream_1.event_remote_closed.is_set() + print( + "!@# ", + stream_0.muxed_conn.event_shutting_down.is_set(), + stream_0.muxed_conn.event_closed.is_set(), + stream_1.muxed_conn.event_shutting_down.is_set(), + stream_1.muxed_conn.event_closed.is_set(), + ) assert (await stream_1.read(MAX_READ_LEN)) == DATA with pytest.raises(MplexStreamEOF): await stream_1.read(MAX_READ_LEN) @@ -87,7 +116,8 @@ async def test_mplex_stream_read_after_remote_reset(mplex_stream_pair): await stream_0.write(DATA) await stream_0.reset() # Sleep to let `stream_1` receive the message. - await trio.sleep(0.01) + await trio.sleep(0.1) + await wait_all_tasks_blocked() with pytest.raises(MplexStreamReset): await stream_1.read(MAX_READ_LEN) diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 index 7bd807d5..00000000 --- a/tests/test_utils.py +++ /dev/null @@ -1,19 +0,0 @@ -import pytest -import trio - -from libp2p.utils import TrioQueue - - -@pytest.mark.trio -async def test_trio_queue(): - queue = TrioQueue() - - async def queue_get(task_status=None): - result = await queue.get() - task_status.started(result) - - async with trio.open_nursery() as nursery: - nursery.start_soon(queue.put, 123) - result = await nursery.start(queue_get) - - assert result == 123