mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
feat: implement AsyncContextManager for IMuxedStream to support async… (#629)
* feat: implement AsyncContextManager for IMuxedStream to support async with * doc: add newsfragment
This commit is contained in:
@ -8,10 +8,14 @@ from collections.abc import (
|
|||||||
KeysView,
|
KeysView,
|
||||||
Sequence,
|
Sequence,
|
||||||
)
|
)
|
||||||
|
from types import (
|
||||||
|
TracebackType,
|
||||||
|
)
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
AsyncContextManager,
|
AsyncContextManager,
|
||||||
|
Optional,
|
||||||
)
|
)
|
||||||
|
|
||||||
from multiaddr import (
|
from multiaddr import (
|
||||||
@ -215,7 +219,7 @@ class IMuxedConn(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class IMuxedStream(ReadWriteCloser):
|
class IMuxedStream(ReadWriteCloser, AsyncContextManager["IMuxedStream"]):
|
||||||
"""
|
"""
|
||||||
Interface for a multiplexed stream.
|
Interface for a multiplexed stream.
|
||||||
|
|
||||||
@ -249,6 +253,20 @@ class IMuxedStream(ReadWriteCloser):
|
|||||||
otherwise False.
|
otherwise False.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def __aenter__(self) -> "IMuxedStream":
|
||||||
|
"""Enter the async context manager."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[type[BaseException]],
|
||||||
|
exc_val: Optional[BaseException],
|
||||||
|
exc_tb: Optional[TracebackType],
|
||||||
|
) -> None:
|
||||||
|
"""Exit the async context manager and close the stream."""
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
|
||||||
# -------------------------- net_stream interface.py --------------------------
|
# -------------------------- net_stream interface.py --------------------------
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,6 @@
|
|||||||
|
from types import (
|
||||||
|
TracebackType,
|
||||||
|
)
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Optional,
|
Optional,
|
||||||
@ -257,3 +260,16 @@ class MplexStream(IMuxedStream):
|
|||||||
def get_remote_address(self) -> Optional[tuple[str, int]]:
|
def get_remote_address(self) -> Optional[tuple[str, int]]:
|
||||||
"""Delegate to the parent Mplex connection."""
|
"""Delegate to the parent Mplex connection."""
|
||||||
return self.muxed_conn.get_remote_address()
|
return self.muxed_conn.get_remote_address()
|
||||||
|
|
||||||
|
async def __aenter__(self) -> "MplexStream":
|
||||||
|
"""Enter the async context manager."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[type[BaseException]],
|
||||||
|
exc_val: Optional[BaseException],
|
||||||
|
exc_tb: Optional[TracebackType],
|
||||||
|
) -> None:
|
||||||
|
"""Exit the async context manager and close the stream."""
|
||||||
|
await self.close()
|
||||||
|
|||||||
@ -9,6 +9,9 @@ from collections.abc import (
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
|
from types import (
|
||||||
|
TracebackType,
|
||||||
|
)
|
||||||
from typing import (
|
from typing import (
|
||||||
Callable,
|
Callable,
|
||||||
Optional,
|
Optional,
|
||||||
@ -74,6 +77,19 @@ class YamuxStream(IMuxedStream):
|
|||||||
self.recv_window = DEFAULT_WINDOW_SIZE
|
self.recv_window = DEFAULT_WINDOW_SIZE
|
||||||
self.window_lock = trio.Lock()
|
self.window_lock = trio.Lock()
|
||||||
|
|
||||||
|
async def __aenter__(self) -> "YamuxStream":
|
||||||
|
"""Enter the async context manager."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[type[BaseException]],
|
||||||
|
exc_val: Optional[BaseException],
|
||||||
|
exc_tb: Optional[TracebackType],
|
||||||
|
) -> None:
|
||||||
|
"""Exit the async context manager and close the stream."""
|
||||||
|
await self.close()
|
||||||
|
|
||||||
async def write(self, data: bytes) -> None:
|
async def write(self, data: bytes) -> None:
|
||||||
if self.send_closed:
|
if self.send_closed:
|
||||||
raise MuxedStreamError("Stream is closed for sending")
|
raise MuxedStreamError("Stream is closed for sending")
|
||||||
|
|||||||
1
newsfragments/629.feature.rst
Normal file
1
newsfragments/629.feature.rst
Normal file
@ -0,0 +1 @@
|
|||||||
|
implement AsyncContextManager for IMuxedStream to support async with
|
||||||
127
tests/stream_muxer/test_async_context_manager.py
Normal file
127
tests/stream_muxer/test_async_context_manager.py
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
import pytest
|
||||||
|
import trio
|
||||||
|
|
||||||
|
from libp2p.stream_muxer.exceptions import (
|
||||||
|
MuxedStreamClosed,
|
||||||
|
MuxedStreamError,
|
||||||
|
)
|
||||||
|
from libp2p.stream_muxer.mplex.datastructures import (
|
||||||
|
StreamID,
|
||||||
|
)
|
||||||
|
from libp2p.stream_muxer.mplex.mplex_stream import (
|
||||||
|
MplexStream,
|
||||||
|
)
|
||||||
|
from libp2p.stream_muxer.yamux.yamux import (
|
||||||
|
YamuxStream,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DummySecuredConn:
|
||||||
|
async def write(self, data):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MockMuxedConn:
|
||||||
|
def __init__(self):
|
||||||
|
self.streams = {}
|
||||||
|
self.streams_lock = trio.Lock()
|
||||||
|
self.event_shutting_down = trio.Event()
|
||||||
|
self.event_closed = trio.Event()
|
||||||
|
self.event_started = trio.Event()
|
||||||
|
self.secured_conn = DummySecuredConn() # For YamuxStream
|
||||||
|
|
||||||
|
async def send_message(self, flag, data, stream_id):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_remote_address(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_mplex_stream_async_context_manager():
|
||||||
|
muxed_conn = MockMuxedConn()
|
||||||
|
stream_id = StreamID(1, True) # Use real StreamID
|
||||||
|
stream = MplexStream(
|
||||||
|
name="test_stream",
|
||||||
|
stream_id=stream_id,
|
||||||
|
muxed_conn=muxed_conn,
|
||||||
|
incoming_data_channel=trio.open_memory_channel(8)[1],
|
||||||
|
)
|
||||||
|
async with stream as s:
|
||||||
|
assert s is stream
|
||||||
|
assert not stream.event_local_closed.is_set()
|
||||||
|
assert not stream.event_remote_closed.is_set()
|
||||||
|
assert not stream.event_reset.is_set()
|
||||||
|
assert stream.event_local_closed.is_set()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_yamux_stream_async_context_manager():
|
||||||
|
muxed_conn = MockMuxedConn()
|
||||||
|
stream = YamuxStream(stream_id=1, conn=muxed_conn, is_initiator=True)
|
||||||
|
async with stream as s:
|
||||||
|
assert s is stream
|
||||||
|
assert not stream.closed
|
||||||
|
assert not stream.send_closed
|
||||||
|
assert not stream.recv_closed
|
||||||
|
assert stream.send_closed
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_mplex_stream_async_context_manager_with_error():
|
||||||
|
muxed_conn = MockMuxedConn()
|
||||||
|
stream_id = StreamID(1, True)
|
||||||
|
stream = MplexStream(
|
||||||
|
name="test_stream",
|
||||||
|
stream_id=stream_id,
|
||||||
|
muxed_conn=muxed_conn,
|
||||||
|
incoming_data_channel=trio.open_memory_channel(8)[1],
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
async with stream as s:
|
||||||
|
assert s is stream
|
||||||
|
assert not stream.event_local_closed.is_set()
|
||||||
|
assert not stream.event_remote_closed.is_set()
|
||||||
|
assert not stream.event_reset.is_set()
|
||||||
|
raise ValueError("Test error")
|
||||||
|
assert stream.event_local_closed.is_set()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_yamux_stream_async_context_manager_with_error():
|
||||||
|
muxed_conn = MockMuxedConn()
|
||||||
|
stream = YamuxStream(stream_id=1, conn=muxed_conn, is_initiator=True)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
async with stream as s:
|
||||||
|
assert s is stream
|
||||||
|
assert not stream.closed
|
||||||
|
assert not stream.send_closed
|
||||||
|
assert not stream.recv_closed
|
||||||
|
raise ValueError("Test error")
|
||||||
|
assert stream.send_closed
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_mplex_stream_async_context_manager_write_after_close():
|
||||||
|
muxed_conn = MockMuxedConn()
|
||||||
|
stream_id = StreamID(1, True)
|
||||||
|
stream = MplexStream(
|
||||||
|
name="test_stream",
|
||||||
|
stream_id=stream_id,
|
||||||
|
muxed_conn=muxed_conn,
|
||||||
|
incoming_data_channel=trio.open_memory_channel(8)[1],
|
||||||
|
)
|
||||||
|
async with stream as s:
|
||||||
|
assert s is stream
|
||||||
|
with pytest.raises(MuxedStreamClosed):
|
||||||
|
await stream.write(b"test data")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_yamux_stream_async_context_manager_write_after_close():
|
||||||
|
muxed_conn = MockMuxedConn()
|
||||||
|
stream = YamuxStream(stream_id=1, conn=muxed_conn, is_initiator=True)
|
||||||
|
async with stream as s:
|
||||||
|
assert s is stream
|
||||||
|
with pytest.raises(MuxedStreamError):
|
||||||
|
await stream.write(b"test data")
|
||||||
Reference in New Issue
Block a user