From 91dca97d8363df816fe3498057732b5d238658ea Mon Sep 17 00:00:00 2001 From: Jinesh Jain <732005jinesh@gmail.com> Date: Mon, 7 Jul 2025 21:55:32 +0530 Subject: [PATCH 01/22] TODO: add read/write lock --- libp2p/stream_muxer/mplex/mplex_stream.py | 138 +++++++++++++--------- 1 file changed, 85 insertions(+), 53 deletions(-) diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 3b640df1..8f45495f 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -31,6 +31,29 @@ if TYPE_CHECKING: Mplex, ) +class ReadWriteLock: + def __init__(self): + self._readers = 0 + self._lock = trio.Lock() # Protects _readers + self._write_lock = trio.Lock() + + async def acquire_read(self): + async with self._lock: + self._readers += 1 + if self._readers == 1: + await self._write_lock.acquire() + + async def release_read(self): + async with self._lock: + self._readers -= 1 + if self._readers == 0: + self._write_lock.release() + + async def acquire_write(self): + await self._write_lock.acquire() + + def release_write(self): + self._write_lock.release() class MplexStream(IMuxedStream): """ @@ -39,17 +62,17 @@ class MplexStream(IMuxedStream): name: str stream_id: StreamID - # NOTE: All methods used here are part of `Mplex` which is a derived + # NOTE: All methods used here are part of Mplex which is a derived # class of IMuxedConn. Ignoring this type assignment should not pose # any risk. muxed_conn: "Mplex" # type: ignore[assignment] read_deadline: int | None write_deadline: int | None - # TODO: Add lock for read/write to avoid interleaving receiving messages? + rw_lock: ReadWriteLock close_lock: trio.Lock - # NOTE: `dataIn` is size of 8 in Go implementation. + # NOTE: dataIn is size of 8 in Go implementation. incoming_data_channel: "trio.MemoryReceiveChannel[bytes]" event_local_closed: trio.Event @@ -80,6 +103,7 @@ class MplexStream(IMuxedStream): self.event_remote_closed = trio.Event() self.event_reset = trio.Event() self.close_lock = trio.Lock() + self.rw_lock = ReadWriteLock() self.incoming_data_channel = incoming_data_channel self._buf = bytearray() @@ -106,55 +130,59 @@ class MplexStream(IMuxedStream): async def read(self, n: int | None = None) -> bytes: """ - Read up to n bytes. Read possibly returns fewer than `n` bytes, if - there are not enough bytes in the Mplex buffer. If `n is None`, read + Read up to n bytes. Read possibly returns fewer than n bytes, if + there are not enough bytes in the Mplex buffer. If n is None, read until EOF. :param n: number of bytes to read :return: bytes actually read """ - if n is not None and n < 0: - raise ValueError( - "the number of bytes to read `n` must be non-negative or " - f"`None` to indicate read until EOF, got n={n}" - ) - if self.event_reset.is_set(): - raise MplexStreamReset - if n is None: - return await self._read_until_eof() - if len(self._buf) == 0: - data: bytes - # Peek whether there is data available. If yes, we just read until there is - # no data, then return. - try: - data = self.incoming_data_channel.receive_nowait() - self._buf.extend(data) - except trio.EndOfChannel: - raise MplexStreamEOF - except trio.WouldBlock: - # We know `receive` will be blocked here. Wait for data here with - # `receive` and catch all kinds of errors here. 
+ await self.rw_lock.acquire_read() + try: + if n is not None and n < 0: + raise ValueError( + "the number of bytes to read n must be non-negative or " + f"None to indicate read until EOF, got n={n}" + ) + if self.event_reset.is_set(): + raise MplexStreamReset + if n is None: + return await self._read_until_eof() + if len(self._buf) == 0: + data: bytes + # Peek whether there is data available. If yes, we just read until there is + # no data, then return. try: - data = await self.incoming_data_channel.receive() + data = self.incoming_data_channel.receive_nowait() self._buf.extend(data) except trio.EndOfChannel: - if self.event_reset.is_set(): - raise MplexStreamReset - if self.event_remote_closed.is_set(): - raise MplexStreamEOF - except trio.ClosedResourceError as error: - # Probably `incoming_data_channel` is closed in `reset` when we are - # waiting for `receive`. - if self.event_reset.is_set(): - raise MplexStreamReset - raise Exception( - "`incoming_data_channel` is closed but stream is not reset. " - "This should never happen." - ) from error - self._buf.extend(self._read_return_when_blocked()) - payload = self._buf[:n] - self._buf = self._buf[len(payload) :] - return bytes(payload) + raise MplexStreamEOF + except trio.WouldBlock: + # We know receive will be blocked here. Wait for data here with + # receive and catch all kinds of errors here. + try: + data = await self.incoming_data_channel.receive() + self._buf.extend(data) + except trio.EndOfChannel: + if self.event_reset.is_set(): + raise MplexStreamReset + if self.event_remote_closed.is_set(): + raise MplexStreamEOF + except trio.ClosedResourceError as error: + # Probably incoming_data_channel is closed in reset when we are + # waiting for receive. + if self.event_reset.is_set(): + raise MplexStreamReset + raise Exception( + "incoming_data_channel is closed but stream is not reset. " + "This should never happen." + ) from error + self._buf.extend(self._read_return_when_blocked()) + payload = self._buf[:n] + self._buf = self._buf[len(payload) :] + return bytes(payload) + finally: + await self.rw_lock.release_read() async def write(self, data: bytes) -> None: """ @@ -162,14 +190,18 @@ class MplexStream(IMuxedStream): :return: number of bytes written """ - if self.event_local_closed.is_set(): - raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}") - flag = ( - HeaderTags.MessageInitiator - if self.is_initiator - else HeaderTags.MessageReceiver - ) - await self.muxed_conn.send_message(flag, data, self.stream_id) + await self.rw_lock.acquire_write() + try: + if self.event_local_closed.is_set(): + raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}") + flag = ( + HeaderTags.MessageInitiator + if self.is_initiator + else HeaderTags.MessageReceiver + ) + await self.muxed_conn.send_message(flag, data, self.stream_id) + finally: + self.rw_lock.release_write() async def close(self) -> None: """ @@ -185,7 +217,7 @@ class MplexStream(IMuxedStream): flag = ( HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver ) - # TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown. + # TODO: Raise when muxed_conn.send_message fails and Mplex isn't shutdown. 
await self.muxed_conn.send_message(flag, None, self.stream_id) _is_remote_closed: bool From 75abc8b863f6804d5be7206c66302e7f1b9052e3 Mon Sep 17 00:00:00 2001 From: Jinesh Jain <732005jinesh@gmail.com> Date: Tue, 8 Jul 2025 07:35:45 +0530 Subject: [PATCH 02/22] run ruff format --- libp2p/stream_muxer/mplex/mplex_stream.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 8f45495f..a7a510b6 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -31,6 +31,7 @@ if TYPE_CHECKING: Mplex, ) + class ReadWriteLock: def __init__(self): self._readers = 0 @@ -55,6 +56,7 @@ class ReadWriteLock: def release_write(self): self._write_lock.release() + class MplexStream(IMuxedStream): """ reference: https://github.com/libp2p/go-mplex/blob/master/stream.go From 8fb664bfdfd1df7e2a840d1797c1d5d749ea88c0 Mon Sep 17 00:00:00 2001 From: Jinesh Jain <732005jinesh@gmail.com> Date: Tue, 8 Jul 2025 18:34:30 +0530 Subject: [PATCH 03/22] Fix: linting errors --- libp2p/stream_muxer/mplex/mplex_stream.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index a7a510b6..7c5dc6ab 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -33,27 +33,27 @@ if TYPE_CHECKING: class ReadWriteLock: - def __init__(self): + def __init__(self) -> None: self._readers = 0 self._lock = trio.Lock() # Protects _readers self._write_lock = trio.Lock() - async def acquire_read(self): + async def acquire_read(self) -> None: async with self._lock: self._readers += 1 if self._readers == 1: await self._write_lock.acquire() - async def release_read(self): + async def release_read(self) -> None: async with self._lock: self._readers -= 1 if self._readers == 0: self._write_lock.release() - async def acquire_write(self): + async def acquire_write(self) -> None: await self._write_lock.acquire() - def release_write(self): + def release_write(self) -> None: self._write_lock.release() @@ -152,8 +152,8 @@ class MplexStream(IMuxedStream): return await self._read_until_eof() if len(self._buf) == 0: data: bytes - # Peek whether there is data available. If yes, we just read until there is - # no data, then return. + # Peek whether there is data available. If yes, we just read until + # there is no data, then return. 
try: data = self.incoming_data_channel.receive_nowait() self._buf.extend(data) @@ -185,6 +185,7 @@ class MplexStream(IMuxedStream): return bytes(payload) finally: await self.rw_lock.release_read() + return b"" async def write(self, data: bytes) -> None: """ From e65e38a3f1fbdebcff1b71374ed2256ea48c441c Mon Sep 17 00:00:00 2001 From: Jinesh Jain <732005jinesh@gmail.com> Date: Tue, 8 Jul 2025 19:11:56 +0530 Subject: [PATCH 04/22] fix: linting error related to read --- libp2p/stream_muxer/mplex/mplex_stream.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 7c5dc6ab..9c5f04da 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -140,6 +140,7 @@ class MplexStream(IMuxedStream): :return: bytes actually read """ await self.rw_lock.acquire_read() + payload: bytes = b"" try: if n is not None and n < 0: raise ValueError( @@ -180,12 +181,12 @@ class MplexStream(IMuxedStream): "This should never happen." ) from error self._buf.extend(self._read_return_when_blocked()) - payload = self._buf[:n] - self._buf = self._buf[len(payload) :] - return bytes(payload) + chunk = self._buf[:n] + self._buf = self._buf[len(chunk) :] + payload = bytes(chunk) finally: await self.rw_lock.release_read() - return b"" + return payload async def write(self, data: bytes) -> None: """ From 5f497c7f5dab497ccc8017f024e40c51959a5c3d Mon Sep 17 00:00:00 2001 From: Jinesh Jain <732005jinesh@gmail.com> Date: Tue, 8 Jul 2025 19:17:43 +0530 Subject: [PATCH 05/22] add file in newsfragments folder --- newsfragments/748.feature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 newsfragments/748.feature.rst diff --git a/newsfragments/748.feature.rst b/newsfragments/748.feature.rst new file mode 100644 index 00000000..199e5b3b --- /dev/null +++ b/newsfragments/748.feature.rst @@ -0,0 +1 @@ + Add lock for read/write to avoid interleaving receiving messages in mplex_stream.py From 242998ae9ddee4fac12d2409e88fd883471b8c2c Mon Sep 17 00:00:00 2001 From: Jinesh Jain <732005jinesh@gmail.com> Date: Tue, 8 Jul 2025 20:06:30 +0530 Subject: [PATCH 06/22] add test for read-write-lock --- libp2p/stream_muxer/mplex/mplex_stream.py | 16 +- .../mplex/test_read_write_lock.py | 203 ++++++++++++++++++ 2 files changed, 211 insertions(+), 8 deletions(-) create mode 100644 libp2p/stream_muxer/mplex/test_read_write_lock.py diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 9c5f04da..e6b98244 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -35,26 +35,26 @@ if TYPE_CHECKING: class ReadWriteLock: def __init__(self) -> None: self._readers = 0 - self._lock = trio.Lock() # Protects _readers - self._write_lock = trio.Lock() + self._readers_lock = trio.Lock() # Protects readers count + self._writer_lock = trio.Semaphore(1) # Acts like a task-transferable lock async def acquire_read(self) -> None: - async with self._lock: + async with self._readers_lock: self._readers += 1 if self._readers == 1: - await self._write_lock.acquire() + await self._writer_lock.acquire() async def release_read(self) -> None: - async with self._lock: + async with self._readers_lock: self._readers -= 1 if self._readers == 0: - self._write_lock.release() + self._writer_lock.release() async def acquire_write(self) -> None: - await self._write_lock.acquire() + await self._writer_lock.acquire() def 
release_write(self) -> None: - self._write_lock.release() + self._writer_lock.release() class MplexStream(IMuxedStream): diff --git a/libp2p/stream_muxer/mplex/test_read_write_lock.py b/libp2p/stream_muxer/mplex/test_read_write_lock.py new file mode 100644 index 00000000..c52aa36c --- /dev/null +++ b/libp2p/stream_muxer/mplex/test_read_write_lock.py @@ -0,0 +1,203 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest +import trio + +from libp2p.stream_muxer.mplex.mplex_stream import MplexStream, StreamID + + +@pytest.fixture +def stream_with_lock() -> tuple[MplexStream, trio.MemorySendChannel[bytes]]: + muxed_conn = MagicMock() + muxed_conn.send_message = AsyncMock() + muxed_conn.streams_lock = trio.Lock() + muxed_conn.streams = {} + muxed_conn.get_remote_address = MagicMock(return_value=("127.0.0.1", 8000)) + + send_chan: trio.MemorySendChannel[bytes] + recv_chan: trio.MemoryReceiveChannel[bytes] + send_chan, recv_chan = trio.open_memory_channel(0) + + dummy_stream_id = MagicMock(spec=StreamID) + dummy_stream_id.is_initiator = True # mock read-only property + + stream = MplexStream( + name="test", + stream_id=dummy_stream_id, + muxed_conn=muxed_conn, + incoming_data_channel=recv_chan, + ) + return stream, send_chan + + +@pytest.mark.trio +async def test_writing_blocked_if_read_in_progress( + stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]], +) -> None: + stream, _ = stream_with_lock + log: list[str] = [] + + async def reader() -> None: + await stream.rw_lock.acquire_read() + log.append("read_acquired") + await trio.sleep(0.3) + log.append("read_released") + await stream.rw_lock.release_read() + + async def writer() -> None: + await stream.rw_lock.acquire_write() + log.append("write_acquired") + await trio.sleep(0.1) + log.append("write_released") + stream.rw_lock.release_write() + + async with trio.open_nursery() as nursery: + nursery.start_soon(reader) + await trio.sleep(0.05) + nursery.start_soon(writer) + + assert log == [ + "read_acquired", + "read_released", + "write_acquired", + "write_released", + ], f"Unexpected order: {log}" + + +@pytest.mark.trio +async def test_reading_blocked_if_write_in_progress( + stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]], +) -> None: + stream, _ = stream_with_lock + log: list[str] = [] + + async def writer() -> None: + await stream.rw_lock.acquire_write() + log.append("write_acquired") + await trio.sleep(0.3) + log.append("write_released") + stream.rw_lock.release_write() + + async def reader() -> None: + await stream.rw_lock.acquire_read() + log.append("read_acquired") + await trio.sleep(0.1) + log.append("read_released") + await stream.rw_lock.release_read() + + async with trio.open_nursery() as nursery: + nursery.start_soon(writer) + await trio.sleep(0.05) + nursery.start_soon(reader) + + assert log == [ + "write_acquired", + "write_released", + "read_acquired", + "read_released", + ], f"Unexpected order: {log}" + + +@pytest.mark.trio +async def test_multiple_reads_allowed_concurrently( + stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]], +) -> None: + stream, _ = stream_with_lock + log: list[str] = [] + + async def read_task(i: int) -> None: + await stream.rw_lock.acquire_read() + log.append(f"read_{i}_acquired") + await trio.sleep(0.2) + log.append(f"read_{i}_released") + await stream.rw_lock.release_read() + + async with trio.open_nursery() as nursery: + for i in range(5): + nursery.start_soon(read_task, i) + + acquires = [entry for entry in log if "acquired" in entry] + 
releases = [entry for entry in log if "released" in entry] + + assert len(acquires) == 5 and len(releases) == 5, "Not all reads executed" + assert all( + log.index(acq) < min(log.index(rel) for rel in releases) for acq in acquires + ), f"Reads didn't overlap properly: {log}" + + +@pytest.mark.trio +async def test_only_one_write_allowed( + stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]], +) -> None: + stream, _ = stream_with_lock + log: list[str] = [] + + async def write_task(i: int) -> None: + await stream.rw_lock.acquire_write() + log.append(f"write_{i}_acquired") + await trio.sleep(0.2) + log.append(f"write_{i}_released") + stream.rw_lock.release_write() + + async with trio.open_nursery() as nursery: + for i in range(5): + nursery.start_soon(write_task, i) + + active = 0 + for entry in log: + if "acquired" in entry: + active += 1 + elif "released" in entry: + active -= 1 + assert active <= 1, f"More than one write active: {log}" + assert active == 0, f"Write locks not properly released: {log}" + + +@pytest.mark.trio +async def test_interleaved_read_write_behavior( + stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]], +) -> None: + stream, _ = stream_with_lock + log: list[str] = [] + + async def read(i: int) -> None: + await stream.rw_lock.acquire_read() + log.append(f"read_{i}_acquired") + await trio.sleep(0.15) + log.append(f"read_{i}_released") + await stream.rw_lock.release_read() + + async def write(i: int) -> None: + await stream.rw_lock.acquire_write() + log.append(f"write_{i}_acquired") + await trio.sleep(0.2) + log.append(f"write_{i}_released") + stream.rw_lock.release_write() + + async with trio.open_nursery() as nursery: + nursery.start_soon(read, 1) + await trio.sleep(0.05) + nursery.start_soon(read, 2) + await trio.sleep(0.05) + nursery.start_soon(write, 1) + await trio.sleep(0.05) + nursery.start_soon(read, 3) + await trio.sleep(0.05) + nursery.start_soon(write, 2) + + read1_released = log.index("read_1_released") + read2_released = log.index("read_2_released") + write1_acquired = log.index("write_1_acquired") + assert write1_acquired > read1_released and write1_acquired > read2_released, ( + f"write_1 acquired too early: {log}" + ) + + read3_acquired = log.index("read_3_acquired") + read3_released = log.index("read_3_released") + write1_released = log.index("write_1_released") + assert read3_released < write1_acquired or read3_acquired > write1_released, ( + f"read_3 improperly overlapped with write_1: {log}" + ) + + write2_acquired = log.index("write_2_acquired") + assert write2_acquired > write1_released, f"write_2 started too early: {log}" From 26ed99dafda8aa65359353182a4ad48238a52883 Mon Sep 17 00:00:00 2001 From: Jinesh Jain <732005jinesh@gmail.com> Date: Wed, 9 Jul 2025 18:09:07 +0530 Subject: [PATCH 07/22] change tests path --- .../mplex => tests/core/stream_muxer}/test_read_write_lock.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {libp2p/stream_muxer/mplex => tests/core/stream_muxer}/test_read_write_lock.py (100%) diff --git a/libp2p/stream_muxer/mplex/test_read_write_lock.py b/tests/core/stream_muxer/test_read_write_lock.py similarity index 100% rename from libp2p/stream_muxer/mplex/test_read_write_lock.py rename to tests/core/stream_muxer/test_read_write_lock.py From cda163fc48d32033fe45b3fb3bc9543ec5345aaf Mon Sep 17 00:00:00 2001 From: Jinesh Jain <732005jinesh@gmail.com> Date: Wed, 9 Jul 2025 18:18:37 +0530 Subject: [PATCH 08/22] change ReadWriteLock class --- libp2p/stream_muxer/mplex/mplex_stream.py | 41 
+++++++++++++++++++---- 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index e6b98244..91f872d4 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -1,3 +1,5 @@ +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager from types import ( TracebackType, ) @@ -36,13 +38,21 @@ class ReadWriteLock: def __init__(self) -> None: self._readers = 0 self._readers_lock = trio.Lock() # Protects readers count - self._writer_lock = trio.Semaphore(1) # Acts like a task-transferable lock + self._writer_lock = trio.Semaphore(1) # Ensures mutual exclusion for writers async def acquire_read(self) -> None: - async with self._readers_lock: - self._readers += 1 - if self._readers == 1: - await self._writer_lock.acquire() + try: + async with self._readers_lock: + self._readers += 1 + if self._readers == 1: + await self._writer_lock.acquire() + except trio.Cancelled: + async with self._readers_lock: + if self._readers > 0: + self._readers -= 1 + if self._readers == 0: + self._writer_lock.release() + raise async def release_read(self) -> None: async with self._readers_lock: @@ -51,11 +61,30 @@ class ReadWriteLock: self._writer_lock.release() async def acquire_write(self) -> None: - await self._writer_lock.acquire() + try: + await self._writer_lock.acquire() + except trio.Cancelled: + raise def release_write(self) -> None: self._writer_lock.release() + @asynccontextmanager + async def read_lock(self) -> AsyncGenerator[None, None]: + await self.acquire_read() + try: + yield + finally: + await self.release_read() + + @asynccontextmanager + async def write_lock(self) -> AsyncGenerator[None, None]: + await self.acquire_write() + try: + yield + finally: + self.release_write() + class MplexStream(IMuxedStream): """ From 9cd38055427903b1d1beddcefe3452555d93f55f Mon Sep 17 00:00:00 2001 From: Jinesh Jain <732005jinesh@gmail.com> Date: Sun, 13 Jul 2025 18:37:44 +0530 Subject: [PATCH 09/22] make readwrite more safe --- libp2p/stream_muxer/mplex/mplex_stream.py | 62 +- .../core/stream_muxer/test_read_write_lock.py | 731 +++++++++++++----- 2 files changed, 593 insertions(+), 200 deletions(-) diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 91f872d4..7e439160 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -35,55 +35,69 @@ if TYPE_CHECKING: class ReadWriteLock: + """ + A read-write lock that allows multiple concurrent readers + or one exclusive writer, implemented using Trio primitives. + """ + def __init__(self) -> None: self._readers = 0 - self._readers_lock = trio.Lock() # Protects readers count - self._writer_lock = trio.Semaphore(1) # Ensures mutual exclusion for writers + self._readers_lock = trio.Lock() # Protects access to _readers count + self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time async def acquire_read(self) -> None: + """Acquire a read lock. 
Multiple readers can hold it simultaneously.""" try: async with self._readers_lock: - self._readers += 1 - if self._readers == 1: + if self._readers == 0: await self._writer_lock.acquire() + self._readers += 1 except trio.Cancelled: - async with self._readers_lock: - if self._readers > 0: - self._readers -= 1 - if self._readers == 0: - self._writer_lock.release() raise async def release_read(self) -> None: + """Release a read lock.""" async with self._readers_lock: - self._readers -= 1 - if self._readers == 0: + if self._readers == 1: self._writer_lock.release() + self._readers -= 1 async def acquire_write(self) -> None: + """Acquire an exclusive write lock.""" try: await self._writer_lock.acquire() except trio.Cancelled: raise def release_write(self) -> None: + """Release the exclusive write lock.""" self._writer_lock.release() @asynccontextmanager async def read_lock(self) -> AsyncGenerator[None, None]: - await self.acquire_read() + """Context manager for acquiring and releasing a read lock safely.""" + acquire = False try: + await self.acquire_read() + acquire = True yield finally: - await self.release_read() + if acquire: + with trio.CancelScope() as scope: + scope.shield = True + await self.release_read() @asynccontextmanager async def write_lock(self) -> AsyncGenerator[None, None]: - await self.acquire_write() + """Context manager for acquiring and releasing a write lock safely.""" + acquire = False try: + await self.acquire_write() + acquire = True yield finally: - self.release_write() + if acquire: + self.release_write() class MplexStream(IMuxedStream): @@ -168,9 +182,7 @@ class MplexStream(IMuxedStream): :param n: number of bytes to read :return: bytes actually read """ - await self.rw_lock.acquire_read() - payload: bytes = b"" - try: + async with self.rw_lock.read_lock(): if n is not None and n < 0: raise ValueError( "the number of bytes to read n must be non-negative or " @@ -210,12 +222,9 @@ class MplexStream(IMuxedStream): "This should never happen." 
) from error self._buf.extend(self._read_return_when_blocked()) - chunk = self._buf[:n] - self._buf = self._buf[len(chunk) :] - payload = bytes(chunk) - finally: - await self.rw_lock.release_read() - return payload + payload = self._buf[:n] + self._buf = self._buf[len(payload) :] + return bytes(payload) async def write(self, data: bytes) -> None: """ @@ -223,8 +232,7 @@ class MplexStream(IMuxedStream): :return: number of bytes written """ - await self.rw_lock.acquire_write() - try: + async with self.rw_lock.write_lock(): if self.event_local_closed.is_set(): raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}") flag = ( @@ -233,8 +241,6 @@ class MplexStream(IMuxedStream): else HeaderTags.MessageReceiver ) await self.muxed_conn.send_message(flag, data, self.stream_id) - finally: - self.rw_lock.release_write() async def close(self) -> None: """ diff --git a/tests/core/stream_muxer/test_read_write_lock.py b/tests/core/stream_muxer/test_read_write_lock.py index c52aa36c..621f3841 100644 --- a/tests/core/stream_muxer/test_read_write_lock.py +++ b/tests/core/stream_muxer/test_read_write_lock.py @@ -1,203 +1,590 @@ -from unittest.mock import AsyncMock, MagicMock +from typing import Any, cast +from unittest.mock import AsyncMock, MagicMock, patch import pytest import trio +from trio.testing import wait_all_tasks_blocked -from libp2p.stream_muxer.mplex.mplex_stream import MplexStream, StreamID +from libp2p.stream_muxer.exceptions import ( + MuxedConnUnavailable, +) +from libp2p.stream_muxer.mplex.constants import HeaderTags +from libp2p.stream_muxer.mplex.datastructures import StreamID +from libp2p.stream_muxer.mplex.exceptions import ( + MplexStreamClosed, + MplexStreamEOF, + MplexStreamReset, +) +from libp2p.stream_muxer.mplex.mplex_stream import MplexStream + + +class MockMuxedConn: + """A mock Mplex connection for testing purposes.""" + + def __init__(self): + self.sent_messages = [] + self.streams: dict[StreamID, MplexStream] = {} + self.streams_lock = trio.Lock() + self.is_unavailable = False + + async def send_message( + self, flag: HeaderTags, data: bytes | None, stream_id: StreamID + ) -> None: + """Mocks sending a message over the connection.""" + if self.is_unavailable: + raise MuxedConnUnavailable("Connection is unavailable") + self.sent_messages.append((flag, data, stream_id)) + # Yield to allow other tasks to run + await trio.lowlevel.checkpoint() + + def get_remote_address(self) -> tuple[str, int]: + """Mocks getting the remote address.""" + return "127.0.0.1", 4001 @pytest.fixture -def stream_with_lock() -> tuple[MplexStream, trio.MemorySendChannel[bytes]]: - muxed_conn = MagicMock() - muxed_conn.send_message = AsyncMock() - muxed_conn.streams_lock = trio.Lock() - muxed_conn.streams = {} - muxed_conn.get_remote_address = MagicMock(return_value=("127.0.0.1", 8000)) +async def mplex_stream(): + """Provides a fully initialized MplexStream and its communication channels.""" + # Use a buffered channel to prevent deadlocks in simple tests + send_chan, recv_chan = trio.open_memory_channel(10) + stream_id = StreamID(1, is_initiator=True) + muxed_conn = MockMuxedConn() + stream = MplexStream("test-stream", stream_id, cast(Any, muxed_conn), recv_chan) + muxed_conn.streams[stream_id] = stream - send_chan: trio.MemorySendChannel[bytes] - recv_chan: trio.MemoryReceiveChannel[bytes] - send_chan, recv_chan = trio.open_memory_channel(0) + yield stream, send_chan, muxed_conn - dummy_stream_id = MagicMock(spec=StreamID) - dummy_stream_id.is_initiator = True # mock read-only property 
+ # Cleanup: Close channels and reset stream state + await send_chan.aclose() + await recv_chan.aclose() + # Reset stream state to prevent cross-test contamination + stream.event_local_closed = trio.Event() + stream.event_remote_closed = trio.Event() + stream.event_reset = trio.Event() - stream = MplexStream( - name="test", - stream_id=dummy_stream_id, - muxed_conn=muxed_conn, - incoming_data_channel=recv_chan, - ) - return stream, send_chan + +# =============================================== +# 1. Tests for Stream-Level Lock Integration +# =============================================== @pytest.mark.trio -async def test_writing_blocked_if_read_in_progress( - stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]], -) -> None: - stream, _ = stream_with_lock - log: list[str] = [] +async def test_stream_write_is_protected_by_rwlock(mplex_stream): + """Verify that stream.write() acquires and releases the write lock.""" + stream, _, muxed_conn = mplex_stream - async def reader() -> None: - await stream.rw_lock.acquire_read() - log.append("read_acquired") - await trio.sleep(0.3) - log.append("read_released") - await stream.rw_lock.release_read() + # Mock lock methods + original_acquire = stream.rw_lock.acquire_write + original_release = stream.rw_lock.release_write - async def writer() -> None: - await stream.rw_lock.acquire_write() - log.append("write_acquired") - await trio.sleep(0.1) - log.append("write_released") - stream.rw_lock.release_write() + stream.rw_lock.acquire_write = AsyncMock(wraps=original_acquire) + stream.rw_lock.release_write = MagicMock(wraps=original_release) + + await stream.write(b"test data") + + stream.rw_lock.acquire_write.assert_awaited_once() + stream.rw_lock.release_write.assert_called_once() + + # Verify the message was actually sent + assert len(muxed_conn.sent_messages) == 1 + flag, data, stream_id = muxed_conn.sent_messages[0] + assert flag == HeaderTags.MessageInitiator + assert data == b"test data" + assert stream_id == stream.stream_id + + +@pytest.mark.trio +async def test_stream_read_is_protected_by_rwlock(mplex_stream): + """Verify that stream.read() acquires and releases the read lock.""" + stream, send_chan, _ = mplex_stream + + # Mock lock methods + original_acquire = stream.rw_lock.acquire_read + original_release = stream.rw_lock.release_read + + stream.rw_lock.acquire_read = AsyncMock(wraps=original_acquire) + stream.rw_lock.release_read = AsyncMock(wraps=original_release) + + await send_chan.send(b"hello") + result = await stream.read(5) + + stream.rw_lock.acquire_read.assert_awaited_once() + stream.rw_lock.release_read.assert_awaited_once() + assert result == b"hello" + + +@pytest.mark.trio +async def test_multiple_readers_can_coexist(mplex_stream): + """Verify multiple readers can operate concurrently.""" + stream, send_chan, _ = mplex_stream + + # Send enough data for both reads + await send_chan.send(b"data1") + await send_chan.send(b"data2") + + # Track lock acquisition order + acquisition_order = [] + release_order = [] + + # Patch lock methods to track concurrency + original_acquire = stream.rw_lock.acquire_read + original_release = stream.rw_lock.release_read + + async def tracked_acquire(): + nonlocal acquisition_order + acquisition_order.append("start") + await original_acquire() + acquisition_order.append("acquired") + + async def tracked_release(): + nonlocal release_order + release_order.append("start") + await original_release() + release_order.append("released") + + with ( + patch.object( + stream.rw_lock, "acquire_read", 
side_effect=tracked_acquire, autospec=True + ), + patch.object( + stream.rw_lock, "release_read", side_effect=tracked_release, autospec=True + ), + ): + # Execute concurrent reads + async with trio.open_nursery() as nursery: + nursery.start_soon(stream.read, 5) + nursery.start_soon(stream.read, 5) + + # Verify both reads happened + assert acquisition_order.count("start") == 2 + assert acquisition_order.count("acquired") == 2 + assert release_order.count("start") == 2 + assert release_order.count("released") == 2 + + +@pytest.mark.trio +async def test_writer_blocks_readers(mplex_stream): + """Verify that a writer blocks all readers and new readers queue behind.""" + stream, send_chan, _ = mplex_stream + + writer_acquired = trio.Event() + readers_ready = trio.Event() + writer_finished = trio.Event() + all_readers_started = trio.Event() + all_readers_done = trio.Event() + + counters = {"reader_start_count": 0, "reader_done_count": 0} + reader_target = 3 + reader_start_lock = trio.Lock() + + # Patch write lock to control test flow + original_acquire_write = stream.rw_lock.acquire_write + original_release_write = stream.rw_lock.release_write + + async def tracked_acquire_write(): + await original_acquire_write() + writer_acquired.set() + # Wait for readers to queue up + await readers_ready.wait() + + # Must be synchronous since real release_write is sync + def tracked_release_write(): + original_release_write() + writer_finished.set() + + with ( + patch.object( + stream.rw_lock, "acquire_write", side_effect=tracked_acquire_write + ), + patch.object( + stream.rw_lock, "release_write", side_effect=tracked_release_write + ), + ): + async with trio.open_nursery() as nursery: + # Start writer + nursery.start_soon(stream.write, b"test") + await writer_acquired.wait() + + # Start readers + async def reader_task(): + async with reader_start_lock: + counters["reader_start_count"] += 1 + if counters["reader_start_count"] == reader_target: + all_readers_started.set() + + try: + # This will block until data is available + await stream.read(5) + except (MplexStreamReset, MplexStreamEOF): + pass + finally: + async with reader_start_lock: + counters["reader_done_count"] += 1 + if counters["reader_done_count"] == reader_target: + all_readers_done.set() + + for _ in range(reader_target): + nursery.start_soon(reader_task) + + # Wait until all readers are started + await all_readers_started.wait() + + # Let the writer finish and release the lock + readers_ready.set() + await writer_finished.wait() + + # Send data to unblock the readers + for i in range(reader_target): + await send_chan.send(b"data" + str(i).encode()) + + # Wait for all readers to finish + await all_readers_done.wait() + + +@pytest.mark.trio +async def test_writer_waits_for_readers(mplex_stream): + """Verify a writer waits for existing readers to complete.""" + stream, send_chan, _ = mplex_stream + readers_started = trio.Event() + writer_entered = trio.Event() + writer_acquiring = trio.Event() + readers_finished = trio.Event() + + # Send data for readers + await send_chan.send(b"data1") + await send_chan.send(b"data2") + + # Patch read lock to control test flow + original_acquire_read = stream.rw_lock.acquire_read + + async def tracked_acquire_read(): + await original_acquire_read() + readers_started.set() + # Wait until readers are allowed to finish + await readers_finished.wait() + + # Patch write lock to detect when writer is blocked + original_acquire_write = stream.rw_lock.acquire_write + + async def tracked_acquire_write(): + 
writer_acquiring.set() + await original_acquire_write() + writer_entered.set() + + with ( + patch.object(stream.rw_lock, "acquire_read", side_effect=tracked_acquire_read), + patch.object( + stream.rw_lock, "acquire_write", side_effect=tracked_acquire_write + ), + ): + async with trio.open_nursery() as nursery: + # Start readers + nursery.start_soon(stream.read, 5) + nursery.start_soon(stream.read, 5) + + # Wait for at least one reader to acquire the lock + await readers_started.wait() + + # Start writer (should block) + nursery.start_soon(stream.write, b"test") + + # Wait for writer to start acquiring lock + await writer_acquiring.wait() + + # Verify writer hasn't entered critical section + assert not writer_entered.is_set() + + # Allow readers to finish + readers_finished.set() + + # Verify writer can proceed + await writer_entered.wait() + + +@pytest.mark.trio +async def test_lock_behavior_during_cancellation(mplex_stream): + """Verify that a lock is released when a task holding it is cancelled.""" + stream, _, _ = mplex_stream + + reader_acquired_lock = trio.Event() + + async def cancellable_reader(task_status): + async with stream.rw_lock.read_lock(): + reader_acquired_lock.set() + task_status.started() + # Wait indefinitely until cancelled. + await trio.sleep_forever() + + async with trio.open_nursery() as nursery: + # Start the reader and wait for it to acquire the lock. + await nursery.start(cancellable_reader) + await reader_acquired_lock.wait() + + # Now that the reader has the lock, cancel the nursery. + # This will cancel the reader task, and its lock should be released. + nursery.cancel_scope.cancel() + + # After the nursery is cancelled, the reader should have released the lock. + # To verify, we try to acquire a write lock. If the read lock was not + # released, this will time out. + with trio.move_on_after(1) as cancel_scope: + async with stream.rw_lock.write_lock(): + pass + if cancel_scope.cancelled_caught: + pytest.fail( + "Write lock could not be acquired after a cancelled reader, " + "indicating the read lock was not released." + ) + + +@pytest.mark.trio +async def test_concurrent_read_write_sequence(mplex_stream): + """Verify complex sequence of interleaved reads and writes.""" + stream, send_chan, _ = mplex_stream + results = [] + # Use a mock to intercept writes and feed them back to the read channel + original_write = stream.write + + reader1_finished = trio.Event() + writer1_finished = trio.Event() + reader2_finished = trio.Event() + + async def mocked_write(data: bytes) -> None: + await original_write(data) + # Simulate the other side receiving the data and sending a response + # by putting data into the read channel. + await send_chan.send(data) + + with patch.object(stream, "write", wraps=mocked_write) as patched_write: + async with trio.open_nursery() as nursery: + # Test scenario: + # 1. Reader 1 starts, waits for data. + # 2. Writer 1 writes, which gets fed back to the stream. + # 3. Reader 2 starts, reads what Writer 1 wrote. + # 4. Writer 2 writes. 
+ + async def reader1(): + nonlocal results + results.append("R1 start") + data = await stream.read(5) + results.append(data) + results.append("R1 done") + reader1_finished.set() + + async def writer1(): + nonlocal results + await reader1_finished.wait() + results.append("W1 start") + await stream.write(b"write1") + results.append("W1 done") + writer1_finished.set() + + async def reader2(): + nonlocal results + await writer1_finished.wait() + # This will read the data from writer1 + results.append("R2 start") + data = await stream.read(6) + results.append(data) + results.append("R2 done") + reader2_finished.set() + + async def writer2(): + nonlocal results + await reader2_finished.wait() + results.append("W2 start") + await stream.write(b"write2") + results.append("W2 done") + + # Execute sequence + nursery.start_soon(reader1) + nursery.start_soon(writer1) + nursery.start_soon(reader2) + nursery.start_soon(writer2) + + await send_chan.send(b"data1") + + # Verify sequence and that write was called + assert patched_write.call_count == 2 + assert results == [ + "R1 start", + b"data1", + "R1 done", + "W1 start", + "W1 done", + "R2 start", + b"write1", + "R2 done", + "W2 start", + "W2 done", + ] + + +# =============================================== +# 2. Tests for Reset, EOF, and Close Interactions +# =============================================== + + +@pytest.mark.trio +async def test_read_after_remote_close_triggers_eof(mplex_stream): + """Verify reading from a remotely closed stream returns EOF correctly.""" + stream, send_chan, _ = mplex_stream + + # Send some data that can be read first + await send_chan.send(b"data") + # Close the channel to signify no more data will ever arrive + await send_chan.aclose() + + # Mark the stream as remotely closed + stream.event_remote_closed.set() + + # The first read should succeed, consuming the buffered data + data = await stream.read(4) + assert data == b"data" + + # Now that the buffer is empty and the channel is closed, this should raise EOF + with pytest.raises(MplexStreamEOF): + await stream.read(1) + + +@pytest.mark.trio +async def test_read_on_closed_stream_raises_eof(mplex_stream): + """Test that reading from a closed stream with no data raises EOF.""" + stream, send_chan, _ = mplex_stream + stream.event_remote_closed.set() + await send_chan.aclose() # Ensure the channel is closed + + # Reading from a stream that is closed and has no data should raise EOF + with pytest.raises(MplexStreamEOF): + await stream.read(100) + + +@pytest.mark.trio +async def test_write_to_locally_closed_stream_raises(mplex_stream): + """Verify writing to a locally closed stream raises MplexStreamClosed.""" + stream, _, _ = mplex_stream + stream.event_local_closed.set() + + with pytest.raises(MplexStreamClosed): + await stream.write(b"this should fail") + + +@pytest.mark.trio +async def test_read_from_reset_stream_raises(mplex_stream): + """Verify reading from a reset stream raises MplexStreamReset.""" + stream, _, _ = mplex_stream + stream.event_reset.set() + + with pytest.raises(MplexStreamReset): + await stream.read(10) + + +@pytest.mark.trio +async def test_write_to_reset_stream_raises(mplex_stream): + """Verify writing to a reset stream raises MplexStreamClosed.""" + stream, _, _ = mplex_stream + # A stream reset implies it's also locally closed. + await stream.reset() + + # The `write` method checks `event_local_closed`, which `reset` sets. 
+ with pytest.raises(MplexStreamClosed): + await stream.write(b"this should also fail") + + +@pytest.mark.trio +async def test_stream_reset_cleans_up_resources(mplex_stream): + """Verify reset() cleans up stream state and resources.""" + stream, _, muxed_conn = mplex_stream + stream_id = stream.stream_id + + assert stream_id in muxed_conn.streams + await stream.reset() + + assert stream.event_reset.is_set() + assert stream.event_local_closed.is_set() + assert stream.event_remote_closed.is_set() + assert (HeaderTags.ResetInitiator, None, stream_id) in muxed_conn.sent_messages + assert stream_id not in muxed_conn.streams + # Verify the underlying data channel is closed + with pytest.raises(trio.ClosedResourceError): + await stream.incoming_data_channel.receive() + + +# =============================================== +# 3. Rigorous Concurrency Tests with Events +# =============================================== + + +@pytest.mark.trio +async def test_writer_is_blocked_by_reader_using_events(mplex_stream): + """Verify a writer must wait for a reader using trio.Event for synchronization.""" + stream, _, _ = mplex_stream + + reader_has_lock = trio.Event() + writer_finished = trio.Event() + + async def reader(): + async with stream.rw_lock.read_lock(): + reader_has_lock.set() + # Hold the lock until the writer has finished its attempt + await writer_finished.wait() + + async def writer(): + await reader_has_lock.wait() + # This call will now block until the reader releases the lock + await stream.write(b"data") + writer_finished.set() async with trio.open_nursery() as nursery: nursery.start_soon(reader) - await trio.sleep(0.05) nursery.start_soon(writer) - assert log == [ - "read_acquired", - "read_released", - "write_acquired", - "write_released", - ], f"Unexpected order: {log}" + # Verify writer is blocked + await wait_all_tasks_blocked() + assert not writer_finished.is_set() + + # Signal the reader to finish + writer_finished.set() @pytest.mark.trio -async def test_reading_blocked_if_write_in_progress( - stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]], -) -> None: - stream, _ = stream_with_lock - log: list[str] = [] +async def test_multiple_readers_can_read_concurrently_using_events(mplex_stream): + """Verify that multiple readers can acquire a read lock simultaneously.""" + stream, _, _ = mplex_stream - async def writer() -> None: - await stream.rw_lock.acquire_write() - log.append("write_acquired") - await trio.sleep(0.3) - log.append("write_released") - stream.rw_lock.release_write() + counters = {"readers_in_critical_section": 0} + lock = trio.Lock() # To safely mutate the counter - async def reader() -> None: - await stream.rw_lock.acquire_read() - log.append("read_acquired") - await trio.sleep(0.1) - log.append("read_released") - await stream.rw_lock.release_read() + reader1_acquired = trio.Event() + reader2_acquired = trio.Event() + all_readers_finished = trio.Event() + + async def concurrent_reader(event_to_set: trio.Event): + async with stream.rw_lock.read_lock(): + async with lock: + counters["readers_in_critical_section"] += 1 + event_to_set.set() + # Wait until all readers have finished before exiting the lock context + await all_readers_finished.wait() + async with lock: + counters["readers_in_critical_section"] -= 1 async with trio.open_nursery() as nursery: - nursery.start_soon(writer) - await trio.sleep(0.05) - nursery.start_soon(reader) + nursery.start_soon(concurrent_reader, reader1_acquired) + nursery.start_soon(concurrent_reader, reader2_acquired) - 
assert log == [ - "write_acquired", - "write_released", - "read_acquired", - "read_released", - ], f"Unexpected order: {log}" + # Wait for both readers to acquire their locks + await reader1_acquired.wait() + await reader2_acquired.wait() + # Check that both were in the critical section at the same time + async with lock: + assert counters["readers_in_critical_section"] == 2 -@pytest.mark.trio -async def test_multiple_reads_allowed_concurrently( - stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]], -) -> None: - stream, _ = stream_with_lock - log: list[str] = [] + # Signal for all readers to finish + all_readers_finished.set() - async def read_task(i: int) -> None: - await stream.rw_lock.acquire_read() - log.append(f"read_{i}_acquired") - await trio.sleep(0.2) - log.append(f"read_{i}_released") - await stream.rw_lock.release_read() - - async with trio.open_nursery() as nursery: - for i in range(5): - nursery.start_soon(read_task, i) - - acquires = [entry for entry in log if "acquired" in entry] - releases = [entry for entry in log if "released" in entry] - - assert len(acquires) == 5 and len(releases) == 5, "Not all reads executed" - assert all( - log.index(acq) < min(log.index(rel) for rel in releases) for acq in acquires - ), f"Reads didn't overlap properly: {log}" - - -@pytest.mark.trio -async def test_only_one_write_allowed( - stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]], -) -> None: - stream, _ = stream_with_lock - log: list[str] = [] - - async def write_task(i: int) -> None: - await stream.rw_lock.acquire_write() - log.append(f"write_{i}_acquired") - await trio.sleep(0.2) - log.append(f"write_{i}_released") - stream.rw_lock.release_write() - - async with trio.open_nursery() as nursery: - for i in range(5): - nursery.start_soon(write_task, i) - - active = 0 - for entry in log: - if "acquired" in entry: - active += 1 - elif "released" in entry: - active -= 1 - assert active <= 1, f"More than one write active: {log}" - assert active == 0, f"Write locks not properly released: {log}" - - -@pytest.mark.trio -async def test_interleaved_read_write_behavior( - stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]], -) -> None: - stream, _ = stream_with_lock - log: list[str] = [] - - async def read(i: int) -> None: - await stream.rw_lock.acquire_read() - log.append(f"read_{i}_acquired") - await trio.sleep(0.15) - log.append(f"read_{i}_released") - await stream.rw_lock.release_read() - - async def write(i: int) -> None: - await stream.rw_lock.acquire_write() - log.append(f"write_{i}_acquired") - await trio.sleep(0.2) - log.append(f"write_{i}_released") - stream.rw_lock.release_write() - - async with trio.open_nursery() as nursery: - nursery.start_soon(read, 1) - await trio.sleep(0.05) - nursery.start_soon(read, 2) - await trio.sleep(0.05) - nursery.start_soon(write, 1) - await trio.sleep(0.05) - nursery.start_soon(read, 3) - await trio.sleep(0.05) - nursery.start_soon(write, 2) - - read1_released = log.index("read_1_released") - read2_released = log.index("read_2_released") - write1_acquired = log.index("write_1_acquired") - assert write1_acquired > read1_released and write1_acquired > read2_released, ( - f"write_1 acquired too early: {log}" - ) - - read3_acquired = log.index("read_3_acquired") - read3_released = log.index("read_3_released") - write1_released = log.index("write_1_released") - assert read3_released < write1_acquired or read3_acquired > write1_released, ( - f"read_3 improperly overlapped with write_1: {log}" - ) - - 
write2_acquired = log.index("write_2_acquired") - assert write2_acquired > write1_released, f"write_2 started too early: {log}" + # Verify they exit cleanly + await wait_all_tasks_blocked() + async with lock: + assert counters["readers_in_critical_section"] == 0 From 6aeb217349b1c196f36690c34b4600f0794d180e Mon Sep 17 00:00:00 2001 From: Luca Vivona Date: Tue, 15 Jul 2025 14:59:34 -0400 Subject: [PATCH 10/22] replace: attributes with cache cached_property --- libp2p/peer/id.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/libp2p/peer/id.py b/libp2p/peer/id.py index 0be51ea2..61e399cd 100644 --- a/libp2p/peer/id.py +++ b/libp2p/peer/id.py @@ -3,6 +3,10 @@ import hashlib import base58 import multihash +from functools import ( + cached_property, +) + from libp2p.crypto.keys import ( PublicKey, ) @@ -36,25 +40,23 @@ if ENABLE_INLINING: class ID: _bytes: bytes - _xor_id: int | None = None - _b58_str: str | None = None def __init__(self, peer_id_bytes: bytes) -> None: self._bytes = peer_id_bytes - @property + @cached_property def xor_id(self) -> int: - if not self._xor_id: - self._xor_id = int(sha256_digest(self._bytes).hex(), 16) - return self._xor_id + return int(sha256_digest(self._bytes).hex(), 16) + + @cached_property + def base58(self) -> str: + return base58.b58encode(self._bytes).decode() def to_bytes(self) -> bytes: return self._bytes def to_base58(self) -> str: - if not self._b58_str: - self._b58_str = base58.b58encode(self._bytes).decode() - return self._b58_str + return self.base58 def __repr__(self) -> str: return f"" From 23622ea1a088a39f3ba1fe5539eeb59afd205f5d Mon Sep 17 00:00:00 2001 From: Luca Vivona Date: Tue, 15 Jul 2025 15:28:03 -0400 Subject: [PATCH 11/22] style: enforce consistent import block --- libp2p/peer/id.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/libp2p/peer/id.py b/libp2p/peer/id.py index 61e399cd..28a9d75a 100644 --- a/libp2p/peer/id.py +++ b/libp2p/peer/id.py @@ -1,12 +1,9 @@ +import functools import hashlib import base58 import multihash -from functools import ( - cached_property, -) - from libp2p.crypto.keys import ( PublicKey, ) @@ -44,11 +41,11 @@ class ID: def __init__(self, peer_id_bytes: bytes) -> None: self._bytes = peer_id_bytes - @cached_property + @functools.cached_property def xor_id(self) -> int: return int(sha256_digest(self._bytes).hex(), 16) - @cached_property + @functools.cached_property def base58(self) -> str: return base58.b58encode(self._bytes).decode() From 9f40d97a056d1d493120be1afd204ec6b5f95615 Mon Sep 17 00:00:00 2001 From: Luca Vivona Date: Wed, 16 Jul 2025 22:08:25 -0400 Subject: [PATCH 12/22] chore(newsfragment): add entry to the release notes --- newsfragments/772.internal.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 newsfragments/772.internal.rst diff --git a/newsfragments/772.internal.rst b/newsfragments/772.internal.rst new file mode 100644 index 00000000..7079d6c4 --- /dev/null +++ b/newsfragments/772.internal.rst @@ -0,0 +1 @@ +Replace the libp2p.peer.ID cache attributes with functools.cached_property functional decorator. 
\ No newline at end of file From ae82895d86fd0992824fd93ea00fc4d8026587aa Mon Sep 17 00:00:00 2001 From: Luca Vivona Date: Wed, 16 Jul 2025 22:12:05 -0400 Subject: [PATCH 13/22] style: add new line within newsfragment --- newsfragments/772.internal.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/newsfragments/772.internal.rst b/newsfragments/772.internal.rst index 7079d6c4..2c84641c 100644 --- a/newsfragments/772.internal.rst +++ b/newsfragments/772.internal.rst @@ -1 +1 @@ -Replace the libp2p.peer.ID cache attributes with functools.cached_property functional decorator. \ No newline at end of file +Replace the libp2p.peer.ID cache attributes with functools.cached_property functional decorator. From c9162beb2bd390050e4dd0e824c8550be8443a61 Mon Sep 17 00:00:00 2001 From: Jinesh Jain <732005jinesh@gmail.com> Date: Thu, 17 Jul 2025 20:55:32 +0530 Subject: [PATCH 14/22] add backticks that were removed by mistake --- libp2p/stream_muxer/mplex/mplex_stream.py | 24 +++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 7e439160..77546036 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -107,7 +107,7 @@ class MplexStream(IMuxedStream): name: str stream_id: StreamID - # NOTE: All methods used here are part of Mplex which is a derived + # NOTE: All methods used here are part of `Mplex` which is a derived # class of IMuxedConn. Ignoring this type assignment should not pose # any risk. muxed_conn: "Mplex" # type: ignore[assignment] @@ -117,7 +117,7 @@ class MplexStream(IMuxedStream): rw_lock: ReadWriteLock close_lock: trio.Lock - # NOTE: dataIn is size of 8 in Go implementation. + # NOTE: `dataIn` is size of 8 in Go implementation. incoming_data_channel: "trio.MemoryReceiveChannel[bytes]" event_local_closed: trio.Event @@ -175,8 +175,8 @@ class MplexStream(IMuxedStream): async def read(self, n: int | None = None) -> bytes: """ - Read up to n bytes. Read possibly returns fewer than n bytes, if - there are not enough bytes in the Mplex buffer. If n is None, read + Read up to n bytes. Read possibly returns fewer than `n` bytes, if + there are not enough bytes in the Mplex buffer. If `n is None`, read until EOF. :param n: number of bytes to read @@ -185,8 +185,8 @@ class MplexStream(IMuxedStream): async with self.rw_lock.read_lock(): if n is not None and n < 0: raise ValueError( - "the number of bytes to read n must be non-negative or " - f"None to indicate read until EOF, got n={n}" + "the number of bytes to read `n` must be non-negative or " + f"`None` to indicate read until EOF, got n={n}" ) if self.event_reset.is_set(): raise MplexStreamReset @@ -202,8 +202,8 @@ class MplexStream(IMuxedStream): except trio.EndOfChannel: raise MplexStreamEOF except trio.WouldBlock: - # We know receive will be blocked here. Wait for data here with - # receive and catch all kinds of errors here. + # We know `receive` will be blocked here. Wait for data here with + # `receive` and catch all kinds of errors here. try: data = await self.incoming_data_channel.receive() self._buf.extend(data) except trio.EndOfChannel: if self.event_reset.is_set(): raise MplexStreamReset if self.event_remote_closed.is_set(): raise MplexStreamEOF except trio.ClosedResourceError as error: - # Probably incoming_data_channel is closed in reset when we are - # waiting for receive. + # Probably `incoming_data_channel` is closed in `reset` when + # we are waiting for `receive`.
if self.event_reset.is_set(): raise MplexStreamReset raise Exception( - "incoming_data_channel is closed but stream is not reset. " + "`incoming_data_channel` is closed but stream is not reset. " "This should never happen." ) from error self._buf.extend(self._read_return_when_blocked()) @@ -256,7 +256,7 @@ class MplexStream(IMuxedStream): flag = ( HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver ) - # TODO: Raise when muxed_conn.send_message fails and Mplex isn't shutdown. + # TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown. await self.muxed_conn.send_message(flag, None, self.stream_id) _is_remote_closed: bool From 35075313448e3038a3d41f52d169eb15e9f2eeea Mon Sep 17 00:00:00 2001 From: Luca Vivona Date: Thu, 17 Jul 2025 22:43:00 -0400 Subject: [PATCH 15/22] chore: clarify newline requirement in newsfragments README.md (#775) * chore: clarify newline requirement in README Small change in newsfragments README.md, that reduces some possible room for pull-request tox workflow errors. * style: remove double backticks for single backticks the linter strikes again XD. * docs: clarify trailing newline requirement in newsfragments for lint checks --------- Co-authored-by: Manu Sheel Gupta --- newsfragments/775.docs.rst | 1 + newsfragments/README.md | 11 +++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) create mode 100644 newsfragments/775.docs.rst diff --git a/newsfragments/775.docs.rst b/newsfragments/775.docs.rst new file mode 100644 index 00000000..300b27ca --- /dev/null +++ b/newsfragments/775.docs.rst @@ -0,0 +1 @@ +Clarified the requirement for a trailing newline in newsfragments to pass lint checks. diff --git a/newsfragments/README.md b/newsfragments/README.md index 177d6492..4b54df7c 100644 --- a/newsfragments/README.md +++ b/newsfragments/README.md @@ -18,12 +18,19 @@ Each file should be named like `<ISSUE>.<TYPE>.rst`, where - `performance` - `removal` -So for example: `123.feature.rst`, `456.bugfix.rst` +So for example: `1024.feature.rst` + +**Important**: Ensure the file ends with a newline character (`\n`) to pass GitHub tox linting checks. + +``` +Added support for Ed25519 key generation in libp2p peer identity creation. + +``` If the PR fixes an issue, use that number here. If there is no issue, then open up the PR first and use the PR number for the newsfragment. -Note that the `towncrier` tool will automatically +**Note** that the `towncrier` tool will automatically reflow your text, so don't try to do any fancy formatting. Run `towncrier build --draft` to get a preview of what the release notes entry will look like in the final release notes.
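For context, a minimal usage sketch of the ReadWriteLock context-manager API that PATCH 08/09 introduce; the demo function, reader count, print calls, and sleep duration here are illustrative assumptions, not code from any patch:

import trio

from libp2p.stream_muxer.mplex.mplex_stream import ReadWriteLock


async def demo() -> None:
    lock = ReadWriteLock()

    async def reader(i: int) -> None:
        # Any number of readers may hold the lock concurrently.
        async with lock.read_lock():
            print(f"reader {i} holding read lock")
            await trio.sleep(0.1)

    async def writer() -> None:
        # The writer waits for all current readers to drain, then runs
        # exclusively; readers arriving while it holds the lock block.
        async with lock.write_lock():
            print("writer holding exclusive write lock")

    async with trio.open_nursery() as nursery:
        for i in range(3):
            nursery.start_soon(reader, i)
        nursery.start_soon(writer)


trio.run(demo)

In MplexStream itself, read() wraps its body in rw_lock.read_lock() and write() in rw_lock.write_lock() (PATCH 09), which prevents interleaved receives without serializing concurrent readers.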
From 11560f5cc95baad057eedca2830a1fdc9b95d353 Mon Sep 17 00:00:00 2001
From: Abhinav Agarwalla <120122716+lla-dane@users.noreply.github.com>
Date: Fri, 18 Jul 2025 17:31:28 +0530
Subject: [PATCH 16/22] TODO: throttle on async validators (#755)

* fixed todo: throttle on async validators

* added test: validate message respects concurrency limit

* added newsfragment

* added configurable validator semaphore in the PubSub constructor

* added the concurrency-checker in the original test-validate-msg test case

* separate out a _run_async_validator function

* remove redundant run_async_validator

---
 libp2p/pubsub/pubsub.py           | 39 ++++++++++++++++++------
 newsfragments/755.performance.rst |  2 ++
 tests/core/pubsub/test_pubsub.py  | 50 +++++++++++++++++++++++++++----
 3 files changed, 77 insertions(+), 14 deletions(-)
 create mode 100644 newsfragments/755.performance.rst

diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py
index a913c721..5641ec5d 100644
--- a/libp2p/pubsub/pubsub.py
+++ b/libp2p/pubsub/pubsub.py
@@ -102,6 +102,9 @@ class TopicValidator(NamedTuple):
     is_async: bool


+MAX_CONCURRENT_VALIDATORS = 10
+
+
 class Pubsub(Service, IPubsub):
     host: IHost

@@ -109,6 +112,7 @@ class Pubsub(Service, IPubsub):
     peer_receive_channel: trio.MemoryReceiveChannel[ID]
     dead_peer_receive_channel: trio.MemoryReceiveChannel[ID]
+    _validator_semaphore: trio.Semaphore

     seen_messages: LastSeenCache
@@ -143,6 +147,7 @@ class Pubsub(Service, IPubsub):
         msg_id_constructor: Callable[
             [rpc_pb2.Message], bytes
         ] = get_peer_and_seqno_msg_id,
+        max_concurrent_validator_count: int = MAX_CONCURRENT_VALIDATORS,
     ) -> None:
         """
         Construct a new Pubsub object, which is responsible for handling all
@@ -168,6 +173,7 @@
         # Therefore, we can only close from the receive side.
         self.peer_receive_channel = peer_receive
         self.dead_peer_receive_channel = dead_peer_receive
+        self._validator_semaphore = trio.Semaphore(max_concurrent_validator_count)
         # Register a notifee
         self.host.get_network().register_notifee(
             PubsubNotifee(peer_send, dead_peer_send)
         )
@@ -657,7 +663,11 @@

         logger.debug("successfully published message %s", msg)

-    async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
+    async def validate_msg(
+        self,
+        msg_forwarder: ID,
+        msg: rpc_pb2.Message,
+    ) -> None:
         """
         Validate the received message.
@@ -680,23 +690,34 @@
             if not validator(msg_forwarder, msg):
                 raise ValidationError(f"Validation failed for msg={msg}")

-        # TODO: Implement throttle on async validators
         if len(async_topic_validators) > 0:
             # Appends to lists are thread safe in CPython
-            results = []
-
-            async def run_async_validator(func: AsyncValidatorFn) -> None:
-                result = await func(msg_forwarder, msg)
-                results.append(result)
+            results: list[bool] = []

             async with trio.open_nursery() as nursery:
                 for async_validator in async_topic_validators:
-                    nursery.start_soon(run_async_validator, async_validator)
+                    nursery.start_soon(
+                        self._run_async_validator,
+                        async_validator,
+                        msg_forwarder,
+                        msg,
+                        results,
+                    )

             if not all(results):
                 raise ValidationError(f"Validation failed for msg={msg}")

+    async def _run_async_validator(
+        self,
+        func: AsyncValidatorFn,
+        msg_forwarder: ID,
+        msg: rpc_pb2.Message,
+        results: list[bool],
+    ) -> None:
+        async with self._validator_semaphore:
+            result = await func(msg_forwarder, msg)
+            results.append(result)
+
     async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
         """
         Push a pubsub message to others.
diff --git a/newsfragments/755.performance.rst b/newsfragments/755.performance.rst
new file mode 100644
index 00000000..386e661b
--- /dev/null
+++ b/newsfragments/755.performance.rst
@@ -0,0 +1,2 @@
+Added throttling for async topic validators in validate_msg, enforcing a
+concurrency limit to prevent resource exhaustion under heavy load.
diff --git a/tests/core/pubsub/test_pubsub.py b/tests/core/pubsub/test_pubsub.py
index 81389ed1..e674dbc0 100644
--- a/tests/core/pubsub/test_pubsub.py
+++ b/tests/core/pubsub/test_pubsub.py
@@ -5,10 +5,12 @@ import inspect
 from typing import (
     NamedTuple,
 )
+from unittest.mock import patch

 import pytest
 import trio

+from libp2p.custom_types import AsyncValidatorFn
 from libp2p.exceptions import (
     ValidationError,
 )
@@ -243,7 +245,37 @@ async def test_get_msg_validators():
     ((False, True), (True, False), (True, True)),
 )
 @pytest.mark.trio
-async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
+async def test_validate_msg_with_throttle_condition(
+    is_topic_1_val_passed, is_topic_2_val_passed
+):
+    CONCURRENCY_LIMIT = 10
+
+    state = {
+        "concurrency_counter": 0,
+        "max_observed": 0,
+    }
+    lock = trio.Lock()
+
+    async def mock_run_async_validator(
+        self,
+        func: AsyncValidatorFn,
+        msg_forwarder: ID,
+        msg: rpc_pb2.Message,
+        results: list[bool],
+    ) -> None:
+        async with self._validator_semaphore:
+            async with lock:
+                state["concurrency_counter"] += 1
+                if state["concurrency_counter"] > state["max_observed"]:
+                    state["max_observed"] = state["concurrency_counter"]
+
+            try:
+                result = await func(msg_forwarder, msg)
+                results.append(result)
+            finally:
+                async with lock:
+                    state["concurrency_counter"] -= 1
+
     async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:

         def passed_sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
@@ -280,11 +312,19 @@ async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
             seqno=b"\x00" * 8,
         )

-        if is_topic_1_val_passed and is_topic_2_val_passed:
-            await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
-        else:
-            with pytest.raises(ValidationError):
+        with patch(
+            "libp2p.pubsub.pubsub.Pubsub._run_async_validator",
+            new=mock_run_async_validator,
+        ):
+            if is_topic_1_val_passed and is_topic_2_val_passed:
                 await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
+            else:
+                with pytest.raises(ValidationError):
+                    await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
+
+        assert state["max_observed"] <= CONCURRENCY_LIMIT, (
+            f"Max concurrency observed: {state['max_observed']}"
+        )


 @pytest.mark.trio
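The throttle PATCH 16 introduces is a shared `trio.Semaphore` wrapped around each validator call: the nursery still starts every task, but at most N of them execute a validator body at once. A standalone sketch of the same pattern under those assumptions (the validator here is a stand-in, not a real pubsub validator):

    import trio

    MAX_CONCURRENT = 10

    async def fake_validator(i: int) -> bool:
        await trio.sleep(0.01)  # stand-in for real async validation work
        return True

    async def main() -> None:
        semaphore = trio.Semaphore(MAX_CONCURRENT)
        results: list[bool] = []

        async def run_throttled(i: int) -> None:
            # Each task waits here until one of the MAX_CONCURRENT slots frees up.
            async with semaphore:
                results.append(await fake_validator(i))

        async with trio.open_nursery() as nursery:
            for i in range(50):
                nursery.start_soon(run_throttled, i)

        # The nursery exits only after all 50 tasks finish; the semaphore only
        # bounds how many run at the same time, mirroring _run_async_validator.
        assert len(results) == 50 and all(results)

    trio.run(main)

This is also exactly what the test above measures: the patched `_run_async_validator` counts how many tasks are inside the semaphore simultaneously and asserts the peak never exceeds the limit.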
From 85bad2d0ae2ad10e4f93a5a9da98349c32137e9c Mon Sep 17 00:00:00 2001
From: Luca Vivona
Date: Tue, 15 Jul 2025 14:59:34 -0400
Subject: [PATCH 17/22] replace: cache attributes with cached_property

---
 libp2p/peer/id.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/libp2p/peer/id.py b/libp2p/peer/id.py
index 0be51ea2..61e399cd 100644
--- a/libp2p/peer/id.py
+++ b/libp2p/peer/id.py
@@ -3,6 +3,10 @@ import hashlib
 import base58
 import multihash

+from functools import (
+    cached_property,
+)
+
 from libp2p.crypto.keys import (
     PublicKey,
 )
@@ -36,25 +40,23 @@ if ENABLE_INLINING:
 class ID:
     _bytes: bytes
-    _xor_id: int | None = None
-    _b58_str: str | None = None

     def __init__(self, peer_id_bytes: bytes) -> None:
         self._bytes = peer_id_bytes

-    @property
+    @cached_property
     def xor_id(self) -> int:
-        if not self._xor_id:
-            self._xor_id = int(sha256_digest(self._bytes).hex(), 16)
-        return self._xor_id
+        return int(sha256_digest(self._bytes).hex(), 16)
+
+    @cached_property
+    def base58(self) -> str:
+        return base58.b58encode(self._bytes).decode()

     def to_bytes(self) -> bytes:
         return self._bytes

     def to_base58(self) -> str:
-        if not self._b58_str:
-            self._b58_str = base58.b58encode(self._bytes).decode()
-        return self._b58_str
+        return self.base58

     def __repr__(self) -> str:
         return f"<libp2p.peer.id.ID ({self!s})>"

From fcf05468317934b7008c72a475ab91c42c1d180e Mon Sep 17 00:00:00 2001
From: Luca Vivona
Date: Tue, 15 Jul 2025 15:28:03 -0400
Subject: [PATCH 18/22] style: enforce consistent import block

---
 libp2p/peer/id.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/libp2p/peer/id.py b/libp2p/peer/id.py
index 61e399cd..28a9d75a 100644
--- a/libp2p/peer/id.py
+++ b/libp2p/peer/id.py
@@ -1,12 +1,9 @@
+import functools
 import hashlib

 import base58
 import multihash

-from functools import (
-    cached_property,
-)
-
 from libp2p.crypto.keys import (
     PublicKey,
 )
@@ -44,11 +41,11 @@ class ID:
     def __init__(self, peer_id_bytes: bytes) -> None:
         self._bytes = peer_id_bytes

-    @cached_property
+    @functools.cached_property
     def xor_id(self) -> int:
         return int(sha256_digest(self._bytes).hex(), 16)

-    @cached_property
+    @functools.cached_property
     def base58(self) -> str:
         return base58.b58encode(self._bytes).decode()

From 092b9c0c579ede31b483fc0dbbe94069e420b067 Mon Sep 17 00:00:00 2001
From: Luca Vivona
Date: Wed, 16 Jul 2025 22:08:25 -0400
Subject: [PATCH 19/22] chore(newsfragment): add entry to the release notes

---
 newsfragments/772.internal.rst | 1 +
 1 file changed, 1 insertion(+)
 create mode 100644 newsfragments/772.internal.rst

diff --git a/newsfragments/772.internal.rst b/newsfragments/772.internal.rst
new file mode 100644
index 00000000..7079d6c4
--- /dev/null
+++ b/newsfragments/772.internal.rst
@@ -0,0 +1 @@
+Replace the libp2p.peer.ID cache attributes with functools.cached_property functional decorator.
\ No newline at end of file

From 7cfe5b9dc7bb9098dac0c121731fd204136c1c68 Mon Sep 17 00:00:00 2001
From: Luca Vivona
Date: Wed, 16 Jul 2025 22:12:05 -0400
Subject: [PATCH 20/22] style: add new line within newsfragment

---
 newsfragments/772.internal.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/newsfragments/772.internal.rst b/newsfragments/772.internal.rst
index 7079d6c4..2c84641c 100644
--- a/newsfragments/772.internal.rst
+++ b/newsfragments/772.internal.rst
@@ -1 +1 @@
-Replace the libp2p.peer.ID cache attributes with functools.cached_property functional decorator.
\ No newline at end of file
+Replace the libp2p.peer.ID cache attributes with functools.cached_property functional decorator.
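For context on PATCHES 17-18: `functools.cached_property` runs the getter once and stores the result in the instance `__dict__`, after which attribute lookup bypasses the descriptor entirely. It also avoids a subtle flaw of the old `if not self._xor_id` sentinel, which would recompute whenever the cached value happened to be falsy. A small illustrative demo (the class and counter are hypothetical, not from the repo):

    import functools

    class Demo:
        def __init__(self) -> None:
            self.calls = 0

        @functools.cached_property
        def expensive(self) -> int:
            # Body runs only on first access; the result is then stored
            # in self.__dict__["expensive"] and served from there.
            self.calls += 1
            return 42

    d = Demo()
    assert d.expensive == 42
    assert d.expensive == 42
    assert d.calls == 1  # the getter ran exactly once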
From 99db5b309f28bb5cd0e1901a510ee7405b93856e Mon Sep 17 00:00:00 2001 From: acul71 Date: Sat, 19 Jul 2025 04:11:27 +0200 Subject: [PATCH 21/22] fix raw format in identify and tests --- examples/identify/identify.py | 123 +++++- .../identify/test_identify_integration.py | 241 ++++++++++ .../identify/test_identify_parsing.py | 410 ------------------ 3 files changed, 356 insertions(+), 418 deletions(-) create mode 100644 tests/core/identity/identify/test_identify_integration.py delete mode 100644 tests/core/identity/identify/test_identify_parsing.py diff --git a/examples/identify/identify.py b/examples/identify/identify.py index 4882d2c3..60ccb75c 100644 --- a/examples/identify/identify.py +++ b/examples/identify/identify.py @@ -72,13 +72,46 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/") format_name = "length-prefixed" if use_varint_format else "raw protobuf" + format_flag = "--raw-format" if not use_varint_format else "" print( f"First host listening (using {format_name} format). " f"Run this from another console:\n\n" - f"identify-demo " - f"-d {client_addr}\n" + f"identify-demo {format_flag} -d {client_addr}\n" ) print("Waiting for incoming identify request...") + + # Add a custom handler to show connection events + async def custom_identify_handler(stream): + peer_id = stream.muxed_conn.peer_id + print(f"\nšŸ”— Received identify request from peer: {peer_id}") + + # Show remote address in multiaddr format + try: + from libp2p.identity.identify.identify import ( + _remote_address_to_multiaddr, + ) + + remote_address = stream.get_remote_address() + if remote_address: + observed_multiaddr = _remote_address_to_multiaddr( + remote_address + ) + # Add the peer ID to create a complete multiaddr + complete_multiaddr = f"{observed_multiaddr}/p2p/{peer_id}" + print(f" Remote address: {complete_multiaddr}") + else: + print(f" Remote address: {remote_address}") + except Exception: + print(f" Remote address: {stream.get_remote_address()}") + + # Call the original handler + await identify_handler(stream) + + print(f"āœ… Successfully processed identify request from {peer_id}") + + # Replace the handler with our custom one + host_a.set_stream_handler(IDENTIFY_PROTOCOL_ID, custom_identify_handler) + await trio.sleep_forever() else: @@ -93,25 +126,99 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No info = info_from_p2p_addr(maddr) print(f"Second host connecting to peer: {info.peer_id}") - await host_b.connect(info) + try: + await host_b.connect(info) + except Exception as e: + error_msg = str(e) + if "unable to connect" in error_msg or "SwarmException" in error_msg: + print(f"\nāŒ Cannot connect to peer: {info.peer_id}") + print(f" Address: {destination}") + print(f" Error: {error_msg}") + print( + "\nšŸ’” Make sure the peer is running and the address is correct." 
+ ) + return + else: + # Re-raise other exceptions + raise + stream = await host_b.new_stream(info.peer_id, (IDENTIFY_PROTOCOL_ID,)) try: print("Starting identify protocol...") - # Read the complete response (could be either format) - # Read a larger chunk to get all the data before stream closes - response = await stream.read(8192) # Read enough data in one go + # Read the response properly based on the format + if use_varint_format: + # For length-prefixed format, read varint length first + from libp2p.utils.varint import decode_varint_from_bytes + + # Read varint length prefix + length_bytes = b"" + while True: + b = await stream.read(1) + if not b: + raise Exception("Stream closed while reading varint length") + length_bytes += b + if b[0] & 0x80 == 0: + break + + msg_length = decode_varint_from_bytes(length_bytes) + print(f"Expected message length: {msg_length} bytes") + + # Read the protobuf message + response = await stream.read(msg_length) + if len(response) != msg_length: + raise Exception( + f"Incomplete message: expected {msg_length} bytes, " + f"got {len(response)}" + ) + + # Combine length prefix and message + full_response = length_bytes + response + else: + # For raw format, read all available data + response = await stream.read(8192) + full_response = response await stream.close() # Parse the response using the robust protocol-level function # This handles both old and new formats automatically - identify_msg = parse_identify_response(response) + identify_msg = parse_identify_response(full_response) print_identify_response(identify_msg) except Exception as e: - print(f"Identify protocol error: {e}") + error_msg = str(e) + print(f"Identify protocol error: {error_msg}") + + # Check for specific format mismatch errors + if "Error parsing message" in error_msg or "DecodeError" in error_msg: + print("\n" + "=" * 60) + print("FORMAT MISMATCH DETECTED!") + print("=" * 60) + if use_varint_format: + print( + "You are using length-prefixed format (default) but the " + "listener" + ) + print("is using raw protobuf format.") + print( + "\nTo fix this, run the dialer with the --raw-format flag:" + ) + print(f"identify-demo --raw-format -d {destination}") + else: + print("You are using raw protobuf format but the listener") + print("is using length-prefixed format (default).") + print( + "\nTo fix this, run the dialer without the --raw-format " + "flag:" + ) + print(f"identify-demo -d {destination}") + print("=" * 60) + else: + import traceback + + traceback.print_exc() return diff --git a/tests/core/identity/identify/test_identify_integration.py b/tests/core/identity/identify/test_identify_integration.py new file mode 100644 index 00000000..e4ebcba7 --- /dev/null +++ b/tests/core/identity/identify/test_identify_integration.py @@ -0,0 +1,241 @@ +import logging + +import pytest + +from libp2p.custom_types import TProtocol +from libp2p.identity.identify.identify import ( + AGENT_VERSION, + ID, + PROTOCOL_VERSION, + _multiaddr_to_bytes, + identify_handler_for, + parse_identify_response, +) +from tests.utils.factories import host_pair_factory + +logger = logging.getLogger("libp2p.identity.identify-integration-test") + + +@pytest.mark.trio +async def test_identify_protocol_varint_format_integration(security_protocol): + """Test identify protocol with varint format in real network scenario.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=True) + ) + + # Make 
identify request + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + response = await stream.read(8192) + await stream.close() + + # Parse response + result = parse_identify_response(response) + + # Verify response content + assert result.agent_version == AGENT_VERSION + assert result.protocol_version == PROTOCOL_VERSION + assert result.public_key == host_a.get_public_key().serialize() + assert result.listen_addrs == [ + _multiaddr_to_bytes(addr) for addr in host_a.get_addrs() + ] + + +@pytest.mark.trio +async def test_identify_protocol_raw_format_integration(security_protocol): + """Test identify protocol with raw format in real network scenario.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=False) + ) + + # Make identify request + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + response = await stream.read(8192) + await stream.close() + + # Parse response + result = parse_identify_response(response) + + # Verify response content + assert result.agent_version == AGENT_VERSION + assert result.protocol_version == PROTOCOL_VERSION + assert result.public_key == host_a.get_public_key().serialize() + assert result.listen_addrs == [ + _multiaddr_to_bytes(addr) for addr in host_a.get_addrs() + ] + + +@pytest.mark.trio +async def test_identify_default_format_behavior(security_protocol): + """Test identify protocol uses correct default format.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Use default identify handler (should use varint format) + host_a.set_stream_handler(ID, identify_handler_for(host_a)) + + # Make identify request + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + response = await stream.read(8192) + await stream.close() + + # Parse response + result = parse_identify_response(response) + + # Verify response content + assert result.agent_version == AGENT_VERSION + assert result.protocol_version == PROTOCOL_VERSION + assert result.public_key == host_a.get_public_key().serialize() + + +@pytest.mark.trio +async def test_identify_cross_format_compatibility_varint_to_raw(security_protocol): + """Test varint dialer with raw listener compatibility.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Host A uses raw format + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=False) + ) + + # Host B makes request (will automatically detect format) + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + response = await stream.read(8192) + await stream.close() + + # Parse response (should work with automatic format detection) + result = parse_identify_response(response) + + # Verify response content + assert result.agent_version == AGENT_VERSION + assert result.protocol_version == PROTOCOL_VERSION + assert result.public_key == host_a.get_public_key().serialize() + + +@pytest.mark.trio +async def test_identify_cross_format_compatibility_raw_to_varint(security_protocol): + """Test raw dialer with varint listener compatibility.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Host A uses varint format + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=True) + ) + + # Host B makes request (will automatically detect format) + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + response = await 
stream.read(8192) + await stream.close() + + # Parse response (should work with automatic format detection) + result = parse_identify_response(response) + + # Verify response content + assert result.agent_version == AGENT_VERSION + assert result.protocol_version == PROTOCOL_VERSION + assert result.public_key == host_a.get_public_key().serialize() + + +@pytest.mark.trio +async def test_identify_format_detection_robustness(security_protocol): + """Test identify protocol format detection is robust with various message sizes.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Test both formats with different message sizes + for use_varint in [True, False]: + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=use_varint) + ) + + # Make identify request + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + response = await stream.read(8192) + await stream.close() + + # Parse response + result = parse_identify_response(response) + + # Verify response content + assert result.agent_version == AGENT_VERSION + assert result.protocol_version == PROTOCOL_VERSION + assert result.public_key == host_a.get_public_key().serialize() + + +@pytest.mark.trio +async def test_identify_large_message_handling(security_protocol): + """Test identify protocol handles large messages with many protocols.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Add many protocols to make the message larger + async def dummy_handler(stream): + pass + + for i in range(10): + host_a.set_stream_handler(TProtocol(f"/test/protocol/{i}"), dummy_handler) + + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=True) + ) + + # Make identify request + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + response = await stream.read(8192) + await stream.close() + + # Parse response + result = parse_identify_response(response) + + # Verify response content + assert result.agent_version == AGENT_VERSION + assert result.protocol_version == PROTOCOL_VERSION + assert result.public_key == host_a.get_public_key().serialize() + + +@pytest.mark.trio +async def test_identify_message_equivalence_real_network(security_protocol): + """Test that both formats produce equivalent messages in real network.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Test varint format + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=True) + ) + stream_varint = await host_b.new_stream(host_a.get_id(), (ID,)) + response_varint = await stream_varint.read(8192) + await stream_varint.close() + + # Test raw format + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=False) + ) + stream_raw = await host_b.new_stream(host_a.get_id(), (ID,)) + response_raw = await stream_raw.read(8192) + await stream_raw.close() + + # Parse both responses + result_varint = parse_identify_response(response_varint) + result_raw = parse_identify_response(response_raw) + + # Both should produce identical parsed results + assert result_varint.agent_version == result_raw.agent_version + assert result_varint.protocol_version == result_raw.protocol_version + assert result_varint.public_key == result_raw.public_key + assert result_varint.listen_addrs == result_raw.listen_addrs diff --git a/tests/core/identity/identify/test_identify_parsing.py b/tests/core/identity/identify/test_identify_parsing.py deleted file 
mode 100644 index d76d82a1..00000000 --- a/tests/core/identity/identify/test_identify_parsing.py +++ /dev/null @@ -1,410 +0,0 @@ -import pytest - -from libp2p.identity.identify.identify import ( - _mk_identify_protobuf, -) -from libp2p.identity.identify.pb.identify_pb2 import ( - Identify, -) -from libp2p.io.abc import Closer, Reader, Writer -from libp2p.utils.varint import ( - decode_varint_from_bytes, - encode_varint_prefixed, -) -from tests.utils.factories import ( - host_pair_factory, -) - - -class MockStream(Reader, Writer, Closer): - """Mock stream for testing identify protocol compatibility.""" - - def __init__(self, data: bytes): - self.data = data - self.position = 0 - self.closed = False - - async def read(self, n: int | None = None) -> bytes: - if self.closed or self.position >= len(self.data): - return b"" - if n is None: - n = len(self.data) - self.position - result = self.data[self.position : self.position + n] - self.position += len(result) - return result - - async def write(self, data: bytes) -> None: - # Mock write - just store the data - pass - - async def close(self) -> None: - self.closed = True - - -def create_identify_message(host, observed_multiaddr=None): - """Create an identify protobuf message.""" - return _mk_identify_protobuf(host, observed_multiaddr) - - -def create_new_format_message(identify_msg): - """Create a new format (length-prefixed) identify message.""" - msg_bytes = identify_msg.SerializeToString() - return encode_varint_prefixed(msg_bytes) - - -def create_old_format_message(identify_msg): - """Create an old format (raw protobuf) identify message.""" - return identify_msg.SerializeToString() - - -async def read_new_format_message(stream) -> bytes: - """Read a new format (length-prefixed) identify message.""" - # Read varint length prefix - length_bytes = b"" - while True: - b = await stream.read(1) - if not b: - break - length_bytes += b - if b[0] & 0x80 == 0: - break - - if not length_bytes: - raise ValueError("No length prefix received") - - msg_length = decode_varint_from_bytes(length_bytes) - - # Read the protobuf message - response = await stream.read(msg_length) - if len(response) != msg_length: - raise ValueError("Incomplete message received") - - return response - - -async def read_old_format_message(stream) -> bytes: - """Read an old format (raw protobuf) identify message.""" - # Read all available data - response = b"" - while True: - chunk = await stream.read(4096) - if not chunk: - break - response += chunk - - return response - - -async def read_compatible_message(stream) -> bytes: - """Read an identify message in either old or new format.""" - # Try to read a few bytes to detect the format - first_bytes = await stream.read(10) - if not first_bytes: - raise ValueError("No data received") - - # Try to decode as varint length prefix (new format) - try: - msg_length = decode_varint_from_bytes(first_bytes) - - # Validate that the length is reasonable (not too large) - if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB - # Calculate how many bytes the varint consumed - varint_len = 0 - for i, byte in enumerate(first_bytes): - varint_len += 1 - if (byte & 0x80) == 0: - break - - # Read the remaining protobuf message - remaining_bytes = await stream.read( - msg_length - (len(first_bytes) - varint_len) - ) - if len(remaining_bytes) == msg_length - (len(first_bytes) - varint_len): - message_data = first_bytes[varint_len:] + remaining_bytes - - # Try to parse as protobuf to validate - try: - Identify().ParseFromString(message_data) - 
return message_data - except Exception: - # If protobuf parsing fails, fall back to old format - pass - except Exception: - pass - - # Fall back to old format (raw protobuf) - response = first_bytes - - # Read more data if available - while True: - chunk = await stream.read(4096) - if not chunk: - break - response += chunk - - return response - - -async def read_compatible_message_simple(stream) -> bytes: - """Read a message in either old or new format (simplified version for testing).""" - # Try to read a few bytes to detect the format - first_bytes = await stream.read(10) - if not first_bytes: - raise ValueError("No data received") - - # Try to decode as varint length prefix (new format) - try: - msg_length = decode_varint_from_bytes(first_bytes) - - # Validate that the length is reasonable (not too large) - if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB - # Calculate how many bytes the varint consumed - varint_len = 0 - for i, byte in enumerate(first_bytes): - varint_len += 1 - if (byte & 0x80) == 0: - break - - # Read the remaining message - remaining_bytes = await stream.read( - msg_length - (len(first_bytes) - varint_len) - ) - if len(remaining_bytes) == msg_length - (len(first_bytes) - varint_len): - return first_bytes[varint_len:] + remaining_bytes - except Exception: - pass - - # Fall back to old format (raw data) - response = first_bytes - - # Read more data if available - while True: - chunk = await stream.read(4096) - if not chunk: - break - response += chunk - - return response - - -def detect_format(data): - """Detect if data is in new or old format (varint-prefixed or raw protobuf).""" - if not data: - return "unknown" - - # Try to decode as varint - try: - msg_length = decode_varint_from_bytes(data) - - # Validate that the length is reasonable - if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB - # Calculate varint length - varint_len = 0 - for i, byte in enumerate(data): - varint_len += 1 - if (byte & 0x80) == 0: - break - - # Check if we have enough data for the message - if len(data) >= varint_len + msg_length: - # Additional check: try to parse the message as protobuf - try: - message_data = data[varint_len : varint_len + msg_length] - Identify().ParseFromString(message_data) - return "new" - except Exception: - # If protobuf parsing fails, it's probably not a valid new format - pass - except Exception: - pass - - # If varint decoding fails or length is unreasonable, assume old format - return "old" - - -@pytest.mark.trio -async def test_identify_new_format_compatibility(security_protocol): - """Test that identify protocol works with new format (length-prefixed) messages.""" - async with host_pair_factory(security_protocol=security_protocol) as ( - host_a, - host_b, - ): - # Create identify message - identify_msg = create_identify_message(host_a) - - # Create new format message - new_format_data = create_new_format_message(identify_msg) - - # Create mock stream with new format data - stream = MockStream(new_format_data) - - # Read using new format reader - response = await read_new_format_message(stream) - - # Parse the response - parsed_msg = Identify() - parsed_msg.ParseFromString(response) - - # Verify the message content - assert parsed_msg.protocol_version == identify_msg.protocol_version - assert parsed_msg.agent_version == identify_msg.agent_version - assert parsed_msg.public_key == identify_msg.public_key - - -@pytest.mark.trio -async def test_identify_old_format_compatibility(security_protocol): - """Test that identify protocol works with 
old format (raw protobuf) messages.""" - async with host_pair_factory(security_protocol=security_protocol) as ( - host_a, - host_b, - ): - # Create identify message - identify_msg = create_identify_message(host_a) - - # Create old format message - old_format_data = create_old_format_message(identify_msg) - - # Create mock stream with old format data - stream = MockStream(old_format_data) - - # Read using old format reader - response = await read_old_format_message(stream) - - # Parse the response - parsed_msg = Identify() - parsed_msg.ParseFromString(response) - - # Verify the message content - assert parsed_msg.protocol_version == identify_msg.protocol_version - assert parsed_msg.agent_version == identify_msg.agent_version - assert parsed_msg.public_key == identify_msg.public_key - - -@pytest.mark.trio -async def test_identify_backward_compatibility_old_format(security_protocol): - """Test backward compatibility reader with old format messages.""" - async with host_pair_factory(security_protocol=security_protocol) as ( - host_a, - host_b, - ): - # Create identify message - identify_msg = create_identify_message(host_a) - - # Create old format message - old_format_data = create_old_format_message(identify_msg) - - # Create mock stream with old format data - stream = MockStream(old_format_data) - - # Read using old format reader (which should work reliably) - response = await read_old_format_message(stream) - - # Parse the response - parsed_msg = Identify() - parsed_msg.ParseFromString(response) - - # Verify the message content - assert parsed_msg.protocol_version == identify_msg.protocol_version - assert parsed_msg.agent_version == identify_msg.agent_version - assert parsed_msg.public_key == identify_msg.public_key - - -@pytest.mark.trio -async def test_identify_backward_compatibility_new_format(security_protocol): - """Test backward compatibility reader with new format messages.""" - async with host_pair_factory(security_protocol=security_protocol) as ( - host_a, - host_b, - ): - # Create identify message - identify_msg = create_identify_message(host_a) - - # Create new format message - new_format_data = create_new_format_message(identify_msg) - - # Create mock stream with new format data - stream = MockStream(new_format_data) - - # Read using new format reader (which should work reliably) - response = await read_new_format_message(stream) - - # Parse the response - parsed_msg = Identify() - parsed_msg.ParseFromString(response) - - # Verify the message content - assert parsed_msg.protocol_version == identify_msg.protocol_version - assert parsed_msg.agent_version == identify_msg.agent_version - assert parsed_msg.public_key == identify_msg.public_key - - -@pytest.mark.trio -async def test_identify_format_detection(security_protocol): - """Test that the format detection works correctly.""" - async with host_pair_factory(security_protocol=security_protocol) as ( - host_a, - host_b, - ): - # Create identify message - identify_msg = create_identify_message(host_a) - - # Test new format detection - new_format_data = create_new_format_message(identify_msg) - format_type = detect_format(new_format_data) - assert format_type == "new", "New format should be detected correctly" - - # Test old format detection - old_format_data = create_old_format_message(identify_msg) - format_type = detect_format(old_format_data) - assert format_type == "old", "Old format should be detected correctly" - - -@pytest.mark.trio -async def test_identify_error_handling(security_protocol): - """Test error handling for 
malformed messages.""" - from libp2p.exceptions import ParseError - - # Test with empty data - stream = MockStream(b"") - with pytest.raises(ValueError, match="No data received"): - await read_compatible_message(stream) - - # Test with incomplete varint - stream = MockStream(b"\x80") # Incomplete varint - with pytest.raises(ParseError, match="Unexpected end of data"): - await read_new_format_message(stream) - - # Test with invalid protobuf data - stream = MockStream(b"\x05invalid") # Length prefix but invalid protobuf - with pytest.raises(Exception): # Should fail when parsing protobuf - response = await read_new_format_message(stream) - Identify().ParseFromString(response) - - -@pytest.mark.trio -async def test_identify_message_equivalence(security_protocol): - """Test that old and new format messages are equivalent.""" - async with host_pair_factory(security_protocol=security_protocol) as ( - host_a, - host_b, - ): - # Create identify message - identify_msg = create_identify_message(host_a) - - # Create both formats - new_format_data = create_new_format_message(identify_msg) - old_format_data = create_old_format_message(identify_msg) - - # Extract the protobuf message from new format - varint_len = 0 - for i, byte in enumerate(new_format_data): - varint_len += 1 - if (byte & 0x80) == 0: - break - - new_format_protobuf = new_format_data[varint_len:] - - # The protobuf messages should be identical - assert new_format_protobuf == old_format_data, ( - "Protobuf messages should be identical in both formats" - ) From 26fd169ccc11ac516243e57c97d756cf5b1aada6 Mon Sep 17 00:00:00 2001 From: acul71 Date: Sat, 19 Jul 2025 04:25:06 +0200 Subject: [PATCH 22/22] doc: newsfragment raw identify message --- newsfragments/778.bugfix.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 newsfragments/778.bugfix.rst diff --git a/newsfragments/778.bugfix.rst b/newsfragments/778.bugfix.rst new file mode 100644 index 00000000..a18832a4 --- /dev/null +++ b/newsfragments/778.bugfix.rst @@ -0,0 +1 @@ +Fixed incorrect handling of raw protobuf format in identify protocol. The identify example now properly handles both raw and length-prefixed (varint) message formats, provides better error messages, and displays connection status with peer IDs. Replaced mock-based tests with comprehensive real network integration tests for both formats.
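For reference on the two wire formats PATCH 21 reconciles: the length-prefixed format frames the serialized Identify protobuf with an unsigned-varint (LEB128) length, while the raw format sends the protobuf bytes bare. A minimal sketch of that framing written out by hand, independent of the `libp2p.utils.varint` helpers (the payload below is a placeholder, not a real Identify message):

    def encode_uvarint(n: int) -> bytes:
        # Unsigned LEB128: 7 payload bits per byte, MSB set on every byte but the last.
        out = bytearray()
        while True:
            byte = n & 0x7F
            n >>= 7
            if n:
                out.append(byte | 0x80)
            else:
                out.append(byte)
                return bytes(out)

    def decode_uvarint(data: bytes) -> tuple[int, int]:
        # Returns (value, bytes consumed); raises if the varint is truncated.
        value = shift = 0
        for i, byte in enumerate(data):
            value |= (byte & 0x7F) << shift
            if byte & 0x80 == 0:
                return value, i + 1
            shift += 7
        raise ValueError("truncated varint")

    payload = b"\x0a\x04test"  # placeholder for a serialized Identify message
    framed = encode_uvarint(len(payload)) + payload  # length-prefixed format
    length, consumed = decode_uvarint(framed)
    assert framed[consumed : consumed + length] == payload  # round-trips

This is why a raw-format listener chokes on a length-prefixed dialer and vice versa: the first byte of a framed message is a length, not the start of a protobuf, so parsing fails until the peer either detects the format (as `parse_identify_response` does) or is told which one to use via the `--raw-format` flag.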