mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
reorg test structure to match tox and CI jobs, drop bumpversion for bump-my-version and move config to pyproject.toml, fix docs building
This commit is contained in:
197
tests/core/pubsub/test_dummyaccount_demo.py
Normal file
197
tests/core/pubsub/test_dummyaccount_demo.py
Normal file
@ -0,0 +1,197 @@
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.tools.pubsub.dummy_account_node import (
|
||||
DummyAccountNode,
|
||||
)
|
||||
from libp2p.tools.utils import (
|
||||
connect,
|
||||
)
|
||||
|
||||
|
||||
async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
|
||||
"""
|
||||
Helper function to allow for easy construction of custom tests for dummy
|
||||
account nodes in various network topologies.
|
||||
|
||||
:param num_nodes: number of nodes in the test
|
||||
:param adjacency_map: adjacency map defining each node and its list of neighbors
|
||||
:param action_func: function to execute that includes actions by the nodes,
|
||||
such as send crypto and set crypto
|
||||
:param assertion_func: assertions for testing the results of the actions are correct
|
||||
"""
|
||||
|
||||
async with DummyAccountNode.create(num_nodes) as dummy_nodes:
|
||||
# Create connections between nodes according to `adjacency_map`
|
||||
async with trio.open_nursery() as nursery:
|
||||
for source_num in adjacency_map:
|
||||
target_nums = adjacency_map[source_num]
|
||||
for target_num in target_nums:
|
||||
nursery.start_soon(
|
||||
connect,
|
||||
dummy_nodes[source_num].host,
|
||||
dummy_nodes[target_num].host,
|
||||
)
|
||||
|
||||
# Allow time for network creation to take place
|
||||
await trio.sleep(0.25)
|
||||
|
||||
# Perform action function
|
||||
await action_func(dummy_nodes)
|
||||
|
||||
# Allow time for action function to be performed (i.e. messages to propogate)
|
||||
await trio.sleep(1)
|
||||
|
||||
# Perform assertion function
|
||||
for dummy_node in dummy_nodes:
|
||||
assertion_func(dummy_node)
|
||||
|
||||
# Success, terminate pending tasks.
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_simple_two_nodes():
|
||||
num_nodes = 2
|
||||
adj_map = {0: [1]}
|
||||
|
||||
async def action_func(dummy_nodes):
|
||||
await dummy_nodes[0].publish_set_crypto("aspyn", 10)
|
||||
|
||||
def assertion_func(dummy_node):
|
||||
assert dummy_node.get_balance("aspyn") == 10
|
||||
|
||||
await perform_test(num_nodes, adj_map, action_func, assertion_func)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_simple_three_nodes_line_topography():
|
||||
num_nodes = 3
|
||||
adj_map = {0: [1], 1: [2]}
|
||||
|
||||
async def action_func(dummy_nodes):
|
||||
await dummy_nodes[0].publish_set_crypto("aspyn", 10)
|
||||
|
||||
def assertion_func(dummy_node):
|
||||
assert dummy_node.get_balance("aspyn") == 10
|
||||
|
||||
await perform_test(num_nodes, adj_map, action_func, assertion_func)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_simple_three_nodes_triangle_topography():
|
||||
num_nodes = 3
|
||||
adj_map = {0: [1, 2], 1: [2]}
|
||||
|
||||
async def action_func(dummy_nodes):
|
||||
await dummy_nodes[0].publish_set_crypto("aspyn", 20)
|
||||
|
||||
def assertion_func(dummy_node):
|
||||
assert dummy_node.get_balance("aspyn") == 20
|
||||
|
||||
await perform_test(num_nodes, adj_map, action_func, assertion_func)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_simple_seven_nodes_tree_topography():
|
||||
num_nodes = 7
|
||||
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
|
||||
|
||||
async def action_func(dummy_nodes):
|
||||
await dummy_nodes[0].publish_set_crypto("aspyn", 20)
|
||||
|
||||
def assertion_func(dummy_node):
|
||||
assert dummy_node.get_balance("aspyn") == 20
|
||||
|
||||
await perform_test(num_nodes, adj_map, action_func, assertion_func)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_set_then_send_from_root_seven_nodes_tree_topography():
|
||||
num_nodes = 7
|
||||
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
|
||||
|
||||
async def action_func(dummy_nodes):
|
||||
await dummy_nodes[0].publish_set_crypto("aspyn", 20)
|
||||
await trio.sleep(0.25)
|
||||
await dummy_nodes[0].publish_send_crypto("aspyn", "alex", 5)
|
||||
|
||||
def assertion_func(dummy_node):
|
||||
assert dummy_node.get_balance("aspyn") == 15
|
||||
assert dummy_node.get_balance("alex") == 5
|
||||
|
||||
await perform_test(num_nodes, adj_map, action_func, assertion_func)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography():
|
||||
num_nodes = 7
|
||||
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
|
||||
|
||||
async def action_func(dummy_nodes):
|
||||
await dummy_nodes[6].publish_set_crypto("aspyn", 20)
|
||||
await trio.sleep(0.25)
|
||||
await dummy_nodes[4].publish_send_crypto("aspyn", "alex", 5)
|
||||
|
||||
def assertion_func(dummy_node):
|
||||
assert dummy_node.get_balance("aspyn") == 15
|
||||
assert dummy_node.get_balance("alex") == 5
|
||||
|
||||
await perform_test(num_nodes, adj_map, action_func, assertion_func)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_simple_five_nodes_ring_topography():
|
||||
num_nodes = 5
|
||||
adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
|
||||
|
||||
async def action_func(dummy_nodes):
|
||||
await dummy_nodes[0].publish_set_crypto("aspyn", 20)
|
||||
|
||||
def assertion_func(dummy_node):
|
||||
assert dummy_node.get_balance("aspyn") == 20
|
||||
|
||||
await perform_test(num_nodes, adj_map, action_func, assertion_func)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography():
|
||||
num_nodes = 5
|
||||
adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
|
||||
|
||||
async def action_func(dummy_nodes):
|
||||
await dummy_nodes[0].publish_set_crypto("alex", 20)
|
||||
await trio.sleep(0.25)
|
||||
await dummy_nodes[3].publish_send_crypto("alex", "rob", 12)
|
||||
|
||||
def assertion_func(dummy_node):
|
||||
assert dummy_node.get_balance("alex") == 8
|
||||
assert dummy_node.get_balance("rob") == 12
|
||||
|
||||
await perform_test(num_nodes, adj_map, action_func, assertion_func)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
@pytest.mark.slow
|
||||
async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography():
|
||||
num_nodes = 5
|
||||
adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
|
||||
|
||||
async def action_func(dummy_nodes):
|
||||
await dummy_nodes[0].publish_set_crypto("alex", 20)
|
||||
await trio.sleep(1)
|
||||
await dummy_nodes[1].publish_send_crypto("alex", "rob", 3)
|
||||
await trio.sleep(1)
|
||||
await dummy_nodes[2].publish_send_crypto("rob", "aspyn", 2)
|
||||
await trio.sleep(1)
|
||||
await dummy_nodes[3].publish_send_crypto("aspyn", "zx", 1)
|
||||
await trio.sleep(1)
|
||||
await dummy_nodes[4].publish_send_crypto("zx", "raul", 1)
|
||||
|
||||
def assertion_func(dummy_node):
|
||||
assert dummy_node.get_balance("alex") == 17
|
||||
assert dummy_node.get_balance("rob") == 1
|
||||
assert dummy_node.get_balance("aspyn") == 1
|
||||
assert dummy_node.get_balance("zx") == 0
|
||||
assert dummy_node.get_balance("raul") == 1
|
||||
|
||||
await perform_test(num_nodes, adj_map, action_func, assertion_func)
|
||||
95
tests/core/pubsub/test_floodsub.py
Normal file
95
tests/core/pubsub/test_floodsub.py
Normal file
@ -0,0 +1,95 @@
|
||||
import functools
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.tools.factories import (
|
||||
PubsubFactory,
|
||||
)
|
||||
from libp2p.tools.pubsub.floodsub_integration_test_settings import (
|
||||
floodsub_protocol_pytest_params,
|
||||
perform_test_from_obj,
|
||||
)
|
||||
from libp2p.tools.utils import (
|
||||
connect,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_simple_two_nodes():
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
|
||||
topic = "my_topic"
|
||||
data = b"some data"
|
||||
|
||||
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
|
||||
await trio.sleep(0.25)
|
||||
|
||||
sub_b = await pubsubs_fsub[1].subscribe(topic)
|
||||
# Sleep to let a know of b's subscription
|
||||
await trio.sleep(0.25)
|
||||
|
||||
await pubsubs_fsub[0].publish(topic, data)
|
||||
|
||||
res_b = await sub_b.get()
|
||||
|
||||
# Check that the msg received by node_b is the same
|
||||
# as the message sent by node_a
|
||||
assert ID(res_b.from_id) == pubsubs_fsub[0].host.get_id()
|
||||
assert res_b.data == data
|
||||
assert res_b.topicIDs == [topic]
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_lru_cache_two_nodes():
|
||||
# two nodes with cache_size of 4
|
||||
|
||||
# Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`.
|
||||
def get_msg_id(msg):
|
||||
# Originally it is `(msg.seqno, msg.from_id)`
|
||||
return (msg.data, msg.from_id)
|
||||
|
||||
async with PubsubFactory.create_batch_with_floodsub(
|
||||
2, cache_size=4, msg_id_constructor=get_msg_id
|
||||
) as pubsubs_fsub:
|
||||
# `node_a` send the following messages to node_b
|
||||
message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
|
||||
# `node_b` should only receive the following
|
||||
expected_received_indices = [1, 2, 3, 4, 5, 1]
|
||||
|
||||
topic = "my_topic"
|
||||
|
||||
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
|
||||
await trio.sleep(0.25)
|
||||
|
||||
sub_b = await pubsubs_fsub[1].subscribe(topic)
|
||||
await trio.sleep(0.25)
|
||||
|
||||
def _make_testing_data(i: int) -> bytes:
|
||||
num_int_bytes = 4
|
||||
if i >= 2 ** (num_int_bytes * 8):
|
||||
raise ValueError("integer is too large to be serialized")
|
||||
return b"data" + i.to_bytes(num_int_bytes, "big")
|
||||
|
||||
for index in message_indices:
|
||||
await pubsubs_fsub[0].publish(topic, _make_testing_data(index))
|
||||
await trio.sleep(0.25)
|
||||
|
||||
for index in expected_received_indices:
|
||||
res_b = await sub_b.get()
|
||||
assert res_b.data == _make_testing_data(index)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params)
|
||||
@pytest.mark.trio
|
||||
@pytest.mark.slow
|
||||
async def test_gossipsub_run_with_floodsub_tests(test_case_obj, security_protocol):
|
||||
await perform_test_from_obj(
|
||||
test_case_obj,
|
||||
functools.partial(
|
||||
PubsubFactory.create_batch_with_floodsub,
|
||||
security_protocol=security_protocol,
|
||||
),
|
||||
)
|
||||
488
tests/core/pubsub/test_gossipsub.py
Normal file
488
tests/core/pubsub/test_gossipsub.py
Normal file
@ -0,0 +1,488 @@
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.pubsub.gossipsub import (
|
||||
PROTOCOL_ID,
|
||||
)
|
||||
from libp2p.tools.factories import (
|
||||
IDFactory,
|
||||
PubsubFactory,
|
||||
)
|
||||
from libp2p.tools.pubsub.utils import (
|
||||
dense_connect,
|
||||
one_to_all_connect,
|
||||
)
|
||||
from libp2p.tools.utils import (
|
||||
connect,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_join():
|
||||
async with PubsubFactory.create_batch_with_gossipsub(
|
||||
4, degree=4, degree_low=3, degree_high=5
|
||||
) as pubsubs_gsub:
|
||||
gossipsubs = [pubsub.router for pubsub in pubsubs_gsub]
|
||||
hosts = [pubsub.host for pubsub in pubsubs_gsub]
|
||||
hosts_indices = list(range(len(pubsubs_gsub)))
|
||||
|
||||
topic = "test_join"
|
||||
central_node_index = 0
|
||||
# Remove index of central host from the indices
|
||||
hosts_indices.remove(central_node_index)
|
||||
num_subscribed_peer = 2
|
||||
subscribed_peer_indices = random.sample(hosts_indices, num_subscribed_peer)
|
||||
|
||||
# All pubsub except the one of central node subscribe to topic
|
||||
for i in subscribed_peer_indices:
|
||||
await pubsubs_gsub[i].subscribe(topic)
|
||||
|
||||
# Connect central host to all other hosts
|
||||
await one_to_all_connect(hosts, central_node_index)
|
||||
|
||||
# Wait 2 seconds for heartbeat to allow mesh to connect
|
||||
await trio.sleep(2)
|
||||
|
||||
# Central node publish to the topic so that this topic
|
||||
# is added to central node's fanout
|
||||
# publish from the randomly chosen host
|
||||
await pubsubs_gsub[central_node_index].publish(topic, b"data")
|
||||
|
||||
# Check that the gossipsub of central node has fanout for the topic
|
||||
assert topic in gossipsubs[central_node_index].fanout
|
||||
# Check that the gossipsub of central node does not have a mesh for the topic
|
||||
assert topic not in gossipsubs[central_node_index].mesh
|
||||
|
||||
# Central node subscribes the topic
|
||||
await pubsubs_gsub[central_node_index].subscribe(topic)
|
||||
|
||||
await trio.sleep(2)
|
||||
|
||||
# Check that the gossipsub of central node no longer has fanout for the topic
|
||||
assert topic not in gossipsubs[central_node_index].fanout
|
||||
|
||||
for i in hosts_indices:
|
||||
if i in subscribed_peer_indices:
|
||||
assert hosts[i].get_id() in gossipsubs[central_node_index].mesh[topic]
|
||||
assert hosts[central_node_index].get_id() in gossipsubs[i].mesh[topic]
|
||||
else:
|
||||
assert (
|
||||
hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic]
|
||||
)
|
||||
assert topic not in gossipsubs[i].mesh
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_leave():
|
||||
async with PubsubFactory.create_batch_with_gossipsub(1) as pubsubs_gsub:
|
||||
gossipsub = pubsubs_gsub[0].router
|
||||
topic = "test_leave"
|
||||
|
||||
assert topic not in gossipsub.mesh
|
||||
|
||||
await gossipsub.join(topic)
|
||||
assert topic in gossipsub.mesh
|
||||
|
||||
await gossipsub.leave(topic)
|
||||
assert topic not in gossipsub.mesh
|
||||
|
||||
# Test re-leave
|
||||
await gossipsub.leave(topic)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_handle_graft(monkeypatch):
|
||||
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
|
||||
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
|
||||
|
||||
index_alice = 0
|
||||
id_alice = pubsubs_gsub[index_alice].my_id
|
||||
index_bob = 1
|
||||
id_bob = pubsubs_gsub[index_bob].my_id
|
||||
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
|
||||
|
||||
# Wait 2 seconds for heartbeat to allow mesh to connect
|
||||
await trio.sleep(2)
|
||||
|
||||
topic = "test_handle_graft"
|
||||
# Only lice subscribe to the topic
|
||||
await gossipsubs[index_alice].join(topic)
|
||||
|
||||
# Monkey patch bob's `emit_prune` function so we can
|
||||
# check if it is called in `handle_graft`
|
||||
event_emit_prune = trio.Event()
|
||||
|
||||
async def emit_prune(topic, sender_peer_id):
|
||||
event_emit_prune.set()
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
monkeypatch.setattr(gossipsubs[index_bob], "emit_prune", emit_prune)
|
||||
|
||||
# Check that alice is bob's peer but not his mesh peer
|
||||
assert gossipsubs[index_bob].peer_protocol[id_alice] == PROTOCOL_ID
|
||||
assert topic not in gossipsubs[index_bob].mesh
|
||||
|
||||
await gossipsubs[index_alice].emit_graft(topic, id_bob)
|
||||
|
||||
# Check that `emit_prune` is called
|
||||
await event_emit_prune.wait()
|
||||
|
||||
# Check that bob is alice's peer but not her mesh peer
|
||||
assert topic in gossipsubs[index_alice].mesh
|
||||
assert id_bob not in gossipsubs[index_alice].mesh[topic]
|
||||
assert gossipsubs[index_alice].peer_protocol[id_bob] == PROTOCOL_ID
|
||||
|
||||
await gossipsubs[index_bob].emit_graft(topic, id_alice)
|
||||
|
||||
await trio.sleep(1)
|
||||
|
||||
# Check that bob is now alice's mesh peer
|
||||
assert id_bob in gossipsubs[index_alice].mesh[topic]
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_handle_prune():
|
||||
async with PubsubFactory.create_batch_with_gossipsub(
|
||||
2, heartbeat_interval=3
|
||||
) as pubsubs_gsub:
|
||||
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
|
||||
|
||||
index_alice = 0
|
||||
id_alice = pubsubs_gsub[index_alice].my_id
|
||||
index_bob = 1
|
||||
id_bob = pubsubs_gsub[index_bob].my_id
|
||||
|
||||
topic = "test_handle_prune"
|
||||
for pubsub in pubsubs_gsub:
|
||||
await pubsub.subscribe(topic)
|
||||
|
||||
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
|
||||
|
||||
# Wait for heartbeat to allow mesh to connect
|
||||
await trio.sleep(1)
|
||||
|
||||
# Check that they are each other's mesh peer
|
||||
assert id_alice in gossipsubs[index_bob].mesh[topic]
|
||||
assert id_bob in gossipsubs[index_alice].mesh[topic]
|
||||
|
||||
# alice emit prune message to bob, alice should be removed
|
||||
# from bob's mesh peer
|
||||
await gossipsubs[index_alice].emit_prune(topic, id_bob)
|
||||
# `emit_prune` does not remove bob from alice's mesh peers
|
||||
assert id_bob in gossipsubs[index_alice].mesh[topic]
|
||||
|
||||
# NOTE: We increase `heartbeat_interval` to 3 seconds so that bob will not
|
||||
# add alice back to his mesh after heartbeat.
|
||||
# Wait for bob to `handle_prune`
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Check that alice is no longer bob's mesh peer
|
||||
assert id_alice not in gossipsubs[index_bob].mesh[topic]
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_dense():
|
||||
async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
|
||||
hosts = [pubsub.host for pubsub in pubsubs_gsub]
|
||||
num_msgs = 5
|
||||
|
||||
# All pubsub subscribe to foobar
|
||||
queues = [await pubsub.subscribe("foobar") for pubsub in pubsubs_gsub]
|
||||
|
||||
# Densely connect libp2p hosts in a random way
|
||||
await dense_connect(hosts)
|
||||
|
||||
# Wait 2 seconds for heartbeat to allow mesh to connect
|
||||
await trio.sleep(2)
|
||||
|
||||
for i in range(num_msgs):
|
||||
msg_content = b"foo " + i.to_bytes(1, "big")
|
||||
|
||||
# randomly pick a message origin
|
||||
origin_idx = random.randint(0, len(hosts) - 1)
|
||||
|
||||
# publish from the randomly chosen host
|
||||
await pubsubs_gsub[origin_idx].publish("foobar", msg_content)
|
||||
|
||||
await trio.sleep(0.5)
|
||||
# Assert that all blocking queues receive the message
|
||||
for queue in queues:
|
||||
msg = await queue.get()
|
||||
assert msg.data == msg_content
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_fanout():
|
||||
async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
|
||||
hosts = [pubsub.host for pubsub in pubsubs_gsub]
|
||||
num_msgs = 5
|
||||
|
||||
# All pubsub subscribe to foobar except for `pubsubs_gsub[0]`
|
||||
subs = [await pubsub.subscribe("foobar") for pubsub in pubsubs_gsub[1:]]
|
||||
|
||||
# Sparsely connect libp2p hosts in random way
|
||||
await dense_connect(hosts)
|
||||
|
||||
# Wait 2 seconds for heartbeat to allow mesh to connect
|
||||
await trio.sleep(2)
|
||||
|
||||
topic = "foobar"
|
||||
# Send messages with origin not subscribed
|
||||
for i in range(num_msgs):
|
||||
msg_content = b"foo " + i.to_bytes(1, "big")
|
||||
|
||||
# Pick the message origin to the node that is not subscribed to 'foobar'
|
||||
origin_idx = 0
|
||||
|
||||
# publish from the randomly chosen host
|
||||
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
|
||||
|
||||
await trio.sleep(0.5)
|
||||
# Assert that all blocking queues receive the message
|
||||
for sub in subs:
|
||||
msg = await sub.get()
|
||||
assert msg.data == msg_content
|
||||
|
||||
# Subscribe message origin
|
||||
subs.insert(0, await pubsubs_gsub[0].subscribe(topic))
|
||||
|
||||
# Send messages again
|
||||
for i in range(num_msgs):
|
||||
msg_content = b"bar " + i.to_bytes(1, "big")
|
||||
|
||||
# Pick the message origin to the node that is not subscribed to 'foobar'
|
||||
origin_idx = 0
|
||||
|
||||
# publish from the randomly chosen host
|
||||
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
|
||||
|
||||
await trio.sleep(0.5)
|
||||
# Assert that all blocking queues receive the message
|
||||
for sub in subs:
|
||||
msg = await sub.get()
|
||||
assert msg.data == msg_content
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
@pytest.mark.slow
|
||||
async def test_fanout_maintenance():
|
||||
async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
|
||||
hosts = [pubsub.host for pubsub in pubsubs_gsub]
|
||||
num_msgs = 5
|
||||
|
||||
# All pubsub subscribe to foobar
|
||||
queues = []
|
||||
topic = "foobar"
|
||||
for i in range(1, len(pubsubs_gsub)):
|
||||
q = await pubsubs_gsub[i].subscribe(topic)
|
||||
|
||||
# Add each blocking queue to an array of blocking queues
|
||||
queues.append(q)
|
||||
|
||||
# Sparsely connect libp2p hosts in random way
|
||||
await dense_connect(hosts)
|
||||
|
||||
# Wait 2 seconds for heartbeat to allow mesh to connect
|
||||
await trio.sleep(2)
|
||||
|
||||
# Send messages with origin not subscribed
|
||||
for i in range(num_msgs):
|
||||
msg_content = b"foo " + i.to_bytes(1, "big")
|
||||
|
||||
# Pick the message origin to the node that is not subscribed to 'foobar'
|
||||
origin_idx = 0
|
||||
|
||||
# publish from the randomly chosen host
|
||||
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
|
||||
|
||||
await trio.sleep(0.5)
|
||||
# Assert that all blocking queues receive the message
|
||||
for queue in queues:
|
||||
msg = await queue.get()
|
||||
assert msg.data == msg_content
|
||||
|
||||
for sub in pubsubs_gsub:
|
||||
await sub.unsubscribe(topic)
|
||||
|
||||
queues = []
|
||||
|
||||
await trio.sleep(2)
|
||||
|
||||
# Resub and repeat
|
||||
for i in range(1, len(pubsubs_gsub)):
|
||||
q = await pubsubs_gsub[i].subscribe(topic)
|
||||
|
||||
# Add each blocking queue to an array of blocking queues
|
||||
queues.append(q)
|
||||
|
||||
await trio.sleep(2)
|
||||
|
||||
# Check messages can still be sent
|
||||
for i in range(num_msgs):
|
||||
msg_content = b"bar " + i.to_bytes(1, "big")
|
||||
|
||||
# Pick the message origin to the node that is not subscribed to 'foobar'
|
||||
origin_idx = 0
|
||||
|
||||
# publish from the randomly chosen host
|
||||
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
|
||||
|
||||
await trio.sleep(0.5)
|
||||
# Assert that all blocking queues receive the message
|
||||
for queue in queues:
|
||||
msg = await queue.get()
|
||||
assert msg.data == msg_content
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_gossip_propagation():
|
||||
async with PubsubFactory.create_batch_with_gossipsub(
|
||||
2, degree=1, degree_low=0, degree_high=2, gossip_window=50, gossip_history=100
|
||||
) as pubsubs_gsub:
|
||||
topic = "foo"
|
||||
queue_0 = await pubsubs_gsub[0].subscribe(topic)
|
||||
|
||||
# node 0 publish to topic
|
||||
msg_content = b"foo_msg"
|
||||
|
||||
# publish from the randomly chosen host
|
||||
await pubsubs_gsub[0].publish(topic, msg_content)
|
||||
|
||||
await trio.sleep(0.5)
|
||||
# Assert that the blocking queues receive the message
|
||||
msg = await queue_0.get()
|
||||
assert msg.data == msg_content
|
||||
|
||||
|
||||
@pytest.mark.parametrize("initial_mesh_peer_count", (7, 10, 13))
|
||||
@pytest.mark.trio
|
||||
async def test_mesh_heartbeat(initial_mesh_peer_count, monkeypatch):
|
||||
async with PubsubFactory.create_batch_with_gossipsub(
|
||||
1, heartbeat_initial_delay=100
|
||||
) as pubsubs_gsub:
|
||||
# It's difficult to set up the initial peer subscription condition.
|
||||
# Ideally I would like to have initial mesh peer count that's below
|
||||
# ``GossipSubDegree`` so I can test if `mesh_heartbeat` return correct peers to
|
||||
# GRAFT. The problem is that I can not set it up so that we have peers subscribe
|
||||
# to the topic but not being part of our mesh peers (as these peers are the
|
||||
# peers to GRAFT). So I monkeypatch the peer subscriptions and our mesh peers.
|
||||
total_peer_count = 14
|
||||
topic = "TEST_MESH_HEARTBEAT"
|
||||
|
||||
fake_peer_ids = [IDFactory() for _ in range(total_peer_count)]
|
||||
peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids}
|
||||
monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol)
|
||||
|
||||
peer_topics = {topic: set(fake_peer_ids)}
|
||||
# Monkeypatch the peer subscriptions
|
||||
monkeypatch.setattr(pubsubs_gsub[0], "peer_topics", peer_topics)
|
||||
|
||||
mesh_peer_indices = random.sample(
|
||||
range(total_peer_count), initial_mesh_peer_count
|
||||
)
|
||||
mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices]
|
||||
router_mesh = {topic: set(mesh_peers)}
|
||||
# Monkeypatch our mesh peers
|
||||
monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh)
|
||||
|
||||
peers_to_graft, peers_to_prune = pubsubs_gsub[0].router.mesh_heartbeat()
|
||||
if initial_mesh_peer_count > pubsubs_gsub[0].router.degree:
|
||||
# If number of initial mesh peers is more than `GossipSubDegree`,
|
||||
# we should PRUNE mesh peers
|
||||
assert len(peers_to_graft) == 0
|
||||
assert (
|
||||
len(peers_to_prune)
|
||||
== initial_mesh_peer_count - pubsubs_gsub[0].router.degree
|
||||
)
|
||||
for peer in peers_to_prune:
|
||||
assert peer in mesh_peers
|
||||
elif initial_mesh_peer_count < pubsubs_gsub[0].router.degree:
|
||||
# If number of initial mesh peers is less than `GossipSubDegree`,
|
||||
# we should GRAFT more peers
|
||||
assert len(peers_to_prune) == 0
|
||||
assert (
|
||||
len(peers_to_graft)
|
||||
== pubsubs_gsub[0].router.degree - initial_mesh_peer_count
|
||||
)
|
||||
for peer in peers_to_graft:
|
||||
assert peer not in mesh_peers
|
||||
else:
|
||||
assert len(peers_to_prune) == 0 and len(peers_to_graft) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("initial_peer_count", (1, 4, 7))
|
||||
@pytest.mark.trio
|
||||
async def test_gossip_heartbeat(initial_peer_count, monkeypatch):
|
||||
async with PubsubFactory.create_batch_with_gossipsub(
|
||||
1, heartbeat_initial_delay=100
|
||||
) as pubsubs_gsub:
|
||||
# The problem is that I can not set it up so that we have peers subscribe to the
|
||||
# topic but not being part of our mesh peers (as these peers are the peers to
|
||||
# GRAFT). So I monkeypatch the peer subscriptions and our mesh peers.
|
||||
total_peer_count = 28
|
||||
topic_mesh = "TEST_GOSSIP_HEARTBEAT_1"
|
||||
topic_fanout = "TEST_GOSSIP_HEARTBEAT_2"
|
||||
|
||||
fake_peer_ids = [IDFactory() for _ in range(total_peer_count)]
|
||||
peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids}
|
||||
monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol)
|
||||
|
||||
topic_mesh_peer_count = 14
|
||||
# Split into mesh peers and fanout peers
|
||||
peer_topics = {
|
||||
topic_mesh: set(fake_peer_ids[:topic_mesh_peer_count]),
|
||||
topic_fanout: set(fake_peer_ids[topic_mesh_peer_count:]),
|
||||
}
|
||||
# Monkeypatch the peer subscriptions
|
||||
monkeypatch.setattr(pubsubs_gsub[0], "peer_topics", peer_topics)
|
||||
|
||||
mesh_peer_indices = random.sample(
|
||||
range(topic_mesh_peer_count), initial_peer_count
|
||||
)
|
||||
mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices]
|
||||
router_mesh = {topic_mesh: set(mesh_peers)}
|
||||
# Monkeypatch our mesh peers
|
||||
monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh)
|
||||
fanout_peer_indices = random.sample(
|
||||
range(topic_mesh_peer_count, total_peer_count), initial_peer_count
|
||||
)
|
||||
fanout_peers = [fake_peer_ids[i] for i in fanout_peer_indices]
|
||||
router_fanout = {topic_fanout: set(fanout_peers)}
|
||||
# Monkeypatch our fanout peers
|
||||
monkeypatch.setattr(pubsubs_gsub[0].router, "fanout", router_fanout)
|
||||
|
||||
def window(topic):
|
||||
if topic == topic_mesh:
|
||||
return [topic_mesh]
|
||||
elif topic == topic_fanout:
|
||||
return [topic_fanout]
|
||||
else:
|
||||
return []
|
||||
|
||||
# Monkeypatch the memory cache messages
|
||||
monkeypatch.setattr(pubsubs_gsub[0].router.mcache, "window", window)
|
||||
|
||||
peers_to_gossip = pubsubs_gsub[0].router.gossip_heartbeat()
|
||||
# If our mesh peer count is less than `GossipSubDegree`, we should gossip to up
|
||||
# to `GossipSubDegree` peers (exclude mesh peers).
|
||||
if topic_mesh_peer_count - initial_peer_count < pubsubs_gsub[0].router.degree:
|
||||
# The same goes for fanout so it's two times the number of peers to gossip.
|
||||
assert len(peers_to_gossip) == 2 * (
|
||||
topic_mesh_peer_count - initial_peer_count
|
||||
)
|
||||
elif (
|
||||
topic_mesh_peer_count - initial_peer_count >= pubsubs_gsub[0].router.degree
|
||||
):
|
||||
assert len(peers_to_gossip) == 2 * (pubsubs_gsub[0].router.degree)
|
||||
|
||||
for peer in peers_to_gossip:
|
||||
if peer in peer_topics[topic_mesh]:
|
||||
# Check that the peer to gossip to is not in our mesh peers
|
||||
assert peer not in mesh_peers
|
||||
assert topic_mesh in peers_to_gossip[peer]
|
||||
elif peer in peer_topics[topic_fanout]:
|
||||
# Check that the peer to gossip to is not in our fanout peers
|
||||
assert peer not in fanout_peers
|
||||
assert topic_fanout in peers_to_gossip[peer]
|
||||
31
tests/core/pubsub/test_gossipsub_backward_compatibility.py
Normal file
31
tests/core/pubsub/test_gossipsub_backward_compatibility.py
Normal file
@ -0,0 +1,31 @@
|
||||
import functools
|
||||
|
||||
import pytest
|
||||
|
||||
from libp2p.tools.constants import (
|
||||
FLOODSUB_PROTOCOL_ID,
|
||||
)
|
||||
from libp2p.tools.factories import (
|
||||
PubsubFactory,
|
||||
)
|
||||
from libp2p.tools.pubsub.floodsub_integration_test_settings import (
|
||||
floodsub_protocol_pytest_params,
|
||||
perform_test_from_obj,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params)
|
||||
@pytest.mark.trio
|
||||
@pytest.mark.slow
|
||||
async def test_gossipsub_run_with_floodsub_tests(test_case_obj):
|
||||
await perform_test_from_obj(
|
||||
test_case_obj,
|
||||
functools.partial(
|
||||
PubsubFactory.create_batch_with_gossipsub,
|
||||
protocols=[FLOODSUB_PROTOCOL_ID],
|
||||
degree=3,
|
||||
degree_low=2,
|
||||
degree_high=4,
|
||||
time_to_live=30,
|
||||
),
|
||||
)
|
||||
130
tests/core/pubsub/test_mcache.py
Normal file
130
tests/core/pubsub/test_mcache.py
Normal file
@ -0,0 +1,130 @@
|
||||
from libp2p.pubsub.mcache import (
|
||||
MessageCache,
|
||||
)
|
||||
|
||||
|
||||
class Msg:
|
||||
__slots__ = ["topicIDs", "seqno", "from_id"]
|
||||
|
||||
def __init__(self, topicIDs, seqno, from_id):
|
||||
self.topicIDs = topicIDs
|
||||
self.seqno = seqno
|
||||
self.from_id = from_id
|
||||
|
||||
|
||||
def test_mcache():
|
||||
# Ported from:
|
||||
# https://github.com/libp2p/go-libp2p-pubsub/blob/51b7501433411b5096cac2b4994a36a68515fc03/mcache_test.go
|
||||
mcache = MessageCache(3, 5)
|
||||
msgs = []
|
||||
|
||||
for i in range(60):
|
||||
msgs.append(Msg(["test"], i, "test"))
|
||||
|
||||
for i in range(10):
|
||||
mcache.put(msgs[i])
|
||||
|
||||
for i in range(10):
|
||||
msg = msgs[i]
|
||||
mid = (msg.seqno, msg.from_id)
|
||||
get_msg = mcache.get(mid)
|
||||
|
||||
# successful read
|
||||
assert get_msg == msg
|
||||
|
||||
gids = mcache.window("test")
|
||||
|
||||
assert len(gids) == 10
|
||||
|
||||
for i in range(10):
|
||||
msg = msgs[i]
|
||||
mid = (msg.seqno, msg.from_id)
|
||||
|
||||
assert mid == gids[i]
|
||||
|
||||
mcache.shift()
|
||||
|
||||
for i in range(10, 20):
|
||||
mcache.put(msgs[i])
|
||||
|
||||
for i in range(20):
|
||||
msg = msgs[i]
|
||||
mid = (msg.seqno, msg.from_id)
|
||||
get_msg = mcache.get(mid)
|
||||
|
||||
assert get_msg == msg
|
||||
|
||||
gids = mcache.window("test")
|
||||
|
||||
assert len(gids) == 20
|
||||
|
||||
for i in range(10):
|
||||
msg = msgs[i]
|
||||
mid = (msg.seqno, msg.from_id)
|
||||
|
||||
assert mid == gids[10 + i]
|
||||
|
||||
for i in range(10, 20):
|
||||
msg = msgs[i]
|
||||
mid = (msg.seqno, msg.from_id)
|
||||
|
||||
assert mid == gids[i - 10]
|
||||
|
||||
mcache.shift()
|
||||
|
||||
for i in range(20, 30):
|
||||
mcache.put(msgs[i])
|
||||
|
||||
mcache.shift()
|
||||
|
||||
for i in range(30, 40):
|
||||
mcache.put(msgs[i])
|
||||
|
||||
mcache.shift()
|
||||
|
||||
for i in range(40, 50):
|
||||
mcache.put(msgs[i])
|
||||
|
||||
mcache.shift()
|
||||
|
||||
for i in range(50, 60):
|
||||
mcache.put(msgs[i])
|
||||
|
||||
assert len(mcache.msgs) == 50
|
||||
|
||||
for i in range(10):
|
||||
msg = msgs[i]
|
||||
mid = (msg.seqno, msg.from_id)
|
||||
get_msg = mcache.get(mid)
|
||||
|
||||
# Should be evicted from cache
|
||||
assert not get_msg
|
||||
|
||||
for i in range(10, 60):
|
||||
msg = msgs[i]
|
||||
mid = (msg.seqno, msg.from_id)
|
||||
get_msg = mcache.get(mid)
|
||||
|
||||
assert get_msg == msg
|
||||
|
||||
gids = mcache.window("test")
|
||||
|
||||
assert len(gids) == 30
|
||||
|
||||
for i in range(10):
|
||||
msg = msgs[50 + i]
|
||||
mid = (msg.seqno, msg.from_id)
|
||||
|
||||
assert mid == gids[i]
|
||||
|
||||
for i in range(10, 20):
|
||||
msg = msgs[30 + i]
|
||||
mid = (msg.seqno, msg.from_id)
|
||||
|
||||
assert mid == gids[i]
|
||||
|
||||
for i in range(20, 30):
|
||||
msg = msgs[10 + i]
|
||||
mid = (msg.seqno, msg.from_id)
|
||||
|
||||
assert mid == gids[i]
|
||||
684
tests/core/pubsub/test_pubsub.py
Normal file
684
tests/core/pubsub/test_pubsub.py
Normal file
@ -0,0 +1,684 @@
|
||||
from contextlib import (
|
||||
contextmanager,
|
||||
)
|
||||
from typing import (
|
||||
NamedTuple,
|
||||
)
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.exceptions import (
|
||||
ValidationError,
|
||||
)
|
||||
from libp2p.pubsub.pb import (
|
||||
rpc_pb2,
|
||||
)
|
||||
from libp2p.pubsub.pubsub import (
|
||||
PUBSUB_SIGNING_PREFIX,
|
||||
SUBSCRIPTION_CHANNEL_SIZE,
|
||||
)
|
||||
from libp2p.tools.constants import (
|
||||
MAX_READ_LEN,
|
||||
)
|
||||
from libp2p.tools.factories import (
|
||||
IDFactory,
|
||||
PubsubFactory,
|
||||
net_stream_pair_factory,
|
||||
)
|
||||
from libp2p.tools.pubsub.utils import (
|
||||
make_pubsub_msg,
|
||||
)
|
||||
from libp2p.tools.utils import (
|
||||
connect,
|
||||
)
|
||||
from libp2p.utils import (
|
||||
encode_varint_prefixed,
|
||||
)
|
||||
|
||||
TESTING_TOPIC = "TEST_SUBSCRIBE"
|
||||
TESTING_DATA = b"data"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_subscribe_and_unsubscribe():
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids
|
||||
|
||||
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_re_subscribe():
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids
|
||||
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_re_unsubscribe():
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
# Unsubscribe from topic we didn't even subscribe to
|
||||
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].topic_ids
|
||||
await pubsubs_fsub[0].unsubscribe("NOT_MY_TOPIC")
|
||||
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].topic_ids
|
||||
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids
|
||||
|
||||
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids
|
||||
|
||||
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_peers_subscribe():
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
|
||||
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
# Yield to let 0 notify 1
|
||||
await trio.sleep(1)
|
||||
assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
|
||||
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
||||
# Yield to let 0 notify 1
|
||||
await trio.sleep(1)
|
||||
assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_get_hello_packet():
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
|
||||
def _get_hello_packet_topic_ids():
|
||||
packet = pubsubs_fsub[0].get_hello_packet()
|
||||
return tuple(sub.topicid for sub in packet.subscriptions)
|
||||
|
||||
# Test: No subscription, so there should not be any topic ids in the
|
||||
# hello packet.
|
||||
assert len(_get_hello_packet_topic_ids()) == 0
|
||||
|
||||
# Test: After subscriptions, topic ids should be in the hello packet.
|
||||
topic_ids = ["t", "o", "p", "i", "c"]
|
||||
for topic in topic_ids:
|
||||
await pubsubs_fsub[0].subscribe(topic)
|
||||
topic_ids_in_hello = _get_hello_packet_topic_ids()
|
||||
for topic in topic_ids:
|
||||
assert topic in topic_ids_in_hello
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_set_and_remove_topic_validator():
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
is_sync_validator_called = False
|
||||
|
||||
def sync_validator(peer_id, msg):
|
||||
nonlocal is_sync_validator_called
|
||||
is_sync_validator_called = True
|
||||
|
||||
is_async_validator_called = False
|
||||
|
||||
async def async_validator(peer_id, msg):
|
||||
nonlocal is_async_validator_called
|
||||
is_async_validator_called = True
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
topic = "TEST_VALIDATOR"
|
||||
|
||||
assert topic not in pubsubs_fsub[0].topic_validators
|
||||
|
||||
# Register sync validator
|
||||
pubsubs_fsub[0].set_topic_validator(topic, sync_validator, False)
|
||||
|
||||
assert topic in pubsubs_fsub[0].topic_validators
|
||||
topic_validator = pubsubs_fsub[0].topic_validators[topic]
|
||||
assert not topic_validator.is_async
|
||||
|
||||
# Validate with sync validator
|
||||
topic_validator.validator(peer_id=IDFactory(), msg="msg")
|
||||
|
||||
assert is_sync_validator_called
|
||||
assert not is_async_validator_called
|
||||
|
||||
# Register with async validator
|
||||
pubsubs_fsub[0].set_topic_validator(topic, async_validator, True)
|
||||
|
||||
is_sync_validator_called = False
|
||||
assert topic in pubsubs_fsub[0].topic_validators
|
||||
topic_validator = pubsubs_fsub[0].topic_validators[topic]
|
||||
assert topic_validator.is_async
|
||||
|
||||
# Validate with async validator
|
||||
await topic_validator.validator(peer_id=IDFactory(), msg="msg")
|
||||
|
||||
assert is_async_validator_called
|
||||
assert not is_sync_validator_called
|
||||
|
||||
# Remove validator
|
||||
pubsubs_fsub[0].remove_topic_validator(topic)
|
||||
assert topic not in pubsubs_fsub[0].topic_validators
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_get_msg_validators():
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
times_sync_validator_called = 0
|
||||
|
||||
def sync_validator(peer_id, msg):
|
||||
nonlocal times_sync_validator_called
|
||||
times_sync_validator_called += 1
|
||||
|
||||
times_async_validator_called = 0
|
||||
|
||||
async def async_validator(peer_id, msg):
|
||||
nonlocal times_async_validator_called
|
||||
times_async_validator_called += 1
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
topic_1 = "TEST_VALIDATOR_1"
|
||||
topic_2 = "TEST_VALIDATOR_2"
|
||||
topic_3 = "TEST_VALIDATOR_3"
|
||||
|
||||
# Register sync validator for topic 1 and 2
|
||||
pubsubs_fsub[0].set_topic_validator(topic_1, sync_validator, False)
|
||||
pubsubs_fsub[0].set_topic_validator(topic_2, sync_validator, False)
|
||||
|
||||
# Register async validator for topic 3
|
||||
pubsubs_fsub[0].set_topic_validator(topic_3, async_validator, True)
|
||||
|
||||
msg = make_pubsub_msg(
|
||||
origin_id=pubsubs_fsub[0].my_id,
|
||||
topic_ids=[topic_1, topic_2, topic_3],
|
||||
data=b"1234",
|
||||
seqno=b"\x00" * 8,
|
||||
)
|
||||
|
||||
topic_validators = pubsubs_fsub[0].get_msg_validators(msg)
|
||||
for topic_validator in topic_validators:
|
||||
if topic_validator.is_async:
|
||||
await topic_validator.validator(peer_id=IDFactory(), msg="msg")
|
||||
else:
|
||||
topic_validator.validator(peer_id=IDFactory(), msg="msg")
|
||||
|
||||
assert times_sync_validator_called == 2
|
||||
assert times_async_validator_called == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"is_topic_1_val_passed, is_topic_2_val_passed",
|
||||
((False, True), (True, False), (True, True)),
|
||||
)
|
||||
@pytest.mark.trio
|
||||
async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
|
||||
def passed_sync_validator(peer_id, msg):
|
||||
return True
|
||||
|
||||
def failed_sync_validator(peer_id, msg):
|
||||
return False
|
||||
|
||||
async def passed_async_validator(peer_id, msg):
|
||||
await trio.lowlevel.checkpoint()
|
||||
return True
|
||||
|
||||
async def failed_async_validator(peer_id, msg):
|
||||
await trio.lowlevel.checkpoint()
|
||||
return False
|
||||
|
||||
topic_1 = "TEST_SYNC_VALIDATOR"
|
||||
topic_2 = "TEST_ASYNC_VALIDATOR"
|
||||
|
||||
if is_topic_1_val_passed:
|
||||
pubsubs_fsub[0].set_topic_validator(topic_1, passed_sync_validator, False)
|
||||
else:
|
||||
pubsubs_fsub[0].set_topic_validator(topic_1, failed_sync_validator, False)
|
||||
|
||||
if is_topic_2_val_passed:
|
||||
pubsubs_fsub[0].set_topic_validator(topic_2, passed_async_validator, True)
|
||||
else:
|
||||
pubsubs_fsub[0].set_topic_validator(topic_2, failed_async_validator, True)
|
||||
|
||||
msg = make_pubsub_msg(
|
||||
origin_id=pubsubs_fsub[0].my_id,
|
||||
topic_ids=[topic_1, topic_2],
|
||||
data=b"1234",
|
||||
seqno=b"\x00" * 8,
|
||||
)
|
||||
|
||||
if is_topic_1_val_passed and is_topic_2_val_passed:
|
||||
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
||||
else:
|
||||
with pytest.raises(ValidationError):
|
||||
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_continuously_read_stream(monkeypatch, nursery, security_protocol):
|
||||
async def wait_for_event_occurring(event):
|
||||
await trio.lowlevel.checkpoint()
|
||||
with trio.fail_after(0.1):
|
||||
await event.wait()
|
||||
|
||||
class Events(NamedTuple):
|
||||
push_msg: trio.Event
|
||||
handle_subscription: trio.Event
|
||||
handle_rpc: trio.Event
|
||||
|
||||
@contextmanager
|
||||
def mock_methods():
|
||||
event_push_msg = trio.Event()
|
||||
event_handle_subscription = trio.Event()
|
||||
event_handle_rpc = trio.Event()
|
||||
|
||||
async def mock_push_msg(msg_forwarder, msg):
|
||||
event_push_msg.set()
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
def mock_handle_subscription(origin_id, sub_message):
|
||||
event_handle_subscription.set()
|
||||
|
||||
async def mock_handle_rpc(rpc, sender_peer_id):
|
||||
event_handle_rpc.set()
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg)
|
||||
m.setattr(pubsubs_fsub[0], "handle_subscription", mock_handle_subscription)
|
||||
m.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc)
|
||||
yield Events(event_push_msg, event_handle_subscription, event_handle_rpc)
|
||||
|
||||
async with PubsubFactory.create_batch_with_floodsub(
|
||||
1, security_protocol=security_protocol
|
||||
) as pubsubs_fsub, net_stream_pair_factory(
|
||||
security_protocol=security_protocol
|
||||
) as stream_pair:
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
# Kick off the task `continuously_read_stream`
|
||||
nursery.start_soon(pubsubs_fsub[0].continuously_read_stream, stream_pair[0])
|
||||
|
||||
# Test: `push_msg` is called when publishing to a subscribed topic.
|
||||
publish_subscribed_topic = rpc_pb2.RPC(
|
||||
publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])]
|
||||
)
|
||||
with mock_methods() as events:
|
||||
await stream_pair[1].write(
|
||||
encode_varint_prefixed(publish_subscribed_topic.SerializeToString())
|
||||
)
|
||||
await wait_for_event_occurring(events.push_msg)
|
||||
# Make sure the other events are not emitted.
|
||||
with pytest.raises(trio.TooSlowError):
|
||||
await wait_for_event_occurring(events.handle_subscription)
|
||||
with pytest.raises(trio.TooSlowError):
|
||||
await wait_for_event_occurring(events.handle_rpc)
|
||||
|
||||
# Test: `push_msg` is not called when publishing to a topic-not-subscribed.
|
||||
publish_not_subscribed_topic = rpc_pb2.RPC(
|
||||
publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])]
|
||||
)
|
||||
with mock_methods() as events:
|
||||
await stream_pair[1].write(
|
||||
encode_varint_prefixed(publish_not_subscribed_topic.SerializeToString())
|
||||
)
|
||||
with pytest.raises(trio.TooSlowError):
|
||||
await wait_for_event_occurring(events.push_msg)
|
||||
|
||||
# Test: `handle_subscription` is called when a subscription message is received.
|
||||
subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()])
|
||||
with mock_methods() as events:
|
||||
await stream_pair[1].write(
|
||||
encode_varint_prefixed(subscription_msg.SerializeToString())
|
||||
)
|
||||
await wait_for_event_occurring(events.handle_subscription)
|
||||
# Make sure the other events are not emitted.
|
||||
with pytest.raises(trio.TooSlowError):
|
||||
await wait_for_event_occurring(events.push_msg)
|
||||
with pytest.raises(trio.TooSlowError):
|
||||
await wait_for_event_occurring(events.handle_rpc)
|
||||
|
||||
# Test: `handle_rpc` is called when a control message is received.
|
||||
control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage())
|
||||
with mock_methods() as events:
|
||||
await stream_pair[1].write(
|
||||
encode_varint_prefixed(control_msg.SerializeToString())
|
||||
)
|
||||
await wait_for_event_occurring(events.handle_rpc)
|
||||
# Make sure the other events are not emitted.
|
||||
with pytest.raises(trio.TooSlowError):
|
||||
await wait_for_event_occurring(events.push_msg)
|
||||
with pytest.raises(trio.TooSlowError):
|
||||
await wait_for_event_occurring(events.handle_subscription)
|
||||
|
||||
|
||||
# TODO: Add the following tests after they are aligned with Go.
|
||||
# (Issue #191: https://github.com/libp2p/py-libp2p/issues/191)
|
||||
# - `test_stream_handler`
|
||||
# - `test_handle_peer_queue`
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_handle_subscription():
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
assert len(pubsubs_fsub[0].peer_topics) == 0
|
||||
sub_msg_0 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=TESTING_TOPIC)
|
||||
peer_ids = [IDFactory() for _ in range(2)]
|
||||
# Test: One peer is subscribed
|
||||
pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_0)
|
||||
assert (
|
||||
len(pubsubs_fsub[0].peer_topics) == 1
|
||||
and TESTING_TOPIC in pubsubs_fsub[0].peer_topics
|
||||
)
|
||||
assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 1
|
||||
assert peer_ids[0] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
|
||||
# Test: Another peer is subscribed
|
||||
pubsubs_fsub[0].handle_subscription(peer_ids[1], sub_msg_0)
|
||||
assert len(pubsubs_fsub[0].peer_topics) == 1
|
||||
assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 2
|
||||
assert peer_ids[1] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
|
||||
# Test: Subscribe to another topic
|
||||
another_topic = "ANOTHER_TOPIC"
|
||||
sub_msg_1 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=another_topic)
|
||||
pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_1)
|
||||
assert len(pubsubs_fsub[0].peer_topics) == 2
|
||||
assert another_topic in pubsubs_fsub[0].peer_topics
|
||||
assert peer_ids[0] in pubsubs_fsub[0].peer_topics[another_topic]
|
||||
# Test: unsubscribe
|
||||
unsub_msg = rpc_pb2.RPC.SubOpts(subscribe=False, topicid=TESTING_TOPIC)
|
||||
pubsubs_fsub[0].handle_subscription(peer_ids[0], unsub_msg)
|
||||
assert peer_ids[0] not in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_handle_talk():
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
msg_0 = make_pubsub_msg(
|
||||
origin_id=pubsubs_fsub[0].my_id,
|
||||
topic_ids=[TESTING_TOPIC],
|
||||
data=b"1234",
|
||||
seqno=b"\x00" * 8,
|
||||
)
|
||||
pubsubs_fsub[0].notify_subscriptions(msg_0)
|
||||
msg_1 = make_pubsub_msg(
|
||||
origin_id=pubsubs_fsub[0].my_id,
|
||||
topic_ids=["NOT_SUBSCRIBED"],
|
||||
data=b"1234",
|
||||
seqno=b"\x11" * 8,
|
||||
)
|
||||
pubsubs_fsub[0].notify_subscriptions(msg_1)
|
||||
assert (
|
||||
len(pubsubs_fsub[0].topic_ids) == 1
|
||||
and sub == pubsubs_fsub[0].subscribed_topics_receive[TESTING_TOPIC]
|
||||
)
|
||||
assert (await sub.get()) == msg_0
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_message_all_peers(monkeypatch, security_protocol):
|
||||
async with PubsubFactory.create_batch_with_floodsub(
|
||||
1, security_protocol=security_protocol
|
||||
) as pubsubs_fsub, net_stream_pair_factory(
|
||||
security_protocol=security_protocol
|
||||
) as stream_pair:
|
||||
peer_id = IDFactory()
|
||||
mock_peers = {peer_id: stream_pair[0]}
|
||||
with monkeypatch.context() as m:
|
||||
m.setattr(pubsubs_fsub[0], "peers", mock_peers)
|
||||
|
||||
empty_rpc = rpc_pb2.RPC()
|
||||
empty_rpc_bytes = empty_rpc.SerializeToString()
|
||||
empty_rpc_bytes_len_prefixed = encode_varint_prefixed(empty_rpc_bytes)
|
||||
await pubsubs_fsub[0].message_all_peers(empty_rpc_bytes)
|
||||
assert (
|
||||
await stream_pair[1].read(MAX_READ_LEN)
|
||||
) == empty_rpc_bytes_len_prefixed
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_subscribe_and_publish():
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
pubsub = pubsubs_fsub[0]
|
||||
|
||||
list_data = [b"d0", b"d1"]
|
||||
event_receive_data_started = trio.Event()
|
||||
|
||||
async def publish_data(topic):
|
||||
await event_receive_data_started.wait()
|
||||
for data in list_data:
|
||||
await pubsub.publish(topic, data)
|
||||
|
||||
async def receive_data(topic):
|
||||
i = 0
|
||||
event_receive_data_started.set()
|
||||
assert topic not in pubsub.topic_ids
|
||||
subscription = await pubsub.subscribe(topic)
|
||||
async with subscription:
|
||||
assert topic in pubsub.topic_ids
|
||||
async for msg in subscription:
|
||||
assert msg.data == list_data[i]
|
||||
i += 1
|
||||
if i == len(list_data):
|
||||
break
|
||||
assert topic not in pubsub.topic_ids
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(receive_data, TESTING_TOPIC)
|
||||
nursery.start_soon(publish_data, TESTING_TOPIC)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_subscribe_and_publish_full_channel():
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
pubsub = pubsubs_fsub[0]
|
||||
|
||||
extra_data_0 = b"extra_data_0"
|
||||
extra_data_1 = b"extra_data_1"
|
||||
|
||||
# Test: Subscription channel is of size `SUBSCRIPTION_CHANNEL_SIZE`.
|
||||
# When the channel is full, new received messages are dropped.
|
||||
# After the channel has empty slot, the channel can receive new messages.
|
||||
|
||||
# Assume `SUBSCRIPTION_CHANNEL_SIZE` is smaller than `2**(4*8)`.
|
||||
list_data = [i.to_bytes(4, "big") for i in range(SUBSCRIPTION_CHANNEL_SIZE)]
|
||||
# Expect `extra_data_0` is dropped and `extra_data_1` is appended.
|
||||
expected_list_data = list_data + [extra_data_1]
|
||||
|
||||
subscription = await pubsub.subscribe(TESTING_TOPIC)
|
||||
for data in list_data:
|
||||
await pubsub.publish(TESTING_TOPIC, data)
|
||||
|
||||
# Publish `extra_data_0` which should be dropped since the channel is
|
||||
# already full.
|
||||
await pubsub.publish(TESTING_TOPIC, extra_data_0)
|
||||
# Consume a message and there is an empty slot in the channel.
|
||||
assert (await subscription.get()).data == expected_list_data.pop(0)
|
||||
# Publish `extra_data_1` which should be appended to the channel.
|
||||
await pubsub.publish(TESTING_TOPIC, extra_data_1)
|
||||
|
||||
for expected_data in expected_list_data:
|
||||
assert (await subscription.get()).data == expected_data
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_publish_push_msg_is_called(monkeypatch):
|
||||
msg_forwarders = []
|
||||
msgs = []
|
||||
|
||||
async def push_msg(msg_forwarder, msg):
|
||||
msg_forwarders.append(msg_forwarder)
|
||||
msgs.append(msg)
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
with monkeypatch.context() as m:
|
||||
m.setattr(pubsubs_fsub[0], "push_msg", push_msg)
|
||||
|
||||
await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
|
||||
await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
|
||||
|
||||
assert (
|
||||
len(msgs) == 2
|
||||
), "`push_msg` should be called every time `publish` is called"
|
||||
assert (msg_forwarders[0] == msg_forwarders[1]) and (
|
||||
msg_forwarders[1] == pubsubs_fsub[0].my_id
|
||||
)
|
||||
assert (
|
||||
msgs[0].seqno != msgs[1].seqno
|
||||
), "`seqno` should be different every time"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_push_msg(monkeypatch):
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
msg_0 = make_pubsub_msg(
|
||||
origin_id=pubsubs_fsub[0].my_id,
|
||||
topic_ids=[TESTING_TOPIC],
|
||||
data=TESTING_DATA,
|
||||
seqno=b"\x00" * 8,
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def mock_router_publish():
|
||||
event = trio.Event()
|
||||
|
||||
async def router_publish(*args, **kwargs):
|
||||
event.set()
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setattr(pubsubs_fsub[0].router, "publish", router_publish)
|
||||
yield event
|
||||
|
||||
with mock_router_publish() as event:
|
||||
# Test: `msg` is not seen before `push_msg`, and is seen after `push_msg`.
|
||||
assert not pubsubs_fsub[0]._is_msg_seen(msg_0)
|
||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0)
|
||||
assert pubsubs_fsub[0]._is_msg_seen(msg_0)
|
||||
# Test: Ensure `router.publish` is called in `push_msg`
|
||||
with trio.fail_after(0.1):
|
||||
await event.wait()
|
||||
|
||||
with mock_router_publish() as event:
|
||||
# Test: `push_msg` the message again and it will be reject.
|
||||
# `router_publish` is not called then.
|
||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0)
|
||||
await trio.sleep(0.01)
|
||||
assert not event.is_set()
|
||||
|
||||
sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
# Test: `push_msg` succeeds with another unseen msg.
|
||||
msg_1 = make_pubsub_msg(
|
||||
origin_id=pubsubs_fsub[0].my_id,
|
||||
topic_ids=[TESTING_TOPIC],
|
||||
data=TESTING_DATA,
|
||||
seqno=b"\x11" * 8,
|
||||
)
|
||||
assert not pubsubs_fsub[0]._is_msg_seen(msg_1)
|
||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_1)
|
||||
assert pubsubs_fsub[0]._is_msg_seen(msg_1)
|
||||
with trio.fail_after(0.1):
|
||||
await event.wait()
|
||||
# Test: Subscribers are notified when `push_msg` new messages.
|
||||
assert (await sub.get()) == msg_1
|
||||
|
||||
with mock_router_publish() as event:
|
||||
# Test: add a topic validator and `push_msg` the message that
|
||||
# does not pass the validation.
|
||||
# `router_publish` is not called then.
|
||||
def failed_sync_validator(peer_id, msg):
|
||||
return False
|
||||
|
||||
pubsubs_fsub[0].set_topic_validator(
|
||||
TESTING_TOPIC, failed_sync_validator, False
|
||||
)
|
||||
|
||||
msg_2 = make_pubsub_msg(
|
||||
origin_id=pubsubs_fsub[0].my_id,
|
||||
topic_ids=[TESTING_TOPIC],
|
||||
data=TESTING_DATA,
|
||||
seqno=b"\x22" * 8,
|
||||
)
|
||||
|
||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_2)
|
||||
await trio.sleep(0.01)
|
||||
assert not event.is_set()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_strict_signing():
|
||||
async with PubsubFactory.create_batch_with_floodsub(
|
||||
2, strict_signing=True
|
||||
) as pubsubs_fsub:
|
||||
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
await pubsubs_fsub[1].subscribe(TESTING_TOPIC)
|
||||
await trio.sleep(1)
|
||||
|
||||
await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
|
||||
await trio.sleep(1)
|
||||
|
||||
assert len(pubsubs_fsub[0].seen_messages) == 1
|
||||
assert len(pubsubs_fsub[1].seen_messages) == 1
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_strict_signing_failed_validation(monkeypatch):
|
||||
async with PubsubFactory.create_batch_with_floodsub(
|
||||
2, strict_signing=True
|
||||
) as pubsubs_fsub:
|
||||
msg = make_pubsub_msg(
|
||||
origin_id=pubsubs_fsub[0].my_id,
|
||||
topic_ids=[TESTING_TOPIC],
|
||||
data=TESTING_DATA,
|
||||
seqno=b"\x00" * 8,
|
||||
)
|
||||
priv_key = pubsubs_fsub[0].sign_key
|
||||
signature = priv_key.sign(
|
||||
PUBSUB_SIGNING_PREFIX.encode() + msg.SerializeToString()
|
||||
)
|
||||
|
||||
event = trio.Event()
|
||||
|
||||
def _is_msg_seen(msg):
|
||||
return False
|
||||
|
||||
# Use router publish to check if `push_msg` succeed.
|
||||
async def router_publish(*args, **kwargs):
|
||||
await trio.lowlevel.checkpoint()
|
||||
# The event will only be set if `push_msg` succeed.
|
||||
event.set()
|
||||
|
||||
monkeypatch.setattr(pubsubs_fsub[0], "_is_msg_seen", _is_msg_seen)
|
||||
monkeypatch.setattr(pubsubs_fsub[0].router, "publish", router_publish)
|
||||
|
||||
# Test: no signature attached in `msg`
|
||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg)
|
||||
await trio.sleep(0.01)
|
||||
assert not event.is_set()
|
||||
|
||||
# Test: `msg.key` does not match `msg.from_id`
|
||||
msg.key = pubsubs_fsub[1].host.get_public_key().serialize()
|
||||
msg.signature = signature
|
||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg)
|
||||
await trio.sleep(0.01)
|
||||
assert not event.is_set()
|
||||
|
||||
# Test: invalid signature
|
||||
msg.key = pubsubs_fsub[0].host.get_public_key().serialize()
|
||||
msg.signature = b"\x12" * 100
|
||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg)
|
||||
await trio.sleep(0.01)
|
||||
assert not event.is_set()
|
||||
|
||||
# Finally, assert the signature indeed will pass validation
|
||||
msg.key = pubsubs_fsub[0].host.get_public_key().serialize()
|
||||
msg.signature = signature
|
||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg)
|
||||
await trio.sleep(0.01)
|
||||
assert event.is_set()
|
||||
88
tests/core/pubsub/test_subscription.py
Normal file
88
tests/core/pubsub/test_subscription.py
Normal file
@ -0,0 +1,88 @@
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.pubsub.pb import (
|
||||
rpc_pb2,
|
||||
)
|
||||
from libp2p.pubsub.subscription import (
|
||||
TrioSubscriptionAPI,
|
||||
)
|
||||
|
||||
GET_TIMEOUT = 0.001
|
||||
|
||||
|
||||
def make_trio_subscription():
|
||||
send_channel, receive_channel = trio.open_memory_channel(math.inf)
|
||||
|
||||
async def unsubscribe_fn():
|
||||
await send_channel.aclose()
|
||||
|
||||
return (
|
||||
send_channel,
|
||||
TrioSubscriptionAPI(receive_channel, unsubscribe_fn=unsubscribe_fn),
|
||||
)
|
||||
|
||||
|
||||
def make_pubsub_msg():
|
||||
return rpc_pb2.Message()
|
||||
|
||||
|
||||
async def send_something(send_channel):
|
||||
msg = make_pubsub_msg()
|
||||
await send_channel.send(msg)
|
||||
return msg
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_trio_subscription_get():
|
||||
send_channel, sub = make_trio_subscription()
|
||||
data_0 = await send_something(send_channel)
|
||||
data_1 = await send_something(send_channel)
|
||||
assert data_0 == await sub.get()
|
||||
assert data_1 == await sub.get()
|
||||
# No more message
|
||||
with pytest.raises(trio.TooSlowError):
|
||||
with trio.fail_after(GET_TIMEOUT):
|
||||
await sub.get()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_trio_subscription_iter():
|
||||
send_channel, sub = make_trio_subscription()
|
||||
received_data = []
|
||||
|
||||
async def iter_subscriptions(subscription):
|
||||
async for data in sub:
|
||||
received_data.append(data)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(iter_subscriptions, sub)
|
||||
await send_something(send_channel)
|
||||
await send_something(send_channel)
|
||||
await send_channel.aclose()
|
||||
|
||||
assert len(received_data) == 2
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_trio_subscription_unsubscribe():
|
||||
send_channel, sub = make_trio_subscription()
|
||||
await sub.unsubscribe()
|
||||
# Test: If the subscription is unsubscribed, `send_channel` should be closed.
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await send_something(send_channel)
|
||||
# Test: No side effect when cancelled twice.
|
||||
await sub.unsubscribe()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_trio_subscription_async_context_manager():
|
||||
send_channel, sub = make_trio_subscription()
|
||||
async with sub:
|
||||
# Test: `sub` is not cancelled yet, so `send_something` works fine.
|
||||
await send_something(send_channel)
|
||||
# Test: `sub` is cancelled, `send_something` fails
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await send_something(send_channel)
|
||||
Reference in New Issue
Block a user