Mirror of https://github.com/varun-r-mallya/py-libp2p.git (synced 2026-02-12 16:10:57 +00:00)

Compare commits: 80b58a2ae0...chore-02 (4 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 17ed7f82fb |  |
|  | 7bddb08808 |  |
|  | dc205bff83 |  |
|  | 1037fbb0aa |  |

Taken together, the chore-02 side of this range drops the Noise EarlyDataHandler machinery (returning to the plain early_data=None / with_noise_pipes=False constructor and a proto3 NoiseHandshakePayload without NoiseExtensions), drops pubsub's typed MessageID parsing helpers in favor of ast.literal_eval, and adds wall-clock read/write deadline support to Mplex streams.
@@ -24,8 +24,13 @@ async def main():
     noise_transport = NoiseTransport(
         # local_key_pair: The key pair used for libp2p identity and authentication
         libp2p_keypair=key_pair,
+        # noise_privkey: The private key used for Noise protocol encryption
         noise_privkey=key_pair.private_key,
-        # TODO: add early data
+        # early_data: Optional data to send during the handshake
+        # (None means no early data)
+        early_data=None,
+        # with_noise_pipes: Whether to use Noise pipes for additional security features
+        with_noise_pipes=False,
     )

     # Create a security options dictionary mapping protocol ID to transport
@@ -28,7 +28,9 @@ async def main():
         noise_privkey=key_pair.private_key,
         # early_data: Optional data to send during the handshake
         # (None means no early data)
-        # TODO: add early data
+        early_data=None,
+        # with_noise_pipes: Whether to use Noise pipes for additional security features
+        with_noise_pipes=False,
     )

     # Create a security options dictionary mapping protocol ID to transport
@@ -31,7 +31,9 @@ async def main():
         noise_privkey=key_pair.private_key,
         # early_data: Optional data to send during the handshake
         # (None means no early data)
-        # TODO: add early data
+        early_data=None,
+        # with_noise_pipes: Whether to use Noise pipes for additional security features
+        with_noise_pipes=False,
     )

     # Create a security options dictionary mapping protocol ID to transport
@@ -28,7 +28,9 @@ async def main():
         noise_privkey=key_pair.private_key,
         # early_data: Optional data to send during the handshake
         # (None means no early data)
-        # TODO: add early data
+        early_data=None,
+        # with_noise_pipes: Whether to use Noise pipes for additional security features
+        with_noise_pipes=False,
     )

     # Create a security options dictionary mapping protocol ID to transport
@@ -37,4 +37,3 @@ SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
 AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
 ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
 UnsubscribeFn = Callable[[], Awaitable[None]]
-MessageID = NewType("MessageID", str)
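Dropping the MessageID alias is behavior-neutral at runtime, since typing.NewType is erased after type checking. A minimal standalone sketch of that point (not from the repo):

```python
from typing import NewType

MessageID = NewType("MessageID", str)

msg_id = MessageID("(b'1234', b'abcd')")
# NewType has no runtime identity of its own: the value is a plain str, so
# replacing list[MessageID] with list[str] below only affects static typing.
assert type(msg_id) is str
```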
@@ -1,3 +1,6 @@
+from ast import (
+    literal_eval,
+)
 from collections import (
     defaultdict,
 )
@@ -19,7 +22,6 @@ from libp2p.abc import (
     IPubsubRouter,
 )
 from libp2p.custom_types import (
-    MessageID,
     TProtocol,
 )
 from libp2p.peer.id import (
@@ -54,10 +56,6 @@ from .pb import (
 from .pubsub import (
     Pubsub,
 )
-from .utils import (
-    parse_message_id_safe,
-    safe_parse_message_id,
-)

 PROTOCOL_ID = TProtocol("/meshsub/1.0.0")
 PROTOCOL_ID_V11 = TProtocol("/meshsub/1.1.0")
@@ -796,8 +794,8 @@ class GossipSub(IPubsubRouter, Service):

         # Add all unknown message ids (ids that appear in ihave_msg but not in
         # seen_seqnos) to list of messages we want to request
-        msg_ids_wanted: list[MessageID] = [
-            parse_message_id_safe(msg_id)
+        msg_ids_wanted: list[str] = [
+            msg_id
             for msg_id in ihave_msg.messageIDs
             if msg_id not in seen_seqnos_and_peers
         ]
@@ -813,9 +811,9 @@ class GossipSub(IPubsubRouter, Service):
        Forwards all request messages that are present in mcache to the
        requesting peer.
        """
-        msg_ids: list[tuple[bytes, bytes]] = [
-            safe_parse_message_id(msg) for msg in iwant_msg.messageIDs
-        ]
+        # FIXME: Update type of message ID
+        # FIXME: Find a better way to parse the msg ids
+        msg_ids: list[Any] = [literal_eval(msg) for msg in iwant_msg.messageIDs]
         msgs_to_forward: list[rpc_pb2.Message] = []
         for msg_id_iwant in msg_ids:
             # Check if the wanted message ID is present in mcache
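For context on what literal_eval is undoing here: gossipsub's default message ID is the str() of a (seqno, from_id) bytes tuple, so parsing an IWANT entry is a str()/literal_eval round trip. A standalone sketch (values made up):

```python
from ast import literal_eval

seqno, from_id = b"1234", b"\x12peer"
msg_id = str((seqno, from_id))  # "(b'1234', b'\\x12peer')"

# literal_eval inverts str() for Python literals, recovering the tuple...
assert literal_eval(msg_id) == (seqno, from_id)

# ...but it raises on anything that is not a valid literal, which is what the
# removed safe_parse_message_id wrapper (pubsub utils hunk below) guarded against.
try:
    literal_eval("not_a_valid_msg_id")
except (ValueError, SyntaxError):
    pass
```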
@@ -1,10 +1,6 @@
-import ast
 import logging

 from libp2p.abc import IHost
-from libp2p.custom_types import (
-    MessageID,
-)
 from libp2p.peer.envelope import consume_envelope
 from libp2p.peer.id import ID
 from libp2p.pubsub.pb.rpc_pb2 import RPC
@@ -52,29 +48,3 @@ def maybe_consume_signed_record(msg: RPC, host: IHost, peer_id: ID) -> bool:
             logger.error("Failed to update the Certified-Addr-Book: %s", e)
             return False
     return True
-
-
-def parse_message_id_safe(msg_id_str: str) -> MessageID:
-    """Safely handle message ID as string."""
-    return MessageID(msg_id_str)
-
-
-def safe_parse_message_id(msg_id_str: str) -> tuple[bytes, bytes]:
-    """
-    Safely parse message ID using ast.literal_eval with validation.
-
-    :param msg_id_str: String representation of message ID
-    :return: Tuple of (seqno, from_id) as bytes
-    :raises ValueError: If parsing fails
-    """
-    try:
-        parsed = ast.literal_eval(msg_id_str)
-        if not isinstance(parsed, tuple) or len(parsed) != 2:
-            raise ValueError("Invalid message ID format")
-
-        seqno, from_id = parsed
-        if not isinstance(seqno, bytes) or not isinstance(from_id, bytes):
-            raise ValueError("Message ID components must be bytes")
-
-        return (seqno, from_id)
-    except (ValueError, SyntaxError) as e:
-        raise ValueError(f"Invalid message ID format: {e}")
@@ -1,68 +0,0 @@
-from abc import ABC, abstractmethod
-
-from libp2p.abc import IRawConnection
-from libp2p.custom_types import TProtocol
-from libp2p.peer.id import ID
-
-from .pb import noise_pb2 as noise_pb
-
-
-class EarlyDataHandler(ABC):
-    """Interface for handling early data during Noise handshake"""
-
-    @abstractmethod
-    async def send(
-        self, conn: IRawConnection, peer_id: ID
-    ) -> noise_pb.NoiseExtensions | None:
-        """Called to generate early data to send during handshake"""
-        pass
-
-    @abstractmethod
-    async def received(
-        self, conn: IRawConnection, extensions: noise_pb.NoiseExtensions | None
-    ) -> None:
-        """Called when early data is received during handshake"""
-        pass
-
-
-class TransportEarlyDataHandler(EarlyDataHandler):
-    """Default early data handler for muxer negotiation"""
-
-    def __init__(self, supported_muxers: list[TProtocol]):
-        self.supported_muxers = supported_muxers
-        self.received_muxers: list[TProtocol] = []
-
-    async def send(
-        self, conn: IRawConnection, peer_id: ID
-    ) -> noise_pb.NoiseExtensions | None:
-        """Send our supported muxers list"""
-        if not self.supported_muxers:
-            return None
-
-        extensions = noise_pb.NoiseExtensions()
-        # Convert TProtocol to string for serialization
-        extensions.stream_muxers[:] = [str(muxer) for muxer in self.supported_muxers]
-        return extensions
-
-    async def received(
-        self, conn: IRawConnection, extensions: noise_pb.NoiseExtensions | None
-    ) -> None:
-        """Store received muxers list"""
-        if extensions and extensions.stream_muxers:
-            self.received_muxers = [
-                TProtocol(muxer) for muxer in extensions.stream_muxers
-            ]
-
-    def match_muxers(self, is_initiator: bool) -> TProtocol | None:
-        """Find first common muxer between local and remote"""
-        if is_initiator:
-            # Initiator: find first local muxer that remote supports
-            for local_muxer in self.supported_muxers:
-                if local_muxer in self.received_muxers:
-                    return local_muxer
-        else:
-            # Responder: find first remote muxer that we support
-            for remote_muxer in self.received_muxers:
-                if remote_muxer in self.supported_muxers:
-                    return remote_muxer
-        return None
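The deleted handler's negotiation is order-sensitive: both sides converge on the initiator's highest-ranked common muxer. A standalone sketch of the same logic with plain strings in place of TProtocol (not repo code):

```python
def match_muxers(local: list[str], remote: list[str], is_initiator: bool) -> str | None:
    # Mirrors TransportEarlyDataHandler.match_muxers above: the initiator
    # scans its own preference list; the responder scans the remote's.
    preferred, other = (local, remote) if is_initiator else (remote, local)
    for muxer in preferred:
        if muxer in other:
            return muxer
    return None


# Both roles pick "/yamux/1.0.0": the first entry of the initiator's list
# that the responder also supports, so the choice agrees on both ends.
assert match_muxers(["/yamux/1.0.0", "/mplex/6.7.0"],
                    ["/mplex/6.7.0", "/yamux/1.0.0"], True) == "/yamux/1.0.0"
assert match_muxers(["/mplex/6.7.0", "/yamux/1.0.0"],
                    ["/yamux/1.0.0", "/mplex/6.7.0"], False) == "/yamux/1.0.0"
```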
@@ -30,9 +30,6 @@ from libp2p.security.secure_session import (
     SecureSession,
 )

-from .early_data import (
-    EarlyDataHandler,
-)
 from .exceptions import (
     HandshakeHasNotFinished,
     InvalidSignature,
@@ -48,7 +45,6 @@ from .messages import (
     make_handshake_payload_sig,
     verify_handshake_payload_sig,
 )
-from .pb import noise_pb2 as noise_pb


 class IPattern(ABC):
@@ -66,8 +62,7 @@ class BasePattern(IPattern):
     noise_static_key: PrivateKey
     local_peer: ID
     libp2p_privkey: PrivateKey
-    initiator_early_data_handler: EarlyDataHandler | None
-    responder_early_data_handler: EarlyDataHandler | None
+    early_data: bytes | None

     def create_noise_state(self) -> NoiseState:
         noise_state = NoiseState.from_name(self.protocol_name)
@@ -78,50 +73,11 @@ class BasePattern(IPattern):
             raise NoiseStateError("noise_protocol is not initialized")
         return noise_state

-    async def make_handshake_payload(
-        self, conn: IRawConnection, peer_id: ID, is_initiator: bool
-    ) -> NoiseHandshakePayload:
+    def make_handshake_payload(self) -> NoiseHandshakePayload:
         signature = make_handshake_payload_sig(
             self.libp2p_privkey, self.noise_static_key.get_public_key()
         )
-
-        # NEW: Get early data from appropriate handler
-        extensions = None
-        if is_initiator and self.initiator_early_data_handler:
-            extensions = await self.initiator_early_data_handler.send(conn, peer_id)
-        elif not is_initiator and self.responder_early_data_handler:
-            extensions = await self.responder_early_data_handler.send(conn, peer_id)
-
-        # NEW: Serialize extensions into early_data field
-        early_data = None
-        if extensions:
-            early_data = extensions.SerializeToString()
-
-        return NoiseHandshakePayload(
-            self.libp2p_privkey.get_public_key(),
-            signature,
-            early_data,  # ← This is the key addition
-        )
-
-    async def handle_received_payload(
-        self, conn: IRawConnection, payload: NoiseHandshakePayload, is_initiator: bool
-    ) -> None:
-        """Process early data from received payload"""
-        if not payload.early_data:
-            return
-
-        # Deserialize the NoiseExtensions from early_data field
-        try:
-            extensions = noise_pb.NoiseExtensions.FromString(payload.early_data)
-        except Exception:
-            # Invalid extensions, ignore silently
-            return
-
-        # Pass to appropriate handler
-        if is_initiator and self.initiator_early_data_handler:
-            await self.initiator_early_data_handler.received(conn, extensions)
-        elif not is_initiator and self.responder_early_data_handler:
-            await self.responder_early_data_handler.received(conn, extensions)
+        return NoiseHandshakePayload(self.libp2p_privkey.get_public_key(), signature)


 class PatternXX(BasePattern):
@@ -130,15 +86,13 @@ class PatternXX(BasePattern):
         local_peer: ID,
         libp2p_privkey: PrivateKey,
         noise_static_key: PrivateKey,
-        initiator_early_data_handler: EarlyDataHandler | None,
-        responder_early_data_handler: EarlyDataHandler | None,
+        early_data: bytes | None = None,
     ) -> None:
         self.protocol_name = b"Noise_XX_25519_ChaChaPoly_SHA256"
         self.local_peer = local_peer
         self.libp2p_privkey = libp2p_privkey
         self.noise_static_key = noise_static_key
-        self.initiator_early_data_handler = initiator_early_data_handler
-        self.responder_early_data_handler = responder_early_data_handler
+        self.early_data = early_data

     async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
         noise_state = self.create_noise_state()
@@ -152,23 +106,18 @@ class PatternXX(BasePattern):

         read_writer = NoiseHandshakeReadWriter(conn, noise_state)

-        # 1. Consume msg#1 (just empty bytes)
+        # Consume msg#1.
         await read_writer.read_msg()

-        # 2. Send msg#2 with our payload INCLUDING EARLY DATA
-        our_payload = await self.make_handshake_payload(
-            conn,
-            self.local_peer,  # We send our own peer ID in responder role
-            is_initiator=False,
-        )
+        # Send msg#2, which should include our handshake payload.
+        our_payload = self.make_handshake_payload()
         msg_2 = our_payload.serialize()
         await read_writer.write_msg(msg_2)

-        # 3. Receive msg#3
+        # Receive and consume msg#3.
         msg_3 = await read_writer.read_msg()
         peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3)

-        # Extract remote pubkey from noise handshake state
         if handshake_state.rs is None:
             raise NoiseStateError(
                 "something is wrong in the underlying noise `handshake_state`: "
@@ -177,31 +126,14 @@ class PatternXX(BasePattern):
             )
         remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
-
-        # 4. Verify signature (unchanged)
         if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
             raise InvalidSignature
-
-        # NEW: Process early data from msg#3 AFTER signature verification
-        await self.handle_received_payload(
-            conn, peer_handshake_payload, is_initiator=False
-        )
-
         remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)

         if not noise_state.handshake_finished:
             raise HandshakeHasNotFinished(
                 "handshake is done but it is not marked as finished in `noise_state`"
             )

-        # NEW: Get negotiated muxer for connection state
-        # negotiated_muxer = None
-        if self.responder_early_data_handler and hasattr(
-            self.responder_early_data_handler, "match_muxers"
-        ):
-            # negotiated_muxer =
-            # self.responder_early_data_handler.match_muxers(is_initiator=False)
-            pass
-
         transport_read_writer = NoiseTransportReadWriter(conn, noise_state)
         return SecureSession(
             local_peer=self.local_peer,
@@ -210,8 +142,6 @@ class PatternXX(BasePattern):
             remote_permanent_pubkey=remote_pubkey,
             is_initiator=False,
             conn=transport_read_writer,
-            # NOTE: negotiated_muxer would need to be added to SecureSession constructor
-            # For now, store it in connection metadata or similar
         )

     async def handshake_outbound(
@@ -228,27 +158,24 @@ class PatternXX(BasePattern):
         if handshake_state is None:
             raise NoiseStateError("Handshake state is not initialized")

-        # 1. Send msg#1 (empty) - no early data possible in XX pattern
+        # Send msg#1, which is *not* encrypted.
         msg_1 = b""
         await read_writer.write_msg(msg_1)

-        # 2. Read msg#2 from responder
+        # Read msg#2 from the remote, which contains the public key of the peer.
         msg_2 = await read_writer.read_msg()
         peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2)

-        # Extract remote pubkey from noise handshake state
         if handshake_state.rs is None:
             raise NoiseStateError(
                 "something is wrong in the underlying noise `handshake_state`: "
-                "we received and consumed msg#2, which should have included the "
+                "we received and consumed msg#3, which should have included the "
                 "remote static public key, but it is not present in the handshake_state"
             )
         remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
-
-        # Verify signature BEFORE processing early data (security)
         if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
             raise InvalidSignature

         remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
         if remote_peer_id_from_pubkey != remote_peer:
             raise PeerIDMismatchesPubkey(
@@ -257,15 +184,8 @@ class PatternXX(BasePattern):
                 f"remote_peer_id_from_pubkey={remote_peer_id_from_pubkey}"
             )

-        # NEW: Process early data from msg#2 AFTER verification
-        await self.handle_received_payload(
-            conn, peer_handshake_payload, is_initiator=True
-        )
-
-        # 3. Send msg#3 with our payload INCLUDING EARLY DATA
-        our_payload = await self.make_handshake_payload(
-            conn, remote_peer, is_initiator=True
-        )
+        # Send msg#3, which includes our encrypted payload and our noise static key.
+        our_payload = self.make_handshake_payload()
         msg_3 = our_payload.serialize()
         await read_writer.write_msg(msg_3)

@@ -273,16 +193,6 @@ class PatternXX(BasePattern):
             raise HandshakeHasNotFinished(
                 "handshake is done but it is not marked as finished in `noise_state`"
             )
-
-        # NEW: Get negotiated muxer
-        # negotiated_muxer = None
-        if self.initiator_early_data_handler and hasattr(
-            self.initiator_early_data_handler, "match_muxers"
-        ):
-            pass
-            # negotiated_muxer =
-            # self.initiator_early_data_handler.match_muxers(is_initiator=True)
-
         transport_read_writer = NoiseTransportReadWriter(conn, noise_state)
         return SecureSession(
             local_peer=self.local_peer,
@@ -291,8 +201,6 @@ class PatternXX(BasePattern):
             remote_permanent_pubkey=remote_pubkey,
             is_initiator=True,
             conn=transport_read_writer,
-            # NOTE: negotiated_muxer would need to be added to SecureSession constructor
-            # For now, store it in connection metadata or similar
         )

     @staticmethod
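For orientation (not part of the diff): both handshake methods above implement the standard Noise XX pattern; the early-data work only changed what rides inside the msg#2 and msg#3 payloads. An informal sketch of the flow, in the Noise spec's token notation:

```python
# Noise XX message flow, as driven by handshake_outbound (initiator) and
# handshake_inbound (responder) above:
#
#   msg#1  initiator -> responder : e             (empty payload, plaintext)
#   msg#2  responder -> initiator : e, ee, s, es  (carries responder's signed payload)
#   msg#3  initiator -> responder : s, se         (carries initiator's signed payload)
#
# Each side learns the peer's Noise static key (handshake_state.rs) from the
# pattern itself, then calls verify_handshake_payload_sig() to bind that key
# to the libp2p identity key before trusting the announced peer ID.
```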
@@ -1,13 +1,8 @@
-syntax = "proto2";
+syntax = "proto3";
 package pb;

-message NoiseExtensions {
-  repeated bytes webtransport_certhashes = 1;
-  repeated string stream_muxers = 2;
-}
-
 message NoiseHandshakePayload {
-  optional bytes identity_key = 1;
-  optional bytes identity_sig = 2;
-  optional bytes data = 3;
+  bytes identity_key = 1;
+  bytes identity_sig = 2;
+  bytes data = 3;
 }
@@ -13,15 +13,13 @@ _sym_db = _symbol_database.Default()



-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$libp2p/security/noise/pb/noise.proto\x12\x02pb\"I\n\x0fNoiseExtensions\x12\x1f\n\x17webtransport_certhashes\x18\x01 \x03(\x0c\x12\x15\n\rstream_muxers\x18\x02 \x03(\t\"Q\n\x15NoiseHandshakePayload\x12\x14\n\x0cidentity_key\x18\x01 \x01(\x0c\x12\x14\n\x0cidentity_sig\x18\x02 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$libp2p/security/noise/pb/noise.proto\x12\x02pb\"Q\n\x15NoiseHandshakePayload\x12\x14\n\x0cidentity_key\x18\x01 \x01(\x0c\x12\x14\n\x0cidentity_sig\x18\x02 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x62\x06proto3')

 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
 _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.security.noise.pb.noise_pb2', globals())
 if _descriptor._USE_C_DESCRIPTORS == False:

   DESCRIPTOR._options = None
-  _NOISEEXTENSIONS._serialized_start=44
-  _NOISEEXTENSIONS._serialized_end=117
-  _NOISEHANDSHAKEPAYLOAD._serialized_start=119
-  _NOISEHANDSHAKEPAYLOAD._serialized_end=200
+  _NOISEHANDSHAKEPAYLOAD._serialized_start=44
+  _NOISEHANDSHAKEPAYLOAD._serialized_end=125
 # @@protoc_insertion_point(module_scope)
@@ -4,34 +4,12 @@ isort:skip_file
 """

 import builtins
-import collections.abc
 import google.protobuf.descriptor
-import google.protobuf.internal.containers
 import google.protobuf.message
 import typing

 DESCRIPTOR: google.protobuf.descriptor.FileDescriptor

-@typing.final
-class NoiseExtensions(google.protobuf.message.Message):
-    DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
-    WEBTRANSPORT_CERTHASHES_FIELD_NUMBER: builtins.int
-    STREAM_MUXERS_FIELD_NUMBER: builtins.int
-    @property
-    def webtransport_certhashes(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
-    @property
-    def stream_muxers(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
-    def __init__(
-        self,
-        *,
-        webtransport_certhashes: collections.abc.Iterable[builtins.bytes] | None = ...,
-        stream_muxers: collections.abc.Iterable[builtins.str] | None = ...,
-    ) -> None: ...
-    def ClearField(self, field_name: typing.Literal["stream_muxers", b"stream_muxers", "webtransport_certhashes", b"webtransport_certhashes"]) -> None: ...
-
-global___NoiseExtensions = NoiseExtensions
-
 @typing.final
 class NoiseHandshakePayload(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -45,11 +23,10 @@ class NoiseHandshakePayload(google.protobuf.message.Message):
     def __init__(
         self,
         *,
-        identity_key: builtins.bytes | None = ...,
-        identity_sig: builtins.bytes | None = ...,
-        data: builtins.bytes | None = ...,
+        identity_key: builtins.bytes = ...,
+        identity_sig: builtins.bytes = ...,
+        data: builtins.bytes = ...,
     ) -> None: ...
-    def HasField(self, field_name: typing.Literal["data", b"data", "identity_key", b"identity_key", "identity_sig", b"identity_sig"]) -> builtins.bool: ...
     def ClearField(self, field_name: typing.Literal["data", b"data", "identity_key", b"identity_key", "identity_sig", b"identity_sig"]) -> None: ...

 global___NoiseHandshakePayload = NoiseHandshakePayload
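The proto2-to-proto3 switch also changes field presence: without an explicit `optional`, proto3 scalar fields carry no presence bit and always read back as their default value, which is why the regenerated stub above drops `HasField()` for these fields. A minimal sketch, assuming the regenerated bindings from this diff are importable:

```python
from libp2p.security.noise.pb import noise_pb2

payload = noise_pb2.NoiseHandshakePayload(identity_key=b"key", identity_sig=b"sig")

# proto3 scalar fields have no presence tracking: an unset bytes field reads
# back as b"" rather than None, so callers must treat empty and absent alike.
assert payload.data == b""
```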
@@ -14,7 +14,6 @@ from libp2p.peer.id import (
     ID,
 )

-from .early_data import EarlyDataHandler, TransportEarlyDataHandler
 from .patterns import (
     IPattern,
     PatternXX,
@@ -27,40 +26,35 @@ class Transport(ISecureTransport):
     libp2p_privkey: PrivateKey
     noise_privkey: PrivateKey
     local_peer: ID
-    supported_muxers: list[TProtocol]
-    initiator_early_data_handler: EarlyDataHandler | None
-    responder_early_data_handler: EarlyDataHandler | None
+    early_data: bytes | None
+    with_noise_pipes: bool

     def __init__(
         self,
         libp2p_keypair: KeyPair,
         noise_privkey: PrivateKey,
-        supported_muxers: list[TProtocol] | None = None,
-        initiator_handler: EarlyDataHandler | None = None,
-        responder_handler: EarlyDataHandler | None = None,
+        early_data: bytes | None = None,
+        with_noise_pipes: bool = False,
     ) -> None:
         self.libp2p_privkey = libp2p_keypair.private_key
         self.noise_privkey = noise_privkey
         self.local_peer = ID.from_pubkey(libp2p_keypair.public_key)
-        self.supported_muxers = supported_muxers or []
+        self.early_data = early_data
+        self.with_noise_pipes = with_noise_pipes

-        # Create default handlers for muxer negotiation if none provided
-        if initiator_handler is None and self.supported_muxers:
-            initiator_handler = TransportEarlyDataHandler(self.supported_muxers)
-        if responder_handler is None and self.supported_muxers:
-            responder_handler = TransportEarlyDataHandler(self.supported_muxers)
-
-        self.initiator_early_data_handler = initiator_handler
-        self.responder_early_data_handler = responder_handler
+        if self.with_noise_pipes:
+            raise NotImplementedError

     def get_pattern(self) -> IPattern:
-        return PatternXX(
-            self.local_peer,
-            self.libp2p_privkey,
-            self.noise_privkey,
-            self.initiator_early_data_handler,
-            self.responder_early_data_handler,
-        )
+        if self.with_noise_pipes:
+            raise NotImplementedError
+        else:
+            return PatternXX(
+                self.local_peer,
+                self.libp2p_privkey,
+                self.noise_privkey,
+                self.early_data,
+            )

     async def secure_inbound(self, conn: IRawConnection) -> ISecureConn:
         pattern = self.get_pattern()
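A usage sketch of the reverted constructor, matching what the four example diffs at the top of this compare now do; create_new_key_pair and PROTOCOL_ID are assumed from py-libp2p and are not part of this diff:

```python
from libp2p.crypto.ed25519 import create_new_key_pair
from libp2p.security.noise.transport import PROTOCOL_ID, Transport as NoiseTransport

key_pair = create_new_key_pair()
noise_transport = NoiseTransport(
    libp2p_keypair=key_pair,
    noise_privkey=key_pair.private_key,
    early_data=None,         # no payload piggybacked on the handshake
    with_noise_pipes=False,  # True raises NotImplementedError (see __init__ above)
)
# Map the security protocol ID to the transport when building a host.
security_options = {PROTOCOL_ID: noise_transport}
```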
@@ -1,5 +1,6 @@
 from collections.abc import AsyncGenerator
 from contextlib import asynccontextmanager
+import time
 from types import (
     TracebackType,
 )
@@ -100,6 +101,12 @@ class ReadWriteLock:
             self.release_write()


+class MplexStreamTimeout(Exception):
+    """Raised when a stream operation exceeds its deadline."""
+
+    pass
+
+
 class MplexStream(IMuxedStream):
     """
     reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
@@ -111,8 +118,8 @@ class MplexStream(IMuxedStream):
     # class of IMuxedConn. Ignoring this type assignment should not pose
     # any risk.
     muxed_conn: "Mplex"  # type: ignore[assignment]
-    read_deadline: int | None
-    write_deadline: int | None
+    read_deadline: float | None
+    write_deadline: float | None

     rw_lock: ReadWriteLock
     close_lock: trio.Lock
@@ -156,6 +163,30 @@ class MplexStream(IMuxedStream):
     def is_initiator(self) -> bool:
         return self.stream_id.is_initiator

+    def _check_read_deadline(self) -> None:
+        """Check if read deadline has expired and raise timeout if needed."""
+        if self.read_deadline is not None and time.time() > self.read_deadline:
+            raise MplexStreamTimeout("Read operation exceeded deadline")
+
+    def _check_write_deadline(self) -> None:
+        """Check if write deadline has expired and raise timeout if needed."""
+        if self.write_deadline is not None and time.time() > self.write_deadline:
+            raise MplexStreamTimeout("Write operation exceeded deadline")
+
+    def _get_read_timeout(self) -> float | None:
+        """Calculate remaining time until read deadline."""
+        if self.read_deadline is None:
+            return None
+        remaining = self.read_deadline - time.time()
+        return max(0.0, remaining) if remaining > 0 else 0
+
+    def _get_write_timeout(self) -> float | None:
+        """Calculate remaining time until write deadline."""
+        if self.write_deadline is None:
+            return None
+        remaining = self.write_deadline - time.time()
+        return max(0.0, remaining) if remaining > 0 else 0
+
     async def _read_until_eof(self) -> bytes:
         async for data in self.incoming_data_channel:
             self._buf.extend(data)
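The helpers above store absolute wall-clock deadlines (seconds since the epoch, as set by the set_*_deadline methods further down) and convert them to the relative timeouts trio expects at each blocking call. A self-contained sketch of that conversion; recv_with_deadline is a hypothetical name, not repo API:

```python
import time

import trio


async def recv_with_deadline(receive_channel: trio.MemoryReceiveChannel, deadline: float | None) -> bytes:
    # Convert the absolute deadline into a relative timeout; trio.fail_after
    # raises trio.TooSlowError once that much time has elapsed.
    if deadline is None:
        return await receive_channel.receive()
    remaining = max(0.0, deadline - time.time())
    with trio.fail_after(remaining):
        return await receive_channel.receive()


async def main() -> None:
    send, receive = trio.open_memory_channel(1)
    await send.send(b"hello")
    # Data is already buffered, so this completes well before the deadline.
    assert await recv_with_deadline(receive, time.time() + 1.0) == b"hello"


trio.run(main)
```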
@@ -182,6 +213,9 @@ class MplexStream(IMuxedStream):
        :param n: number of bytes to read
        :return: bytes actually read
        """
+        # check deadline before starting
+        self._check_read_deadline()
+
         async with self.rw_lock.read_lock():
             if n is not None and n < 0:
                 raise ValueError(
@@ -192,8 +226,13 @@ class MplexStream(IMuxedStream):
                 raise MplexStreamReset
             if n is None:
                 return await self._read_until_eof()
+
+            # check deadline again before potentially blocking operation
+            self._check_read_deadline()
+
             if len(self._buf) == 0:
                 data: bytes
+                timeout = self._get_read_timeout()
                 # Peek whether there is data available. If yes, we just read until
                 # there is no data, then return.
                 try:
@@ -207,6 +246,20 @@ class MplexStream(IMuxedStream):
                     try:
                         data = await self.incoming_data_channel.receive()
                         self._buf.extend(data)
+                        if timeout is not None and timeout <= 0:
+                            raise MplexStreamTimeout(
+                                "Read deadline exceeded while waiting for data"
+                            )
+
+                        if timeout is not None:
+                            with trio.fail_after(timeout):
+                                data = await self.incoming_data_channel.receive()
+                        else:
+                            data = await self.incoming_data_channel.receive()
+
+                        self._buf.extend(data)
+                    except trio.TooSlowError:
+                        raise MplexStreamTimeout("Read operation timed out")
                     except trio.EndOfChannel:
                         if self.event_reset.is_set():
                             raise MplexStreamReset
@@ -226,15 +279,43 @@ class MplexStream(IMuxedStream):
             self._buf = self._buf[len(payload) :]
             return bytes(payload)

+    async def _read_until_eof_with_timeout(self) -> bytes:
+        """Read until EOF with timeout support."""
+        timeout = self._get_read_timeout()
+
+        try:
+            if timeout is not None:
+                with trio.fail_after(timeout):
+                    async for data in self.incoming_data_channel:
+                        self._buf.extend(data)
+            else:
+                async for data in self.incoming_data_channel:
+                    self._buf.extend(data)
+        except trio.TooSlowError:
+            raise MplexStreamTimeout("Read until EOF operation timed out")
+
+        payload = self._buf
+        self._buf = self._buf[len(payload) :]
+        return bytes(payload)
+
     async def write(self, data: bytes) -> None:
         """
         Write to stream.

         :return: number of bytes written
         """
+        # Check deadline before starting
+        self._check_write_deadline()
+
         async with self.rw_lock.write_lock():
             if self.event_local_closed.is_set():
                 raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
+
+            # Check deadline again after acquiring lock
+            timeout = self._get_write_timeout()
+            if timeout is not None and timeout <= 0:
+                raise MplexStreamTimeout("Write deadline exceeded")
+
             flag = (
                 HeaderTags.MessageInitiator
                 if self.is_initiator
@@ -315,8 +396,9 @@ class MplexStream(IMuxedStream):

         :return: True if successful
         """
-        self.read_deadline = ttl
-        self.write_deadline = ttl
+        deadline = time.time() + ttl
+        self.read_deadline = deadline
+        self.write_deadline = deadline
         return True

     def set_read_deadline(self, ttl: int) -> bool:
@@ -325,7 +407,7 @@ class MplexStream(IMuxedStream):

         :return: True if successful
         """
-        self.read_deadline = ttl
+        self.read_deadline = time.time() + ttl
         return True

     def set_write_deadline(self, ttl: int) -> bool:
@@ -334,7 +416,7 @@ class MplexStream(IMuxedStream):

         :return: True if successful
         """
-        self.write_deadline = ttl
+        self.write_deadline = ttl + time.time()
         return True

     def get_remote_address(self) -> tuple[str, int] | None:
@@ -1 +0,0 @@
-Fixed message id type inconsistency in handle ihave and message id parsing improvement in handle iwant in pubsub module.
@@ -1,8 +1,4 @@
 import random
-from unittest.mock import (
-    AsyncMock,
-    MagicMock,
-)

 import pytest
 import trio
@@ -11,9 +7,6 @@ from libp2p.pubsub.gossipsub import (
     PROTOCOL_ID,
     GossipSub,
 )
-from libp2p.pubsub.pb import (
-    rpc_pb2,
-)
 from libp2p.tools.utils import (
     connect,
 )
@@ -761,173 +754,3 @@
     assert connected_peers == 0, (
         f"Single host has {connected_peers} connections, expected 0"
     )
-
-
-@pytest.mark.trio
-async def test_handle_ihave(monkeypatch):
-    async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
-        gossipsub_routers = []
-        for pubsub in pubsubs_gsub:
-            if isinstance(pubsub.router, GossipSub):
-                gossipsub_routers.append(pubsub.router)
-        gossipsubs = tuple(gossipsub_routers)
-
-        index_alice = 0
-        index_bob = 1
-        id_bob = pubsubs_gsub[index_bob].my_id
-
-        # Connect Alice and Bob
-        await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
-        await trio.sleep(0.1)  # Allow connections to establish
-
-        # Mock emit_iwant to capture calls
-        mock_emit_iwant = AsyncMock()
-        monkeypatch.setattr(gossipsubs[index_alice], "emit_iwant", mock_emit_iwant)
-
-        # Create a test message ID as a string representation of a (seqno, from) tuple
-        test_seqno = b"1234"
-        test_from = id_bob.to_bytes()
-        test_msg_id = f"(b'{test_seqno.hex()}', b'{test_from.hex()}')"
-        ihave_msg = rpc_pb2.ControlIHave(messageIDs=[test_msg_id])
-
-        # Mock seen_messages.cache to avoid false positives
-        monkeypatch.setattr(pubsubs_gsub[index_alice].seen_messages, "cache", {})
-
-        # Simulate Bob sending IHAVE to Alice
-        await gossipsubs[index_alice].handle_ihave(ihave_msg, id_bob)
-
-        # Check if emit_iwant was called with the correct message ID
-        mock_emit_iwant.assert_called_once()
-        called_args = mock_emit_iwant.call_args[0]
-        assert called_args[0] == [test_msg_id]  # Expected message IDs
-        assert called_args[1] == id_bob  # Sender peer ID
-
-
-@pytest.mark.trio
-async def test_handle_iwant(monkeypatch):
-    async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
-        gossipsub_routers = []
-        for pubsub in pubsubs_gsub:
-            if isinstance(pubsub.router, GossipSub):
-                gossipsub_routers.append(pubsub.router)
-        gossipsubs = tuple(gossipsub_routers)
-
-        index_alice = 0
-        index_bob = 1
-        id_alice = pubsubs_gsub[index_alice].my_id
-
-        # Connect Alice and Bob
-        await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
-        await trio.sleep(0.1)  # Allow connections to establish
-
-        # Mock mcache.get to return a message
-        test_message = rpc_pb2.Message(data=b"test_data")
-        test_seqno = b"1234"
-        test_from = id_alice.to_bytes()
-
-        # ✅ Correct: use raw tuple and str() to serialize, no hex()
-        test_msg_id = str((test_seqno, test_from))
-
-        mock_mcache_get = MagicMock(return_value=test_message)
-        monkeypatch.setattr(gossipsubs[index_bob].mcache, "get", mock_mcache_get)
-
-        # Mock write_msg to capture the sent packet
-        mock_write_msg = AsyncMock()
-        monkeypatch.setattr(gossipsubs[index_bob].pubsub, "write_msg", mock_write_msg)
-
-        # Simulate Alice sending IWANT to Bob
-        iwant_msg = rpc_pb2.ControlIWant(messageIDs=[test_msg_id])
-        await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice)
-
-        # Check if write_msg was called with the correct packet
-        mock_write_msg.assert_called_once()
-        packet = mock_write_msg.call_args[0][1]
-        assert isinstance(packet, rpc_pb2.RPC)
-        assert len(packet.publish) == 1
-        assert packet.publish[0] == test_message
-
-        # Verify that mcache.get was called with the correct parsed message ID
-        mock_mcache_get.assert_called_once()
-        called_msg_id = mock_mcache_get.call_args[0][0]
-        assert isinstance(called_msg_id, tuple)
-        assert called_msg_id == (test_seqno, test_from)
-
-
-@pytest.mark.trio
-async def test_handle_iwant_invalid_msg_id(monkeypatch):
-    """
-    Test that handle_iwant raises ValueError for malformed message IDs.
-    """
-    async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
-        gossipsub_routers = []
-        for pubsub in pubsubs_gsub:
-            if isinstance(pubsub.router, GossipSub):
-                gossipsub_routers.append(pubsub.router)
-        gossipsubs = tuple(gossipsub_routers)
-
-        index_alice = 0
-        index_bob = 1
-        id_alice = pubsubs_gsub[index_alice].my_id
-
-        await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
-        await trio.sleep(0.1)
-
-        # Malformed message ID (not a tuple string)
-        malformed_msg_id = "not_a_valid_msg_id"
-        iwant_msg = rpc_pb2.ControlIWant(messageIDs=[malformed_msg_id])
-
-        # Mock mcache.get and write_msg to ensure they are not called
-        mock_mcache_get = MagicMock()
-        monkeypatch.setattr(gossipsubs[index_bob].mcache, "get", mock_mcache_get)
-        mock_write_msg = AsyncMock()
-        monkeypatch.setattr(gossipsubs[index_bob].pubsub, "write_msg", mock_write_msg)
-
-        with pytest.raises(ValueError):
-            await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice)
-        mock_mcache_get.assert_not_called()
-        mock_write_msg.assert_not_called()
-
-        # Message ID that's a tuple string but not (bytes, bytes)
-        invalid_tuple_msg_id = "('abc', 123)"
-        iwant_msg = rpc_pb2.ControlIWant(messageIDs=[invalid_tuple_msg_id])
-        with pytest.raises(ValueError):
-            await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice)
-        mock_mcache_get.assert_not_called()
-        mock_write_msg.assert_not_called()
-
-
-@pytest.mark.trio
-async def test_handle_ihave_empty_message_ids(monkeypatch):
-    """
-    Test that handle_ihave with an empty messageIDs list does not call emit_iwant.
-    """
-    async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
-        gossipsub_routers = []
-        for pubsub in pubsubs_gsub:
-            if isinstance(pubsub.router, GossipSub):
-                gossipsub_routers.append(pubsub.router)
-        gossipsubs = tuple(gossipsub_routers)
-
-        index_alice = 0
-        index_bob = 1
-        id_bob = pubsubs_gsub[index_bob].my_id
-
-        # Connect Alice and Bob
-        await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
-        await trio.sleep(0.1)  # Allow connections to establish
-
-        # Mock emit_iwant to capture calls
-        mock_emit_iwant = AsyncMock()
-        monkeypatch.setattr(gossipsubs[index_alice], "emit_iwant", mock_emit_iwant)
-
-        # Empty messageIDs list
-        ihave_msg = rpc_pb2.ControlIHave(messageIDs=[])
-
-        # Mock seen_messages.cache to avoid false positives
-        monkeypatch.setattr(pubsubs_gsub[index_alice].seen_messages, "cache", {})
-
-        # Simulate Bob sending IHAVE to Alice
-        await gossipsubs[index_alice].handle_ihave(ihave_msg, id_bob)
-
-        # emit_iwant should not be called since there are no message IDs
-        mock_emit_iwant.assert_not_called()
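One detail worth keeping from the deleted tests: they encoded the (seqno, from_id) message ID two different ways, and only the plain str() form round-trips to the original bytes. A standalone illustration:

```python
from ast import literal_eval

seqno, from_id = b"1234", b"\x01\x02"

# test_handle_iwant's form: str() of the raw bytes tuple round-trips exactly.
msg_id = str((seqno, from_id))
assert literal_eval(msg_id) == (seqno, from_id)

# test_handle_ihave's hex() form also parses, but yields different bytes, so
# such an ID could never match an mcache entry keyed by the raw tuple.
hexed = f"(b'{seqno.hex()}', b'{from_id.hex()}')"
assert literal_eval(hexed) == (b"31323334", b"0102")
```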
@@ -1,13 +0,0 @@
-from libp2p.security.noise.pb import noise_pb2 as noise_pb
-
-
-def test_noise_extensions_serialization():
-    # Test NoiseExtensions
-    ext = noise_pb.NoiseExtensions()
-    ext.stream_muxers.append("/mplex/6.7.0")
-    ext.stream_muxers.append("/yamux/1.0.0")
-
-    # Serialize and deserialize
-    data = ext.SerializeToString()
-    ext2 = noise_pb.NoiseExtensions.FromString(data)
-    assert list(ext2.stream_muxers) == ["/mplex/6.7.0", "/yamux/1.0.0"]
@@ -173,7 +173,8 @@ def noise_transport_factory(key_pair: KeyPair) -> ISecureTransport:
     return NoiseTransport(
         libp2p_keypair=key_pair,
         noise_privkey=noise_static_key_factory(),
-        # TODO: add early data
+        early_data=None,
+        with_noise_pipes=False,
     )
