mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-11 15:40:54 +00:00
stream_muxer(yamux): add ReadWriteLock to YamuxStream to prevent concurrent read/write corruption
Introduce a read/write lock abstraction and integrate it into `YamuxStream` so that simultaneous reads and writes do not interleave, eliminating potential data corruption and race conditions. Major changes: - Abstract `ReadWriteLock` into its own util module - Integrate locking into YamuxStream for `write` operations - Ensure tests pass for lock correctness - Fix lint & type issues discovered during review Closes #793
This commit is contained in:
@ -1,5 +1,3 @@
|
|||||||
from collections.abc import AsyncGenerator
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from types import (
|
from types import (
|
||||||
TracebackType,
|
TracebackType,
|
||||||
)
|
)
|
||||||
@ -15,6 +13,7 @@ from libp2p.abc import (
|
|||||||
from libp2p.stream_muxer.exceptions import (
|
from libp2p.stream_muxer.exceptions import (
|
||||||
MuxedConnUnavailable,
|
MuxedConnUnavailable,
|
||||||
)
|
)
|
||||||
|
from libp2p.stream_muxer.rw_lock import ReadWriteLock
|
||||||
|
|
||||||
from .constants import (
|
from .constants import (
|
||||||
HeaderTags,
|
HeaderTags,
|
||||||
@ -34,72 +33,6 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ReadWriteLock:
|
|
||||||
"""
|
|
||||||
A read-write lock that allows multiple concurrent readers
|
|
||||||
or one exclusive writer, implemented using Trio primitives.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._readers = 0
|
|
||||||
self._readers_lock = trio.Lock() # Protects access to _readers count
|
|
||||||
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
|
|
||||||
|
|
||||||
async def acquire_read(self) -> None:
|
|
||||||
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
|
|
||||||
try:
|
|
||||||
async with self._readers_lock:
|
|
||||||
if self._readers == 0:
|
|
||||||
await self._writer_lock.acquire()
|
|
||||||
self._readers += 1
|
|
||||||
except trio.Cancelled:
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def release_read(self) -> None:
|
|
||||||
"""Release a read lock."""
|
|
||||||
async with self._readers_lock:
|
|
||||||
if self._readers == 1:
|
|
||||||
self._writer_lock.release()
|
|
||||||
self._readers -= 1
|
|
||||||
|
|
||||||
async def acquire_write(self) -> None:
|
|
||||||
"""Acquire an exclusive write lock."""
|
|
||||||
try:
|
|
||||||
await self._writer_lock.acquire()
|
|
||||||
except trio.Cancelled:
|
|
||||||
raise
|
|
||||||
|
|
||||||
def release_write(self) -> None:
|
|
||||||
"""Release the exclusive write lock."""
|
|
||||||
self._writer_lock.release()
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def read_lock(self) -> AsyncGenerator[None, None]:
|
|
||||||
"""Context manager for acquiring and releasing a read lock safely."""
|
|
||||||
acquire = False
|
|
||||||
try:
|
|
||||||
await self.acquire_read()
|
|
||||||
acquire = True
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
if acquire:
|
|
||||||
with trio.CancelScope() as scope:
|
|
||||||
scope.shield = True
|
|
||||||
await self.release_read()
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def write_lock(self) -> AsyncGenerator[None, None]:
|
|
||||||
"""Context manager for acquiring and releasing a write lock safely."""
|
|
||||||
acquire = False
|
|
||||||
try:
|
|
||||||
await self.acquire_write()
|
|
||||||
acquire = True
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
if acquire:
|
|
||||||
self.release_write()
|
|
||||||
|
|
||||||
|
|
||||||
class MplexStream(IMuxedStream):
|
class MplexStream(IMuxedStream):
|
||||||
"""
|
"""
|
||||||
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
|
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
|
||||||
|
|||||||
70
libp2p/stream_muxer/rw_lock.py
Normal file
70
libp2p/stream_muxer/rw_lock.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import trio
|
||||||
|
|
||||||
|
|
||||||
|
class ReadWriteLock:
|
||||||
|
"""
|
||||||
|
A read-write lock that allows multiple concurrent readers
|
||||||
|
or one exclusive writer, implemented using Trio primitives.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._readers = 0
|
||||||
|
self._readers_lock = trio.Lock() # Protects access to _readers count
|
||||||
|
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
|
||||||
|
|
||||||
|
async def acquire_read(self) -> None:
|
||||||
|
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
|
||||||
|
try:
|
||||||
|
async with self._readers_lock:
|
||||||
|
if self._readers == 0:
|
||||||
|
await self._writer_lock.acquire()
|
||||||
|
self._readers += 1
|
||||||
|
except trio.Cancelled:
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def release_read(self) -> None:
|
||||||
|
"""Release a read lock."""
|
||||||
|
async with self._readers_lock:
|
||||||
|
if self._readers == 1:
|
||||||
|
self._writer_lock.release()
|
||||||
|
self._readers -= 1
|
||||||
|
|
||||||
|
async def acquire_write(self) -> None:
|
||||||
|
"""Acquire an exclusive write lock."""
|
||||||
|
try:
|
||||||
|
await self._writer_lock.acquire()
|
||||||
|
except trio.Cancelled:
|
||||||
|
raise
|
||||||
|
|
||||||
|
def release_write(self) -> None:
|
||||||
|
"""Release the exclusive write lock."""
|
||||||
|
self._writer_lock.release()
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def read_lock(self) -> AsyncGenerator[None, None]:
|
||||||
|
"""Context manager for acquiring and releasing a read lock safely."""
|
||||||
|
acquire = False
|
||||||
|
try:
|
||||||
|
await self.acquire_read()
|
||||||
|
acquire = True
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
if acquire:
|
||||||
|
with trio.CancelScope() as scope:
|
||||||
|
scope.shield = True
|
||||||
|
await self.release_read()
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def write_lock(self) -> AsyncGenerator[None, None]:
|
||||||
|
"""Context manager for acquiring and releasing a write lock safely."""
|
||||||
|
acquire = False
|
||||||
|
try:
|
||||||
|
await self.acquire_write()
|
||||||
|
acquire = True
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
if acquire:
|
||||||
|
self.release_write()
|
||||||
@ -44,6 +44,7 @@ from libp2p.stream_muxer.exceptions import (
|
|||||||
MuxedStreamError,
|
MuxedStreamError,
|
||||||
MuxedStreamReset,
|
MuxedStreamReset,
|
||||||
)
|
)
|
||||||
|
from libp2p.stream_muxer.rw_lock import ReadWriteLock
|
||||||
|
|
||||||
# Configure logger for this module
|
# Configure logger for this module
|
||||||
logger = logging.getLogger("libp2p.stream_muxer.yamux")
|
logger = logging.getLogger("libp2p.stream_muxer.yamux")
|
||||||
@ -80,6 +81,8 @@ class YamuxStream(IMuxedStream):
|
|||||||
self.send_window = DEFAULT_WINDOW_SIZE
|
self.send_window = DEFAULT_WINDOW_SIZE
|
||||||
self.recv_window = DEFAULT_WINDOW_SIZE
|
self.recv_window = DEFAULT_WINDOW_SIZE
|
||||||
self.window_lock = trio.Lock()
|
self.window_lock = trio.Lock()
|
||||||
|
self.rw_lock = ReadWriteLock()
|
||||||
|
self.close_lock = trio.Lock()
|
||||||
|
|
||||||
async def __aenter__(self) -> "YamuxStream":
|
async def __aenter__(self) -> "YamuxStream":
|
||||||
"""Enter the async context manager."""
|
"""Enter the async context manager."""
|
||||||
@ -95,6 +98,7 @@ class YamuxStream(IMuxedStream):
|
|||||||
await self.close()
|
await self.close()
|
||||||
|
|
||||||
async def write(self, data: bytes) -> None:
|
async def write(self, data: bytes) -> None:
|
||||||
|
async with self.rw_lock.write_lock():
|
||||||
if self.send_closed:
|
if self.send_closed:
|
||||||
raise MuxedStreamError("Stream is closed for sending")
|
raise MuxedStreamError("Stream is closed for sending")
|
||||||
|
|
||||||
@ -108,7 +112,8 @@ class YamuxStream(IMuxedStream):
|
|||||||
async with self.window_lock:
|
async with self.window_lock:
|
||||||
if self.send_window == 0:
|
if self.send_window == 0:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Stream {self.stream_id}: Window is zero, waiting for update"
|
f"Stream {self.stream_id}: "
|
||||||
|
"Window is zero, waiting for update"
|
||||||
)
|
)
|
||||||
# Release lock and wait with timeout
|
# Release lock and wait with timeout
|
||||||
self.window_lock.release()
|
self.window_lock.release()
|
||||||
@ -257,6 +262,7 @@ class YamuxStream(IMuxedStream):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
|
async with self.close_lock:
|
||||||
if not self.send_closed:
|
if not self.send_closed:
|
||||||
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
|
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
|
||||||
header = struct.pack(
|
header = struct.pack(
|
||||||
@ -274,6 +280,7 @@ class YamuxStream(IMuxedStream):
|
|||||||
|
|
||||||
async def reset(self) -> None:
|
async def reset(self) -> None:
|
||||||
if not self.closed:
|
if not self.closed:
|
||||||
|
async with self.close_lock:
|
||||||
logger.debug(f"Resetting stream {self.stream_id}")
|
logger.debug(f"Resetting stream {self.stream_id}")
|
||||||
header = struct.pack(
|
header = struct.pack(
|
||||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
|
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
|
||||||
|
|||||||
6
newsfragments/897.bugfix.rst
Normal file
6
newsfragments/897.bugfix.rst
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
enhancement: Add write lock to `YamuxStream` to prevent concurrent write race conditions
|
||||||
|
|
||||||
|
- Implements ReadWriteLock for `YamuxStream` write operations
|
||||||
|
- Prevents data corruption from concurrent write operations
|
||||||
|
- Read operations remain lock-free due to existing `Yamux` architecture
|
||||||
|
- Resolves race conditions identified in Issue #793
|
||||||
@ -1,6 +1,5 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def security_protocol():
|
def security_protocol():
|
||||||
return None
|
return None
|
||||||
Reference in New Issue
Block a user