mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
stream_muxer(yamux): add ReadWriteLock to YamuxStream to prevent concurrent read/write corruption
Introduce a read/write lock abstraction and integrate it into `YamuxStream` so that simultaneous reads and writes do not interleave, eliminating potential data corruption and race conditions. Major changes: - Abstract `ReadWriteLock` into its own util module - Integrate locking into YamuxStream for `write` operations - Ensure tests pass for lock correctness - Fix lint & type issues discovered during review Closes #793
This commit is contained in:
@ -1,5 +1,3 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from types import (
|
||||
TracebackType,
|
||||
)
|
||||
@ -15,6 +13,7 @@ from libp2p.abc import (
|
||||
from libp2p.stream_muxer.exceptions import (
|
||||
MuxedConnUnavailable,
|
||||
)
|
||||
from libp2p.stream_muxer.rw_lock import ReadWriteLock
|
||||
|
||||
from .constants import (
|
||||
HeaderTags,
|
||||
@ -34,72 +33,6 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
|
||||
class ReadWriteLock:
|
||||
"""
|
||||
A read-write lock that allows multiple concurrent readers
|
||||
or one exclusive writer, implemented using Trio primitives.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._readers = 0
|
||||
self._readers_lock = trio.Lock() # Protects access to _readers count
|
||||
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
|
||||
|
||||
async def acquire_read(self) -> None:
|
||||
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
|
||||
try:
|
||||
async with self._readers_lock:
|
||||
if self._readers == 0:
|
||||
await self._writer_lock.acquire()
|
||||
self._readers += 1
|
||||
except trio.Cancelled:
|
||||
raise
|
||||
|
||||
async def release_read(self) -> None:
|
||||
"""Release a read lock."""
|
||||
async with self._readers_lock:
|
||||
if self._readers == 1:
|
||||
self._writer_lock.release()
|
||||
self._readers -= 1
|
||||
|
||||
async def acquire_write(self) -> None:
|
||||
"""Acquire an exclusive write lock."""
|
||||
try:
|
||||
await self._writer_lock.acquire()
|
||||
except trio.Cancelled:
|
||||
raise
|
||||
|
||||
def release_write(self) -> None:
|
||||
"""Release the exclusive write lock."""
|
||||
self._writer_lock.release()
|
||||
|
||||
@asynccontextmanager
|
||||
async def read_lock(self) -> AsyncGenerator[None, None]:
|
||||
"""Context manager for acquiring and releasing a read lock safely."""
|
||||
acquire = False
|
||||
try:
|
||||
await self.acquire_read()
|
||||
acquire = True
|
||||
yield
|
||||
finally:
|
||||
if acquire:
|
||||
with trio.CancelScope() as scope:
|
||||
scope.shield = True
|
||||
await self.release_read()
|
||||
|
||||
@asynccontextmanager
|
||||
async def write_lock(self) -> AsyncGenerator[None, None]:
|
||||
"""Context manager for acquiring and releasing a write lock safely."""
|
||||
acquire = False
|
||||
try:
|
||||
await self.acquire_write()
|
||||
acquire = True
|
||||
yield
|
||||
finally:
|
||||
if acquire:
|
||||
self.release_write()
|
||||
|
||||
|
||||
class MplexStream(IMuxedStream):
|
||||
"""
|
||||
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
|
||||
|
||||
70
libp2p/stream_muxer/rw_lock.py
Normal file
70
libp2p/stream_muxer/rw_lock.py
Normal file
@ -0,0 +1,70 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import trio
|
||||
|
||||
|
||||
class ReadWriteLock:
|
||||
"""
|
||||
A read-write lock that allows multiple concurrent readers
|
||||
or one exclusive writer, implemented using Trio primitives.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._readers = 0
|
||||
self._readers_lock = trio.Lock() # Protects access to _readers count
|
||||
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
|
||||
|
||||
async def acquire_read(self) -> None:
|
||||
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
|
||||
try:
|
||||
async with self._readers_lock:
|
||||
if self._readers == 0:
|
||||
await self._writer_lock.acquire()
|
||||
self._readers += 1
|
||||
except trio.Cancelled:
|
||||
raise
|
||||
|
||||
async def release_read(self) -> None:
|
||||
"""Release a read lock."""
|
||||
async with self._readers_lock:
|
||||
if self._readers == 1:
|
||||
self._writer_lock.release()
|
||||
self._readers -= 1
|
||||
|
||||
async def acquire_write(self) -> None:
|
||||
"""Acquire an exclusive write lock."""
|
||||
try:
|
||||
await self._writer_lock.acquire()
|
||||
except trio.Cancelled:
|
||||
raise
|
||||
|
||||
def release_write(self) -> None:
|
||||
"""Release the exclusive write lock."""
|
||||
self._writer_lock.release()
|
||||
|
||||
@asynccontextmanager
|
||||
async def read_lock(self) -> AsyncGenerator[None, None]:
|
||||
"""Context manager for acquiring and releasing a read lock safely."""
|
||||
acquire = False
|
||||
try:
|
||||
await self.acquire_read()
|
||||
acquire = True
|
||||
yield
|
||||
finally:
|
||||
if acquire:
|
||||
with trio.CancelScope() as scope:
|
||||
scope.shield = True
|
||||
await self.release_read()
|
||||
|
||||
@asynccontextmanager
|
||||
async def write_lock(self) -> AsyncGenerator[None, None]:
|
||||
"""Context manager for acquiring and releasing a write lock safely."""
|
||||
acquire = False
|
||||
try:
|
||||
await self.acquire_write()
|
||||
acquire = True
|
||||
yield
|
||||
finally:
|
||||
if acquire:
|
||||
self.release_write()
|
||||
@ -44,6 +44,7 @@ from libp2p.stream_muxer.exceptions import (
|
||||
MuxedStreamError,
|
||||
MuxedStreamReset,
|
||||
)
|
||||
from libp2p.stream_muxer.rw_lock import ReadWriteLock
|
||||
|
||||
# Configure logger for this module
|
||||
logger = logging.getLogger("libp2p.stream_muxer.yamux")
|
||||
@ -80,6 +81,8 @@ class YamuxStream(IMuxedStream):
|
||||
self.send_window = DEFAULT_WINDOW_SIZE
|
||||
self.recv_window = DEFAULT_WINDOW_SIZE
|
||||
self.window_lock = trio.Lock()
|
||||
self.rw_lock = ReadWriteLock()
|
||||
self.close_lock = trio.Lock()
|
||||
|
||||
async def __aenter__(self) -> "YamuxStream":
|
||||
"""Enter the async context manager."""
|
||||
@ -95,52 +98,54 @@ class YamuxStream(IMuxedStream):
|
||||
await self.close()
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
if self.send_closed:
|
||||
raise MuxedStreamError("Stream is closed for sending")
|
||||
async with self.rw_lock.write_lock():
|
||||
if self.send_closed:
|
||||
raise MuxedStreamError("Stream is closed for sending")
|
||||
|
||||
# Flow control: Check if we have enough send window
|
||||
total_len = len(data)
|
||||
sent = 0
|
||||
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
|
||||
while sent < total_len:
|
||||
# Wait for available window with timeout
|
||||
timeout = False
|
||||
async with self.window_lock:
|
||||
if self.send_window == 0:
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id}: Window is zero, waiting for update"
|
||||
# Flow control: Check if we have enough send window
|
||||
total_len = len(data)
|
||||
sent = 0
|
||||
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
|
||||
while sent < total_len:
|
||||
# Wait for available window with timeout
|
||||
timeout = False
|
||||
async with self.window_lock:
|
||||
if self.send_window == 0:
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id}: "
|
||||
"Window is zero, waiting for update"
|
||||
)
|
||||
# Release lock and wait with timeout
|
||||
self.window_lock.release()
|
||||
# To avoid re-acquiring the lock immediately,
|
||||
with trio.move_on_after(5.0) as cancel_scope:
|
||||
while self.send_window == 0 and not self.closed:
|
||||
await trio.sleep(0.01)
|
||||
# If we timed out, cancel the scope
|
||||
timeout = cancel_scope.cancelled_caught
|
||||
# Re-acquire lock
|
||||
await self.window_lock.acquire()
|
||||
|
||||
# If we timed out waiting for window update, raise an error
|
||||
if timeout:
|
||||
raise MuxedStreamError(
|
||||
"Timed out waiting for window update after 5 seconds."
|
||||
)
|
||||
|
||||
if self.closed:
|
||||
raise MuxedStreamError("Stream is closed")
|
||||
|
||||
# Calculate how much we can send now
|
||||
to_send = min(self.send_window, total_len - sent)
|
||||
chunk = data[sent : sent + to_send]
|
||||
self.send_window -= to_send
|
||||
|
||||
# Send the data
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk)
|
||||
)
|
||||
# Release lock and wait with timeout
|
||||
self.window_lock.release()
|
||||
# To avoid re-acquiring the lock immediately,
|
||||
with trio.move_on_after(5.0) as cancel_scope:
|
||||
while self.send_window == 0 and not self.closed:
|
||||
await trio.sleep(0.01)
|
||||
# If we timed out, cancel the scope
|
||||
timeout = cancel_scope.cancelled_caught
|
||||
# Re-acquire lock
|
||||
await self.window_lock.acquire()
|
||||
|
||||
# If we timed out waiting for window update, raise an error
|
||||
if timeout:
|
||||
raise MuxedStreamError(
|
||||
"Timed out waiting for window update after 5 seconds."
|
||||
)
|
||||
|
||||
if self.closed:
|
||||
raise MuxedStreamError("Stream is closed")
|
||||
|
||||
# Calculate how much we can send now
|
||||
to_send = min(self.send_window, total_len - sent)
|
||||
chunk = data[sent : sent + to_send]
|
||||
self.send_window -= to_send
|
||||
|
||||
# Send the data
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk)
|
||||
)
|
||||
await self.conn.secured_conn.write(header + chunk)
|
||||
sent += to_send
|
||||
await self.conn.secured_conn.write(header + chunk)
|
||||
sent += to_send
|
||||
|
||||
async def send_window_update(self, increment: int, skip_lock: bool = False) -> None:
|
||||
"""
|
||||
@ -257,30 +262,32 @@ class YamuxStream(IMuxedStream):
|
||||
return data
|
||||
|
||||
async def close(self) -> None:
|
||||
if not self.send_closed:
|
||||
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
|
||||
)
|
||||
await self.conn.secured_conn.write(header)
|
||||
self.send_closed = True
|
||||
async with self.close_lock:
|
||||
if not self.send_closed:
|
||||
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
|
||||
)
|
||||
await self.conn.secured_conn.write(header)
|
||||
self.send_closed = True
|
||||
|
||||
# Only set fully closed if both directions are closed
|
||||
if self.send_closed and self.recv_closed:
|
||||
self.closed = True
|
||||
else:
|
||||
# Stream is half-closed but not fully closed
|
||||
self.closed = False
|
||||
# Only set fully closed if both directions are closed
|
||||
if self.send_closed and self.recv_closed:
|
||||
self.closed = True
|
||||
else:
|
||||
# Stream is half-closed but not fully closed
|
||||
self.closed = False
|
||||
|
||||
async def reset(self) -> None:
|
||||
if not self.closed:
|
||||
logger.debug(f"Resetting stream {self.stream_id}")
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
|
||||
)
|
||||
await self.conn.secured_conn.write(header)
|
||||
self.closed = True
|
||||
self.reset_received = True # Mark as reset
|
||||
async with self.close_lock:
|
||||
logger.debug(f"Resetting stream {self.stream_id}")
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
|
||||
)
|
||||
await self.conn.secured_conn.write(header)
|
||||
self.closed = True
|
||||
self.reset_received = True # Mark as reset
|
||||
|
||||
def set_deadline(self, ttl: int) -> bool:
|
||||
"""
|
||||
|
||||
6
newsfragments/897.bugfix.rst
Normal file
6
newsfragments/897.bugfix.rst
Normal file
@ -0,0 +1,6 @@
|
||||
enhancement: Add write lock to `YamuxStream` to prevent concurrent write race conditions
|
||||
|
||||
- Implements ReadWriteLock for `YamuxStream` write operations
|
||||
- Prevents data corruption from concurrent write operations
|
||||
- Read operations remain lock-free due to existing `Yamux` architecture
|
||||
- Resolves race conditions identified in Issue #793
|
||||
@ -1,6 +1,5 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def security_protocol():
|
||||
return None
|
||||
return None
|
||||
Reference in New Issue
Block a user