Mirror of https://github.com/varun-r-mallya/py-libp2p.git
Merge branch 'main' into feature/mDNS
@@ -134,7 +134,7 @@ async def test_handle_graft(monkeypatch):
     # check if it is called in `handle_graft`
     event_emit_prune = trio.Event()
 
-    async def emit_prune(topic, sender_peer_id):
+    async def emit_prune(topic, sender_peer_id, do_px, is_unsubscribe):
         event_emit_prune.set()
         await trio.lowlevel.checkpoint()
 
@@ -193,7 +193,7 @@ async def test_handle_prune():
 
     # alice emit prune message to bob, alice should be removed
     # from bob's mesh peer
-    await gossipsubs[index_alice].emit_prune(topic, id_bob)
+    await gossipsubs[index_alice].emit_prune(topic, id_bob, False, False)
     # `emit_prune` does not remove bob from alice's mesh peers
     assert id_bob in gossipsubs[index_alice].mesh[topic]
 
@@ -292,7 +292,9 @@ async def test_fanout():
 @pytest.mark.trio
 @pytest.mark.slow
 async def test_fanout_maintenance():
-    async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
+    async with PubsubFactory.create_batch_with_gossipsub(
+        10, unsubscribe_back_off=1
+    ) as pubsubs_gsub:
         hosts = [pubsub.host for pubsub in pubsubs_gsub]
         num_msgs = 5
 
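Note: the hunks above widen `emit_prune` from `(topic, sender_peer_id)` to `(topic, sender_peer_id, do_px, is_unsubscribe)`. Judging by the tests introduced below, `do_px` controls whether the PRUNE carries peer-exchange (PX) records, and `is_unsubscribe` selects the longer unsubscribe backoff window. A rough receiver-side sketch of those semantics; the helper and attribute names mirror the tests but are illustrative, not the library's actual internals:

import time

def handle_prune_sketch(router, topic, peer, do_px: bool, is_unsubscribe: bool) -> None:
    # Unsubscribing peers get the longer penalty window.
    backoff_secs = (
        router.unsubscribe_back_off if is_unsubscribe else router.prune_back_off
    )
    # Record when this peer may be grafted again on this topic.
    router.back_off.setdefault(topic, {})[peer] = time.time() + backoff_secs
    # Drop the peer from the topic mesh.
    router.mesh[topic].discard(peer)
    # When do_px was set by the sender, the PRUNE additionally carries
    # peer-exchange records the receiver can use to find replacement peers
    # (exercised by test_peer_exchange below).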
tests/core/pubsub/test_gossipsub_px_and_backoff.py (new file, 274 lines)
@@ -0,0 +1,274 @@
+import pytest
+import trio
+
+from libp2p.pubsub.gossipsub import (
+    GossipSub,
+)
+from libp2p.tools.utils import (
+    connect,
+)
+from tests.utils.factories import (
+    PubsubFactory,
+)
+
+
+@pytest.mark.trio
+async def test_prune_backoff():
+    async with PubsubFactory.create_batch_with_gossipsub(
+        2, heartbeat_interval=0.5, prune_back_off=2
+    ) as pubsubs:
+        gsub0 = pubsubs[0].router
+        gsub1 = pubsubs[1].router
+        assert isinstance(gsub0, GossipSub)
+        assert isinstance(gsub1, GossipSub)
+        host_0 = pubsubs[0].host
+        host_1 = pubsubs[1].host
+
+        topic = "test_prune_backoff"
+
+        # connect hosts
+        await connect(host_0, host_1)
+        await trio.sleep(0.5)
+
+        # both join the topic
+        await gsub0.join(topic)
+        await gsub1.join(topic)
+        await gsub0.emit_graft(topic, host_1.get_id())
+        await trio.sleep(0.5)
+
+        # ensure peer is registered in mesh
+        assert host_0.get_id() in gsub1.mesh[topic]
+
+        # prune host_1 from gsub0's mesh
+        await gsub0.emit_prune(topic, host_1.get_id(), False, False)
+        await trio.sleep(0.5)
+
+        # host_0 should not be in gsub1's mesh
+        assert host_0.get_id() not in gsub1.mesh[topic]
+
+        # try to graft again immediately (should be rejected due to backoff)
+        await gsub0.emit_graft(topic, host_1.get_id())
+        await trio.sleep(0.5)
+        assert host_0.get_id() not in gsub1.mesh[topic], (
+            "peer should be backoffed and not re-added"
+        )
+
+        # try to graft again (should succeed after backoff)
+        await trio.sleep(2)
+        await gsub0.emit_graft(topic, host_1.get_id())
+        await trio.sleep(1)
+        assert host_0.get_id() in gsub1.mesh[topic], (
+            "peer should be able to rejoin after backoff"
+        )
+
+
+@pytest.mark.trio
+async def test_unsubscribe_backoff():
+    async with PubsubFactory.create_batch_with_gossipsub(
+        2, heartbeat_interval=1, prune_back_off=1, unsubscribe_back_off=2
+    ) as pubsubs:
+        gsub0 = pubsubs[0].router
+        gsub1 = pubsubs[1].router
+        assert isinstance(gsub0, GossipSub)
+        assert isinstance(gsub1, GossipSub)
+        host_0 = pubsubs[0].host
+        host_1 = pubsubs[1].host
+
+        topic = "test_unsubscribe_backoff"
+
+        # connect hosts
+        await connect(host_0, host_1)
+        await trio.sleep(0.5)
+
+        # both join the topic
+        await gsub0.join(topic)
+        await gsub1.join(topic)
+        await gsub0.emit_graft(topic, host_1.get_id())
+        await trio.sleep(0.5)
+
+        # ensure peer is registered in mesh
+        assert host_0.get_id() in gsub1.mesh[topic]
+
+        # host_1 unsubscribes from the topic
+        await gsub1.leave(topic)
+        await trio.sleep(0.5)
+        assert topic not in gsub1.mesh
+
+        # host_1 resubscribes to the topic
+        await gsub1.join(topic)
+        await trio.sleep(0.5)
+        assert topic in gsub1.mesh
+
+        # try to graft again immediately (should be rejected due to backoff)
+        await gsub0.emit_graft(topic, host_1.get_id())
+        await trio.sleep(0.5)
+        assert host_0.get_id() not in gsub1.mesh[topic], (
+            "peer should be backoffed and not re-added"
+        )
+
+        # try to graft again (should succeed after backoff)
+        await trio.sleep(1)
+        await gsub0.emit_graft(topic, host_1.get_id())
+        await trio.sleep(1)
+        assert host_0.get_id() in gsub1.mesh[topic], (
+            "peer should be able to rejoin after backoff"
+        )
+
+
+@pytest.mark.trio
+async def test_peer_exchange():
+    async with PubsubFactory.create_batch_with_gossipsub(
+        3,
+        heartbeat_interval=0.5,
+        do_px=True,
+        px_peers_count=1,
+    ) as pubsubs:
+        gsub0 = pubsubs[0].router
+        gsub1 = pubsubs[1].router
+        gsub2 = pubsubs[2].router
+        assert isinstance(gsub0, GossipSub)
+        assert isinstance(gsub1, GossipSub)
+        assert isinstance(gsub2, GossipSub)
+        host_0 = pubsubs[0].host
+        host_1 = pubsubs[1].host
+        host_2 = pubsubs[2].host
+
+        topic = "test_peer_exchange"
+
+        # connect hosts
+        await connect(host_1, host_0)
+        await connect(host_1, host_2)
+        await trio.sleep(0.5)
+
+        # all join the topic and 0 <-> 1 and 1 <-> 2 graft
+        await pubsubs[1].subscribe(topic)
+        await pubsubs[0].subscribe(topic)
+        await pubsubs[2].subscribe(topic)
+        await gsub1.emit_graft(topic, host_0.get_id())
+        await gsub1.emit_graft(topic, host_2.get_id())
+        await gsub0.emit_graft(topic, host_1.get_id())
+        await gsub2.emit_graft(topic, host_1.get_id())
+        await trio.sleep(1)
+
+        # ensure peer is registered in mesh
+        assert host_0.get_id() in gsub1.mesh[topic]
+        assert host_2.get_id() in gsub1.mesh[topic]
+        assert host_2.get_id() not in gsub0.mesh[topic]
+
+        # host_1 unsubscribes from the topic
+        await gsub1.leave(topic)
+        await trio.sleep(1)  # Wait for heartbeat to update mesh
+        assert topic not in gsub1.mesh
+
+        # Wait for gsub0 to graft host_2 into its mesh via PX
+        await trio.sleep(1)
+        assert host_2.get_id() in gsub0.mesh[topic]
+
+
+@pytest.mark.trio
+async def test_topics_are_isolated():
+    async with PubsubFactory.create_batch_with_gossipsub(
+        2, heartbeat_interval=0.5, prune_back_off=2
+    ) as pubsubs:
+        gsub0 = pubsubs[0].router
+        gsub1 = pubsubs[1].router
+        assert isinstance(gsub0, GossipSub)
+        assert isinstance(gsub1, GossipSub)
+        host_0 = pubsubs[0].host
+        host_1 = pubsubs[1].host
+
+        topic1 = "test_prune_backoff"
+        topic2 = "test_prune_backoff2"
+
+        # connect hosts
+        await connect(host_0, host_1)
+        await trio.sleep(0.5)
+
+        # both peers join both the topics
+        await gsub0.join(topic1)
+        await gsub1.join(topic1)
+        await gsub0.join(topic2)
+        await gsub1.join(topic2)
+        await gsub0.emit_graft(topic1, host_1.get_id())
+        await trio.sleep(0.5)
+
+        # ensure topic1 for peer is registered in mesh
+        assert host_0.get_id() in gsub1.mesh[topic1]
+
+        # prune topic1 for host_1 from gsub0's mesh
+        await gsub0.emit_prune(topic1, host_1.get_id(), False, False)
+        await trio.sleep(0.5)
+
+        # topic1 for host_0 should not be in gsub1's mesh
+        assert host_0.get_id() not in gsub1.mesh[topic1]
+
+        # try to regraft topic1 and graft new topic2
+        await gsub0.emit_graft(topic1, host_1.get_id())
+        await gsub0.emit_graft(topic2, host_1.get_id())
+        await trio.sleep(0.5)
+        assert host_0.get_id() not in gsub1.mesh[topic1], (
+            "peer should be backoffed and not re-added"
+        )
+        assert host_0.get_id() in gsub1.mesh[topic2], (
+            "peer should be able to join a different topic"
+        )
+
+
+@pytest.mark.trio
+async def test_stress_churn():
+    NUM_PEERS = 5
+    CHURN_CYCLES = 30
+    TOPIC = "stress_churn_topic"
+    PRUNE_BACKOFF = 1
+    HEARTBEAT_INTERVAL = 0.2
+
+    async with PubsubFactory.create_batch_with_gossipsub(
+        NUM_PEERS,
+        heartbeat_interval=HEARTBEAT_INTERVAL,
+        prune_back_off=PRUNE_BACKOFF,
+    ) as pubsubs:
+        routers: list[GossipSub] = []
+        for ps in pubsubs:
+            assert isinstance(ps.router, GossipSub)
+            routers.append(ps.router)
+        hosts = [ps.host for ps in pubsubs]
+
+        # fully connect all peers
+        for i in range(NUM_PEERS):
+            for j in range(i + 1, NUM_PEERS):
+                await connect(hosts[i], hosts[j])
+        await trio.sleep(1)
+
+        # all peers join the topic
+        for router in routers:
+            await router.join(TOPIC)
+        await trio.sleep(1)
+
+        # rapid join/prune cycles
+        for cycle in range(CHURN_CYCLES):
+            for i, router in enumerate(routers):
+                # prune all other peers from this router's mesh
+                for j, peer_host in enumerate(hosts):
+                    if i != j:
+                        await router.emit_prune(TOPIC, peer_host.get_id(), False, False)
+            await trio.sleep(0.1)
+            for i, router in enumerate(routers):
+                # graft all other peers back
+                for j, peer_host in enumerate(hosts):
+                    if i != j:
+                        await router.emit_graft(TOPIC, peer_host.get_id())
+            await trio.sleep(0.1)
+
+        # wait for backoff entries to expire and cleanup
+        await trio.sleep(PRUNE_BACKOFF * 2)
+
+        # check that the backoff table is not unbounded
+        for router in routers:
+            # backoff is a dict: topic -> peer -> expiry
+            backoff = getattr(router, "back_off", None)
+            assert backoff is not None, "router missing backoff table"
+            # only a small number of entries should remain (ideally 0)
+            total_entries = sum(len(peers) for peers in backoff.values())
+            assert total_entries < NUM_PEERS * 2, (
+                f"backoff table grew too large: {total_entries} entries"
+            )
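Note: `test_stress_churn` assumes the router keeps a `back_off` mapping shaped `topic -> peer -> expiry` and that heartbeat cleanup evicts expired entries. A minimal sketch of the membership check these tests exercise (assumed shape, not the library's code):

import time

def is_backed_off(back_off: dict, topic: str, peer) -> bool:
    # A peer stays backed off while its recorded expiry is in the future.
    # Heartbeat cleanup is expected to evict expired entries, which is why
    # test_stress_churn asserts the table stays bounded after the churn.
    return back_off.get(topic, {}).get(peer, 0.0) > time.time()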
tests/core/security/test_insecure_peerstore_integration.py (new file, 314 lines)
@@ -0,0 +1,314 @@
+import pytest
+import trio
+from trio.testing import memory_stream_pair
+
+from libp2p.abc import IRawConnection
+from libp2p.crypto.rsa import create_new_key_pair
+from libp2p.peer.id import ID
+from libp2p.peer.peerdata import PeerData
+from libp2p.peer.peerstore import PeerStore
+from libp2p.security.exceptions import HandshakeFailure
+from libp2p.security.insecure.transport import InsecureTransport
+
+
+# Adapter class to bridge between trio streams and libp2p raw connections
+class TrioStreamAdapter(IRawConnection):
+    def __init__(self, send_stream, receive_stream, is_initiator: bool = False):
+        self.send_stream = send_stream
+        self.receive_stream = receive_stream
+        self.is_initiator = is_initiator
+
+    async def write(self, data: bytes) -> None:
+        await self.send_stream.send_all(data)
+
+    async def read(self, n: int | None = None) -> bytes:
+        if n is None or n == -1:
+            raise ValueError("Reading unbounded not supported")
+        return await self.receive_stream.receive_some(n)
+
+    async def close(self) -> None:
+        await self.send_stream.aclose()
+        await self.receive_stream.aclose()
+
+    def get_remote_address(self) -> tuple[str, int] | None:
+        # Return None since this is a test adapter without real network info
+        return None
+
+
+@pytest.mark.trio
+async def test_insecure_transport_stores_pubkey_in_peerstore():
+    """
+    Test that InsecureTransport stores the pubkey and peerid in
+    peerstore during handshake.
+    """
+    # Create key pairs for both sides
+    local_key_pair = create_new_key_pair()
+    remote_key_pair = create_new_key_pair()
+
+    # Create peer IDs
+    remote_peer_id = ID.from_pubkey(remote_key_pair.public_key)
+
+    # Create peerstore
+    peerstore = PeerStore()
+
+    # Create memory streams for communication
+    local_send, remote_receive = memory_stream_pair()
+    remote_send, local_receive = memory_stream_pair()
+
+    # Create adapters
+    local_stream = TrioStreamAdapter(local_send, local_receive, is_initiator=True)
+    remote_stream = TrioStreamAdapter(remote_send, remote_receive, is_initiator=False)
+
+    # Create transports
+    local_transport = InsecureTransport(local_key_pair, peerstore=peerstore)
+    remote_transport = InsecureTransport(remote_key_pair, peerstore=None)
+
+    # Run handshake
+    async def run_local_handshake(nursery_results):
+        with trio.move_on_after(5):
+            local_conn = await local_transport.secure_outbound(
+                local_stream, remote_peer_id
+            )
+            nursery_results["local"] = local_conn
+
+    async def run_remote_handshake(nursery_results):
+        with trio.move_on_after(5):
+            remote_conn = await remote_transport.secure_inbound(remote_stream)
+            nursery_results["remote"] = remote_conn
+
+    nursery_results = {}
+    async with trio.open_nursery() as nursery:
+        nursery.start_soon(run_local_handshake, nursery_results)
+        nursery.start_soon(run_remote_handshake, nursery_results)
+        await trio.sleep(0.1)  # Give tasks a chance to finish
+
+    local_conn = nursery_results.get("local")
+    remote_conn = nursery_results.get("remote")
+
+    assert local_conn is not None, "Local handshake failed"
+    assert remote_conn is not None, "Remote handshake failed"
+
+    # Verify that the remote peer ID is in the peerstore
+    assert remote_peer_id in peerstore.peer_ids()
+
+    # Verify that the public key was stored and matches
+    stored_pubkey = peerstore.pubkey(remote_peer_id)
+    assert stored_pubkey is not None
+    assert stored_pubkey.serialize() == remote_key_pair.public_key.serialize()
+
+
+@pytest.mark.trio
+async def test_insecure_transport_without_peerstore():
+    """
+    Test that InsecureTransport works correctly
+    without a peerstore.
+    """
+    # Create key pairs for both sides
+    local_key_pair = create_new_key_pair()
+    remote_key_pair = create_new_key_pair()
+
+    # Create peer IDs
+    remote_peer_id = ID.from_pubkey(remote_key_pair.public_key)
+
+    # Create memory streams for communication
+    local_send, remote_receive = memory_stream_pair()
+    remote_send, local_receive = memory_stream_pair()
+
+    # Create adapters
+    local_stream = TrioStreamAdapter(local_send, local_receive, is_initiator=True)
+    remote_stream = TrioStreamAdapter(remote_send, remote_receive, is_initiator=False)
+
+    # Create transports without peerstore
+    local_transport = InsecureTransport(local_key_pair, peerstore=None)
+    remote_transport = InsecureTransport(remote_key_pair, peerstore=None)
+
+    # Run handshake
+    async def run_local_handshake(nursery_results):
+        with trio.move_on_after(5):
+            local_conn = await local_transport.secure_outbound(
+                local_stream, remote_peer_id
+            )
+            nursery_results["local"] = local_conn
+
+    async def run_remote_handshake(nursery_results):
+        with trio.move_on_after(5):
+            remote_conn = await remote_transport.secure_inbound(remote_stream)
+            nursery_results["remote"] = remote_conn
+
+    nursery_results = {}
+    async with trio.open_nursery() as nursery:
+        nursery.start_soon(run_local_handshake, nursery_results)
+        nursery.start_soon(run_remote_handshake, nursery_results)
+        await trio.sleep(0.1)  # Give tasks a chance to finish
+
+    local_conn = nursery_results.get("local")
+    remote_conn = nursery_results.get("remote")
+
+    # Verify that handshake still works without a peerstore
+    assert local_conn is not None, "Local handshake failed"
+    assert remote_conn is not None, "Remote handshake failed"
+
+
+@pytest.mark.trio
+async def test_peerstore_unchanged_when_handshake_fails():
+    """
+    Test that the peerstore remains unchanged if the handshake fails
+    due to a peer ID mismatch.
+    """
+    # Create key pairs for both sides
+    local_key_pair = create_new_key_pair()
+    remote_key_pair = create_new_key_pair()
+
+    # Create a third key pair to cause a mismatch
+    mismatch_key_pair = create_new_key_pair()
+
+    # Create peer IDs
+    remote_peer_id = ID.from_pubkey(remote_key_pair.public_key)
+    mismatch_peer_id = ID.from_pubkey(mismatch_key_pair.public_key)
+
+    # Create peerstore and add some initial data to verify it stays unchanged
+    peerstore = PeerStore()
+
+    # Store some initial data in peerstore to verify it remains unchanged
+    initial_key_pair = create_new_key_pair()
+    initial_peer_id = ID.from_pubkey(initial_key_pair.public_key)
+    peerstore.add_pubkey(initial_peer_id, initial_key_pair.public_key)
+
+    # Remember the initial state of the peerstore
+    initial_peer_ids = set(peerstore.peer_ids())
+
+    # Create memory streams for communication
+    local_send, remote_receive = memory_stream_pair()
+    remote_send, local_receive = memory_stream_pair()
+
+    # Create adapters
+    local_stream = TrioStreamAdapter(local_send, local_receive, is_initiator=True)
+    remote_stream = TrioStreamAdapter(remote_send, remote_receive, is_initiator=False)
+
+    # Create transports
+    local_transport = InsecureTransport(local_key_pair, peerstore=peerstore)
+    remote_transport = InsecureTransport(remote_key_pair, peerstore=None)
+
+    # Run handshake with mismatched peer_id
+    # (expecting remote_peer_id but sending mismatch_peer_id to cause a failure)
+    async def run_local_handshake(nursery_results):
+        with trio.move_on_after(5):
+            try:
+                # Pass mismatch_peer_id instead of remote_peer_id
+                # to cause a handshake failure
+                local_conn = await local_transport.secure_outbound(
+                    local_stream, mismatch_peer_id
+                )
+                nursery_results["local"] = local_conn
+            except HandshakeFailure:
+                nursery_results["local_error"] = True
+
+    async def run_remote_handshake(nursery_results):
+        with trio.move_on_after(5):
+            try:
+                remote_conn = await remote_transport.secure_inbound(remote_stream)
+                nursery_results["remote"] = remote_conn
+            except HandshakeFailure:
+                nursery_results["remote_error"] = True
+
+    nursery_results = {}
+    async with trio.open_nursery() as nursery:
+        nursery.start_soon(run_local_handshake, nursery_results)
+        nursery.start_soon(run_remote_handshake, nursery_results)
+        await trio.sleep(0.1)
+
+    # Verify that at least one side encountered an error
+    assert "local_error" in nursery_results or "remote_error" in nursery_results, (
+        "Expected handshake to fail due to peer ID mismatch"
+    )
+
+    # Verify that the peerstore remains unchanged
+    current_peer_ids = set(peerstore.peer_ids())
+    assert current_peer_ids == initial_peer_ids, (
+        "Peerstore should remain unchanged when handshake fails"
+    )
+
+    # Verify that neither the remote_peer_id nor mismatch_peer_id was added
+    assert remote_peer_id not in peerstore.peer_ids(), (
+        "Remote peer ID should not be added on handshake failure"
+    )
+    assert mismatch_peer_id not in peerstore.peer_ids(), (
+        "Mismatch peer ID should not be added on handshake failure"
+    )
+
+
+@pytest.mark.trio
+async def test_handshake_adds_pubkey_to_existing_peer():
+    """
+    Test that when a peer ID already exists in the peerstore but without
+    a public key, the handshake correctly adds the public key.
+
+    This tests the case where we might have a peer ID from another source
+    (like a routing table) but don't yet have its public key.
+    """
+    # Create key pairs for both sides
+    local_key_pair = create_new_key_pair()
+    remote_key_pair = create_new_key_pair()
+
+    # Create peer IDs
+    remote_peer_id = ID.from_pubkey(remote_key_pair.public_key)
+
+    # Create peerstore and add the peer ID without a public key
+    peerstore = PeerStore()
+
+    # Add the peer ID to the peerstore without its public key
+    # (adding an address for the peer, which creates the peer entry)
+    # This simulates having discovered a peer through DHT or other means
+    # without having its public key yet
+    peerstore.peer_data_map[remote_peer_id] = PeerData()
+
+    # Verify initial state - the peer ID should exist but without a public key
+    assert remote_peer_id in peerstore.peer_ids()
+    with pytest.raises(Exception):
+        peerstore.pubkey(remote_peer_id)
+
+    # Create memory streams for communication
+    local_send, remote_receive = memory_stream_pair()
+    remote_send, local_receive = memory_stream_pair()
+
+    # Create adapters
+    local_stream = TrioStreamAdapter(local_send, local_receive, is_initiator=True)
+    remote_stream = TrioStreamAdapter(remote_send, remote_receive, is_initiator=False)
+
+    # Create transports
+    local_transport = InsecureTransport(local_key_pair, peerstore=peerstore)
+    remote_transport = InsecureTransport(remote_key_pair, peerstore=None)
+
+    # Run handshake
+    async def run_local_handshake(nursery_results):
+        with trio.move_on_after(5):
+            local_conn = await local_transport.secure_outbound(
+                local_stream, remote_peer_id
+            )
+            nursery_results["local"] = local_conn
+
+    async def run_remote_handshake(nursery_results):
+        with trio.move_on_after(5):
+            remote_conn = await remote_transport.secure_inbound(remote_stream)
+            nursery_results["remote"] = remote_conn
+
+    nursery_results = {}
+    async with trio.open_nursery() as nursery:
+        nursery.start_soon(run_local_handshake, nursery_results)
+        nursery.start_soon(run_remote_handshake, nursery_results)
+        await trio.sleep(0.1)  # Give tasks a chance to finish
+
+    local_conn = nursery_results.get("local")
+    remote_conn = nursery_results.get("remote")
+
+    # Verify that the handshake succeeded
+    assert local_conn is not None, "Local handshake failed"
+    assert remote_conn is not None, "Remote handshake failed"
+
+    # Verify that the peer ID is still in the peerstore
+    assert remote_peer_id in peerstore.peer_ids()
+
+    # Verify that the public key was added
+    stored_pubkey = peerstore.pubkey(remote_peer_id)
+    assert stored_pubkey is not None
+    assert stored_pubkey.serialize() == remote_key_pair.public_key.serialize()
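Note: the four tests above repeat the same two-task handshake scaffolding. A hypothetical helper (not part of this commit) that condenses it; because a trio nursery only exits once both child tasks finish, the trailing `trio.sleep(0.1)` becomes unnecessary:

import trio

async def run_handshake(local_transport, remote_transport,
                        local_stream, remote_stream, remote_peer_id,
                        timeout: float = 5):
    # Drive both ends of the handshake concurrently and collect the results.
    results: dict = {}

    async def outbound():
        with trio.move_on_after(timeout):
            results["local"] = await local_transport.secure_outbound(
                local_stream, remote_peer_id
            )

    async def inbound():
        with trio.move_on_after(timeout):
            results["remote"] = await remote_transport.secure_inbound(remote_stream)

    # The nursery waits for both tasks before exiting.
    async with trio.open_nursery() as nursery:
        nursery.start_soon(outbound)
        nursery.start_soon(inbound)

    return results.get("local"), results.get("remote")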
@@ -79,7 +79,7 @@ async def secure_conn_pair(key_pair, peer_id):
     client_rw = TrioStreamAdapter(client_send, client_receive, is_initiator=True)
     server_rw = TrioStreamAdapter(server_send, server_receive, is_initiator=False)
 
-    insecure_transport = InsecureTransport(key_pair)
+    insecure_transport = InsecureTransport(key_pair, peerstore=None)
 
     async def run_outbound(nursery_results):
         with trio.move_on_after(5):
@@ -161,8 +161,8 @@ def noise_handshake_payload_factory() -> NoiseHandshakePayload:
     )
 
 
-def plaintext_transport_factory(key_pair: KeyPair) -> ISecureTransport:
-    return InsecureTransport(key_pair)
+def plaintext_transport_factory(key_pair: KeyPair, peerstore=None) -> ISecureTransport:
+    return InsecureTransport(key_pair, peerstore=peerstore)
 
 
 def secio_transport_factory(key_pair: KeyPair) -> ISecureTransport:
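Note: `plaintext_transport_factory` now threads an optional peerstore through to `InsecureTransport`; the `None` default keeps existing call sites working. Illustrative usage:

# Existing call sites keep working:
transport = plaintext_transport_factory(key_pair)
# New call sites can opt in to peerstore-backed identity bookkeeping:
transport = plaintext_transport_factory(key_pair, peerstore=PeerStore())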
@@ -443,6 +443,10 @@ class GossipsubFactory(factory.Factory):
     heartbeat_interval = GOSSIPSUB_PARAMS.heartbeat_interval
     direct_connect_initial_delay = GOSSIPSUB_PARAMS.direct_connect_initial_delay
     direct_connect_interval = GOSSIPSUB_PARAMS.direct_connect_interval
+    do_px = GOSSIPSUB_PARAMS.do_px
+    px_peers_count = GOSSIPSUB_PARAMS.px_peers_count
+    prune_back_off = GOSSIPSUB_PARAMS.prune_back_off
+    unsubscribe_back_off = GOSSIPSUB_PARAMS.unsubscribe_back_off
 
 
 class PubsubFactory(factory.Factory):
@@ -568,6 +572,10 @@ class PubsubFactory(factory.Factory):
         heartbeat_initial_delay: float = GOSSIPSUB_PARAMS.heartbeat_initial_delay,
         direct_connect_initial_delay: float = GOSSIPSUB_PARAMS.direct_connect_initial_delay,  # noqa: E501
         direct_connect_interval: int = GOSSIPSUB_PARAMS.direct_connect_interval,
+        do_px: bool = GOSSIPSUB_PARAMS.do_px,
+        px_peers_count: int = GOSSIPSUB_PARAMS.px_peers_count,
+        prune_back_off: int = GOSSIPSUB_PARAMS.prune_back_off,
+        unsubscribe_back_off: int = GOSSIPSUB_PARAMS.unsubscribe_back_off,
         security_protocol: TProtocol | None = None,
         muxer_opt: TMuxerOptions | None = None,
         msg_id_constructor: None
@@ -588,6 +596,10 @@ class PubsubFactory(factory.Factory):
                 heartbeat_interval=heartbeat_interval,
                 direct_connect_initial_delay=direct_connect_initial_delay,
                 direct_connect_interval=direct_connect_interval,
+                do_px=do_px,
+                px_peers_count=px_peers_count,
+                prune_back_off=prune_back_off,
+                unsubscribe_back_off=unsubscribe_back_off,
             )
         else:
             gossipsubs = GossipsubFactory.create_batch(
@@ -602,6 +614,10 @@ class PubsubFactory(factory.Factory):
                 heartbeat_initial_delay=heartbeat_initial_delay,
                 direct_connect_initial_delay=direct_connect_initial_delay,
                 direct_connect_interval=direct_connect_interval,
+                do_px=do_px,
+                px_peers_count=px_peers_count,
+                prune_back_off=prune_back_off,
+                unsubscribe_back_off=unsubscribe_back_off,
             )
 
         async with cls._create_batch_with_router(
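Note: with the four new keyword arguments threaded through `GossipsubFactory` and `PubsubFactory.create_batch_with_gossipsub`, tests can tune PX and backoff behavior per batch, as the new test module does. A representative call; the comment annotations are the editor's reading of the parameters, not documented API semantics:

async with PubsubFactory.create_batch_with_gossipsub(
    3,
    heartbeat_interval=0.5,
    do_px=True,              # attach peer-exchange records to PRUNEs
    px_peers_count=1,        # how many PX records to include
    prune_back_off=2,        # seconds before a pruned peer may re-graft
    unsubscribe_back_off=2,  # longer penalty after an unsubscribe
) as pubsubs:
    ...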