feat: implement AsyncContextManager for IMuxedStream to support async… (#629)

* feat: implement AsyncContextManager for IMuxedStream to support async with

* doc: add newsfragment
This commit is contained in:
acul71
2025-05-30 16:44:33 +02:00
committed by GitHub
parent 67ca1d7769
commit 67ab6e27d8
5 changed files with 179 additions and 1 deletions

View File

@ -8,10 +8,14 @@ from collections.abc import (
KeysView, KeysView,
Sequence, Sequence,
) )
from types import (
TracebackType,
)
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
AsyncContextManager, AsyncContextManager,
Optional,
) )
from multiaddr import ( from multiaddr import (
@ -215,7 +219,7 @@ class IMuxedConn(ABC):
""" """
class IMuxedStream(ReadWriteCloser): class IMuxedStream(ReadWriteCloser, AsyncContextManager["IMuxedStream"]):
""" """
Interface for a multiplexed stream. Interface for a multiplexed stream.
@ -249,6 +253,20 @@ class IMuxedStream(ReadWriteCloser):
otherwise False. otherwise False.
""" """
@abstractmethod
async def __aenter__(self) -> "IMuxedStream":
"""Enter the async context manager."""
return self
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
"""Exit the async context manager and close the stream."""
await self.close()
# -------------------------- net_stream interface.py -------------------------- # -------------------------- net_stream interface.py --------------------------

View File

@ -1,3 +1,6 @@
from types import (
TracebackType,
)
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Optional, Optional,
@ -257,3 +260,16 @@ class MplexStream(IMuxedStream):
def get_remote_address(self) -> Optional[tuple[str, int]]: def get_remote_address(self) -> Optional[tuple[str, int]]:
"""Delegate to the parent Mplex connection.""" """Delegate to the parent Mplex connection."""
return self.muxed_conn.get_remote_address() return self.muxed_conn.get_remote_address()
async def __aenter__(self) -> "MplexStream":
"""Enter the async context manager."""
return self
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
"""Exit the async context manager and close the stream."""
await self.close()

View File

@ -9,6 +9,9 @@ from collections.abc import (
import inspect import inspect
import logging import logging
import struct import struct
from types import (
TracebackType,
)
from typing import ( from typing import (
Callable, Callable,
Optional, Optional,
@ -74,6 +77,19 @@ class YamuxStream(IMuxedStream):
self.recv_window = DEFAULT_WINDOW_SIZE self.recv_window = DEFAULT_WINDOW_SIZE
self.window_lock = trio.Lock() self.window_lock = trio.Lock()
async def __aenter__(self) -> "YamuxStream":
"""Enter the async context manager."""
return self
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
"""Exit the async context manager and close the stream."""
await self.close()
async def write(self, data: bytes) -> None: async def write(self, data: bytes) -> None:
if self.send_closed: if self.send_closed:
raise MuxedStreamError("Stream is closed for sending") raise MuxedStreamError("Stream is closed for sending")

View File

@ -0,0 +1 @@
implement AsyncContextManager for IMuxedStream to support async with

View File

@ -0,0 +1,127 @@
import pytest
import trio
from libp2p.stream_muxer.exceptions import (
MuxedStreamClosed,
MuxedStreamError,
)
from libp2p.stream_muxer.mplex.datastructures import (
StreamID,
)
from libp2p.stream_muxer.mplex.mplex_stream import (
MplexStream,
)
from libp2p.stream_muxer.yamux.yamux import (
YamuxStream,
)
class DummySecuredConn:
async def write(self, data):
pass
class MockMuxedConn:
def __init__(self):
self.streams = {}
self.streams_lock = trio.Lock()
self.event_shutting_down = trio.Event()
self.event_closed = trio.Event()
self.event_started = trio.Event()
self.secured_conn = DummySecuredConn() # For YamuxStream
async def send_message(self, flag, data, stream_id):
pass
def get_remote_address(self):
return None
@pytest.mark.trio
async def test_mplex_stream_async_context_manager():
muxed_conn = MockMuxedConn()
stream_id = StreamID(1, True) # Use real StreamID
stream = MplexStream(
name="test_stream",
stream_id=stream_id,
muxed_conn=muxed_conn,
incoming_data_channel=trio.open_memory_channel(8)[1],
)
async with stream as s:
assert s is stream
assert not stream.event_local_closed.is_set()
assert not stream.event_remote_closed.is_set()
assert not stream.event_reset.is_set()
assert stream.event_local_closed.is_set()
@pytest.mark.trio
async def test_yamux_stream_async_context_manager():
muxed_conn = MockMuxedConn()
stream = YamuxStream(stream_id=1, conn=muxed_conn, is_initiator=True)
async with stream as s:
assert s is stream
assert not stream.closed
assert not stream.send_closed
assert not stream.recv_closed
assert stream.send_closed
@pytest.mark.trio
async def test_mplex_stream_async_context_manager_with_error():
muxed_conn = MockMuxedConn()
stream_id = StreamID(1, True)
stream = MplexStream(
name="test_stream",
stream_id=stream_id,
muxed_conn=muxed_conn,
incoming_data_channel=trio.open_memory_channel(8)[1],
)
with pytest.raises(ValueError):
async with stream as s:
assert s is stream
assert not stream.event_local_closed.is_set()
assert not stream.event_remote_closed.is_set()
assert not stream.event_reset.is_set()
raise ValueError("Test error")
assert stream.event_local_closed.is_set()
@pytest.mark.trio
async def test_yamux_stream_async_context_manager_with_error():
muxed_conn = MockMuxedConn()
stream = YamuxStream(stream_id=1, conn=muxed_conn, is_initiator=True)
with pytest.raises(ValueError):
async with stream as s:
assert s is stream
assert not stream.closed
assert not stream.send_closed
assert not stream.recv_closed
raise ValueError("Test error")
assert stream.send_closed
@pytest.mark.trio
async def test_mplex_stream_async_context_manager_write_after_close():
muxed_conn = MockMuxedConn()
stream_id = StreamID(1, True)
stream = MplexStream(
name="test_stream",
stream_id=stream_id,
muxed_conn=muxed_conn,
incoming_data_channel=trio.open_memory_channel(8)[1],
)
async with stream as s:
assert s is stream
with pytest.raises(MuxedStreamClosed):
await stream.write(b"test data")
@pytest.mark.trio
async def test_yamux_stream_async_context_manager_write_after_close():
muxed_conn = MockMuxedConn()
stream = YamuxStream(stream_id=1, conn=muxed_conn, is_initiator=True)
async with stream as s:
assert s is stream
with pytest.raises(MuxedStreamError):
await stream.write(b"test data")