diff --git a/libp2p/custom_types.py b/libp2p/custom_types.py index d54f1257..d8e1a1d9 100644 --- a/libp2p/custom_types.py +++ b/libp2p/custom_types.py @@ -39,3 +39,4 @@ ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn] UnsubscribeFn = Callable[[], Awaitable[None]] TQUICStreamHandlerFn = Callable[[QUICStream], Awaitable[None]] TQUICConnHandlerFn = Callable[[QUICConnection], Awaitable[None]] +MessageID = NewType("MessageID", str) diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py index a4c8c463..45c6cd81 100644 --- a/libp2p/pubsub/gossipsub.py +++ b/libp2p/pubsub/gossipsub.py @@ -1,6 +1,3 @@ -from ast import ( - literal_eval, -) from collections import ( defaultdict, ) @@ -22,6 +19,7 @@ from libp2p.abc import ( IPubsubRouter, ) from libp2p.custom_types import ( + MessageID, TProtocol, ) from libp2p.peer.id import ( @@ -56,6 +54,10 @@ from .pb import ( from .pubsub import ( Pubsub, ) +from .utils import ( + parse_message_id_safe, + safe_parse_message_id, +) PROTOCOL_ID = TProtocol("/meshsub/1.0.0") PROTOCOL_ID_V11 = TProtocol("/meshsub/1.1.0") @@ -794,8 +796,8 @@ class GossipSub(IPubsubRouter, Service): # Add all unknown message ids (ids that appear in ihave_msg but not in # seen_seqnos) to list of messages we want to request - msg_ids_wanted: list[str] = [ - msg_id + msg_ids_wanted: list[MessageID] = [ + parse_message_id_safe(msg_id) for msg_id in ihave_msg.messageIDs if msg_id not in seen_seqnos_and_peers ] @@ -811,9 +813,9 @@ class GossipSub(IPubsubRouter, Service): Forwards all request messages that are present in mcache to the requesting peer. 
""" - # FIXME: Update type of message ID - # FIXME: Find a better way to parse the msg ids - msg_ids: list[Any] = [literal_eval(msg) for msg in iwant_msg.messageIDs] + msg_ids: list[tuple[bytes, bytes]] = [ + safe_parse_message_id(msg) for msg in iwant_msg.messageIDs + ] msgs_to_forward: list[rpc_pb2.Message] = [] for msg_id_iwant in msg_ids: # Check if the wanted message ID is present in mcache diff --git a/libp2p/pubsub/utils.py b/libp2p/pubsub/utils.py index 6686ba69..6beaccc5 100644 --- a/libp2p/pubsub/utils.py +++ b/libp2p/pubsub/utils.py @@ -1,6 +1,10 @@ +import ast import logging from libp2p.abc import IHost +from libp2p.custom_types import ( + MessageID, +) from libp2p.peer.envelope import consume_envelope from libp2p.peer.id import ID from libp2p.pubsub.pb.rpc_pb2 import RPC @@ -48,3 +52,29 @@ def maybe_consume_signed_record(msg: RPC, host: IHost, peer_id: ID) -> bool: logger.error("Failed to update the Certified-Addr-Book: %s", e) return False return True + + +def parse_message_id_safe(msg_id_str: str) -> MessageID: + """Safely handle message ID as string.""" + return MessageID(msg_id_str) + + +def safe_parse_message_id(msg_id_str: str) -> tuple[bytes, bytes]: + """ + Safely parse message ID using ast.literal_eval with validation. 
+ :param msg_id_str: String representation of message ID + :return: Tuple of (seqno, from_id) as bytes + :raises ValueError: If parsing fails + """ + try: + parsed = ast.literal_eval(msg_id_str) + if not isinstance(parsed, tuple) or len(parsed) != 2: + raise ValueError("Invalid message ID format") + + seqno, from_id = parsed + if not isinstance(seqno, bytes) or not isinstance(from_id, bytes): + raise ValueError("Message ID components must be bytes") + + return (seqno, from_id) + except (ValueError, SyntaxError) as e: + raise ValueError(f"Invalid message ID format: {e}") diff --git a/newsfragments/843.bugfix.rst b/newsfragments/843.bugfix.rst new file mode 100644 index 00000000..6160bbc7 --- /dev/null +++ b/newsfragments/843.bugfix.rst @@ -0,0 +1 @@ +Fixed message ID type inconsistency in ``handle_ihave`` and improved message ID parsing in ``handle_iwant`` in the pubsub module. diff --git a/tests/core/pubsub/test_gossipsub.py b/tests/core/pubsub/test_gossipsub.py index 91205b29..5c341d0b 100644 --- a/tests/core/pubsub/test_gossipsub.py +++ b/tests/core/pubsub/test_gossipsub.py @@ -1,4 +1,8 @@ import random +from unittest.mock import ( + AsyncMock, + MagicMock, +) import pytest import trio @@ -7,6 +11,9 @@ from libp2p.pubsub.gossipsub import ( PROTOCOL_ID, GossipSub, ) +from libp2p.pubsub.pb import ( + rpc_pb2, +) from libp2p.tools.utils import ( connect, ) @@ -754,3 +761,173 @@ async def test_single_host(): assert connected_peers == 0, ( f"Single host has {connected_peers} connections, expected 0" ) + + +@pytest.mark.trio +async def test_handle_ihave(monkeypatch): + async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub: + gossipsub_routers = [] + for pubsub in pubsubs_gsub: + if isinstance(pubsub.router, GossipSub): + gossipsub_routers.append(pubsub.router) + gossipsubs = tuple(gossipsub_routers) + + index_alice = 0 + index_bob = 1 + id_bob = pubsubs_gsub[index_bob].my_id + + # Connect Alice and Bob + await connect(pubsubs_gsub[index_alice].host,
pubsubs_gsub[index_bob].host) + await trio.sleep(0.1) # Allow connections to establish + + # Mock emit_iwant to capture calls + mock_emit_iwant = AsyncMock() + monkeypatch.setattr(gossipsubs[index_alice], "emit_iwant", mock_emit_iwant) + + # Create a test message ID as a string representation of a (seqno, from) tuple + test_seqno = b"1234" + test_from = id_bob.to_bytes() + test_msg_id = f"(b'{test_seqno.hex()}', b'{test_from.hex()}')" + ihave_msg = rpc_pb2.ControlIHave(messageIDs=[test_msg_id]) + + # Mock seen_messages.cache to avoid false positives + monkeypatch.setattr(pubsubs_gsub[index_alice].seen_messages, "cache", {}) + + # Simulate Bob sending IHAVE to Alice + await gossipsubs[index_alice].handle_ihave(ihave_msg, id_bob) + + # Check if emit_iwant was called with the correct message ID + mock_emit_iwant.assert_called_once() + called_args = mock_emit_iwant.call_args[0] + assert called_args[0] == [test_msg_id] # Expected message IDs + assert called_args[1] == id_bob # Sender peer ID + + +@pytest.mark.trio +async def test_handle_iwant(monkeypatch): + async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub: + gossipsub_routers = [] + for pubsub in pubsubs_gsub: + if isinstance(pubsub.router, GossipSub): + gossipsub_routers.append(pubsub.router) + gossipsubs = tuple(gossipsub_routers) + + index_alice = 0 + index_bob = 1 + id_alice = pubsubs_gsub[index_alice].my_id + + # Connect Alice and Bob + await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host) + await trio.sleep(0.1) # Allow connections to establish + + # Mock mcache.get to return a message + test_message = rpc_pb2.Message(data=b"test_data") + test_seqno = b"1234" + test_from = id_alice.to_bytes() + + # ✅ Correct: use raw tuple and str() to serialize, no hex() + test_msg_id = str((test_seqno, test_from)) + + mock_mcache_get = MagicMock(return_value=test_message) + monkeypatch.setattr(gossipsubs[index_bob].mcache, "get", mock_mcache_get) + + # Mock write_msg to capture the 
sent packet + mock_write_msg = AsyncMock() + monkeypatch.setattr(gossipsubs[index_bob].pubsub, "write_msg", mock_write_msg) + + # Simulate Alice sending IWANT to Bob + iwant_msg = rpc_pb2.ControlIWant(messageIDs=[test_msg_id]) + await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice) + + # Check if write_msg was called with the correct packet + mock_write_msg.assert_called_once() + packet = mock_write_msg.call_args[0][1] + assert isinstance(packet, rpc_pb2.RPC) + assert len(packet.publish) == 1 + assert packet.publish[0] == test_message + + # Verify that mcache.get was called with the correct parsed message ID + mock_mcache_get.assert_called_once() + called_msg_id = mock_mcache_get.call_args[0][0] + assert isinstance(called_msg_id, tuple) + assert called_msg_id == (test_seqno, test_from) + + +@pytest.mark.trio +async def test_handle_iwant_invalid_msg_id(monkeypatch): + """ + Test that handle_iwant raises ValueError for malformed message IDs. + """ + async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub: + gossipsub_routers = [] + for pubsub in pubsubs_gsub: + if isinstance(pubsub.router, GossipSub): + gossipsub_routers.append(pubsub.router) + gossipsubs = tuple(gossipsub_routers) + + index_alice = 0 + index_bob = 1 + id_alice = pubsubs_gsub[index_alice].my_id + + await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host) + await trio.sleep(0.1) + + # Malformed message ID (not a tuple string) + malformed_msg_id = "not_a_valid_msg_id" + iwant_msg = rpc_pb2.ControlIWant(messageIDs=[malformed_msg_id]) + + # Mock mcache.get and write_msg to ensure they are not called + mock_mcache_get = MagicMock() + monkeypatch.setattr(gossipsubs[index_bob].mcache, "get", mock_mcache_get) + mock_write_msg = AsyncMock() + monkeypatch.setattr(gossipsubs[index_bob].pubsub, "write_msg", mock_write_msg) + + with pytest.raises(ValueError): + await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice) + mock_mcache_get.assert_not_called() + 
mock_write_msg.assert_not_called() + + # Message ID that's a tuple string but not (bytes, bytes) + invalid_tuple_msg_id = "('abc', 123)" + iwant_msg = rpc_pb2.ControlIWant(messageIDs=[invalid_tuple_msg_id]) + with pytest.raises(ValueError): + await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice) + mock_mcache_get.assert_not_called() + mock_write_msg.assert_not_called() + + +@pytest.mark.trio +async def test_handle_ihave_empty_message_ids(monkeypatch): + """ + Test that handle_ihave with an empty messageIDs list does not call emit_iwant. + """ + async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub: + gossipsub_routers = [] + for pubsub in pubsubs_gsub: + if isinstance(pubsub.router, GossipSub): + gossipsub_routers.append(pubsub.router) + gossipsubs = tuple(gossipsub_routers) + + index_alice = 0 + index_bob = 1 + id_bob = pubsubs_gsub[index_bob].my_id + + # Connect Alice and Bob + await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host) + await trio.sleep(0.1) # Allow connections to establish + + # Mock emit_iwant to capture calls + mock_emit_iwant = AsyncMock() + monkeypatch.setattr(gossipsubs[index_alice], "emit_iwant", mock_emit_iwant) + + # Empty messageIDs list + ihave_msg = rpc_pb2.ControlIHave(messageIDs=[]) + + # Mock seen_messages.cache to avoid false positives + monkeypatch.setattr(pubsubs_gsub[index_alice].seen_messages, "cache", {}) + + # Simulate Bob sending IHAVE to Alice + await gossipsubs[index_alice].handle_ihave(ihave_msg, id_bob) + + # emit_iwant should not be called since there are no message IDs + mock_emit_iwant.assert_not_called()