mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
feat: implement AsyncContextManager for IMuxedStream to support async… (#629)
* feat: implement AsyncContextManager for IMuxedStream to support async with * doc: add newsfragment
This commit is contained in:
@ -8,10 +8,14 @@ from collections.abc import (
|
||||
KeysView,
|
||||
Sequence,
|
||||
)
|
||||
from types import (
|
||||
TracebackType,
|
||||
)
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncContextManager,
|
||||
Optional,
|
||||
)
|
||||
|
||||
from multiaddr import (
|
||||
@ -215,7 +219,7 @@ class IMuxedConn(ABC):
|
||||
"""
|
||||
|
||||
|
||||
class IMuxedStream(ReadWriteCloser):
|
||||
class IMuxedStream(ReadWriteCloser, AsyncContextManager["IMuxedStream"]):
|
||||
"""
|
||||
Interface for a multiplexed stream.
|
||||
|
||||
@ -249,6 +253,20 @@ class IMuxedStream(ReadWriteCloser):
|
||||
otherwise False.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def __aenter__(self) -> "IMuxedStream":
|
||||
"""Enter the async context manager."""
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
"""Exit the async context manager and close the stream."""
|
||||
await self.close()
|
||||
|
||||
|
||||
# -------------------------- net_stream interface.py --------------------------
|
||||
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
from types import (
|
||||
TracebackType,
|
||||
)
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Optional,
|
||||
@ -257,3 +260,16 @@ class MplexStream(IMuxedStream):
|
||||
def get_remote_address(self) -> Optional[tuple[str, int]]:
|
||||
"""Delegate to the parent Mplex connection."""
|
||||
return self.muxed_conn.get_remote_address()
|
||||
|
||||
async def __aenter__(self) -> "MplexStream":
|
||||
"""Enter the async context manager."""
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
"""Exit the async context manager and close the stream."""
|
||||
await self.close()
|
||||
|
||||
@ -9,6 +9,9 @@ from collections.abc import (
|
||||
import inspect
|
||||
import logging
|
||||
import struct
|
||||
from types import (
|
||||
TracebackType,
|
||||
)
|
||||
from typing import (
|
||||
Callable,
|
||||
Optional,
|
||||
@ -74,6 +77,19 @@ class YamuxStream(IMuxedStream):
|
||||
self.recv_window = DEFAULT_WINDOW_SIZE
|
||||
self.window_lock = trio.Lock()
|
||||
|
||||
async def __aenter__(self) -> "YamuxStream":
|
||||
"""Enter the async context manager."""
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
"""Exit the async context manager and close the stream."""
|
||||
await self.close()
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
if self.send_closed:
|
||||
raise MuxedStreamError("Stream is closed for sending")
|
||||
|
||||
1
newsfragments/629.feature.rst
Normal file
1
newsfragments/629.feature.rst
Normal file
@ -0,0 +1 @@
|
||||
implement AsyncContextManager for IMuxedStream to support async with
|
||||
127
tests/stream_muxer/test_async_context_manager.py
Normal file
127
tests/stream_muxer/test_async_context_manager.py
Normal file
@ -0,0 +1,127 @@
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.stream_muxer.exceptions import (
|
||||
MuxedStreamClosed,
|
||||
MuxedStreamError,
|
||||
)
|
||||
from libp2p.stream_muxer.mplex.datastructures import (
|
||||
StreamID,
|
||||
)
|
||||
from libp2p.stream_muxer.mplex.mplex_stream import (
|
||||
MplexStream,
|
||||
)
|
||||
from libp2p.stream_muxer.yamux.yamux import (
|
||||
YamuxStream,
|
||||
)
|
||||
|
||||
|
||||
class DummySecuredConn:
|
||||
async def write(self, data):
|
||||
pass
|
||||
|
||||
|
||||
class MockMuxedConn:
|
||||
def __init__(self):
|
||||
self.streams = {}
|
||||
self.streams_lock = trio.Lock()
|
||||
self.event_shutting_down = trio.Event()
|
||||
self.event_closed = trio.Event()
|
||||
self.event_started = trio.Event()
|
||||
self.secured_conn = DummySecuredConn() # For YamuxStream
|
||||
|
||||
async def send_message(self, flag, data, stream_id):
|
||||
pass
|
||||
|
||||
def get_remote_address(self):
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_mplex_stream_async_context_manager():
|
||||
muxed_conn = MockMuxedConn()
|
||||
stream_id = StreamID(1, True) # Use real StreamID
|
||||
stream = MplexStream(
|
||||
name="test_stream",
|
||||
stream_id=stream_id,
|
||||
muxed_conn=muxed_conn,
|
||||
incoming_data_channel=trio.open_memory_channel(8)[1],
|
||||
)
|
||||
async with stream as s:
|
||||
assert s is stream
|
||||
assert not stream.event_local_closed.is_set()
|
||||
assert not stream.event_remote_closed.is_set()
|
||||
assert not stream.event_reset.is_set()
|
||||
assert stream.event_local_closed.is_set()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_yamux_stream_async_context_manager():
|
||||
muxed_conn = MockMuxedConn()
|
||||
stream = YamuxStream(stream_id=1, conn=muxed_conn, is_initiator=True)
|
||||
async with stream as s:
|
||||
assert s is stream
|
||||
assert not stream.closed
|
||||
assert not stream.send_closed
|
||||
assert not stream.recv_closed
|
||||
assert stream.send_closed
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_mplex_stream_async_context_manager_with_error():
|
||||
muxed_conn = MockMuxedConn()
|
||||
stream_id = StreamID(1, True)
|
||||
stream = MplexStream(
|
||||
name="test_stream",
|
||||
stream_id=stream_id,
|
||||
muxed_conn=muxed_conn,
|
||||
incoming_data_channel=trio.open_memory_channel(8)[1],
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
async with stream as s:
|
||||
assert s is stream
|
||||
assert not stream.event_local_closed.is_set()
|
||||
assert not stream.event_remote_closed.is_set()
|
||||
assert not stream.event_reset.is_set()
|
||||
raise ValueError("Test error")
|
||||
assert stream.event_local_closed.is_set()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_yamux_stream_async_context_manager_with_error():
|
||||
muxed_conn = MockMuxedConn()
|
||||
stream = YamuxStream(stream_id=1, conn=muxed_conn, is_initiator=True)
|
||||
with pytest.raises(ValueError):
|
||||
async with stream as s:
|
||||
assert s is stream
|
||||
assert not stream.closed
|
||||
assert not stream.send_closed
|
||||
assert not stream.recv_closed
|
||||
raise ValueError("Test error")
|
||||
assert stream.send_closed
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_mplex_stream_async_context_manager_write_after_close():
|
||||
muxed_conn = MockMuxedConn()
|
||||
stream_id = StreamID(1, True)
|
||||
stream = MplexStream(
|
||||
name="test_stream",
|
||||
stream_id=stream_id,
|
||||
muxed_conn=muxed_conn,
|
||||
incoming_data_channel=trio.open_memory_channel(8)[1],
|
||||
)
|
||||
async with stream as s:
|
||||
assert s is stream
|
||||
with pytest.raises(MuxedStreamClosed):
|
||||
await stream.write(b"test data")
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_yamux_stream_async_context_manager_write_after_close():
|
||||
muxed_conn = MockMuxedConn()
|
||||
stream = YamuxStream(stream_id=1, conn=muxed_conn, is_initiator=True)
|
||||
async with stream as s:
|
||||
assert s is stream
|
||||
with pytest.raises(MuxedStreamError):
|
||||
await stream.write(b"test data")
|
||||
Reference in New Issue
Block a user