18 Commits

Author SHA1 Message Date
80b58a2ae0 Merge branch 'main' into noise-arch-change 2025-09-05 02:55:55 +05:30
9370101a84 Merge pull request #843 from unniznd/fix_pubsub_msg_id_type_inconsistency
fix: message id type inconsistency in handle ihave and message id parsing improvement in handle iwant
2025-09-04 23:39:14 +05:30
56732a1506 Merge branch 'main' into fix_pubsub_msg_id_type_inconsistency 2025-09-04 16:26:01 +05:30
b8217bb8a8 Merge branch 'main' into fix_pubsub_msg_id_type_inconsistency 2025-09-02 10:16:17 +05:30
333d56dc00 Merge branch 'main' into noise-arch-change 2025-09-02 03:40:54 +05:30
20edc3830a Merge branch 'main' into fix_pubsub_msg_id_type_inconsistency 2025-09-02 01:07:16 +05:30
69680e9c1f Added negative testcases 2025-09-01 10:30:25 +05:30
40dad64949 Merge branch 'main' into fix_pubsub_msg_id_type_inconsistency 2025-08-29 03:24:53 +05:30
999315a74a Merge branch 'main' into noise-arch-change 2025-08-29 03:23:05 +05:30
8100a5cd20 removed redundant check in seen seqnos and peers and added test cases of handle iwant and handle ihave 2025-08-26 21:49:12 +05:30
cacb3c8aca feat: add webtransport certhashes field to NoiseExtensions and implement serialization test
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
2025-08-26 12:49:21 +05:30
fb544d6db2 fixed the merge conflict gossipsub module. 2025-08-25 21:12:45 +05:30
b40d84fc26 Merge remote-tracking branch 'origin/main' into fix_pubsub_msg_id_type_inconsistency 2025-08-25 21:11:55 +05:30
05fde3ad40 Merge branch 'main' into noise-arch-change 2025-08-25 16:21:43 +05:30
e4ab3cb2c5 Add early data support to Noise protocol
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
2025-08-19 04:41:14 +05:30
a9a6ed6767 Merge branch 'main' into fix_pubsub_msg_id_type_inconsistency 2025-08-18 22:02:20 +05:30
388302baa7 Added newsfragment 2025-08-15 13:57:21 +05:30
dc04270c19 fix: message id inconsistency in handle ihave and message id parsing improvement in handle iwant 2025-08-15 13:53:24 +05:30
17 changed files with 475 additions and 67 deletions

View File

@ -24,13 +24,8 @@ async def main():
noise_transport = NoiseTransport(
# local_key_pair: The key pair used for libp2p identity and authentication
libp2p_keypair=key_pair,
# noise_privkey: The private key used for Noise protocol encryption
noise_privkey=key_pair.private_key,
# early_data: Optional data to send during the handshake
# (None means no early data)
early_data=None,
# with_noise_pipes: Whether to use Noise pipes for additional security features
with_noise_pipes=False,
# TODO: add early data
)
# Create a security options dictionary mapping protocol ID to transport

View File

@ -28,9 +28,7 @@ async def main():
noise_privkey=key_pair.private_key,
# early_data: Optional data to send during the handshake
# (None means no early data)
early_data=None,
# with_noise_pipes: Whether to use Noise pipes for additional security features
with_noise_pipes=False,
# TODO: add early data
)
# Create a security options dictionary mapping protocol ID to transport

View File

@ -31,9 +31,7 @@ async def main():
noise_privkey=key_pair.private_key,
# early_data: Optional data to send during the handshake
# (None means no early data)
early_data=None,
# with_noise_pipes: Whether to use Noise pipes for additional security features
with_noise_pipes=False,
# TODO: add early data
)
# Create a security options dictionary mapping protocol ID to transport

View File

@ -28,9 +28,7 @@ async def main():
noise_privkey=key_pair.private_key,
# early_data: Optional data to send during the handshake
# (None means no early data)
early_data=None,
# with_noise_pipes: Whether to use Noise pipes for additional security features
with_noise_pipes=False,
# TODO: add early data
)
# Create a security options dictionary mapping protocol ID to transport

View File

@ -37,3 +37,4 @@ SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
UnsubscribeFn = Callable[[], Awaitable[None]]
MessageID = NewType("MessageID", str)

View File

@ -1,6 +1,3 @@
from ast import (
literal_eval,
)
from collections import (
defaultdict,
)
@ -22,6 +19,7 @@ from libp2p.abc import (
IPubsubRouter,
)
from libp2p.custom_types import (
MessageID,
TProtocol,
)
from libp2p.peer.id import (
@ -56,6 +54,10 @@ from .pb import (
from .pubsub import (
Pubsub,
)
from .utils import (
parse_message_id_safe,
safe_parse_message_id,
)
PROTOCOL_ID = TProtocol("/meshsub/1.0.0")
PROTOCOL_ID_V11 = TProtocol("/meshsub/1.1.0")
@ -794,8 +796,8 @@ class GossipSub(IPubsubRouter, Service):
# Add all unknown message ids (ids that appear in ihave_msg but not in
# seen_seqnos) to list of messages we want to request
msg_ids_wanted: list[str] = [
msg_id
msg_ids_wanted: list[MessageID] = [
parse_message_id_safe(msg_id)
for msg_id in ihave_msg.messageIDs
if msg_id not in seen_seqnos_and_peers
]
@ -811,9 +813,9 @@ class GossipSub(IPubsubRouter, Service):
Forwards all request messages that are present in mcache to the
requesting peer.
"""
# FIXME: Update type of message ID
# FIXME: Find a better way to parse the msg ids
msg_ids: list[Any] = [literal_eval(msg) for msg in iwant_msg.messageIDs]
msg_ids: list[tuple[bytes, bytes]] = [
safe_parse_message_id(msg) for msg in iwant_msg.messageIDs
]
msgs_to_forward: list[rpc_pb2.Message] = []
for msg_id_iwant in msg_ids:
# Check if the wanted message ID is present in mcache

View File

@ -1,6 +1,10 @@
import ast
import logging
from libp2p.abc import IHost
from libp2p.custom_types import (
MessageID,
)
from libp2p.peer.envelope import consume_envelope
from libp2p.peer.id import ID
from libp2p.pubsub.pb.rpc_pb2 import RPC
@ -48,3 +52,29 @@ def maybe_consume_signed_record(msg: RPC, host: IHost, peer_id: ID) -> bool:
logger.error("Failed to update the Certified-Addr-Book: %s", e)
return False
return True
def parse_message_id_safe(msg_id_str: str) -> MessageID:
    """
    Tag a raw message-ID string as a ``MessageID``.

    This performs no validation or transformation — it only wraps the value
    in the ``MessageID`` NewType so type checkers can track it.
    """
    return MessageID(msg_id_str)
def safe_parse_message_id(msg_id_str: str) -> tuple[bytes, bytes]:
    """
    Safely parse a message ID of the form ``str((seqno, from_id))``.

    ``ast.literal_eval`` only evaluates Python literals and never executes
    code, so it is safe on untrusted input; the result is then validated to
    be a 2-tuple of bytes.

    :param msg_id_str: String representation of message ID
    :return: Tuple of (seqno, from_id) as bytes
    :raises ValueError: If parsing fails
    """
    # Keep the try body minimal: only the literal_eval call can raise the
    # errors we want to translate. Validation errors below are raised as-is,
    # so a ValueError from validation is no longer re-caught and re-wrapped
    # into a doubled "Invalid message ID format: Invalid message ID format".
    try:
        parsed = ast.literal_eval(msg_id_str)
    except (ValueError, SyntaxError) as e:
        # Chain the original error so the root cause stays in the traceback.
        raise ValueError(f"Invalid message ID format: {e}") from e
    if not isinstance(parsed, tuple) or len(parsed) != 2:
        raise ValueError("Invalid message ID format")
    seqno, from_id = parsed
    if not isinstance(seqno, bytes) or not isinstance(from_id, bytes):
        raise ValueError("Message ID components must be bytes")
    return (seqno, from_id)

View File

@ -0,0 +1,68 @@
from abc import ABC, abstractmethod
from libp2p.abc import IRawConnection
from libp2p.custom_types import TProtocol
from libp2p.peer.id import ID
from .pb import noise_pb2 as noise_pb
class EarlyDataHandler(ABC):
    """
    Interface for handling early data during the Noise handshake.

    Early data is carried as a serialized ``NoiseExtensions`` protobuf inside
    the handshake payload; implementations decide what to send and how to
    interpret what is received.
    """

    @abstractmethod
    async def send(
        self, conn: IRawConnection, peer_id: ID
    ) -> noise_pb.NoiseExtensions | None:
        """
        Produce the extensions to embed in our outgoing handshake payload.

        :param conn: The raw connection being secured.
        :param peer_id: Peer ID passed by the caller (the remote peer for the
            initiator; the responder passes its own — see the pattern code).
        :return: Extensions to send, or ``None`` to send no early data.
        """
        pass

    @abstractmethod
    async def received(
        self, conn: IRawConnection, extensions: noise_pb.NoiseExtensions | None
    ) -> None:
        """
        Handle extensions decoded from the remote handshake payload.

        :param conn: The raw connection being secured.
        :param extensions: The decoded extensions, or ``None`` if absent.
        """
        pass
class TransportEarlyDataHandler(EarlyDataHandler):
    """Default early data handler: negotiates a stream muxer in-handshake."""

    def __init__(self, supported_muxers: list[TProtocol]):
        self.supported_muxers = supported_muxers
        self.received_muxers: list[TProtocol] = []

    async def send(
        self, conn: IRawConnection, peer_id: ID
    ) -> noise_pb.NoiseExtensions | None:
        """Advertise our supported muxers; ``None`` when we have none."""
        if not self.supported_muxers:
            return None
        ext = noise_pb.NoiseExtensions()
        # The protobuf field holds plain strings, so stringify each TProtocol.
        ext.stream_muxers.extend(str(m) for m in self.supported_muxers)
        return ext

    async def received(
        self, conn: IRawConnection, extensions: noise_pb.NoiseExtensions | None
    ) -> None:
        """Record the muxer list advertised by the remote peer."""
        if extensions is None or not extensions.stream_muxers:
            return
        self.received_muxers = [TProtocol(m) for m in extensions.stream_muxers]

    def match_muxers(self, is_initiator: bool) -> TProtocol | None:
        """
        Return the first common muxer, or ``None`` if there is no overlap.

        The initiator's preference order wins: the initiator scans its own
        list against the remote's, the responder scans the remote's list
        against its own.
        """
        if is_initiator:
            ordered, other = self.supported_muxers, self.received_muxers
        else:
            ordered, other = self.received_muxers, self.supported_muxers
        return next((m for m in ordered if m in other), None)

View File

@ -30,6 +30,9 @@ from libp2p.security.secure_session import (
SecureSession,
)
from .early_data import (
EarlyDataHandler,
)
from .exceptions import (
HandshakeHasNotFinished,
InvalidSignature,
@ -45,6 +48,7 @@ from .messages import (
make_handshake_payload_sig,
verify_handshake_payload_sig,
)
from .pb import noise_pb2 as noise_pb
class IPattern(ABC):
@ -62,7 +66,8 @@ class BasePattern(IPattern):
noise_static_key: PrivateKey
local_peer: ID
libp2p_privkey: PrivateKey
early_data: bytes | None
initiator_early_data_handler: EarlyDataHandler | None
responder_early_data_handler: EarlyDataHandler | None
def create_noise_state(self) -> NoiseState:
noise_state = NoiseState.from_name(self.protocol_name)
@ -73,11 +78,50 @@ class BasePattern(IPattern):
raise NoiseStateError("noise_protocol is not initialized")
return noise_state
def make_handshake_payload(self) -> NoiseHandshakePayload:
async def make_handshake_payload(
self, conn: IRawConnection, peer_id: ID, is_initiator: bool
) -> NoiseHandshakePayload:
signature = make_handshake_payload_sig(
self.libp2p_privkey, self.noise_static_key.get_public_key()
)
return NoiseHandshakePayload(self.libp2p_privkey.get_public_key(), signature)
# NEW: Get early data from appropriate handler
extensions = None
if is_initiator and self.initiator_early_data_handler:
extensions = await self.initiator_early_data_handler.send(conn, peer_id)
elif not is_initiator and self.responder_early_data_handler:
extensions = await self.responder_early_data_handler.send(conn, peer_id)
# NEW: Serialize extensions into early_data field
early_data = None
if extensions:
early_data = extensions.SerializeToString()
return NoiseHandshakePayload(
self.libp2p_privkey.get_public_key(),
signature,
early_data, # ← This is the key addition
)
async def handle_received_payload(
self, conn: IRawConnection, payload: NoiseHandshakePayload, is_initiator: bool
) -> None:
"""Process early data from received payload"""
if not payload.early_data:
return
# Deserialize the NoiseExtensions from early_data field
try:
extensions = noise_pb.NoiseExtensions.FromString(payload.early_data)
except Exception:
# Invalid extensions, ignore silently
return
# Pass to appropriate handler
if is_initiator and self.initiator_early_data_handler:
await self.initiator_early_data_handler.received(conn, extensions)
elif not is_initiator and self.responder_early_data_handler:
await self.responder_early_data_handler.received(conn, extensions)
class PatternXX(BasePattern):
@ -86,13 +130,15 @@ class PatternXX(BasePattern):
local_peer: ID,
libp2p_privkey: PrivateKey,
noise_static_key: PrivateKey,
early_data: bytes | None = None,
initiator_early_data_handler: EarlyDataHandler | None,
responder_early_data_handler: EarlyDataHandler | None,
) -> None:
self.protocol_name = b"Noise_XX_25519_ChaChaPoly_SHA256"
self.local_peer = local_peer
self.libp2p_privkey = libp2p_privkey
self.noise_static_key = noise_static_key
self.early_data = early_data
self.initiator_early_data_handler = initiator_early_data_handler
self.responder_early_data_handler = responder_early_data_handler
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
noise_state = self.create_noise_state()
@ -106,18 +152,23 @@ class PatternXX(BasePattern):
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
# Consume msg#1.
# 1. Consume msg#1 (just empty bytes)
await read_writer.read_msg()
# Send msg#2, which should include our handshake payload.
our_payload = self.make_handshake_payload()
# 2. Send msg#2 with our payload INCLUDING EARLY DATA
our_payload = await self.make_handshake_payload(
conn,
self.local_peer, # We send our own peer ID in responder role
is_initiator=False,
)
msg_2 = our_payload.serialize()
await read_writer.write_msg(msg_2)
# Receive and consume msg#3.
# 3. Receive msg#3
msg_3 = await read_writer.read_msg()
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3)
# Extract remote pubkey from noise handshake state
if handshake_state.rs is None:
raise NoiseStateError(
"something is wrong in the underlying noise `handshake_state`: "
@ -126,14 +177,31 @@ class PatternXX(BasePattern):
)
remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
# 4. Verify signature (unchanged)
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
raise InvalidSignature
# NEW: Process early data from msg#3 AFTER signature verification
await self.handle_received_payload(
conn, peer_handshake_payload, is_initiator=False
)
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
if not noise_state.handshake_finished:
raise HandshakeHasNotFinished(
"handshake is done but it is not marked as finished in `noise_state`"
)
# NEW: Get negotiated muxer for connection state
# negotiated_muxer = None
if self.responder_early_data_handler and hasattr(
self.responder_early_data_handler, "match_muxers"
):
# negotiated_muxer =
# self.responder_early_data_handler.match_muxers(is_initiator=False)
pass
transport_read_writer = NoiseTransportReadWriter(conn, noise_state)
return SecureSession(
local_peer=self.local_peer,
@ -142,6 +210,8 @@ class PatternXX(BasePattern):
remote_permanent_pubkey=remote_pubkey,
is_initiator=False,
conn=transport_read_writer,
# NOTE: negotiated_muxer would need to be added to SecureSession constructor
# For now, store it in connection metadata or similar
)
async def handshake_outbound(
@ -158,24 +228,27 @@ class PatternXX(BasePattern):
if handshake_state is None:
raise NoiseStateError("Handshake state is not initialized")
# Send msg#1, which is *not* encrypted.
# 1. Send msg#1 (empty) - no early data possible in XX pattern
msg_1 = b""
await read_writer.write_msg(msg_1)
# Read msg#2 from the remote, which contains the public key of the peer.
# 2. Read msg#2 from responder
msg_2 = await read_writer.read_msg()
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2)
# Extract remote pubkey from noise handshake state
if handshake_state.rs is None:
raise NoiseStateError(
"something is wrong in the underlying noise `handshake_state`: "
"we received and consumed msg#3, which should have included the "
"we received and consumed msg#2, which should have included the "
"remote static public key, but it is not present in the handshake_state"
)
remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
# Verify signature BEFORE processing early data (security)
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
raise InvalidSignature
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
if remote_peer_id_from_pubkey != remote_peer:
raise PeerIDMismatchesPubkey(
@ -184,8 +257,15 @@ class PatternXX(BasePattern):
f"remote_peer_id_from_pubkey={remote_peer_id_from_pubkey}"
)
# Send msg#3, which includes our encrypted payload and our noise static key.
our_payload = self.make_handshake_payload()
# NEW: Process early data from msg#2 AFTER verification
await self.handle_received_payload(
conn, peer_handshake_payload, is_initiator=True
)
# 3. Send msg#3 with our payload INCLUDING EARLY DATA
our_payload = await self.make_handshake_payload(
conn, remote_peer, is_initiator=True
)
msg_3 = our_payload.serialize()
await read_writer.write_msg(msg_3)
@ -193,6 +273,16 @@ class PatternXX(BasePattern):
raise HandshakeHasNotFinished(
"handshake is done but it is not marked as finished in `noise_state`"
)
# NEW: Get negotiated muxer
# negotiated_muxer = None
if self.initiator_early_data_handler and hasattr(
self.initiator_early_data_handler, "match_muxers"
):
pass
# negotiated_muxer =
# self.initiator_early_data_handler.match_muxers(is_initiator=True)
transport_read_writer = NoiseTransportReadWriter(conn, noise_state)
return SecureSession(
local_peer=self.local_peer,
@ -201,6 +291,8 @@ class PatternXX(BasePattern):
remote_permanent_pubkey=remote_pubkey,
is_initiator=True,
conn=transport_read_writer,
# NOTE: negotiated_muxer would need to be added to SecureSession constructor
# For now, store it in connection metadata or similar
)
@staticmethod

View File

@ -1,8 +1,13 @@
syntax = "proto3";
syntax = "proto2";
package pb;
message NoiseHandshakePayload {
bytes identity_key = 1;
bytes identity_sig = 2;
bytes data = 3;
message NoiseExtensions {
repeated bytes webtransport_certhashes = 1;
repeated string stream_muxers = 2;
}
message NoiseHandshakePayload {
optional bytes identity_key = 1;
optional bytes identity_sig = 2;
optional bytes data = 3;
}

View File

@ -13,13 +13,15 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$libp2p/security/noise/pb/noise.proto\x12\x02pb\"Q\n\x15NoiseHandshakePayload\x12\x14\n\x0cidentity_key\x18\x01 \x01(\x0c\x12\x14\n\x0cidentity_sig\x18\x02 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$libp2p/security/noise/pb/noise.proto\x12\x02pb\"I\n\x0fNoiseExtensions\x12\x1f\n\x17webtransport_certhashes\x18\x01 \x03(\x0c\x12\x15\n\rstream_muxers\x18\x02 \x03(\t\"Q\n\x15NoiseHandshakePayload\x12\x14\n\x0cidentity_key\x18\x01 \x01(\x0c\x12\x14\n\x0cidentity_sig\x18\x02 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.security.noise.pb.noise_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_NOISEHANDSHAKEPAYLOAD._serialized_start=44
_NOISEHANDSHAKEPAYLOAD._serialized_end=125
_NOISEEXTENSIONS._serialized_start=44
_NOISEEXTENSIONS._serialized_end=117
_NOISEHANDSHAKEPAYLOAD._serialized_start=119
_NOISEHANDSHAKEPAYLOAD._serialized_end=200
# @@protoc_insertion_point(module_scope)

View File

@ -4,12 +4,34 @@ isort:skip_file
"""
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
import typing
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class NoiseExtensions(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
WEBTRANSPORT_CERTHASHES_FIELD_NUMBER: builtins.int
STREAM_MUXERS_FIELD_NUMBER: builtins.int
@property
def webtransport_certhashes(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
@property
def stream_muxers(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
def __init__(
self,
*,
webtransport_certhashes: collections.abc.Iterable[builtins.bytes] | None = ...,
stream_muxers: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["stream_muxers", b"stream_muxers", "webtransport_certhashes", b"webtransport_certhashes"]) -> None: ...
global___NoiseExtensions = NoiseExtensions
@typing.final
class NoiseHandshakePayload(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@ -23,10 +45,11 @@ class NoiseHandshakePayload(google.protobuf.message.Message):
def __init__(
self,
*,
identity_key: builtins.bytes = ...,
identity_sig: builtins.bytes = ...,
data: builtins.bytes = ...,
identity_key: builtins.bytes | None = ...,
identity_sig: builtins.bytes | None = ...,
data: builtins.bytes | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["data", b"data", "identity_key", b"identity_key", "identity_sig", b"identity_sig"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["data", b"data", "identity_key", b"identity_key", "identity_sig", b"identity_sig"]) -> None: ...
global___NoiseHandshakePayload = NoiseHandshakePayload

View File

@ -14,6 +14,7 @@ from libp2p.peer.id import (
ID,
)
from .early_data import EarlyDataHandler, TransportEarlyDataHandler
from .patterns import (
IPattern,
PatternXX,
@ -26,35 +27,40 @@ class Transport(ISecureTransport):
libp2p_privkey: PrivateKey
noise_privkey: PrivateKey
local_peer: ID
early_data: bytes | None
with_noise_pipes: bool
supported_muxers: list[TProtocol]
initiator_early_data_handler: EarlyDataHandler | None
responder_early_data_handler: EarlyDataHandler | None
def __init__(
self,
libp2p_keypair: KeyPair,
noise_privkey: PrivateKey,
early_data: bytes | None = None,
with_noise_pipes: bool = False,
supported_muxers: list[TProtocol] | None = None,
initiator_handler: EarlyDataHandler | None = None,
responder_handler: EarlyDataHandler | None = None,
) -> None:
self.libp2p_privkey = libp2p_keypair.private_key
self.noise_privkey = noise_privkey
self.local_peer = ID.from_pubkey(libp2p_keypair.public_key)
self.early_data = early_data
self.with_noise_pipes = with_noise_pipes
self.supported_muxers = supported_muxers or []
if self.with_noise_pipes:
raise NotImplementedError
# Create default handlers for muxer negotiation if none provided
if initiator_handler is None and self.supported_muxers:
initiator_handler = TransportEarlyDataHandler(self.supported_muxers)
if responder_handler is None and self.supported_muxers:
responder_handler = TransportEarlyDataHandler(self.supported_muxers)
self.initiator_early_data_handler = initiator_handler
self.responder_early_data_handler = responder_handler
def get_pattern(self) -> IPattern:
if self.with_noise_pipes:
raise NotImplementedError
else:
return PatternXX(
self.local_peer,
self.libp2p_privkey,
self.noise_privkey,
self.early_data,
)
return PatternXX(
self.local_peer,
self.libp2p_privkey,
self.noise_privkey,
self.initiator_early_data_handler,
self.responder_early_data_handler,
)
async def secure_inbound(self, conn: IRawConnection) -> ISecureConn:
pattern = self.get_pattern()

View File

@ -0,0 +1 @@
Fixed message id type inconsistency in handle ihave and message id parsing improvement in handle iwant in pubsub module.

View File

@ -1,4 +1,8 @@
import random
from unittest.mock import (
AsyncMock,
MagicMock,
)
import pytest
import trio
@ -7,6 +11,9 @@ from libp2p.pubsub.gossipsub import (
PROTOCOL_ID,
GossipSub,
)
from libp2p.pubsub.pb import (
rpc_pb2,
)
from libp2p.tools.utils import (
connect,
)
@ -754,3 +761,173 @@ async def test_single_host():
assert connected_peers == 0, (
f"Single host has {connected_peers} connections, expected 0"
)
@pytest.mark.trio
async def test_handle_ihave(monkeypatch):
    """An IHAVE for an unseen message ID must trigger an IWANT to the sender."""
    async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
        # Collect the GossipSub routers backing each pubsub instance.
        gossipsub_routers = []
        for pubsub in pubsubs_gsub:
            if isinstance(pubsub.router, GossipSub):
                gossipsub_routers.append(pubsub.router)
        gossipsubs = tuple(gossipsub_routers)

        index_alice = 0
        index_bob = 1
        id_bob = pubsubs_gsub[index_bob].my_id

        # Connect Alice and Bob
        await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
        await trio.sleep(0.1)  # Allow connections to establish

        # Mock emit_iwant to capture calls
        mock_emit_iwant = AsyncMock()
        monkeypatch.setattr(gossipsubs[index_alice], "emit_iwant", mock_emit_iwant)

        # Create a test message ID as a string representation of a
        # (seqno, from) tuple
        test_seqno = b"1234"
        test_from = id_bob.to_bytes()
        test_msg_id = f"(b'{test_seqno.hex()}', b'{test_from.hex()}')"
        ihave_msg = rpc_pb2.ControlIHave(messageIDs=[test_msg_id])

        # Mock seen_messages.cache to avoid false positives
        monkeypatch.setattr(pubsubs_gsub[index_alice].seen_messages, "cache", {})

        # Simulate Bob sending IHAVE to Alice
        await gossipsubs[index_alice].handle_ihave(ihave_msg, id_bob)

        # Check if emit_iwant was called with the correct message ID
        mock_emit_iwant.assert_called_once()
        called_args = mock_emit_iwant.call_args[0]
        assert called_args[0] == [test_msg_id]  # Expected message IDs
        assert called_args[1] == id_bob  # Sender peer ID
@pytest.mark.trio
async def test_handle_iwant(monkeypatch):
    """An IWANT for a cached message must forward that message to the requester."""
    async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
        # Collect the GossipSub routers backing each pubsub instance.
        gossipsub_routers = []
        for pubsub in pubsubs_gsub:
            if isinstance(pubsub.router, GossipSub):
                gossipsub_routers.append(pubsub.router)
        gossipsubs = tuple(gossipsub_routers)

        index_alice = 0
        index_bob = 1
        id_alice = pubsubs_gsub[index_alice].my_id

        # Connect Alice and Bob
        await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
        await trio.sleep(0.1)  # Allow connections to establish

        # Mock mcache.get to return a message
        test_message = rpc_pb2.Message(data=b"test_data")
        test_seqno = b"1234"
        test_from = id_alice.to_bytes()
        # Serialize the (seqno, from) tuple with str() so it round-trips
        # through ast.literal_eval in safe_parse_message_id (no hex encoding).
        test_msg_id = str((test_seqno, test_from))
        mock_mcache_get = MagicMock(return_value=test_message)
        monkeypatch.setattr(gossipsubs[index_bob].mcache, "get", mock_mcache_get)

        # Mock write_msg to capture the sent packet
        mock_write_msg = AsyncMock()
        monkeypatch.setattr(gossipsubs[index_bob].pubsub, "write_msg", mock_write_msg)

        # Simulate Alice sending IWANT to Bob
        iwant_msg = rpc_pb2.ControlIWant(messageIDs=[test_msg_id])
        await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice)

        # Check if write_msg was called with the correct packet
        mock_write_msg.assert_called_once()
        packet = mock_write_msg.call_args[0][1]
        assert isinstance(packet, rpc_pb2.RPC)
        assert len(packet.publish) == 1
        assert packet.publish[0] == test_message

        # Verify that mcache.get was called with the correct parsed message ID
        mock_mcache_get.assert_called_once()
        called_msg_id = mock_mcache_get.call_args[0][0]
        assert isinstance(called_msg_id, tuple)
        assert called_msg_id == (test_seqno, test_from)
@pytest.mark.trio
async def test_handle_iwant_invalid_msg_id(monkeypatch):
    """
    Test that handle_iwant raises ValueError for malformed message IDs.

    Neither the message cache nor the wire may be touched when the incoming
    ID cannot be parsed into a (seqno, from_id) bytes tuple.
    """
    async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
        # Collect the GossipSub routers backing each pubsub instance.
        gossipsub_routers = []
        for pubsub in pubsubs_gsub:
            if isinstance(pubsub.router, GossipSub):
                gossipsub_routers.append(pubsub.router)
        gossipsubs = tuple(gossipsub_routers)

        index_alice = 0
        index_bob = 1
        id_alice = pubsubs_gsub[index_alice].my_id

        await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
        await trio.sleep(0.1)

        # Malformed message ID (not a tuple string)
        malformed_msg_id = "not_a_valid_msg_id"
        iwant_msg = rpc_pb2.ControlIWant(messageIDs=[malformed_msg_id])

        # Mock mcache.get and write_msg to ensure they are not called
        mock_mcache_get = MagicMock()
        monkeypatch.setattr(gossipsubs[index_bob].mcache, "get", mock_mcache_get)
        mock_write_msg = AsyncMock()
        monkeypatch.setattr(gossipsubs[index_bob].pubsub, "write_msg", mock_write_msg)

        with pytest.raises(ValueError):
            await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice)
        mock_mcache_get.assert_not_called()
        mock_write_msg.assert_not_called()

        # Message ID that's a tuple string but not (bytes, bytes)
        invalid_tuple_msg_id = "('abc', 123)"
        iwant_msg = rpc_pb2.ControlIWant(messageIDs=[invalid_tuple_msg_id])
        with pytest.raises(ValueError):
            await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice)
        mock_mcache_get.assert_not_called()
        mock_write_msg.assert_not_called()
@pytest.mark.trio
async def test_handle_ihave_empty_message_ids(monkeypatch):
    """
    Test that handle_ihave with an empty messageIDs list does not call emit_iwant.
    """
    async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
        # Collect the GossipSub routers backing each pubsub instance.
        gossipsub_routers = []
        for pubsub in pubsubs_gsub:
            if isinstance(pubsub.router, GossipSub):
                gossipsub_routers.append(pubsub.router)
        gossipsubs = tuple(gossipsub_routers)

        index_alice = 0
        index_bob = 1
        id_bob = pubsubs_gsub[index_bob].my_id

        # Connect Alice and Bob
        await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
        await trio.sleep(0.1)  # Allow connections to establish

        # Mock emit_iwant to capture calls
        mock_emit_iwant = AsyncMock()
        monkeypatch.setattr(gossipsubs[index_alice], "emit_iwant", mock_emit_iwant)

        # Empty messageIDs list
        ihave_msg = rpc_pb2.ControlIHave(messageIDs=[])

        # Mock seen_messages.cache to avoid false positives
        monkeypatch.setattr(pubsubs_gsub[index_alice].seen_messages, "cache", {})

        # Simulate Bob sending IHAVE to Alice
        await gossipsubs[index_alice].handle_ihave(ihave_msg, id_bob)

        # emit_iwant should not be called since there are no message IDs
        mock_emit_iwant.assert_not_called()

View File

@ -0,0 +1,13 @@
from libp2p.security.noise.pb import noise_pb2 as noise_pb
def test_noise_extensions_serialization():
    """Round-trip a NoiseExtensions message through its wire format."""
    ext = noise_pb.NoiseExtensions()
    ext.stream_muxers.extend(["/mplex/6.7.0", "/yamux/1.0.0"])

    # Serialize, then rebuild a fresh message from the bytes.
    round_tripped = noise_pb.NoiseExtensions.FromString(ext.SerializeToString())
    assert list(round_tripped.stream_muxers) == ["/mplex/6.7.0", "/yamux/1.0.0"]

View File

@ -173,8 +173,7 @@ def noise_transport_factory(key_pair: KeyPair) -> ISecureTransport:
return NoiseTransport(
libp2p_keypair=key_pair,
noise_privkey=noise_static_key_factory(),
early_data=None,
with_noise_pipes=False,
# TODO: add early data
)