mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
stream_muxer(yamux): add ReadWriteLock to YamuxStream to prevent concurrent read/write corruption
Introduce a read/write lock abstraction and integrate it into `YamuxStream` so that simultaneous reads and writes do not interleave, eliminating potential data corruption and race conditions. Major changes: - Abstract `ReadWriteLock` into its own util module - Integrate locking into YamuxStream for `write` operations - Ensure tests pass for lock correctness - Fix lint & type issues discovered during review Closes #793
This commit is contained in:
@ -1,5 +1,3 @@
|
|||||||
from collections.abc import AsyncGenerator
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from types import (
|
from types import (
|
||||||
TracebackType,
|
TracebackType,
|
||||||
)
|
)
|
||||||
@ -15,6 +13,7 @@ from libp2p.abc import (
|
|||||||
from libp2p.stream_muxer.exceptions import (
|
from libp2p.stream_muxer.exceptions import (
|
||||||
MuxedConnUnavailable,
|
MuxedConnUnavailable,
|
||||||
)
|
)
|
||||||
|
from libp2p.stream_muxer.rw_lock import ReadWriteLock
|
||||||
|
|
||||||
from .constants import (
|
from .constants import (
|
||||||
HeaderTags,
|
HeaderTags,
|
||||||
@ -34,72 +33,6 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ReadWriteLock:
|
|
||||||
"""
|
|
||||||
A read-write lock that allows multiple concurrent readers
|
|
||||||
or one exclusive writer, implemented using Trio primitives.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._readers = 0
|
|
||||||
self._readers_lock = trio.Lock() # Protects access to _readers count
|
|
||||||
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
|
|
||||||
|
|
||||||
async def acquire_read(self) -> None:
|
|
||||||
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
|
|
||||||
try:
|
|
||||||
async with self._readers_lock:
|
|
||||||
if self._readers == 0:
|
|
||||||
await self._writer_lock.acquire()
|
|
||||||
self._readers += 1
|
|
||||||
except trio.Cancelled:
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def release_read(self) -> None:
|
|
||||||
"""Release a read lock."""
|
|
||||||
async with self._readers_lock:
|
|
||||||
if self._readers == 1:
|
|
||||||
self._writer_lock.release()
|
|
||||||
self._readers -= 1
|
|
||||||
|
|
||||||
async def acquire_write(self) -> None:
|
|
||||||
"""Acquire an exclusive write lock."""
|
|
||||||
try:
|
|
||||||
await self._writer_lock.acquire()
|
|
||||||
except trio.Cancelled:
|
|
||||||
raise
|
|
||||||
|
|
||||||
def release_write(self) -> None:
|
|
||||||
"""Release the exclusive write lock."""
|
|
||||||
self._writer_lock.release()
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def read_lock(self) -> AsyncGenerator[None, None]:
|
|
||||||
"""Context manager for acquiring and releasing a read lock safely."""
|
|
||||||
acquire = False
|
|
||||||
try:
|
|
||||||
await self.acquire_read()
|
|
||||||
acquire = True
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
if acquire:
|
|
||||||
with trio.CancelScope() as scope:
|
|
||||||
scope.shield = True
|
|
||||||
await self.release_read()
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def write_lock(self) -> AsyncGenerator[None, None]:
|
|
||||||
"""Context manager for acquiring and releasing a write lock safely."""
|
|
||||||
acquire = False
|
|
||||||
try:
|
|
||||||
await self.acquire_write()
|
|
||||||
acquire = True
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
if acquire:
|
|
||||||
self.release_write()
|
|
||||||
|
|
||||||
|
|
||||||
class MplexStream(IMuxedStream):
|
class MplexStream(IMuxedStream):
|
||||||
"""
|
"""
|
||||||
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
|
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
|
||||||
|
|||||||
70
libp2p/stream_muxer/rw_lock.py
Normal file
70
libp2p/stream_muxer/rw_lock.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import trio
|
||||||
|
|
||||||
|
|
||||||
|
class ReadWriteLock:
|
||||||
|
"""
|
||||||
|
A read-write lock that allows multiple concurrent readers
|
||||||
|
or one exclusive writer, implemented using Trio primitives.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._readers = 0
|
||||||
|
self._readers_lock = trio.Lock() # Protects access to _readers count
|
||||||
|
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
|
||||||
|
|
||||||
|
async def acquire_read(self) -> None:
|
||||||
|
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
|
||||||
|
try:
|
||||||
|
async with self._readers_lock:
|
||||||
|
if self._readers == 0:
|
||||||
|
await self._writer_lock.acquire()
|
||||||
|
self._readers += 1
|
||||||
|
except trio.Cancelled:
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def release_read(self) -> None:
|
||||||
|
"""Release a read lock."""
|
||||||
|
async with self._readers_lock:
|
||||||
|
if self._readers == 1:
|
||||||
|
self._writer_lock.release()
|
||||||
|
self._readers -= 1
|
||||||
|
|
||||||
|
async def acquire_write(self) -> None:
|
||||||
|
"""Acquire an exclusive write lock."""
|
||||||
|
try:
|
||||||
|
await self._writer_lock.acquire()
|
||||||
|
except trio.Cancelled:
|
||||||
|
raise
|
||||||
|
|
||||||
|
def release_write(self) -> None:
|
||||||
|
"""Release the exclusive write lock."""
|
||||||
|
self._writer_lock.release()
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def read_lock(self) -> AsyncGenerator[None, None]:
|
||||||
|
"""Context manager for acquiring and releasing a read lock safely."""
|
||||||
|
acquire = False
|
||||||
|
try:
|
||||||
|
await self.acquire_read()
|
||||||
|
acquire = True
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
if acquire:
|
||||||
|
with trio.CancelScope() as scope:
|
||||||
|
scope.shield = True
|
||||||
|
await self.release_read()
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def write_lock(self) -> AsyncGenerator[None, None]:
|
||||||
|
"""Context manager for acquiring and releasing a write lock safely."""
|
||||||
|
acquire = False
|
||||||
|
try:
|
||||||
|
await self.acquire_write()
|
||||||
|
acquire = True
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
if acquire:
|
||||||
|
self.release_write()
|
||||||
@ -44,6 +44,7 @@ from libp2p.stream_muxer.exceptions import (
|
|||||||
MuxedStreamError,
|
MuxedStreamError,
|
||||||
MuxedStreamReset,
|
MuxedStreamReset,
|
||||||
)
|
)
|
||||||
|
from libp2p.stream_muxer.rw_lock import ReadWriteLock
|
||||||
|
|
||||||
# Configure logger for this module
|
# Configure logger for this module
|
||||||
logger = logging.getLogger("libp2p.stream_muxer.yamux")
|
logger = logging.getLogger("libp2p.stream_muxer.yamux")
|
||||||
@ -80,6 +81,8 @@ class YamuxStream(IMuxedStream):
|
|||||||
self.send_window = DEFAULT_WINDOW_SIZE
|
self.send_window = DEFAULT_WINDOW_SIZE
|
||||||
self.recv_window = DEFAULT_WINDOW_SIZE
|
self.recv_window = DEFAULT_WINDOW_SIZE
|
||||||
self.window_lock = trio.Lock()
|
self.window_lock = trio.Lock()
|
||||||
|
self.rw_lock = ReadWriteLock()
|
||||||
|
self.close_lock = trio.Lock()
|
||||||
|
|
||||||
async def __aenter__(self) -> "YamuxStream":
|
async def __aenter__(self) -> "YamuxStream":
|
||||||
"""Enter the async context manager."""
|
"""Enter the async context manager."""
|
||||||
@ -95,52 +98,54 @@ class YamuxStream(IMuxedStream):
|
|||||||
await self.close()
|
await self.close()
|
||||||
|
|
||||||
async def write(self, data: bytes) -> None:
|
async def write(self, data: bytes) -> None:
|
||||||
if self.send_closed:
|
async with self.rw_lock.write_lock():
|
||||||
raise MuxedStreamError("Stream is closed for sending")
|
if self.send_closed:
|
||||||
|
raise MuxedStreamError("Stream is closed for sending")
|
||||||
|
|
||||||
# Flow control: Check if we have enough send window
|
# Flow control: Check if we have enough send window
|
||||||
total_len = len(data)
|
total_len = len(data)
|
||||||
sent = 0
|
sent = 0
|
||||||
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
|
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
|
||||||
while sent < total_len:
|
while sent < total_len:
|
||||||
# Wait for available window with timeout
|
# Wait for available window with timeout
|
||||||
timeout = False
|
timeout = False
|
||||||
async with self.window_lock:
|
async with self.window_lock:
|
||||||
if self.send_window == 0:
|
if self.send_window == 0:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Stream {self.stream_id}: Window is zero, waiting for update"
|
f"Stream {self.stream_id}: "
|
||||||
|
"Window is zero, waiting for update"
|
||||||
|
)
|
||||||
|
# Release lock and wait with timeout
|
||||||
|
self.window_lock.release()
|
||||||
|
# To avoid re-acquiring the lock immediately,
|
||||||
|
with trio.move_on_after(5.0) as cancel_scope:
|
||||||
|
while self.send_window == 0 and not self.closed:
|
||||||
|
await trio.sleep(0.01)
|
||||||
|
# If we timed out, cancel the scope
|
||||||
|
timeout = cancel_scope.cancelled_caught
|
||||||
|
# Re-acquire lock
|
||||||
|
await self.window_lock.acquire()
|
||||||
|
|
||||||
|
# If we timed out waiting for window update, raise an error
|
||||||
|
if timeout:
|
||||||
|
raise MuxedStreamError(
|
||||||
|
"Timed out waiting for window update after 5 seconds."
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.closed:
|
||||||
|
raise MuxedStreamError("Stream is closed")
|
||||||
|
|
||||||
|
# Calculate how much we can send now
|
||||||
|
to_send = min(self.send_window, total_len - sent)
|
||||||
|
chunk = data[sent : sent + to_send]
|
||||||
|
self.send_window -= to_send
|
||||||
|
|
||||||
|
# Send the data
|
||||||
|
header = struct.pack(
|
||||||
|
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk)
|
||||||
)
|
)
|
||||||
# Release lock and wait with timeout
|
await self.conn.secured_conn.write(header + chunk)
|
||||||
self.window_lock.release()
|
sent += to_send
|
||||||
# To avoid re-acquiring the lock immediately,
|
|
||||||
with trio.move_on_after(5.0) as cancel_scope:
|
|
||||||
while self.send_window == 0 and not self.closed:
|
|
||||||
await trio.sleep(0.01)
|
|
||||||
# If we timed out, cancel the scope
|
|
||||||
timeout = cancel_scope.cancelled_caught
|
|
||||||
# Re-acquire lock
|
|
||||||
await self.window_lock.acquire()
|
|
||||||
|
|
||||||
# If we timed out waiting for window update, raise an error
|
|
||||||
if timeout:
|
|
||||||
raise MuxedStreamError(
|
|
||||||
"Timed out waiting for window update after 5 seconds."
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.closed:
|
|
||||||
raise MuxedStreamError("Stream is closed")
|
|
||||||
|
|
||||||
# Calculate how much we can send now
|
|
||||||
to_send = min(self.send_window, total_len - sent)
|
|
||||||
chunk = data[sent : sent + to_send]
|
|
||||||
self.send_window -= to_send
|
|
||||||
|
|
||||||
# Send the data
|
|
||||||
header = struct.pack(
|
|
||||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk)
|
|
||||||
)
|
|
||||||
await self.conn.secured_conn.write(header + chunk)
|
|
||||||
sent += to_send
|
|
||||||
|
|
||||||
async def send_window_update(self, increment: int, skip_lock: bool = False) -> None:
|
async def send_window_update(self, increment: int, skip_lock: bool = False) -> None:
|
||||||
"""
|
"""
|
||||||
@ -257,30 +262,32 @@ class YamuxStream(IMuxedStream):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
if not self.send_closed:
|
async with self.close_lock:
|
||||||
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
|
if not self.send_closed:
|
||||||
header = struct.pack(
|
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
|
||||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
|
header = struct.pack(
|
||||||
)
|
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
|
||||||
await self.conn.secured_conn.write(header)
|
)
|
||||||
self.send_closed = True
|
await self.conn.secured_conn.write(header)
|
||||||
|
self.send_closed = True
|
||||||
|
|
||||||
# Only set fully closed if both directions are closed
|
# Only set fully closed if both directions are closed
|
||||||
if self.send_closed and self.recv_closed:
|
if self.send_closed and self.recv_closed:
|
||||||
self.closed = True
|
self.closed = True
|
||||||
else:
|
else:
|
||||||
# Stream is half-closed but not fully closed
|
# Stream is half-closed but not fully closed
|
||||||
self.closed = False
|
self.closed = False
|
||||||
|
|
||||||
async def reset(self) -> None:
|
async def reset(self) -> None:
|
||||||
if not self.closed:
|
if not self.closed:
|
||||||
logger.debug(f"Resetting stream {self.stream_id}")
|
async with self.close_lock:
|
||||||
header = struct.pack(
|
logger.debug(f"Resetting stream {self.stream_id}")
|
||||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
|
header = struct.pack(
|
||||||
)
|
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
|
||||||
await self.conn.secured_conn.write(header)
|
)
|
||||||
self.closed = True
|
await self.conn.secured_conn.write(header)
|
||||||
self.reset_received = True # Mark as reset
|
self.closed = True
|
||||||
|
self.reset_received = True # Mark as reset
|
||||||
|
|
||||||
def set_deadline(self, ttl: int) -> bool:
|
def set_deadline(self, ttl: int) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|||||||
6
newsfragments/897.bugfix.rst
Normal file
6
newsfragments/897.bugfix.rst
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
enhancement: Add write lock to `YamuxStream` to prevent concurrent write race conditions
|
||||||
|
|
||||||
|
- Implements ReadWriteLock for `YamuxStream` write operations
|
||||||
|
- Prevents data corruption from concurrent write operations
|
||||||
|
- Read operations remain lock-free due to existing `Yamux` architecture
|
||||||
|
- Resolves race conditions identified in Issue #793
|
||||||
@ -1,6 +1,5 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def security_protocol():
|
def security_protocol():
|
||||||
return None
|
return None
|
||||||
Reference in New Issue
Block a user