mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-09 22:50:54 +00:00
Merge branch 'main' into fix/885-Update-default-Bind-address
This commit is contained in:
@ -37,3 +37,4 @@ SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
|
|||||||
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
|
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
|
||||||
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
|
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
|
||||||
UnsubscribeFn = Callable[[], Awaitable[None]]
|
UnsubscribeFn = Callable[[], Awaitable[None]]
|
||||||
|
MessageID = NewType("MessageID", str)
|
||||||
|
|||||||
@ -1,6 +1,3 @@
|
|||||||
from ast import (
|
|
||||||
literal_eval,
|
|
||||||
)
|
|
||||||
from collections import (
|
from collections import (
|
||||||
defaultdict,
|
defaultdict,
|
||||||
)
|
)
|
||||||
@ -22,6 +19,7 @@ from libp2p.abc import (
|
|||||||
IPubsubRouter,
|
IPubsubRouter,
|
||||||
)
|
)
|
||||||
from libp2p.custom_types import (
|
from libp2p.custom_types import (
|
||||||
|
MessageID,
|
||||||
TProtocol,
|
TProtocol,
|
||||||
)
|
)
|
||||||
from libp2p.peer.id import (
|
from libp2p.peer.id import (
|
||||||
@ -56,6 +54,10 @@ from .pb import (
|
|||||||
from .pubsub import (
|
from .pubsub import (
|
||||||
Pubsub,
|
Pubsub,
|
||||||
)
|
)
|
||||||
|
from .utils import (
|
||||||
|
parse_message_id_safe,
|
||||||
|
safe_parse_message_id,
|
||||||
|
)
|
||||||
|
|
||||||
PROTOCOL_ID = TProtocol("/meshsub/1.0.0")
|
PROTOCOL_ID = TProtocol("/meshsub/1.0.0")
|
||||||
PROTOCOL_ID_V11 = TProtocol("/meshsub/1.1.0")
|
PROTOCOL_ID_V11 = TProtocol("/meshsub/1.1.0")
|
||||||
@ -794,8 +796,8 @@ class GossipSub(IPubsubRouter, Service):
|
|||||||
|
|
||||||
# Add all unknown message ids (ids that appear in ihave_msg but not in
|
# Add all unknown message ids (ids that appear in ihave_msg but not in
|
||||||
# seen_seqnos) to list of messages we want to request
|
# seen_seqnos) to list of messages we want to request
|
||||||
msg_ids_wanted: list[str] = [
|
msg_ids_wanted: list[MessageID] = [
|
||||||
msg_id
|
parse_message_id_safe(msg_id)
|
||||||
for msg_id in ihave_msg.messageIDs
|
for msg_id in ihave_msg.messageIDs
|
||||||
if msg_id not in seen_seqnos_and_peers
|
if msg_id not in seen_seqnos_and_peers
|
||||||
]
|
]
|
||||||
@ -811,9 +813,9 @@ class GossipSub(IPubsubRouter, Service):
|
|||||||
Forwards all request messages that are present in mcache to the
|
Forwards all request messages that are present in mcache to the
|
||||||
requesting peer.
|
requesting peer.
|
||||||
"""
|
"""
|
||||||
# FIXME: Update type of message ID
|
msg_ids: list[tuple[bytes, bytes]] = [
|
||||||
# FIXME: Find a better way to parse the msg ids
|
safe_parse_message_id(msg) for msg in iwant_msg.messageIDs
|
||||||
msg_ids: list[Any] = [literal_eval(msg) for msg in iwant_msg.messageIDs]
|
]
|
||||||
msgs_to_forward: list[rpc_pb2.Message] = []
|
msgs_to_forward: list[rpc_pb2.Message] = []
|
||||||
for msg_id_iwant in msg_ids:
|
for msg_id_iwant in msg_ids:
|
||||||
# Check if the wanted message ID is present in mcache
|
# Check if the wanted message ID is present in mcache
|
||||||
|
|||||||
@ -1,6 +1,10 @@
|
|||||||
|
import ast
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from libp2p.abc import IHost
|
from libp2p.abc import IHost
|
||||||
|
from libp2p.custom_types import (
|
||||||
|
MessageID,
|
||||||
|
)
|
||||||
from libp2p.peer.envelope import consume_envelope
|
from libp2p.peer.envelope import consume_envelope
|
||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
from libp2p.pubsub.pb.rpc_pb2 import RPC
|
from libp2p.pubsub.pb.rpc_pb2 import RPC
|
||||||
@ -48,3 +52,29 @@ def maybe_consume_signed_record(msg: RPC, host: IHost, peer_id: ID) -> bool:
|
|||||||
logger.error("Failed to update the Certified-Addr-Book: %s", e)
|
logger.error("Failed to update the Certified-Addr-Book: %s", e)
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def parse_message_id_safe(msg_id_str: str) -> MessageID:
|
||||||
|
"""Safely handle message ID as string."""
|
||||||
|
return MessageID(msg_id_str)
|
||||||
|
|
||||||
|
|
||||||
|
def safe_parse_message_id(msg_id_str: str) -> tuple[bytes, bytes]:
|
||||||
|
"""
|
||||||
|
Safely parse message ID using ast.literal_eval with validation.
|
||||||
|
:param msg_id_str: String representation of message ID
|
||||||
|
:return: Tuple of (seqno, from_id) as bytes
|
||||||
|
:raises ValueError: If parsing fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
parsed = ast.literal_eval(msg_id_str)
|
||||||
|
if not isinstance(parsed, tuple) or len(parsed) != 2:
|
||||||
|
raise ValueError("Invalid message ID format")
|
||||||
|
|
||||||
|
seqno, from_id = parsed
|
||||||
|
if not isinstance(seqno, bytes) or not isinstance(from_id, bytes):
|
||||||
|
raise ValueError("Message ID components must be bytes")
|
||||||
|
|
||||||
|
return (seqno, from_id)
|
||||||
|
except (ValueError, SyntaxError) as e:
|
||||||
|
raise ValueError(f"Invalid message ID format: {e}")
|
||||||
|
|||||||
1
newsfragments/843.bugfix.rst
Normal file
1
newsfragments/843.bugfix.rst
Normal file
@ -0,0 +1 @@
|
|||||||
|
Fixed message id type inconsistency in handle ihave and message id parsing improvement in handle iwant in pubsub module.
|
||||||
@ -1,4 +1,8 @@
|
|||||||
import random
|
import random
|
||||||
|
from unittest.mock import (
|
||||||
|
AsyncMock,
|
||||||
|
MagicMock,
|
||||||
|
)
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import trio
|
import trio
|
||||||
@ -7,6 +11,9 @@ from libp2p.pubsub.gossipsub import (
|
|||||||
PROTOCOL_ID,
|
PROTOCOL_ID,
|
||||||
GossipSub,
|
GossipSub,
|
||||||
)
|
)
|
||||||
|
from libp2p.pubsub.pb import (
|
||||||
|
rpc_pb2,
|
||||||
|
)
|
||||||
from libp2p.tools.utils import (
|
from libp2p.tools.utils import (
|
||||||
connect,
|
connect,
|
||||||
)
|
)
|
||||||
@ -754,3 +761,173 @@ async def test_single_host():
|
|||||||
assert connected_peers == 0, (
|
assert connected_peers == 0, (
|
||||||
f"Single host has {connected_peers} connections, expected 0"
|
f"Single host has {connected_peers} connections, expected 0"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_handle_ihave(monkeypatch):
|
||||||
|
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
|
||||||
|
gossipsub_routers = []
|
||||||
|
for pubsub in pubsubs_gsub:
|
||||||
|
if isinstance(pubsub.router, GossipSub):
|
||||||
|
gossipsub_routers.append(pubsub.router)
|
||||||
|
gossipsubs = tuple(gossipsub_routers)
|
||||||
|
|
||||||
|
index_alice = 0
|
||||||
|
index_bob = 1
|
||||||
|
id_bob = pubsubs_gsub[index_bob].my_id
|
||||||
|
|
||||||
|
# Connect Alice and Bob
|
||||||
|
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
|
||||||
|
await trio.sleep(0.1) # Allow connections to establish
|
||||||
|
|
||||||
|
# Mock emit_iwant to capture calls
|
||||||
|
mock_emit_iwant = AsyncMock()
|
||||||
|
monkeypatch.setattr(gossipsubs[index_alice], "emit_iwant", mock_emit_iwant)
|
||||||
|
|
||||||
|
# Create a test message ID as a string representation of a (seqno, from) tuple
|
||||||
|
test_seqno = b"1234"
|
||||||
|
test_from = id_bob.to_bytes()
|
||||||
|
test_msg_id = f"(b'{test_seqno.hex()}', b'{test_from.hex()}')"
|
||||||
|
ihave_msg = rpc_pb2.ControlIHave(messageIDs=[test_msg_id])
|
||||||
|
|
||||||
|
# Mock seen_messages.cache to avoid false positives
|
||||||
|
monkeypatch.setattr(pubsubs_gsub[index_alice].seen_messages, "cache", {})
|
||||||
|
|
||||||
|
# Simulate Bob sending IHAVE to Alice
|
||||||
|
await gossipsubs[index_alice].handle_ihave(ihave_msg, id_bob)
|
||||||
|
|
||||||
|
# Check if emit_iwant was called with the correct message ID
|
||||||
|
mock_emit_iwant.assert_called_once()
|
||||||
|
called_args = mock_emit_iwant.call_args[0]
|
||||||
|
assert called_args[0] == [test_msg_id] # Expected message IDs
|
||||||
|
assert called_args[1] == id_bob # Sender peer ID
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_handle_iwant(monkeypatch):
|
||||||
|
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
|
||||||
|
gossipsub_routers = []
|
||||||
|
for pubsub in pubsubs_gsub:
|
||||||
|
if isinstance(pubsub.router, GossipSub):
|
||||||
|
gossipsub_routers.append(pubsub.router)
|
||||||
|
gossipsubs = tuple(gossipsub_routers)
|
||||||
|
|
||||||
|
index_alice = 0
|
||||||
|
index_bob = 1
|
||||||
|
id_alice = pubsubs_gsub[index_alice].my_id
|
||||||
|
|
||||||
|
# Connect Alice and Bob
|
||||||
|
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
|
||||||
|
await trio.sleep(0.1) # Allow connections to establish
|
||||||
|
|
||||||
|
# Mock mcache.get to return a message
|
||||||
|
test_message = rpc_pb2.Message(data=b"test_data")
|
||||||
|
test_seqno = b"1234"
|
||||||
|
test_from = id_alice.to_bytes()
|
||||||
|
|
||||||
|
# ✅ Correct: use raw tuple and str() to serialize, no hex()
|
||||||
|
test_msg_id = str((test_seqno, test_from))
|
||||||
|
|
||||||
|
mock_mcache_get = MagicMock(return_value=test_message)
|
||||||
|
monkeypatch.setattr(gossipsubs[index_bob].mcache, "get", mock_mcache_get)
|
||||||
|
|
||||||
|
# Mock write_msg to capture the sent packet
|
||||||
|
mock_write_msg = AsyncMock()
|
||||||
|
monkeypatch.setattr(gossipsubs[index_bob].pubsub, "write_msg", mock_write_msg)
|
||||||
|
|
||||||
|
# Simulate Alice sending IWANT to Bob
|
||||||
|
iwant_msg = rpc_pb2.ControlIWant(messageIDs=[test_msg_id])
|
||||||
|
await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice)
|
||||||
|
|
||||||
|
# Check if write_msg was called with the correct packet
|
||||||
|
mock_write_msg.assert_called_once()
|
||||||
|
packet = mock_write_msg.call_args[0][1]
|
||||||
|
assert isinstance(packet, rpc_pb2.RPC)
|
||||||
|
assert len(packet.publish) == 1
|
||||||
|
assert packet.publish[0] == test_message
|
||||||
|
|
||||||
|
# Verify that mcache.get was called with the correct parsed message ID
|
||||||
|
mock_mcache_get.assert_called_once()
|
||||||
|
called_msg_id = mock_mcache_get.call_args[0][0]
|
||||||
|
assert isinstance(called_msg_id, tuple)
|
||||||
|
assert called_msg_id == (test_seqno, test_from)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_handle_iwant_invalid_msg_id(monkeypatch):
|
||||||
|
"""
|
||||||
|
Test that handle_iwant raises ValueError for malformed message IDs.
|
||||||
|
"""
|
||||||
|
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
|
||||||
|
gossipsub_routers = []
|
||||||
|
for pubsub in pubsubs_gsub:
|
||||||
|
if isinstance(pubsub.router, GossipSub):
|
||||||
|
gossipsub_routers.append(pubsub.router)
|
||||||
|
gossipsubs = tuple(gossipsub_routers)
|
||||||
|
|
||||||
|
index_alice = 0
|
||||||
|
index_bob = 1
|
||||||
|
id_alice = pubsubs_gsub[index_alice].my_id
|
||||||
|
|
||||||
|
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
|
||||||
|
await trio.sleep(0.1)
|
||||||
|
|
||||||
|
# Malformed message ID (not a tuple string)
|
||||||
|
malformed_msg_id = "not_a_valid_msg_id"
|
||||||
|
iwant_msg = rpc_pb2.ControlIWant(messageIDs=[malformed_msg_id])
|
||||||
|
|
||||||
|
# Mock mcache.get and write_msg to ensure they are not called
|
||||||
|
mock_mcache_get = MagicMock()
|
||||||
|
monkeypatch.setattr(gossipsubs[index_bob].mcache, "get", mock_mcache_get)
|
||||||
|
mock_write_msg = AsyncMock()
|
||||||
|
monkeypatch.setattr(gossipsubs[index_bob].pubsub, "write_msg", mock_write_msg)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice)
|
||||||
|
mock_mcache_get.assert_not_called()
|
||||||
|
mock_write_msg.assert_not_called()
|
||||||
|
|
||||||
|
# Message ID that's a tuple string but not (bytes, bytes)
|
||||||
|
invalid_tuple_msg_id = "('abc', 123)"
|
||||||
|
iwant_msg = rpc_pb2.ControlIWant(messageIDs=[invalid_tuple_msg_id])
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice)
|
||||||
|
mock_mcache_get.assert_not_called()
|
||||||
|
mock_write_msg.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_handle_ihave_empty_message_ids(monkeypatch):
|
||||||
|
"""
|
||||||
|
Test that handle_ihave with an empty messageIDs list does not call emit_iwant.
|
||||||
|
"""
|
||||||
|
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
|
||||||
|
gossipsub_routers = []
|
||||||
|
for pubsub in pubsubs_gsub:
|
||||||
|
if isinstance(pubsub.router, GossipSub):
|
||||||
|
gossipsub_routers.append(pubsub.router)
|
||||||
|
gossipsubs = tuple(gossipsub_routers)
|
||||||
|
|
||||||
|
index_alice = 0
|
||||||
|
index_bob = 1
|
||||||
|
id_bob = pubsubs_gsub[index_bob].my_id
|
||||||
|
|
||||||
|
# Connect Alice and Bob
|
||||||
|
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
|
||||||
|
await trio.sleep(0.1) # Allow connections to establish
|
||||||
|
|
||||||
|
# Mock emit_iwant to capture calls
|
||||||
|
mock_emit_iwant = AsyncMock()
|
||||||
|
monkeypatch.setattr(gossipsubs[index_alice], "emit_iwant", mock_emit_iwant)
|
||||||
|
|
||||||
|
# Empty messageIDs list
|
||||||
|
ihave_msg = rpc_pb2.ControlIHave(messageIDs=[])
|
||||||
|
|
||||||
|
# Mock seen_messages.cache to avoid false positives
|
||||||
|
monkeypatch.setattr(pubsubs_gsub[index_alice].seen_messages, "cache", {})
|
||||||
|
|
||||||
|
# Simulate Bob sending IHAVE to Alice
|
||||||
|
await gossipsubs[index_alice].handle_ihave(ihave_msg, id_bob)
|
||||||
|
|
||||||
|
# emit_iwant should not be called since there are no message IDs
|
||||||
|
mock_emit_iwant.assert_not_called()
|
||||||
|
|||||||
Reference in New Issue
Block a user