Merge branch 'main' into feature/mDNS

This commit is contained in:
Manu Sheel Gupta
2025-06-29 09:43:03 -07:00
committed by GitHub
20 changed files with 924 additions and 64 deletions

View File

@ -134,7 +134,7 @@ async def test_handle_graft(monkeypatch):
# check if it is called in `handle_graft`
event_emit_prune = trio.Event()
async def emit_prune(topic, sender_peer_id):
async def emit_prune(topic, sender_peer_id, do_px, is_unsubscribe):
event_emit_prune.set()
await trio.lowlevel.checkpoint()
@ -193,7 +193,7 @@ async def test_handle_prune():
# alice emit prune message to bob, alice should be removed
# from bob's mesh peer
await gossipsubs[index_alice].emit_prune(topic, id_bob)
await gossipsubs[index_alice].emit_prune(topic, id_bob, False, False)
# `emit_prune` does not remove bob from alice's mesh peers
assert id_bob in gossipsubs[index_alice].mesh[topic]
@ -292,7 +292,9 @@ async def test_fanout():
@pytest.mark.trio
@pytest.mark.slow
async def test_fanout_maintenance():
async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
async with PubsubFactory.create_batch_with_gossipsub(
10, unsubscribe_back_off=1
) as pubsubs_gsub:
hosts = [pubsub.host for pubsub in pubsubs_gsub]
num_msgs = 5

View File

@ -0,0 +1,274 @@
import pytest
import trio
from libp2p.pubsub.gossipsub import (
GossipSub,
)
from libp2p.tools.utils import (
connect,
)
from tests.utils.factories import (
PubsubFactory,
)
@pytest.mark.trio
async def test_prune_backoff():
    """A pruned peer must sit out the prune backoff before re-grafting."""
    async with PubsubFactory.create_batch_with_gossipsub(
        2, heartbeat_interval=0.5, prune_back_off=2
    ) as pubsubs:
        router_a = pubsubs[0].router
        router_b = pubsubs[1].router
        assert isinstance(router_a, GossipSub)
        assert isinstance(router_b, GossipSub)
        host_a = pubsubs[0].host
        host_b = pubsubs[1].host
        topic = "test_prune_backoff"

        # Link the two hosts and let the connection settle.
        await connect(host_a, host_b)
        await trio.sleep(0.5)

        # Both routers join the topic, then A grafts toward B.
        await router_a.join(topic)
        await router_b.join(topic)
        await router_a.emit_graft(topic, host_b.get_id())
        await trio.sleep(0.5)

        # B should now hold A in its mesh for the topic.
        assert host_a.get_id() in router_b.mesh[topic]

        # A prunes B; B must drop A from its mesh.
        await router_a.emit_prune(topic, host_b.get_id(), False, False)
        await trio.sleep(0.5)
        assert host_a.get_id() not in router_b.mesh[topic]

        # An immediate re-graft falls inside the backoff window and must
        # be rejected by B.
        await router_a.emit_graft(topic, host_b.get_id())
        await trio.sleep(0.5)
        assert host_a.get_id() not in router_b.mesh[topic], (
            "peer should be backoffed and not re-added"
        )

        # Once the 2-second backoff has elapsed, grafting works again.
        await trio.sleep(2)
        await router_a.emit_graft(topic, host_b.get_id())
        await trio.sleep(1)
        assert host_a.get_id() in router_b.mesh[topic], (
            "peer should be able to rejoin after backoff"
        )
@pytest.mark.trio
async def test_unsubscribe_backoff():
    """Rejoining a topic after leaving is gated by the unsubscribe backoff."""
    async with PubsubFactory.create_batch_with_gossipsub(
        2, heartbeat_interval=1, prune_back_off=1, unsubscribe_back_off=2
    ) as pubsubs:
        router_a = pubsubs[0].router
        router_b = pubsubs[1].router
        assert isinstance(router_a, GossipSub)
        assert isinstance(router_b, GossipSub)
        host_a = pubsubs[0].host
        host_b = pubsubs[1].host
        topic = "test_unsubscribe_backoff"

        # Link the hosts and let the connection settle.
        await connect(host_a, host_b)
        await trio.sleep(0.5)

        # Both join; A grafts toward B so B meshes with A.
        await router_a.join(topic)
        await router_b.join(topic)
        await router_a.emit_graft(topic, host_b.get_id())
        await trio.sleep(0.5)
        assert host_a.get_id() in router_b.mesh[topic]

        # B leaves the topic entirely...
        await router_b.leave(topic)
        await trio.sleep(0.5)
        assert topic not in router_b.mesh

        # ...and then joins it again.
        await router_b.join(topic)
        await trio.sleep(0.5)
        assert topic in router_b.mesh

        # A graft inside the unsubscribe backoff window must be refused.
        await router_a.emit_graft(topic, host_b.get_id())
        await trio.sleep(0.5)
        assert host_a.get_id() not in router_b.mesh[topic], (
            "peer should be backoffed and not re-added"
        )

        # After the 2-second unsubscribe backoff expires the graft succeeds.
        await trio.sleep(1)
        await router_a.emit_graft(topic, host_b.get_id())
        await trio.sleep(1)
        assert host_a.get_id() in router_b.mesh[topic], (
            "peer should be able to rejoin after backoff"
        )
@pytest.mark.trio
async def test_peer_exchange():
    """A leaving mesh peer with PX enabled advertises alternate peers,
    letting its former neighbours mesh with each other."""
    async with PubsubFactory.create_batch_with_gossipsub(
        3,
        heartbeat_interval=0.5,
        do_px=True,
        px_peers_count=1,
    ) as pubsubs:
        routers = []
        for ps in pubsubs:
            assert isinstance(ps.router, GossipSub)
            routers.append(ps.router)
        gsub0, gsub1, gsub2 = routers
        host_0, host_1, host_2 = (ps.host for ps in pubsubs)
        topic = "test_peer_exchange"

        # Star topology: peer 1 in the middle, connected to 0 and 2.
        await connect(host_1, host_0)
        await connect(host_1, host_2)
        await trio.sleep(0.5)

        # Everyone subscribes; grafts form the meshes 0 <-> 1 and 1 <-> 2.
        await pubsubs[1].subscribe(topic)
        await pubsubs[0].subscribe(topic)
        await pubsubs[2].subscribe(topic)
        await gsub1.emit_graft(topic, host_0.get_id())
        await gsub1.emit_graft(topic, host_2.get_id())
        await gsub0.emit_graft(topic, host_1.get_id())
        await gsub2.emit_graft(topic, host_1.get_id())
        await trio.sleep(1)

        # Peer 1 meshes with both ends; 0 and 2 do not know each other yet.
        assert host_0.get_id() in gsub1.mesh[topic]
        assert host_2.get_id() in gsub1.mesh[topic]
        assert host_2.get_id() not in gsub0.mesh[topic]

        # Peer 1 leaves: its PRUNEs should carry peer-exchange records.
        await gsub1.leave(topic)
        await trio.sleep(1)  # Wait for heartbeat to update mesh
        assert topic not in gsub1.mesh

        # Peer 0 learns about peer 2 via PX and grafts it into its mesh.
        await trio.sleep(1)
        assert host_2.get_id() in gsub0.mesh[topic]
@pytest.mark.trio
async def test_topics_are_isolated():
    """A backoff recorded for one topic must not block grafting on another."""
    async with PubsubFactory.create_batch_with_gossipsub(
        2, heartbeat_interval=0.5, prune_back_off=2
    ) as pubsubs:
        router_a = pubsubs[0].router
        router_b = pubsubs[1].router
        assert isinstance(router_a, GossipSub)
        assert isinstance(router_b, GossipSub)
        host_a = pubsubs[0].host
        host_b = pubsubs[1].host
        topic1 = "test_prune_backoff"
        topic2 = "test_prune_backoff2"

        # Link the hosts and let the connection settle.
        await connect(host_a, host_b)
        await trio.sleep(0.5)

        # Both peers join both topics; A grafts B on topic1 only.
        await router_a.join(topic1)
        await router_b.join(topic1)
        await router_a.join(topic2)
        await router_b.join(topic2)
        await router_a.emit_graft(topic1, host_b.get_id())
        await trio.sleep(0.5)
        assert host_a.get_id() in router_b.mesh[topic1]

        # Pruning on topic1 puts A under backoff for that topic on B.
        await router_a.emit_prune(topic1, host_b.get_id(), False, False)
        await trio.sleep(0.5)
        assert host_a.get_id() not in router_b.mesh[topic1]

        # Re-graft topic1 (should be blocked) and graft topic2 (should pass).
        await router_a.emit_graft(topic1, host_b.get_id())
        await router_a.emit_graft(topic2, host_b.get_id())
        await trio.sleep(0.5)
        assert host_a.get_id() not in router_b.mesh[topic1], (
            "peer should be backoffed and not re-added"
        )
        assert host_a.get_id() in router_b.mesh[topic2], (
            "peer should be able to join a different topic"
        )
@pytest.mark.trio
async def test_stress_churn():
    """Hammer the mesh with rapid prune/graft churn and verify the backoff
    table is cleaned up rather than growing without bound."""
    NUM_PEERS = 5
    CHURN_CYCLES = 30
    TOPIC = "stress_churn_topic"
    PRUNE_BACKOFF = 1
    HEARTBEAT_INTERVAL = 0.2

    async with PubsubFactory.create_batch_with_gossipsub(
        NUM_PEERS,
        heartbeat_interval=HEARTBEAT_INTERVAL,
        prune_back_off=PRUNE_BACKOFF,
    ) as pubsubs:
        routers: list[GossipSub] = []
        for ps in pubsubs:
            assert isinstance(ps.router, GossipSub)
            routers.append(ps.router)
        hosts = [ps.host for ps in pubsubs]

        # Build a full mesh of connections between all peers.
        for i in range(NUM_PEERS):
            for j in range(i + 1, NUM_PEERS):
                await connect(hosts[i], hosts[j])
        await trio.sleep(1)

        # Everyone joins the topic.
        for router in routers:
            await router.join(TOPIC)
        await trio.sleep(1)

        # Alternate full-prune and full-graft rounds to churn the meshes.
        for _ in range(CHURN_CYCLES):
            for i, router in enumerate(routers):
                # Prune every other peer from this router's mesh.
                for j, peer_host in enumerate(hosts):
                    if i != j:
                        await router.emit_prune(
                            TOPIC, peer_host.get_id(), False, False
                        )
            await trio.sleep(0.1)
            for i, router in enumerate(routers):
                # Graft every other peer back.
                for j, peer_host in enumerate(hosts):
                    if i != j:
                        await router.emit_graft(TOPIC, peer_host.get_id())
            await trio.sleep(0.1)

        # Allow all backoff entries to expire and be garbage-collected.
        await trio.sleep(PRUNE_BACKOFF * 2)

        # The backoff table must have been pruned back down.
        for router in routers:
            # backoff is a dict: topic -> peer -> expiry
            backoff = getattr(router, "back_off", None)
            assert backoff is not None, "router missing backoff table"
            # Only a handful of entries should remain (ideally none).
            total_entries = sum(len(peers) for peers in backoff.values())
            assert total_entries < NUM_PEERS * 2, (
                f"backoff table grew too large: {total_entries} entries"
            )

View File

@ -0,0 +1,314 @@
import pytest
import trio
from trio.testing import memory_stream_pair
from libp2p.abc import IRawConnection
from libp2p.crypto.rsa import create_new_key_pair
from libp2p.peer.id import ID
from libp2p.peer.peerdata import PeerData
from libp2p.peer.peerstore import PeerStore
from libp2p.security.exceptions import HandshakeFailure
from libp2p.security.insecure.transport import InsecureTransport
# Adapter class to bridge between trio streams and libp2p raw connections
class TrioStreamAdapter(IRawConnection):
    """Wrap a trio send/receive stream pair as a libp2p ``IRawConnection``.

    Test-only adapter: it carries no real network information and only
    supports bounded reads.
    """

    def __init__(self, send_stream, receive_stream, is_initiator: bool = False):
        self.send_stream = send_stream
        self.receive_stream = receive_stream
        # True for the dialing (outbound) side of the connection.
        self.is_initiator = is_initiator

    async def write(self, data: bytes) -> None:
        """Send all of ``data``, blocking until the stream accepts it."""
        await self.send_stream.send_all(data)

    async def read(self, n: int | None = None) -> bytes:
        """Read up to ``n`` bytes; unbounded reads are not supported.

        NOTE(review): ``receive_some(n)`` may return fewer than ``n`` bytes;
        callers are assumed to tolerate short reads — confirm against the
        security transports using this adapter.
        """
        if n is None or n == -1:
            raise ValueError("Reading unbounded not supported")
        return await self.receive_stream.receive_some(n)

    async def close(self) -> None:
        """Close both directions of the stream pair."""
        await self.send_stream.aclose()
        await self.receive_stream.aclose()

    def get_remote_address(self) -> tuple[str, int] | None:
        # Return None since this is a test adapter without real network info
        return None
@pytest.mark.trio
async def test_insecure_transport_stores_pubkey_in_peerstore():
    """
    Test that InsecureTransport stores the pubkey and peerid in
    peerstore during handshake.
    """
    # Key material for the two endpoints.
    local_key_pair = create_new_key_pair()
    remote_key_pair = create_new_key_pair()
    remote_peer_id = ID.from_pubkey(remote_key_pair.public_key)

    # The initiator's peerstore, which the handshake should populate.
    peerstore = PeerStore()

    # Wire the two sides together over in-memory streams.
    local_send, remote_receive = memory_stream_pair()
    remote_send, local_receive = memory_stream_pair()
    local_stream = TrioStreamAdapter(local_send, local_receive, is_initiator=True)
    remote_stream = TrioStreamAdapter(remote_send, remote_receive, is_initiator=False)

    # Only the local (initiating) side gets the peerstore.
    local_transport = InsecureTransport(local_key_pair, peerstore=peerstore)
    remote_transport = InsecureTransport(remote_key_pair, peerstore=None)

    results: dict = {}

    async def run_local(out):
        # Outbound handshake, bounded at 5 seconds.
        with trio.move_on_after(5):
            out["local"] = await local_transport.secure_outbound(
                local_stream, remote_peer_id
            )

    async def run_remote(out):
        # Inbound handshake, bounded at 5 seconds.
        with trio.move_on_after(5):
            out["remote"] = await remote_transport.secure_inbound(remote_stream)

    async with trio.open_nursery() as nursery:
        nursery.start_soon(run_local, results)
        nursery.start_soon(run_remote, results)
        await trio.sleep(0.1)  # Give tasks a chance to finish

    local_conn = results.get("local")
    remote_conn = results.get("remote")
    assert local_conn is not None, "Local handshake failed"
    assert remote_conn is not None, "Remote handshake failed"

    # The handshake must have recorded the remote peer and its key.
    assert remote_peer_id in peerstore.peer_ids()
    stored_pubkey = peerstore.pubkey(remote_peer_id)
    assert stored_pubkey is not None
    assert stored_pubkey.serialize() == remote_key_pair.public_key.serialize()
@pytest.mark.trio
async def test_insecure_transport_without_peerstore():
    """
    Test that InsecureTransport works correctly
    without a peerstore.
    """
    # Key material for the two endpoints.
    local_key_pair = create_new_key_pair()
    remote_key_pair = create_new_key_pair()
    remote_peer_id = ID.from_pubkey(remote_key_pair.public_key)

    # Wire the two sides together over in-memory streams.
    local_send, remote_receive = memory_stream_pair()
    remote_send, local_receive = memory_stream_pair()
    local_stream = TrioStreamAdapter(local_send, local_receive, is_initiator=True)
    remote_stream = TrioStreamAdapter(remote_send, remote_receive, is_initiator=False)

    # Neither side is given a peerstore.
    local_transport = InsecureTransport(local_key_pair, peerstore=None)
    remote_transport = InsecureTransport(remote_key_pair, peerstore=None)

    results: dict = {}

    async def run_local(out):
        # Outbound handshake, bounded at 5 seconds.
        with trio.move_on_after(5):
            out["local"] = await local_transport.secure_outbound(
                local_stream, remote_peer_id
            )

    async def run_remote(out):
        # Inbound handshake, bounded at 5 seconds.
        with trio.move_on_after(5):
            out["remote"] = await remote_transport.secure_inbound(remote_stream)

    async with trio.open_nursery() as nursery:
        nursery.start_soon(run_local, results)
        nursery.start_soon(run_remote, results)
        await trio.sleep(0.1)  # Give tasks a chance to finish

    # The handshake must still complete even with no peerstore anywhere.
    assert results.get("local") is not None, "Local handshake failed"
    assert results.get("remote") is not None, "Remote handshake failed"
@pytest.mark.trio
async def test_peerstore_unchanged_when_handshake_fails():
    """
    Test that the peerstore remains unchanged if the handshake fails
    due to a peer ID mismatch.
    """
    # Create key pairs for both sides
    local_key_pair = create_new_key_pair()
    remote_key_pair = create_new_key_pair()
    # Create a third key pair to cause a mismatch
    mismatch_key_pair = create_new_key_pair()
    # Create peer IDs
    remote_peer_id = ID.from_pubkey(remote_key_pair.public_key)
    mismatch_peer_id = ID.from_pubkey(mismatch_key_pair.public_key)
    # Create peerstore and add some initial data to verify it stays unchanged
    peerstore = PeerStore()
    # Store some initial data in peerstore to verify it remains unchanged
    initial_key_pair = create_new_key_pair()
    initial_peer_id = ID.from_pubkey(initial_key_pair.public_key)
    peerstore.add_pubkey(initial_peer_id, initial_key_pair.public_key)
    # Remember the initial state of the peerstore
    initial_peer_ids = set(peerstore.peer_ids())
    # Create memory streams for communication
    local_send, remote_receive = memory_stream_pair()
    remote_send, local_receive = memory_stream_pair()
    # Create adapters
    local_stream = TrioStreamAdapter(local_send, local_receive, is_initiator=True)
    remote_stream = TrioStreamAdapter(remote_send, remote_receive, is_initiator=False)
    # Create transports; only the local side carries the peerstore under test
    local_transport = InsecureTransport(local_key_pair, peerstore=peerstore)
    remote_transport = InsecureTransport(remote_key_pair, peerstore=None)

    # Run handshake with mismatched peer_id
    # (expecting remote_peer_id but sending mismatch_peer_id to cause a failure)
    async def run_local_handshake(nursery_results):
        # move_on_after bounds the handshake at 5 seconds so a hang cannot
        # stall the test; on timeout the result key is simply never set.
        with trio.move_on_after(5):
            try:
                # Pass mismatch_peer_id instead of remote_peer_id
                # to cause a handshake failure
                local_conn = await local_transport.secure_outbound(
                    local_stream, mismatch_peer_id
                )
                nursery_results["local"] = local_conn
            except HandshakeFailure:
                nursery_results["local_error"] = True

    async def run_remote_handshake(nursery_results):
        with trio.move_on_after(5):
            try:
                remote_conn = await remote_transport.secure_inbound(remote_stream)
                nursery_results["remote"] = remote_conn
            except HandshakeFailure:
                nursery_results["remote_error"] = True

    nursery_results = {}
    async with trio.open_nursery() as nursery:
        nursery.start_soon(run_local_handshake, nursery_results)
        nursery.start_soon(run_remote_handshake, nursery_results)
        await trio.sleep(0.1)

    # Verify that at least one side encountered an error
    assert "local_error" in nursery_results or "remote_error" in nursery_results, (
        "Expected handshake to fail due to peer ID mismatch"
    )
    # Verify that the peerstore remains unchanged
    current_peer_ids = set(peerstore.peer_ids())
    assert current_peer_ids == initial_peer_ids, (
        "Peerstore should remain unchanged when handshake fails"
    )
    # Verify that neither the remote_peer_id nor mismatch_peer_id was added
    assert remote_peer_id not in peerstore.peer_ids(), (
        "Remote peer ID should not be added on handshake failure"
    )
    assert mismatch_peer_id not in peerstore.peer_ids(), (
        "Mismatch peer ID should not be added on handshake failure"
    )
@pytest.mark.trio
async def test_handshake_adds_pubkey_to_existing_peer():
    """
    Test that when a peer ID already exists in the peerstore but without
    a public key, the handshake correctly adds the public key.

    This tests the case where we might have a peer ID from another source
    (like a routing table) but don't yet have its public key.
    """
    # Key material for both endpoints.
    local_key_pair = create_new_key_pair()
    remote_key_pair = create_new_key_pair()
    remote_peer_id = ID.from_pubkey(remote_key_pair.public_key)

    # Seed the peerstore with the remote peer's ID only — no public key —
    # as if it had been discovered through DHT or another routing source.
    peerstore = PeerStore()
    peerstore.peer_data_map[remote_peer_id] = PeerData()

    # Sanity check: the peer exists but has no pubkey yet.
    assert remote_peer_id in peerstore.peer_ids()
    with pytest.raises(Exception):
        peerstore.pubkey(remote_peer_id)

    # Wire the two sides together over in-memory streams.
    local_send, remote_receive = memory_stream_pair()
    remote_send, local_receive = memory_stream_pair()
    local_stream = TrioStreamAdapter(local_send, local_receive, is_initiator=True)
    remote_stream = TrioStreamAdapter(remote_send, remote_receive, is_initiator=False)

    # Only the local (initiating) side gets the peerstore.
    local_transport = InsecureTransport(local_key_pair, peerstore=peerstore)
    remote_transport = InsecureTransport(remote_key_pair, peerstore=None)

    results: dict = {}

    async def run_local(out):
        # Outbound handshake, bounded at 5 seconds.
        with trio.move_on_after(5):
            out["local"] = await local_transport.secure_outbound(
                local_stream, remote_peer_id
            )

    async def run_remote(out):
        # Inbound handshake, bounded at 5 seconds.
        with trio.move_on_after(5):
            out["remote"] = await remote_transport.secure_inbound(remote_stream)

    async with trio.open_nursery() as nursery:
        nursery.start_soon(run_local, results)
        nursery.start_soon(run_remote, results)
        await trio.sleep(0.1)  # Give tasks a chance to finish

    local_conn = results.get("local")
    remote_conn = results.get("remote")
    assert local_conn is not None, "Local handshake failed"
    assert remote_conn is not None, "Remote handshake failed"

    # The peer entry survives and now carries the exchanged public key.
    assert remote_peer_id in peerstore.peer_ids()
    stored_pubkey = peerstore.pubkey(remote_peer_id)
    assert stored_pubkey is not None
    assert stored_pubkey.serialize() == remote_key_pair.public_key.serialize()

View File

@ -79,7 +79,7 @@ async def secure_conn_pair(key_pair, peer_id):
client_rw = TrioStreamAdapter(client_send, client_receive, is_initiator=True)
server_rw = TrioStreamAdapter(server_send, server_receive, is_initiator=False)
insecure_transport = InsecureTransport(key_pair)
insecure_transport = InsecureTransport(key_pair, peerstore=None)
async def run_outbound(nursery_results):
with trio.move_on_after(5):

View File

@ -161,8 +161,8 @@ def noise_handshake_payload_factory() -> NoiseHandshakePayload:
)
def plaintext_transport_factory(key_pair: KeyPair) -> ISecureTransport:
return InsecureTransport(key_pair)
def plaintext_transport_factory(key_pair: KeyPair, peerstore=None) -> ISecureTransport:
return InsecureTransport(key_pair, peerstore=peerstore)
def secio_transport_factory(key_pair: KeyPair) -> ISecureTransport:
@ -443,6 +443,10 @@ class GossipsubFactory(factory.Factory):
heartbeat_interval = GOSSIPSUB_PARAMS.heartbeat_interval
direct_connect_initial_delay = GOSSIPSUB_PARAMS.direct_connect_initial_delay
direct_connect_interval = GOSSIPSUB_PARAMS.direct_connect_interval
do_px = GOSSIPSUB_PARAMS.do_px
px_peers_count = GOSSIPSUB_PARAMS.px_peers_count
prune_back_off = GOSSIPSUB_PARAMS.prune_back_off
unsubscribe_back_off = GOSSIPSUB_PARAMS.unsubscribe_back_off
class PubsubFactory(factory.Factory):
@ -568,6 +572,10 @@ class PubsubFactory(factory.Factory):
heartbeat_initial_delay: float = GOSSIPSUB_PARAMS.heartbeat_initial_delay,
direct_connect_initial_delay: float = GOSSIPSUB_PARAMS.direct_connect_initial_delay, # noqa: E501
direct_connect_interval: int = GOSSIPSUB_PARAMS.direct_connect_interval,
do_px: bool = GOSSIPSUB_PARAMS.do_px,
px_peers_count: int = GOSSIPSUB_PARAMS.px_peers_count,
prune_back_off: int = GOSSIPSUB_PARAMS.prune_back_off,
unsubscribe_back_off: int = GOSSIPSUB_PARAMS.unsubscribe_back_off,
security_protocol: TProtocol | None = None,
muxer_opt: TMuxerOptions | None = None,
msg_id_constructor: None
@ -588,6 +596,10 @@ class PubsubFactory(factory.Factory):
heartbeat_interval=heartbeat_interval,
direct_connect_initial_delay=direct_connect_initial_delay,
direct_connect_interval=direct_connect_interval,
do_px=do_px,
px_peers_count=px_peers_count,
prune_back_off=prune_back_off,
unsubscribe_back_off=unsubscribe_back_off,
)
else:
gossipsubs = GossipsubFactory.create_batch(
@ -602,6 +614,10 @@ class PubsubFactory(factory.Factory):
heartbeat_initial_delay=heartbeat_initial_delay,
direct_connect_initial_delay=direct_connect_initial_delay,
direct_connect_interval=direct_connect_interval,
do_px=do_px,
px_peers_count=px_peers_count,
prune_back_off=prune_back_off,
unsubscribe_back_off=unsubscribe_back_off,
)
async with cls._create_batch_with_router(