feat: implement AsyncContextManager for IMuxedStream to support async… (#629)

* feat: implement AsyncContextManager for IMuxedStream to support async with

* doc: add newsfragment
This commit is contained in:
acul71
2025-05-30 16:44:33 +02:00
committed by GitHub
parent 67ca1d7769
commit 67ab6e27d8
5 changed files with 179 additions and 1 deletions

View File

@ -8,10 +8,14 @@ from collections.abc import (
KeysView,
Sequence,
)
from types import (
TracebackType,
)
from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
Optional,
)
from multiaddr import (
@ -215,7 +219,7 @@ class IMuxedConn(ABC):
"""
class IMuxedStream(ReadWriteCloser):
class IMuxedStream(ReadWriteCloser, AsyncContextManager["IMuxedStream"]):
"""
Interface for a multiplexed stream.
@ -249,6 +253,20 @@ class IMuxedStream(ReadWriteCloser):
otherwise False.
"""
@abstractmethod
async def __aenter__(self) -> "IMuxedStream":
"""Enter the async context manager."""
return self
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
"""Exit the async context manager and close the stream."""
await self.close()
# -------------------------- net_stream interface.py --------------------------

View File

@ -1,3 +1,6 @@
from types import (
TracebackType,
)
from typing import (
TYPE_CHECKING,
Optional,
@ -257,3 +260,16 @@ class MplexStream(IMuxedStream):
def get_remote_address(self) -> Optional[tuple[str, int]]:
"""Delegate to the parent Mplex connection."""
return self.muxed_conn.get_remote_address()
async def __aenter__(self) -> "MplexStream":
"""Enter the async context manager."""
return self
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
"""Exit the async context manager and close the stream."""
await self.close()

View File

@ -9,6 +9,9 @@ from collections.abc import (
import inspect
import logging
import struct
from types import (
TracebackType,
)
from typing import (
Callable,
Optional,
@ -74,6 +77,19 @@ class YamuxStream(IMuxedStream):
self.recv_window = DEFAULT_WINDOW_SIZE
self.window_lock = trio.Lock()
async def __aenter__(self) -> "YamuxStream":
"""Enter the async context manager."""
return self
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
"""Exit the async context manager and close the stream."""
await self.close()
async def write(self, data: bytes) -> None:
if self.send_closed:
raise MuxedStreamError("Stream is closed for sending")