From 5bc4d01eea22f9526ec5784c794efcefd6afff5f Mon Sep 17 00:00:00 2001 From: Akash Mondal Date: Fri, 6 Jun 2025 06:27:23 +0000 Subject: [PATCH] fix: add connection states for net stream Other changes: 1. Add operation validation based on states 2. Gracefully handle exceptions and cleanup --- libp2p/network/stream/net_stream.py | 189 +++++++++++++++++++++++++--- 1 file changed, 174 insertions(+), 15 deletions(-) diff --git a/libp2p/network/stream/net_stream.py b/libp2p/network/stream/net_stream.py index 300f0fa4..5b1e2668 100644 --- a/libp2p/network/stream/net_stream.py +++ b/libp2p/network/stream/net_stream.py @@ -1,3 +1,12 @@ +from enum import ( + Enum, +) +from typing import ( + Optional, +) + +import trio + from libp2p.abc import ( IMuxedStream, INetStream, @@ -19,18 +28,42 @@ from .exceptions import ( ) -# TODO: Handle exceptions from `muxed_stream` -# TODO: Add stream state -# - Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501 -class NetStream(INetStream): - muxed_stream: IMuxedStream - protocol_id: TProtocol | None +class StreamState(Enum): + """NetStream States""" + + OPEN = "open" + CLOSE_READ = "close_read" + CLOSE_WRITE = "close_write" + CLOSE_BOTH = "close_both" + RESET = "reset" + + +class NetStream(INetStream): + """Class representing NetStream Handler""" + + muxed_stream: IMuxedStream + protocol_id: Optional[TProtocol] + __stream_state: StreamState + + def __init__( + self, muxed_stream: IMuxedStream, nursery: Optional[trio.Nursery] = None + ) -> None: + super().__init__() - def __init__(self, muxed_stream: IMuxedStream) -> None: self.muxed_stream = muxed_stream self.muxed_conn = muxed_stream.muxed_conn self.protocol_id = None + # For background tasks + self._nursery = nursery + + # State management + self.__stream_state = StreamState.OPEN + self._state_lock = trio.Lock() + + # For notification handling + self._notify_lock = trio.Lock() + def get_protocol(self) -> TProtocol | None: """ :return: protocol id that stream runs on @@ -43,42 +76,168 @@ class NetStream(INetStream): """ self.protocol_id = protocol_id - async def read(self, n: int | None = None) -> bytes: + @property + async def state(self) -> StreamState: + """Get current stream state.""" + async with self._state_lock: + return self.__stream_state + + async def read(self, n: Optional[int] = None) -> bytes: """ Read from stream. :param n: number of bytes to read :return: bytes of input """ + async with self._state_lock: + if self.__stream_state in [ + StreamState.CLOSE_READ, + StreamState.CLOSE_BOTH, + ]: + raise StreamClosed("Stream is closed for reading") + + if self.__stream_state == StreamState.RESET: + raise StreamReset("Stream is reset, cannot be used to read") + try: - return await self.muxed_stream.read(n) + data = await self.muxed_stream.read(n) + return data except MuxedStreamEOF as error: + async with self._state_lock: + if self.__stream_state == StreamState.CLOSE_WRITE: + self.__stream_state = StreamState.CLOSE_BOTH + await self._remove() + elif self.__stream_state == StreamState.OPEN: + self.__stream_state = StreamState.CLOSE_READ raise StreamEOF() from error except MuxedStreamReset as error: + async with self._state_lock: + if self.__stream_state in [ + StreamState.OPEN, + StreamState.CLOSE_READ, + StreamState.CLOSE_WRITE, + ]: + self.__stream_state = StreamState.RESET + await self._remove() raise StreamReset() from error async def write(self, data: bytes) -> None: """ Write to stream. - :return: number of bytes written + :param data: bytes to write """ + async with self._state_lock: + if self.__stream_state in [ + StreamState.CLOSE_WRITE, + StreamState.CLOSE_BOTH, + StreamState.RESET, + ]: + raise StreamClosed("Stream is closed for writing") + try: await self.muxed_stream.write(data) except (MuxedStreamClosed, MuxedStreamError) as error: + async with self._state_lock: + if self.__stream_state == StreamState.OPEN: + self.__stream_state = StreamState.CLOSE_WRITE + elif self.__stream_state == StreamState.CLOSE_READ: + self.__stream_state = StreamState.CLOSE_BOTH + await self._remove() raise StreamClosed() from error async def close(self) -> None: - """Close stream.""" + """Close stream for writing.""" + async with self._state_lock: + if self.__stream_state in [ + StreamState.CLOSE_BOTH, + StreamState.RESET, + StreamState.CLOSE_WRITE, + ]: + return + await self.muxed_stream.close() + async with self._state_lock: + if self.__stream_state == StreamState.CLOSE_READ: + self.__stream_state = StreamState.CLOSE_BOTH + await self._remove() + elif self.__stream_state == StreamState.OPEN: + self.__stream_state = StreamState.CLOSE_WRITE + async def reset(self) -> None: + """Reset stream, closing both ends.""" + async with self._state_lock: + if self.__stream_state == StreamState.RESET: + return + await self.muxed_stream.reset() - def get_remote_address(self) -> tuple[str, int] | None: + async with self._state_lock: + if self.__stream_state in [ + StreamState.OPEN, + StreamState.CLOSE_READ, + StreamState.CLOSE_WRITE, + ]: + self.__stream_state = StreamState.RESET + await self._remove() + + async def _remove(self) -> None: + """ + Remove stream from connection and notify listeners. + This is called when the stream is fully closed or reset. + """ + if hasattr(self.muxed_conn, "remove_stream"): + remove_stream = getattr(self.muxed_conn, "remove_stream") + await remove_stream(self) + + # Notify in background using Trio nursery if available + if self._nursery: + self._nursery.start_soon(self._notify_closed) + else: + await self._notify_closed() + + async def _notify_closed(self) -> None: + """ + Notify all listeners that the stream has been closed. + This runs in a separate task to avoid blocking the main flow. + """ + async with self._notify_lock: + if hasattr(self.muxed_conn, "swarm"): + swarm = getattr(self.muxed_conn, "swarm") + + if hasattr(swarm, "notify_all"): + await swarm.notify_all( + lambda notifiee: notifiee.closed_stream(swarm, self) + ) + + if hasattr(swarm, "refs") and hasattr(swarm.refs, "done"): + swarm.refs.done() + + def get_remote_address(self) -> Optional[tuple[str, int]]: """Delegate to the underlying muxed stream.""" return self.muxed_stream.get_remote_address() - # TODO: `remove`: Called by close and write when the stream is in specific states. - # It notifies `ClosedStream` after `SwarmConn.remove_stream` is called. - # Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501 + def is_closed(self) -> bool: + """Check if stream is closed.""" + return self.__stream_state in [StreamState.CLOSE_BOTH, StreamState.RESET] + + def is_readable(self) -> bool: + """Check if stream is readable.""" + return self.__stream_state not in [ + StreamState.CLOSE_READ, + StreamState.CLOSE_BOTH, + StreamState.RESET, + ] + + def is_writable(self) -> bool: + """Check if stream is writable.""" + return self.__stream_state not in [ + StreamState.CLOSE_WRITE, + StreamState.CLOSE_BOTH, + StreamState.RESET, + ] + + def __str__(self) -> str: + """String representation of the stream.""" + return f""