mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Merge pull request #658 from AkMo3/main
fix: add connection state for net stream and gracefully handle failure
This commit is contained in:
263
examples/doc-examples/example_net_stream.py
Normal file
263
examples/doc-examples/example_net_stream.py
Normal file
@ -0,0 +1,263 @@
|
||||
"""
|
||||
Enhanced NetStream Example for py-libp2p with State Management
|
||||
|
||||
This example demonstrates the new NetStream features including:
|
||||
- State tracking and transitions
|
||||
- Proper error handling and validation
|
||||
- Resource cleanup and event notifications
|
||||
- Thread-safe operations with Trio locks
|
||||
|
||||
Based on the standard echo demo but enhanced to show NetStream state management.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import random
|
||||
import secrets
|
||||
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p import (
|
||||
new_host,
|
||||
)
|
||||
from libp2p.crypto.secp256k1 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.network.stream.exceptions import (
|
||||
StreamClosed,
|
||||
StreamEOF,
|
||||
StreamReset,
|
||||
)
|
||||
from libp2p.network.stream.net_stream import (
|
||||
NetStream,
|
||||
StreamState,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
info_from_p2p_addr,
|
||||
)
|
||||
|
||||
PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
||||
|
||||
|
||||
async def enhanced_echo_handler(stream: NetStream) -> None:
|
||||
"""
|
||||
Enhanced echo handler that demonstrates NetStream state management.
|
||||
"""
|
||||
print(f"New connection established: {stream}")
|
||||
print(f"Initial stream state: {await stream.state}")
|
||||
|
||||
try:
|
||||
# Verify stream is in expected initial state
|
||||
assert await stream.state == StreamState.OPEN
|
||||
assert await stream.is_readable()
|
||||
assert await stream.is_writable()
|
||||
print("✓ Stream initialized in OPEN state")
|
||||
|
||||
# Read incoming data with proper state checking
|
||||
print("Waiting for client data...")
|
||||
|
||||
while await stream.is_readable():
|
||||
try:
|
||||
# Read data from client
|
||||
data = await stream.read(1024)
|
||||
if not data:
|
||||
print("Received empty data, client may have closed")
|
||||
break
|
||||
|
||||
print(f"Received: {data.decode('utf-8').strip()}")
|
||||
|
||||
# Check if we can still write before echoing
|
||||
if await stream.is_writable():
|
||||
await stream.write(data)
|
||||
print(f"Echoed: {data.decode('utf-8').strip()}")
|
||||
else:
|
||||
print("Cannot echo - stream not writable")
|
||||
break
|
||||
|
||||
except StreamEOF:
|
||||
print("Client closed their write side (EOF)")
|
||||
break
|
||||
except StreamReset:
|
||||
print("Stream was reset by client")
|
||||
return
|
||||
except StreamClosed as e:
|
||||
print(f"Stream operation failed: {e}")
|
||||
break
|
||||
|
||||
# Demonstrate graceful closure
|
||||
current_state = await stream.state
|
||||
print(f"Current state before close: {current_state}")
|
||||
|
||||
if current_state not in [StreamState.CLOSE_BOTH, StreamState.RESET]:
|
||||
await stream.close()
|
||||
print("Server closed write side")
|
||||
|
||||
final_state = await stream.state
|
||||
print(f"Final stream state: {final_state}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Handler error: {e}")
|
||||
# Reset stream on unexpected errors
|
||||
if await stream.state not in [StreamState.RESET, StreamState.CLOSE_BOTH]:
|
||||
await stream.reset()
|
||||
print("Stream reset due to error")
|
||||
|
||||
|
||||
async def enhanced_client_demo(stream: NetStream) -> None:
|
||||
"""
|
||||
Enhanced client that demonstrates various NetStream state scenarios.
|
||||
"""
|
||||
print(f"Client stream established: {stream}")
|
||||
print(f"Initial state: {await stream.state}")
|
||||
|
||||
try:
|
||||
# Verify initial state
|
||||
assert await stream.state == StreamState.OPEN
|
||||
print("✓ Client stream in OPEN state")
|
||||
|
||||
# Scenario 1: Normal communication
|
||||
message = b"Hello from enhanced NetStream client!\n"
|
||||
|
||||
if await stream.is_writable():
|
||||
await stream.write(message)
|
||||
print(f"Sent: {message.decode('utf-8').strip()}")
|
||||
else:
|
||||
print("Cannot write - stream not writable")
|
||||
return
|
||||
|
||||
# Close write side to signal EOF to server
|
||||
await stream.close()
|
||||
print("Client closed write side")
|
||||
|
||||
# Verify state transition
|
||||
state_after_close = await stream.state
|
||||
print(f"State after close: {state_after_close}")
|
||||
assert state_after_close == StreamState.CLOSE_WRITE
|
||||
assert await stream.is_readable() # Should still be readable
|
||||
assert not await stream.is_writable() # Should not be writable
|
||||
|
||||
# Try to write (should fail)
|
||||
try:
|
||||
await stream.write(b"This should fail")
|
||||
print("ERROR: Write succeeded when it should have failed!")
|
||||
except StreamClosed as e:
|
||||
print(f"✓ Expected error when writing to closed stream: {e}")
|
||||
|
||||
# Read the echo response
|
||||
if await stream.is_readable():
|
||||
try:
|
||||
response = await stream.read()
|
||||
print(f"Received echo: {response.decode('utf-8').strip()}")
|
||||
except StreamEOF:
|
||||
print("Server closed their write side")
|
||||
except StreamReset:
|
||||
print("Stream was reset")
|
||||
|
||||
# Check final state
|
||||
final_state = await stream.state
|
||||
print(f"Final client state: {final_state}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Client error: {e}")
|
||||
# Reset on error
|
||||
await stream.reset()
|
||||
print("Client reset stream due to error")
|
||||
|
||||
|
||||
async def run_enhanced_demo(
|
||||
port: int, destination: str, seed: int | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Run enhanced echo demo with NetStream state management.
|
||||
"""
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
|
||||
# Generate or use provided key
|
||||
if seed:
|
||||
random.seed(seed)
|
||||
secret_number = random.getrandbits(32 * 8)
|
||||
secret = secret_number.to_bytes(length=32, byteorder="big")
|
||||
else:
|
||||
secret = secrets.token_bytes(32)
|
||||
|
||||
host = new_host(key_pair=create_new_key_pair(secret))
|
||||
|
||||
async with host.run(listen_addrs=[listen_addr]):
|
||||
print(f"Host ID: {host.get_id().to_string()}")
|
||||
print("=" * 60)
|
||||
|
||||
if not destination: # Server mode
|
||||
print("🖥️ ENHANCED ECHO SERVER MODE")
|
||||
print("=" * 60)
|
||||
|
||||
# type: ignore: Stream is type of NetStream
|
||||
host.set_stream_handler(PROTOCOL_ID, enhanced_echo_handler)
|
||||
|
||||
print(
|
||||
"Run client from another console:\n"
|
||||
f"python3 example_net_stream.py "
|
||||
f"-d {host.get_addrs()[0]}\n"
|
||||
)
|
||||
print("Waiting for connections...")
|
||||
print("Press Ctrl+C to stop server")
|
||||
await trio.sleep_forever()
|
||||
|
||||
else: # Client mode
|
||||
print("📱 ENHANCED ECHO CLIENT MODE")
|
||||
print("=" * 60)
|
||||
|
||||
# Connect to server
|
||||
maddr = multiaddr.Multiaddr(destination)
|
||||
info = info_from_p2p_addr(maddr)
|
||||
await host.connect(info)
|
||||
print(f"Connected to server: {info.peer_id.pretty()}")
|
||||
|
||||
# Create stream and run enhanced demo
|
||||
stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
|
||||
if isinstance(stream, NetStream):
|
||||
await enhanced_client_demo(stream)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("CLIENT DEMO COMPLETE")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
example_maddr = (
|
||||
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--destination",
|
||||
type=str,
|
||||
help=f"destination multiaddr string, e.g. {example_maddr}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--seed",
|
||||
type=int,
|
||||
help="seed for deterministic peer ID generation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--demo-states", action="store_true", help="run state transition demo only"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
trio.run(run_enhanced_demo, args.port, args.destination, args.seed)
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Demo interrupted by user")
|
||||
except Exception as e:
|
||||
print(f"❌ Demo failed: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -1,3 +1,9 @@
|
||||
from enum import (
|
||||
Enum,
|
||||
)
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
IMuxedStream,
|
||||
INetStream,
|
||||
@ -19,18 +25,102 @@ from .exceptions import (
|
||||
)
|
||||
|
||||
|
||||
# TODO: Handle exceptions from `muxed_stream`
|
||||
# TODO: Add stream state
|
||||
# - Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501
|
||||
class StreamState(Enum):
|
||||
"""NetStream States"""
|
||||
|
||||
OPEN = "open"
|
||||
CLOSE_READ = "close_read"
|
||||
CLOSE_WRITE = "close_write"
|
||||
CLOSE_BOTH = "close_both"
|
||||
RESET = "reset"
|
||||
|
||||
|
||||
class NetStream(INetStream):
|
||||
"""
|
||||
Summary
|
||||
_______
|
||||
A Network stream implementation.
|
||||
|
||||
NetStream wraps a muxed stream and provides proper state tracking, resource cleanup,
|
||||
and event notification capabilities.
|
||||
|
||||
State Machine
|
||||
_____________
|
||||
|
||||
.. code:: markdown
|
||||
|
||||
[CREATED] → OPEN → CLOSE_READ → CLOSE_BOTH → [CLEANUP]
|
||||
↓ ↗ ↗
|
||||
CLOSE_WRITE → ← ↗
|
||||
↓ ↗
|
||||
RESET → → → → → → → →
|
||||
|
||||
State Transitions
|
||||
_________________
|
||||
- OPEN → CLOSE_READ: EOF encountered during read()
|
||||
- OPEN → CLOSE_WRITE: Explicit close() call
|
||||
- OPEN → RESET: reset() call or critical stream error
|
||||
- CLOSE_READ → CLOSE_BOTH: Explicit close() call
|
||||
- CLOSE_WRITE → CLOSE_BOTH: EOF encountered during read()
|
||||
- Any state → RESET: reset() call
|
||||
|
||||
Terminal States (trigger cleanup)
|
||||
_________________________________
|
||||
- CLOSE_BOTH: Stream fully closed, triggers resource cleanup
|
||||
- RESET: Stream reset/terminated, triggers resource cleanup
|
||||
|
||||
Operation Validity by State
|
||||
___________________________
|
||||
OPEN: read() ✓ write() ✓ close() ✓ reset() ✓
|
||||
CLOSE_READ: read() ✗ write() ✓ close() ✓ reset() ✓
|
||||
CLOSE_WRITE: read() ✓ write() ✗ close() ✓ reset() ✓
|
||||
CLOSE_BOTH: read() ✗ write() ✗ close() ✓ reset() ✓
|
||||
RESET: read() ✗ write() ✗ close() ✓ reset() ✓
|
||||
|
||||
Cleanup Process (triggered by CLOSE_BOTH or RESET)
|
||||
__________________________________________________
|
||||
1. Remove stream from SwarmConn
|
||||
2. Notify all listeners with ClosedStream event
|
||||
3. Decrement reference counter
|
||||
4. Background cleanup via nursery (if provided)
|
||||
|
||||
Thread Safety
|
||||
_____________
|
||||
All state operations are protected by trio.Lock() for safe concurrent access.
|
||||
State checks and modifications are atomic operations.
|
||||
|
||||
Example: See :file:`examples/doc-examples/example_net_stream.py`
|
||||
|
||||
:param muxed_stream (IMuxedStream): The underlying muxed stream
|
||||
:param nursery (Optional[trio.Nursery]): Nursery for background cleanup tasks
|
||||
:raises StreamClosed: When attempting invalid operations on closed streams
|
||||
:raises StreamEOF: When EOF is encountered during read operations
|
||||
:raises StreamReset: When the underlying stream has been reset
|
||||
"""
|
||||
|
||||
muxed_stream: IMuxedStream
|
||||
protocol_id: TProtocol | None
|
||||
__stream_state: StreamState
|
||||
|
||||
def __init__(
|
||||
self, muxed_stream: IMuxedStream, nursery: trio.Nursery | None = None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
def __init__(self, muxed_stream: IMuxedStream) -> None:
|
||||
self.muxed_stream = muxed_stream
|
||||
self.muxed_conn = muxed_stream.muxed_conn
|
||||
self.protocol_id = None
|
||||
|
||||
# For background tasks
|
||||
self._nursery = nursery
|
||||
|
||||
# State management
|
||||
self.__stream_state = StreamState.OPEN
|
||||
self._state_lock = trio.Lock()
|
||||
|
||||
# For notification handling
|
||||
self._notify_lock = trio.Lock()
|
||||
|
||||
def get_protocol(self) -> TProtocol | None:
|
||||
"""
|
||||
:return: protocol id that stream runs on
|
||||
@ -43,42 +133,176 @@ class NetStream(INetStream):
|
||||
"""
|
||||
self.protocol_id = protocol_id
|
||||
|
||||
@property
|
||||
async def state(self) -> StreamState:
|
||||
"""Get current stream state."""
|
||||
async with self._state_lock:
|
||||
return self.__stream_state
|
||||
|
||||
async def read(self, n: int | None = None) -> bytes:
|
||||
"""
|
||||
Read from stream.
|
||||
|
||||
:param n: number of bytes to read
|
||||
:return: bytes of input
|
||||
:raises StreamClosed: If `NetStream` is closed for reading
|
||||
:raises StreamReset: If `NetStream` is reset
|
||||
:raises StreamEOF: If trying to read after reaching end of file
|
||||
:return: Bytes read from the stream
|
||||
"""
|
||||
async with self._state_lock:
|
||||
if self.__stream_state in [
|
||||
StreamState.CLOSE_READ,
|
||||
StreamState.CLOSE_BOTH,
|
||||
]:
|
||||
raise StreamClosed("Stream is closed for reading")
|
||||
|
||||
if self.__stream_state == StreamState.RESET:
|
||||
raise StreamReset("Stream is reset, cannot be used to read")
|
||||
|
||||
try:
|
||||
return await self.muxed_stream.read(n)
|
||||
data = await self.muxed_stream.read(n)
|
||||
return data
|
||||
except MuxedStreamEOF as error:
|
||||
async with self._state_lock:
|
||||
if self.__stream_state == StreamState.CLOSE_WRITE:
|
||||
self.__stream_state = StreamState.CLOSE_BOTH
|
||||
await self._remove()
|
||||
elif self.__stream_state == StreamState.OPEN:
|
||||
self.__stream_state = StreamState.CLOSE_READ
|
||||
raise StreamEOF() from error
|
||||
except MuxedStreamReset as error:
|
||||
async with self._state_lock:
|
||||
if self.__stream_state in [
|
||||
StreamState.OPEN,
|
||||
StreamState.CLOSE_READ,
|
||||
StreamState.CLOSE_WRITE,
|
||||
]:
|
||||
self.__stream_state = StreamState.RESET
|
||||
await self._remove()
|
||||
raise StreamReset() from error
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""
|
||||
Write to stream.
|
||||
|
||||
:return: number of bytes written
|
||||
:param data: bytes to write
|
||||
:raises StreamClosed: If `NetStream` is closed for writing or reset
|
||||
:raises StreamClosed: If `StreamError` occurred while writing
|
||||
"""
|
||||
async with self._state_lock:
|
||||
if self.__stream_state in [
|
||||
StreamState.CLOSE_WRITE,
|
||||
StreamState.CLOSE_BOTH,
|
||||
StreamState.RESET,
|
||||
]:
|
||||
raise StreamClosed("Stream is closed for writing")
|
||||
|
||||
try:
|
||||
await self.muxed_stream.write(data)
|
||||
except (MuxedStreamClosed, MuxedStreamError) as error:
|
||||
async with self._state_lock:
|
||||
if self.__stream_state == StreamState.OPEN:
|
||||
self.__stream_state = StreamState.CLOSE_WRITE
|
||||
elif self.__stream_state == StreamState.CLOSE_READ:
|
||||
self.__stream_state = StreamState.CLOSE_BOTH
|
||||
await self._remove()
|
||||
raise StreamClosed() from error
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close stream."""
|
||||
"""Close stream for writing."""
|
||||
async with self._state_lock:
|
||||
if self.__stream_state in [
|
||||
StreamState.CLOSE_BOTH,
|
||||
StreamState.RESET,
|
||||
StreamState.CLOSE_WRITE,
|
||||
]:
|
||||
return
|
||||
|
||||
await self.muxed_stream.close()
|
||||
|
||||
async with self._state_lock:
|
||||
if self.__stream_state == StreamState.CLOSE_READ:
|
||||
self.__stream_state = StreamState.CLOSE_BOTH
|
||||
await self._remove()
|
||||
elif self.__stream_state == StreamState.OPEN:
|
||||
self.__stream_state = StreamState.CLOSE_WRITE
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset stream, closing both ends."""
|
||||
async with self._state_lock:
|
||||
if self.__stream_state == StreamState.RESET:
|
||||
return
|
||||
|
||||
await self.muxed_stream.reset()
|
||||
|
||||
async with self._state_lock:
|
||||
if self.__stream_state in [
|
||||
StreamState.OPEN,
|
||||
StreamState.CLOSE_READ,
|
||||
StreamState.CLOSE_WRITE,
|
||||
]:
|
||||
self.__stream_state = StreamState.RESET
|
||||
await self._remove()
|
||||
|
||||
async def _remove(self) -> None:
|
||||
"""
|
||||
Remove stream from connection and notify listeners.
|
||||
This is called when the stream is fully closed or reset.
|
||||
"""
|
||||
if hasattr(self.muxed_conn, "remove_stream"):
|
||||
remove_stream = getattr(self.muxed_conn, "remove_stream")
|
||||
await remove_stream(self)
|
||||
|
||||
# Notify in background using Trio nursery if available
|
||||
if self._nursery:
|
||||
self._nursery.start_soon(self._notify_closed)
|
||||
else:
|
||||
await self._notify_closed()
|
||||
|
||||
async def _notify_closed(self) -> None:
|
||||
"""
|
||||
Notify all listeners that the stream has been closed.
|
||||
This runs in a separate task to avoid blocking the main flow.
|
||||
"""
|
||||
async with self._notify_lock:
|
||||
if hasattr(self.muxed_conn, "swarm"):
|
||||
swarm = getattr(self.muxed_conn, "swarm")
|
||||
|
||||
if hasattr(swarm, "notify_all"):
|
||||
await swarm.notify_all(
|
||||
lambda notifiee: notifiee.closed_stream(swarm, self)
|
||||
)
|
||||
|
||||
if hasattr(swarm, "refs") and hasattr(swarm.refs, "done"):
|
||||
swarm.refs.done()
|
||||
|
||||
def get_remote_address(self) -> tuple[str, int] | None:
|
||||
"""Delegate to the underlying muxed stream."""
|
||||
return self.muxed_stream.get_remote_address()
|
||||
|
||||
# TODO: `remove`: Called by close and write when the stream is in specific states.
|
||||
# It notifies `ClosedStream` after `SwarmConn.remove_stream` is called.
|
||||
# Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501
|
||||
async def is_closed(self) -> bool:
|
||||
"""Check if stream is closed."""
|
||||
current_state = await self.state
|
||||
return current_state in [StreamState.CLOSE_BOTH, StreamState.RESET]
|
||||
|
||||
async def is_readable(self) -> bool:
|
||||
"""Check if stream is readable."""
|
||||
current_state = await self.state
|
||||
return current_state not in [
|
||||
StreamState.CLOSE_READ,
|
||||
StreamState.CLOSE_BOTH,
|
||||
StreamState.RESET,
|
||||
]
|
||||
|
||||
async def is_writable(self) -> bool:
|
||||
"""Check if stream is writable."""
|
||||
current_state = await self.state
|
||||
return current_state not in [
|
||||
StreamState.CLOSE_WRITE,
|
||||
StreamState.CLOSE_BOTH,
|
||||
StreamState.RESET,
|
||||
]
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the stream."""
|
||||
return f"<NetStream[{self.__stream_state.value}] protocol={self.protocol_id}>"
|
||||
|
||||
1
newsfragments/300.breaking.rst
Normal file
1
newsfragments/300.breaking.rst
Normal file
@ -0,0 +1 @@
|
||||
The `NetStream.state` property is now async and requires `await`. Update any direct state access to use `await stream.state`.
|
||||
1
newsfragments/300.bugfix.rst
Normal file
1
newsfragments/300.bugfix.rst
Normal file
@ -0,0 +1 @@
|
||||
Added proper state management and resource cleanup to `NetStream`, fixing memory leaks and improved error handling.
|
||||
6
tox.ini
6
tox.ini
@ -19,10 +19,10 @@ max_issue_threshold=1
|
||||
[testenv]
|
||||
usedevelop=True
|
||||
commands=
|
||||
core: pytest {posargs:tests/core}
|
||||
interop: pytest {posargs:tests/interop}
|
||||
core: pytest -n auto {posargs:tests/core}
|
||||
interop: pytest -n auto {posargs:tests/interop}
|
||||
docs: make check-docs-ci
|
||||
demos: pytest {posargs:tests/core/examples/test_examples.py}
|
||||
demos: pytest -n auto {posargs:tests/core/examples/test_examples.py}
|
||||
basepython=
|
||||
docs: python
|
||||
windows-wheel: python
|
||||
|
||||
Reference in New Issue
Block a user