Merge pull request #658 from AkMo3/main

fix: add connection state for net stream and gracefully handle failure
This commit is contained in:
Manu Sheel Gupta
2025-06-10 01:03:18 +05:30
committed by GitHub
5 changed files with 503 additions and 14 deletions

View File

@ -0,0 +1,263 @@
"""
Enhanced NetStream Example for py-libp2p with State Management
This example demonstrates the new NetStream features including:
- State tracking and transitions
- Proper error handling and validation
- Resource cleanup and event notifications
- Task-safe state operations guarded by Trio locks (safe across concurrent Trio tasks)
Based on the standard echo demo but enhanced to show NetStream state management.
"""
import argparse
import random
import secrets
import multiaddr
import trio
from libp2p import (
new_host,
)
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.custom_types import (
TProtocol,
)
from libp2p.network.stream.exceptions import (
StreamClosed,
StreamEOF,
StreamReset,
)
from libp2p.network.stream.net_stream import (
NetStream,
StreamState,
)
from libp2p.peer.peerinfo import (
info_from_p2p_addr,
)
PROTOCOL_ID = TProtocol("/echo/1.0.0")
async def enhanced_echo_handler(stream: NetStream) -> None:
    """
    Echo handler that demonstrates NetStream state management.

    Reads chunks of up to 1024 bytes from ``stream`` and writes each chunk
    back unchanged while the stream reports itself readable/writable, then
    closes the local write side and prints the resulting state.

    :param stream: the inbound stream handed to the protocol handler
    """
    print(f"New connection established: {stream}")
    print(f"Initial stream state: {await stream.state}")

    try:
        # Verify stream is in expected initial state
        assert await stream.state == StreamState.OPEN
        assert await stream.is_readable()
        assert await stream.is_writable()
        print("✓ Stream initialized in OPEN state")

        # Read incoming data with proper state checking
        print("Waiting for client data...")

        while await stream.is_readable():
            try:
                # Read data from client
                data = await stream.read(1024)
                if not data:
                    print("Received empty data, client may have closed")
                    break

                # Decode once, for display only. A 1024-byte chunk may end in
                # the middle of a multi-byte character or carry non-UTF-8
                # bytes; a strict decode would raise and make the handler
                # reset the stream instead of echoing the raw bytes.
                text = data.decode("utf-8", errors="replace").strip()
                print(f"Received: {text}")

                # Check if we can still write before echoing
                if await stream.is_writable():
                    await stream.write(data)
                    print(f"Echoed: {text}")
                else:
                    print("Cannot echo - stream not writable")
                    break

            except StreamEOF:
                print("Client closed their write side (EOF)")
                break
            except StreamReset:
                # Nothing left to close after a reset; leave immediately.
                print("Stream was reset by client")
                return
            except StreamClosed as e:
                print(f"Stream operation failed: {e}")
                break

        # Demonstrate graceful closure
        current_state = await stream.state
        print(f"Current state before close: {current_state}")

        if current_state not in [StreamState.CLOSE_BOTH, StreamState.RESET]:
            await stream.close()
            print("Server closed write side")

        final_state = await stream.state
        print(f"Final stream state: {final_state}")

    except Exception as e:
        print(f"Handler error: {e}")
        # Reset stream on unexpected errors, unless it is already terminated
        if await stream.state not in [StreamState.RESET, StreamState.CLOSE_BOTH]:
            await stream.reset()
            print("Stream reset due to error")
async def enhanced_client_demo(stream: NetStream) -> None:
    """
    Client side of the demo; walks one stream through several state scenarios.

    Sends a single message, closes the local write side, checks that the
    stream reports CLOSE_WRITE (readable, not writable), confirms that a
    further write raises ``StreamClosed``, then reads the echoed response.

    :param stream: the outbound stream opened towards the echo server
    """
    print(f"Client stream established: {stream}")
    print(f"Initial state: {await stream.state}")

    try:
        # Verify initial state
        assert await stream.state == StreamState.OPEN
        print("✓ Client stream in OPEN state")

        # Scenario 1: Normal communication
        message = b"Hello from enhanced NetStream client!\n"
        if await stream.is_writable():
            await stream.write(message)
            print(f"Sent: {message.decode('utf-8').strip()}")
        else:
            print("Cannot write - stream not writable")
            return

        # Close write side to signal EOF to server
        await stream.close()
        print("Client closed write side")

        # Verify state transition
        state_after_close = await stream.state
        print(f"State after close: {state_after_close}")
        assert state_after_close == StreamState.CLOSE_WRITE
        assert await stream.is_readable()  # Should still be readable
        assert not await stream.is_writable()  # Should not be writable

        # Try to write (should fail)
        try:
            await stream.write(b"This should fail")
            print("ERROR: Write succeeded when it should have failed!")
        except StreamClosed as e:
            print(f"✓ Expected error when writing to closed stream: {e}")

        # Read the echo response
        if await stream.is_readable():
            try:
                response = await stream.read()
                # Display only: tolerate bytes that are not valid UTF-8 so a
                # bad payload does not turn into a stream reset below.
                echoed = response.decode("utf-8", errors="replace").strip()
                print(f"Received echo: {echoed}")
            except StreamEOF:
                print("Server closed their write side")
            except StreamReset:
                print("Stream was reset")

        # Check final state
        final_state = await stream.state
        print(f"Final client state: {final_state}")

    except Exception as e:
        print(f"Client error: {e}")
        # Reset on error, but only while the stream is not already terminated
        # (same guard the server-side handler applies).
        if await stream.state not in [StreamState.RESET, StreamState.CLOSE_BOTH]:
            await stream.reset()
            print("Client reset stream due to error")
async def run_enhanced_demo(
    port: int, destination: str, seed: int | None = None
) -> None:
    """
    Run enhanced echo demo with NetStream state management.

    With a falsy ``destination`` the host registers the echo handler and
    waits forever (server mode); otherwise it connects to ``destination``,
    opens a stream and runs the client demo.

    :param port: TCP port to listen on (bound on 0.0.0.0)
    :param destination: multiaddr of the server to dial, or empty/None for
        server mode
    :param seed: optional integer seed for a deterministic key pair; when
        ``None`` a random secret is generated
    """
    listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")

    # Generate or use provided key.
    # Compare against None explicitly: 0 is a valid seed and must not fall
    # through to the random-secret branch.
    if seed is not None:
        random.seed(seed)
        secret_number = random.getrandbits(32 * 8)
        secret = secret_number.to_bytes(length=32, byteorder="big")
    else:
        secret = secrets.token_bytes(32)

    host = new_host(key_pair=create_new_key_pair(secret))

    async with host.run(listen_addrs=[listen_addr]):
        print(f"Host ID: {host.get_id().to_string()}")
        print("=" * 60)

        if not destination:  # Server mode
            print("🖥️ ENHANCED ECHO SERVER MODE")
            print("=" * 60)

            # type: ignore: Stream is type of NetStream
            host.set_stream_handler(PROTOCOL_ID, enhanced_echo_handler)

            print(
                "Run client from another console:\n"
                f"python3 example_net_stream.py "
                f"-d {host.get_addrs()[0]}\n"
            )
            print("Waiting for connections...")
            print("Press Ctrl+C to stop server")
            await trio.sleep_forever()

        else:  # Client mode
            print("📱 ENHANCED ECHO CLIENT MODE")
            print("=" * 60)

            # Connect to server
            maddr = multiaddr.Multiaddr(destination)
            info = info_from_p2p_addr(maddr)
            await host.connect(info)
            print(f"Connected to server: {info.peer_id.pretty()}")

            # Create stream and run enhanced demo
            stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
            if isinstance(stream, NetStream):
                await enhanced_client_demo(stream)
            else:
                # Say so instead of silently skipping the whole demo.
                print(
                    f"Skipping demo: expected a NetStream, got {type(stream).__name__}"
                )

            print("\n" + "=" * 60)
            print("CLIENT DEMO COMPLETE")
def main() -> None:
    """
    Command-line entry point.

    Parses ``--port``, ``--destination``, ``--seed`` and ``--demo-states``
    and runs ``run_enhanced_demo`` under Trio with the first three.
    ``--demo-states`` is accepted but not consulted by this function.
    """
    sample_destination = (
        "/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
    )

    cli = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    cli.add_argument("-p", "--port", type=int, default=0, help="source port number")
    cli.add_argument(
        "-d",
        "--destination",
        help=f"destination multiaddr string, e.g. {sample_destination}",
        type=str,
    )
    cli.add_argument(
        "-s",
        "--seed",
        help="seed for deterministic peer ID generation",
        type=int,
    )
    cli.add_argument(
        "--demo-states",
        help="run state transition demo only",
        action="store_true",
    )
    options = cli.parse_args()

    try:
        trio.run(run_enhanced_demo, options.port, options.destination, options.seed)
    except KeyboardInterrupt:
        # Ctrl+C is the documented way to stop the server.
        print("\n👋 Demo interrupted by user")
    except Exception as e:
        print(f"❌ Demo failed: {e}")
if __name__ == "__main__":
main()

View File

@ -1,3 +1,9 @@
from enum import (
Enum,
)
import trio
from libp2p.abc import (
IMuxedStream,
INetStream,
@ -19,18 +25,102 @@ from .exceptions import (
)
# TODO: Handle exceptions from `muxed_stream`
# TODO: Add stream state
# - Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501
class StreamState(Enum):
    """Lifecycle states tracked for a ``NetStream``."""

    # Neither direction has been closed.
    OPEN = "open"
    # Read side closed; writing is still permitted.
    CLOSE_READ = "close_read"
    # Write side closed; reading is still permitted.
    CLOSE_WRITE = "close_write"
    # Both directions closed.
    CLOSE_BOTH = "close_both"
    # Stream was reset; neither reading nor writing is permitted.
    RESET = "reset"
class NetStream(INetStream):
"""
Summary
_______
A Network stream implementation.
NetStream wraps a muxed stream and provides proper state tracking, resource cleanup,
and event notification capabilities.
State Machine
_____________
.. code:: markdown
[CREATED] → OPEN → CLOSE_READ → CLOSE_BOTH → [CLEANUP]
↓ ↗ ↗
CLOSE_WRITE → ← ↗
↓ ↗
RESET → → → → → → → →
State Transitions
_________________
- OPEN → CLOSE_READ: EOF encountered during read()
- OPEN → CLOSE_WRITE: Explicit close() call
- OPEN → RESET: reset() call or critical stream error
- CLOSE_READ → CLOSE_BOTH: Explicit close() call
- CLOSE_WRITE → CLOSE_BOTH: EOF encountered during read()
- Any state → RESET: reset() call
Terminal States (trigger cleanup)
_________________________________
- CLOSE_BOTH: Stream fully closed, triggers resource cleanup
- RESET: Stream reset/terminated, triggers resource cleanup
Operation Validity by State
___________________________
OPEN: read() ✓ write() ✓ close() ✓ reset() ✓
CLOSE_READ: read() ✗ write() ✓ close() ✓ reset() ✓
CLOSE_WRITE: read() ✓ write() ✗ close() ✓ reset() ✓
CLOSE_BOTH: read() ✗ write() ✗ close() ✓ reset() ✓
RESET: read() ✗ write() ✗ close() ✓ reset() ✓
Cleanup Process (triggered by CLOSE_BOTH or RESET)
__________________________________________________
1. Remove stream from SwarmConn
2. Notify all listeners with ClosedStream event
3. Decrement reference counter
4. Background cleanup via nursery (if provided)
Thread Safety
_____________
All state operations are protected by trio.Lock() for safe concurrent access.
State checks and modifications are atomic operations.
Example: See :file:`examples/doc-examples/example_net_stream.py`
:param muxed_stream (IMuxedStream): The underlying muxed stream
:param nursery (Optional[trio.Nursery]): Nursery for background cleanup tasks
:raises StreamClosed: When attempting invalid operations on closed streams
:raises StreamEOF: When EOF is encountered during read operations
:raises StreamReset: When the underlying stream has been reset
"""
muxed_stream: IMuxedStream
protocol_id: TProtocol | None
__stream_state: StreamState
def __init__(
    self, muxed_stream: IMuxedStream, nursery: trio.Nursery | None = None
) -> None:
    """
    Wrap ``muxed_stream`` and start tracking its state as OPEN.

    :param muxed_stream: the underlying muxed stream to delegate I/O to
    :param nursery: optional nursery used to run close notifications in
        the background; when ``None`` they are awaited inline
    """
    super().__init__()

    self.muxed_stream = muxed_stream
    self.muxed_conn = muxed_stream.muxed_conn
    self.protocol_id = None

    # For background tasks
    self._nursery = nursery

    # State management
    self.__stream_state = StreamState.OPEN
    self._state_lock = trio.Lock()

    # For notification handling
    self._notify_lock = trio.Lock()
def get_protocol(self) -> TProtocol | None:
"""
:return: protocol id that stream runs on
@ -43,42 +133,176 @@ class NetStream(INetStream):
"""
self.protocol_id = protocol_id
@property
async def state(self) -> StreamState:
    """
    Current stream state, read while holding the state lock.

    Accessing the property yields an awaitable, so callers use
    ``await stream.state``.
    """
    async with self._state_lock:
        snapshot = self.__stream_state
    return snapshot
async def read(self, n: int | None = None) -> bytes:
    """
    Read from stream.

    :param n: number of bytes to read
    :raises StreamClosed: If `NetStream` is closed for reading
    :raises StreamReset: If `NetStream` is reset
    :raises StreamEOF: If trying to read after reaching end of file
    :return: Bytes read from the stream
    """
    async with self._state_lock:
        if self.__stream_state in [
            StreamState.CLOSE_READ,
            StreamState.CLOSE_BOTH,
        ]:
            raise StreamClosed("Stream is closed for reading")
        if self.__stream_state == StreamState.RESET:
            raise StreamReset("Stream is reset, cannot be used to read")

    # The lock is deliberately not held across the (possibly long) read.
    try:
        return await self.muxed_stream.read(n)
    except MuxedStreamEOF as error:
        reached_terminal = False
        async with self._state_lock:
            if self.__stream_state == StreamState.CLOSE_WRITE:
                self.__stream_state = StreamState.CLOSE_BOTH
                reached_terminal = True
            elif self.__stream_state == StreamState.OPEN:
                self.__stream_state = StreamState.CLOSE_READ
        # Run cleanup only after releasing the state lock: without a nursery
        # the listeners run inline, and one that queries the stream state
        # would try to re-acquire a lock this task already holds.
        if reached_terminal:
            await self._remove()
        raise StreamEOF() from error
    except MuxedStreamReset as error:
        reached_terminal = False
        async with self._state_lock:
            if self.__stream_state in [
                StreamState.OPEN,
                StreamState.CLOSE_READ,
                StreamState.CLOSE_WRITE,
            ]:
                self.__stream_state = StreamState.RESET
                reached_terminal = True
        if reached_terminal:
            await self._remove()
        raise StreamReset() from error
async def write(self, data: bytes) -> None:
    """
    Write to stream.

    :param data: bytes to write
    :raises StreamClosed: If `NetStream` is closed for writing or reset
    :raises StreamClosed: If `StreamError` occurred while writing
    """
    async with self._state_lock:
        if self.__stream_state in [
            StreamState.CLOSE_WRITE,
            StreamState.CLOSE_BOTH,
            StreamState.RESET,
        ]:
            raise StreamClosed("Stream is closed for writing")

    try:
        await self.muxed_stream.write(data)
    except (MuxedStreamClosed, MuxedStreamError) as error:
        # A failed write closes the write side of the state machine.
        reached_terminal = False
        async with self._state_lock:
            if self.__stream_state == StreamState.OPEN:
                self.__stream_state = StreamState.CLOSE_WRITE
            elif self.__stream_state == StreamState.CLOSE_READ:
                self.__stream_state = StreamState.CLOSE_BOTH
                reached_terminal = True
        # Clean up outside the state lock so inline listeners may query the
        # stream state without re-acquiring a lock this task already holds.
        if reached_terminal:
            await self._remove()
        raise StreamClosed() from error
async def close(self) -> None:
    """
    Close stream for writing.

    Does nothing if the write side is already closed or the stream is
    reset. Otherwise closes the underlying muxed stream and moves
    OPEN → CLOSE_WRITE or CLOSE_READ → CLOSE_BOTH; reaching CLOSE_BOTH
    triggers cleanup.
    """
    async with self._state_lock:
        if self.__stream_state in [
            StreamState.CLOSE_BOTH,
            StreamState.RESET,
            StreamState.CLOSE_WRITE,
        ]:
            return

    await self.muxed_stream.close()

    reached_terminal = False
    async with self._state_lock:
        if self.__stream_state == StreamState.CLOSE_READ:
            self.__stream_state = StreamState.CLOSE_BOTH
            reached_terminal = True
        elif self.__stream_state == StreamState.OPEN:
            self.__stream_state = StreamState.CLOSE_WRITE

    # Clean up outside the state lock so inline listeners may query the
    # stream state without re-acquiring a lock this task already holds.
    if reached_terminal:
        await self._remove()
async def reset(self) -> None:
    """
    Reset stream, closing both ends.

    Does nothing if the stream is already in RESET. Otherwise resets the
    underlying muxed stream; a stream that was OPEN, CLOSE_READ or
    CLOSE_WRITE then moves to RESET and is cleaned up.
    """
    async with self._state_lock:
        if self.__stream_state == StreamState.RESET:
            return

    await self.muxed_stream.reset()

    reached_terminal = False
    async with self._state_lock:
        if self.__stream_state in [
            StreamState.OPEN,
            StreamState.CLOSE_READ,
            StreamState.CLOSE_WRITE,
        ]:
            self.__stream_state = StreamState.RESET
            reached_terminal = True

    # Clean up outside the state lock so inline listeners may query the
    # stream state without re-acquiring a lock this task already holds.
    if reached_terminal:
        await self._remove()
async def _remove(self) -> None:
    """
    Detach this stream from its connection, then notify listeners.

    Invoked once the stream has become fully closed or has been reset.
    """
    # Connections that do not expose ``remove_stream`` are simply skipped.
    try:
        detach = self.muxed_conn.remove_stream
    except AttributeError:
        pass
    else:
        await detach(self)

    # Without a nursery the notification is awaited inline; with one it is
    # scheduled as a background task.
    if not self._nursery:
        await self._notify_closed()
        return
    self._nursery.start_soon(self._notify_closed)
async def _notify_closed(self) -> None:
    """
    Tell the swarm's listeners that this stream has closed.

    Every step is optional: a connection without a ``swarm`` attribute, or a
    swarm lacking ``notify_all`` / ``refs.done``, skips that step.
    """
    async with self._notify_lock:
        if not hasattr(self.muxed_conn, "swarm"):
            return
        owner = self.muxed_conn.swarm

        if hasattr(owner, "notify_all"):
            await owner.notify_all(
                lambda listener: listener.closed_stream(owner, self)
            )

        # Mark this stream as finished on the swarm's reference tracker.
        if hasattr(owner, "refs") and hasattr(owner.refs, "done"):
            owner.refs.done()
def get_remote_address(self) -> tuple[str, int] | None:
    """Return whatever the wrapped muxed stream reports as its remote address."""
    underlying = self.muxed_stream
    return underlying.get_remote_address()
# TODO: `remove`: Called by close and write when the stream is in specific states.
# It notifies `ClosedStream` after `SwarmConn.remove_stream` is called.
# Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501
async def is_closed(self) -> bool:
    """Return True when the stream state is CLOSE_BOTH or RESET."""
    terminal_states = (StreamState.CLOSE_BOTH, StreamState.RESET)
    return (await self.state) in terminal_states
async def is_readable(self) -> bool:
    """Return True unless the state is CLOSE_READ, CLOSE_BOTH or RESET."""
    unreadable_states = (
        StreamState.CLOSE_READ,
        StreamState.CLOSE_BOTH,
        StreamState.RESET,
    )
    snapshot = await self.state
    return snapshot not in unreadable_states
async def is_writable(self) -> bool:
    """Return True unless the state is CLOSE_WRITE, CLOSE_BOTH or RESET."""
    unwritable_states = (
        StreamState.CLOSE_WRITE,
        StreamState.CLOSE_BOTH,
        StreamState.RESET,
    )
    snapshot = await self.state
    return snapshot not in unwritable_states
def __str__(self) -> str:
    """Render the stream as ``<NetStream[<state value>] protocol=<protocol id>>``."""
    # Read directly, without the state lock: __str__ is synchronous.
    state_label = self.__stream_state.value
    return f"<NetStream[{state_label}] protocol={self.protocol_id}>"

View File

@ -0,0 +1 @@
The `NetStream.state` property is now async and requires `await`. Update any direct state access to use `await stream.state`.

View File

@ -0,0 +1 @@
Added proper state management and resource cleanup to `NetStream`, fixing memory leaks and improving error handling.

View File

@ -19,10 +19,10 @@ max_issue_threshold=1
[testenv]
usedevelop=True
commands=
core: pytest {posargs:tests/core}
interop: pytest {posargs:tests/interop}
core: pytest -n auto {posargs:tests/core}
interop: pytest -n auto {posargs:tests/interop}
docs: make check-docs-ci
demos: pytest {posargs:tests/core/examples/test_examples.py}
demos: pytest -n auto {posargs:tests/core/examples/test_examples.py}
basepython=
docs: python
windows-wheel: python