diff --git a/libp2p/abc.py b/libp2p/abc.py index 688b1623..f9686bac 100644 --- a/libp2p/abc.py +++ b/libp2p/abc.py @@ -8,10 +8,14 @@ from collections.abc import ( KeysView, Sequence, ) +from types import ( + TracebackType, +) from typing import ( TYPE_CHECKING, Any, AsyncContextManager, + Optional, ) from multiaddr import ( @@ -215,7 +219,7 @@ class IMuxedConn(ABC): """ -class IMuxedStream(ReadWriteCloser): +class IMuxedStream(ReadWriteCloser, AsyncContextManager["IMuxedStream"]): """ Interface for a multiplexed stream. @@ -249,6 +253,20 @@ class IMuxedStream(ReadWriteCloser): otherwise False. """ + @abstractmethod + async def __aenter__(self) -> "IMuxedStream": + """Enter the async context manager.""" + return self + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + """Exit the async context manager and close the stream.""" + await self.close() + # -------------------------- net_stream interface.py -------------------------- diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 9b876a55..a5bce0c1 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -1,3 +1,6 @@ +from types import ( + TracebackType, +) from typing import ( TYPE_CHECKING, Optional, @@ -257,3 +260,16 @@ class MplexStream(IMuxedStream): def get_remote_address(self) -> Optional[tuple[str, int]]: """Delegate to the parent Mplex connection.""" return self.muxed_conn.get_remote_address() + + async def __aenter__(self) -> "MplexStream": + """Enter the async context manager.""" + return self + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + """Exit the async context manager and close the stream.""" + await self.close() diff --git a/libp2p/stream_muxer/yamux/yamux.py b/libp2p/stream_muxer/yamux/yamux.py index 200d986c..ceceb541 100644 --- a/libp2p/stream_muxer/yamux/yamux.py +++ b/libp2p/stream_muxer/yamux/yamux.py @@ -9,6 +9,9 @@ from collections.abc import ( import inspect import logging import struct +from types import ( + TracebackType, +) from typing import ( Callable, Optional, @@ -74,6 +77,19 @@ class YamuxStream(IMuxedStream): self.recv_window = DEFAULT_WINDOW_SIZE self.window_lock = trio.Lock() + async def __aenter__(self) -> "YamuxStream": + """Enter the async context manager.""" + return self + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + """Exit the async context manager and close the stream.""" + await self.close() + async def write(self, data: bytes) -> None: if self.send_closed: raise MuxedStreamError("Stream is closed for sending") diff --git a/newsfragments/629.feature.rst b/newsfragments/629.feature.rst new file mode 100644 index 00000000..939ba6a4 --- /dev/null +++ b/newsfragments/629.feature.rst @@ -0,0 +1 @@ +implement AsyncContextManager for IMuxedStream to support async with diff --git a/tests/stream_muxer/test_async_context_manager.py b/tests/stream_muxer/test_async_context_manager.py new file mode 100644 index 00000000..a79e6a7c --- /dev/null +++ b/tests/stream_muxer/test_async_context_manager.py @@ -0,0 +1,127 @@ +import pytest +import trio + +from libp2p.stream_muxer.exceptions import ( + MuxedStreamClosed, + MuxedStreamError, +) +from libp2p.stream_muxer.mplex.datastructures import ( + StreamID, +) +from libp2p.stream_muxer.mplex.mplex_stream import ( + MplexStream, +) +from libp2p.stream_muxer.yamux.yamux import ( + YamuxStream, +) + + +class DummySecuredConn: + async def write(self, data): + pass + + +class MockMuxedConn: + def __init__(self): + self.streams = {} + self.streams_lock = trio.Lock() + self.event_shutting_down = trio.Event() + self.event_closed = trio.Event() + self.event_started = trio.Event() + self.secured_conn = DummySecuredConn() # For YamuxStream + + async def send_message(self, flag, data, stream_id): + pass + + def get_remote_address(self): + return None + + +@pytest.mark.trio +async def test_mplex_stream_async_context_manager(): + muxed_conn = MockMuxedConn() + stream_id = StreamID(1, True) # Use real StreamID + stream = MplexStream( + name="test_stream", + stream_id=stream_id, + muxed_conn=muxed_conn, + incoming_data_channel=trio.open_memory_channel(8)[1], + ) + async with stream as s: + assert s is stream + assert not stream.event_local_closed.is_set() + assert not stream.event_remote_closed.is_set() + assert not stream.event_reset.is_set() + assert stream.event_local_closed.is_set() + + +@pytest.mark.trio +async def test_yamux_stream_async_context_manager(): + muxed_conn = MockMuxedConn() + stream = YamuxStream(stream_id=1, conn=muxed_conn, is_initiator=True) + async with stream as s: + assert s is stream + assert not stream.closed + assert not stream.send_closed + assert not stream.recv_closed + assert stream.send_closed + + +@pytest.mark.trio +async def test_mplex_stream_async_context_manager_with_error(): + muxed_conn = MockMuxedConn() + stream_id = StreamID(1, True) + stream = MplexStream( + name="test_stream", + stream_id=stream_id, + muxed_conn=muxed_conn, + incoming_data_channel=trio.open_memory_channel(8)[1], + ) + with pytest.raises(ValueError): + async with stream as s: + assert s is stream + assert not stream.event_local_closed.is_set() + assert not stream.event_remote_closed.is_set() + assert not stream.event_reset.is_set() + raise ValueError("Test error") + assert stream.event_local_closed.is_set() + + +@pytest.mark.trio +async def test_yamux_stream_async_context_manager_with_error(): + muxed_conn = MockMuxedConn() + stream = YamuxStream(stream_id=1, conn=muxed_conn, is_initiator=True) + with pytest.raises(ValueError): + async with stream as s: + assert s is stream + assert not stream.closed + assert not stream.send_closed + assert not stream.recv_closed + raise ValueError("Test error") + assert stream.send_closed + + +@pytest.mark.trio +async def test_mplex_stream_async_context_manager_write_after_close(): + muxed_conn = MockMuxedConn() + stream_id = StreamID(1, True) + stream = MplexStream( + name="test_stream", + stream_id=stream_id, + muxed_conn=muxed_conn, + incoming_data_channel=trio.open_memory_channel(8)[1], + ) + async with stream as s: + assert s is stream + with pytest.raises(MuxedStreamClosed): + await stream.write(b"test data") + + +@pytest.mark.trio +async def test_yamux_stream_async_context_manager_write_after_close(): + muxed_conn = MockMuxedConn() + stream = YamuxStream(stream_id=1, conn=muxed_conn, is_initiator=True) + async with stream as s: + assert s is stream + with pytest.raises(MuxedStreamError): + await stream.write(b"test data")