mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
1244 lines
45 KiB
Python
1244 lines
45 KiB
Python
from contextlib import (
|
||
contextmanager,
|
||
)
|
||
import inspect
|
||
from typing import (
|
||
NamedTuple,
|
||
)
|
||
from unittest.mock import patch
|
||
|
||
import pytest
|
||
import multiaddr
|
||
import trio
|
||
|
||
from libp2p.crypto.rsa import create_new_key_pair
|
||
from libp2p.custom_types import AsyncValidatorFn
|
||
from libp2p.exceptions import (
|
||
ValidationError,
|
||
)
|
||
from libp2p.network.stream.exceptions import (
|
||
StreamEOF,
|
||
)
|
||
from libp2p.peer.envelope import Envelope, seal_record
|
||
from libp2p.peer.id import (
|
||
ID,
|
||
)
|
||
from libp2p.peer.peer_record import PeerRecord
|
||
from libp2p.pubsub.pb import (
|
||
rpc_pb2,
|
||
)
|
||
from libp2p.pubsub.pubsub import (
|
||
PUBSUB_SIGNING_PREFIX,
|
||
SUBSCRIPTION_CHANNEL_SIZE,
|
||
)
|
||
from libp2p.tools.constants import (
|
||
MAX_READ_LEN,
|
||
)
|
||
from libp2p.tools.utils import (
|
||
connect,
|
||
)
|
||
from libp2p.utils import (
|
||
encode_varint_prefixed,
|
||
)
|
||
from tests.utils.factories import (
|
||
IDFactory,
|
||
PubsubFactory,
|
||
net_stream_pair_factory,
|
||
)
|
||
from tests.utils.pubsub.utils import (
|
||
make_pubsub_msg,
|
||
)
|
||
|
||
# Topic name shared by the tests in this module.
TESTING_TOPIC = "TEST_SUBSCRIBE"
|
||
# Message payload used by the tests in this module that publish or push data.
TESTING_DATA = b"data"
|
||
|
||
|
||
@pytest.mark.trio
async def test_subscribe_and_unsubscribe():
    """Subscribing adds the topic to ``topic_ids``; unsubscribing removes it."""
    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        node = pubsubs_fsub[0]

        await node.subscribe(TESTING_TOPIC)
        assert TESTING_TOPIC in node.topic_ids

        await node.unsubscribe(TESTING_TOPIC)
        assert TESTING_TOPIC not in node.topic_ids
|
||
|
||
|
||
@pytest.mark.trio
async def test_re_subscribe():
    """Subscribing twice to the same topic keeps the topic in ``topic_ids``."""
    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        node = pubsubs_fsub[0]

        # First subscription.
        await node.subscribe(TESTING_TOPIC)
        assert TESTING_TOPIC in node.topic_ids

        # Second subscription to the very same topic.
        await node.subscribe(TESTING_TOPIC)
        assert TESTING_TOPIC in node.topic_ids
|
||
|
||
|
||
@pytest.mark.trio
async def test_re_unsubscribe():
    """Unsubscribing from an unknown or already-left topic leaves it absent."""
    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        node = pubsubs_fsub[0]

        # A topic we never subscribed to: unsubscribing changes nothing.
        assert "NOT_MY_TOPIC" not in node.topic_ids
        await node.unsubscribe("NOT_MY_TOPIC")
        assert "NOT_MY_TOPIC" not in node.topic_ids

        # Regular subscribe / unsubscribe round trip.
        await node.subscribe(TESTING_TOPIC)
        assert TESTING_TOPIC in node.topic_ids

        await node.unsubscribe(TESTING_TOPIC)
        assert TESTING_TOPIC not in node.topic_ids

        # Unsubscribing a second time from the same topic.
        await node.unsubscribe(TESTING_TOPIC)
        assert TESTING_TOPIC not in node.topic_ids
|
||
|
||
|
||
@pytest.mark.trio
async def test_reissue_when_listen_addrs_change():
    """
    A record sent after the sender's addrs changed carries a higher ``seq``.

    Peer A subscribes, then unsubscribes while its ``get_addrs`` is patched to
    return a different address. The record B stores for A after the
    unsubscribe must have a larger sequence number than the one stored after
    the subscribe.
    """
    async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
        peer_a = pubsubs_fsub[0]
        peer_b = pubsubs_fsub[1]

        def record_of_a_held_by_b():
            return peer_b.host.get_peerstore().get_peer_record(peer_a.host.get_id())

        await connect(peer_a.host, peer_b.host)
        await peer_a.subscribe(TESTING_TOPIC)
        # Give A time to notify B.
        await trio.sleep(1)
        assert peer_a.my_id in peer_b.peer_topics[TESTING_TOPIC]

        # The subscribe call should have transferred A's signed record to B.
        envelope_b_sub = record_of_a_held_by_b()
        assert isinstance(envelope_b_sub, Envelope)

        # Simulate A's host listen addrs changing (different port).
        new_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123")

        # Patch only while A is forced to unsubscribe.
        with patch.object(peer_a.host, "get_addrs", return_value=[new_addr]):
            # Unsubscribing from A's side sends a record to B again.
            await peer_a.unsubscribe(TESTING_TOPIC)
            await trio.sleep(1)

        # B should now be holding A's new record with a bumped seq.
        envelope_b_unsub = record_of_a_held_by_b()
        assert isinstance(envelope_b_unsub, Envelope)

        # A strictly larger seq shows a freshly signed record was issued
        # instead of the previously cached one being sent again.
        assert envelope_b_sub.record().seq < envelope_b_unsub.record().seq
|
||
|
||
|
||
@pytest.mark.trio
async def test_peers_subscribe():
    """
    Subscription changes of one peer are reflected in the other's peer_topics.

    Also checks that B holds an ``Envelope`` for A after both the subscribe
    and the unsubscribe, and that both envelopes carry the same ``seq``.
    """
    async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
        peer_a = pubsubs_fsub[0]
        peer_b = pubsubs_fsub[1]

        def record_of_a_held_by_b():
            return peer_b.host.get_peerstore().get_peer_record(peer_a.host.get_id())

        await connect(peer_a.host, peer_b.host)
        await peer_a.subscribe(TESTING_TOPIC)
        # Give A time to notify B.
        await trio.sleep(1)
        assert peer_a.my_id in peer_b.peer_topics[TESTING_TOPIC]

        # The subscribe call should have transferred A's signed record to B.
        envelope_b_sub = record_of_a_held_by_b()
        assert isinstance(envelope_b_sub, Envelope)

        await peer_a.unsubscribe(TESTING_TOPIC)
        # Give A time to notify B.
        await trio.sleep(1)
        assert peer_a.my_id not in peer_b.peer_topics[TESTING_TOPIC]

        envelope_b_unsub = record_of_a_held_by_b()
        assert isinstance(envelope_b_unsub, Envelope)

        # Equal seq values show the cached record was re-sent rather than a
        # new one being created.
        assert envelope_b_sub.record().seq == envelope_b_unsub.record().seq
|
||
|
||
|
||
@pytest.mark.trio
async def test_peer_subscribe_fail_upon_invald_record_transfer():
    """
    A subscription announced alongside an invalid peer record is not recorded.

    Two corruptions of host A's local record are tried in turn:

    1. the envelope's public key is replaced with an unrelated key;
    2. the envelope is sealed with A's private key but wraps a record whose
       peer id is derived from an unrelated key.

    In both cases B must not list A under ``TESTING_TOPIC``.
    """
    async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
        await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)

        # Corrupt host A's local peer record.
        envelope = pubsubs_fsub[0].host.get_peerstore().get_local_record()
        # Fail right here with a clear message if there is no local record;
        # `true_record` is needed further down and would otherwise be unbound.
        assert envelope is not None, "host A should hold a local peer record"
        true_record = envelope.record()
        key_pair = create_new_key_pair()

        envelope.public_key = key_pair.public_key
        pubsubs_fsub[0].host.get_peerstore().set_local_record(envelope)

        await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
        # Yield to let 0 notify 1
        await trio.sleep(1)
        assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics.get(
            TESTING_TOPIC, set()
        )

        # Create a corrupt envelope with correct signature but false peer-id
        false_record = PeerRecord(
            ID.from_pubkey(key_pair.public_key), true_record.addrs
        )
        false_envelope = seal_record(
            false_record, pubsubs_fsub[0].host.get_private_key()
        )

        pubsubs_fsub[0].host.get_peerstore().set_local_record(false_envelope)

        await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
        # Yield to let 0 notify 1
        await trio.sleep(1)
        assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics.get(
            TESTING_TOPIC, set()
        )
|
||
|
||
|
||
@pytest.mark.trio
async def test_get_hello_packet():
    """The hello packet lists exactly the topics the node has subscribed to."""
    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        node = pubsubs_fsub[0]

        def hello_topic_ids():
            hello = node.get_hello_packet()
            return tuple(entry.topicid for entry in hello.subscriptions)

        # Without any subscription the hello packet carries no topic ids.
        assert len(hello_topic_ids()) == 0

        # Once subscribed, every topic id must show up in the hello packet.
        topic_ids = ["t", "o", "p", "i", "c"]
        for topic_id in topic_ids:
            await node.subscribe(topic_id)

        announced = hello_topic_ids()
        for topic_id in topic_ids:
            assert topic_id in announced
|
||
|
||
|
||
@pytest.mark.trio
async def test_set_and_remove_topic_validator():
    """
    Register a sync validator, replace it with an async one, then remove it.

    After each registration the stored entry's ``is_async`` flag is checked
    and the stored validator is invoked to confirm it is the one registered.
    """
    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        is_sync_validator_called = False

        def sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
            nonlocal is_sync_validator_called
            is_sync_validator_called = True
            return True

        is_async_validator_called = False

        async def async_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
            nonlocal is_async_validator_called
            is_async_validator_called = True
            await trio.lowlevel.checkpoint()
            return True

        topic = "TEST_VALIDATOR"

        assert topic not in pubsubs_fsub[0].topic_validators

        # Register sync validator
        pubsubs_fsub[0].set_topic_validator(topic, sync_validator, False)

        assert topic in pubsubs_fsub[0].topic_validators
        topic_validator = pubsubs_fsub[0].topic_validators[topic]
        assert not topic_validator.is_async

        # Validate with sync validator
        test_msg = make_pubsub_msg(
            origin_id=IDFactory(),
            topic_ids=[topic],
            data=b"test",
            seqno=b"\x00" * 8,
        )
        topic_validator.validator(IDFactory(), test_msg)

        assert is_sync_validator_called
        assert not is_async_validator_called

        # Register with async validator
        pubsubs_fsub[0].set_topic_validator(topic, async_validator, True)

        # Reset so the assertion below proves the sync validator is not
        # invoked again once it has been replaced.
        is_sync_validator_called = False
        assert topic in pubsubs_fsub[0].topic_validators
        topic_validator = pubsubs_fsub[0].topic_validators[topic]
        assert topic_validator.is_async

        # Validate with async validator
        test_msg = make_pubsub_msg(
            origin_id=IDFactory(),
            topic_ids=[topic],
            data=b"test",
            seqno=b"\x00" * 8,
        )
        validator = topic_validator.validator
        # `inspect` comes from the module-level import; no local re-import.
        if topic_validator.is_async:
            if inspect.iscoroutinefunction(validator):
                await validator(IDFactory(), test_msg)
        else:
            validator(IDFactory(), test_msg)

        assert is_async_validator_called
        assert not is_sync_validator_called

        # Remove validator
        pubsubs_fsub[0].remove_topic_validator(topic)
        assert topic not in pubsubs_fsub[0].topic_validators
|
||
|
||
|
||
@pytest.mark.trio
async def test_get_msg_validators():
    """``get_msg_validators`` yields the validators of every topic of a message."""
    counts = {"sync": 0, "async": 0}

    def sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
        counts["sync"] += 1
        return True

    async def async_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
        counts["async"] += 1
        await trio.lowlevel.checkpoint()
        return True

    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        node = pubsubs_fsub[0]
        topic_1 = "TEST_VALIDATOR_1"
        topic_2 = "TEST_VALIDATOR_2"
        topic_3 = "TEST_VALIDATOR_3"

        # Topics 1 and 2 share the sync validator; topic 3 gets the async one.
        node.set_topic_validator(topic_1, sync_validator, False)
        node.set_topic_validator(topic_2, sync_validator, False)
        node.set_topic_validator(topic_3, async_validator, True)

        msg = make_pubsub_msg(
            origin_id=node.my_id,
            topic_ids=[topic_1, topic_2, topic_3],
            data=b"1234",
            seqno=b"\x00" * 8,
        )

        for entry in node.get_msg_validators(msg):
            fn = entry.validator
            if not entry.is_async:
                fn(IDFactory(), msg)
            elif inspect.iscoroutinefunction(fn):
                await fn(IDFactory(), msg)

        # Two sync invocations (topics 1 and 2) and one async (topic 3).
        assert counts["sync"] == 2
        assert counts["async"] == 1
|
||
|
||
|
||
@pytest.mark.parametrize(
    "is_topic_1_val_passed, is_topic_2_val_passed",
    ((False, True), (True, False), (True, True)),
)
@pytest.mark.trio
async def test_validate_msg_with_throttle_condition(
    is_topic_1_val_passed, is_topic_2_val_passed
):
    """
    ``validate_msg`` passes only if both topic validators accept the message.

    ``Pubsub._run_async_validator`` is replaced with a wrapper that records
    how many async validators run at once; the observed maximum must not
    exceed ``CONCURRENCY_LIMIT``.
    """
    CONCURRENCY_LIMIT = 10

    state = {
        "concurrency_counter": 0,
        "max_observed": 0,
    }
    lock = trio.Lock()

    async def mock_run_async_validator(
        self,
        func: AsyncValidatorFn,
        msg_forwarder: ID,
        msg: rpc_pb2.Message,
        results: list[bool],
    ) -> None:
        async with self._validator_semaphore:
            async with lock:
                state["concurrency_counter"] += 1
                state["max_observed"] = max(
                    state["max_observed"], state["concurrency_counter"]
                )

            try:
                results.append(await func(msg_forwarder, msg))
            finally:
                async with lock:
                    state["concurrency_counter"] -= 1

    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        node = pubsubs_fsub[0]

        def passed_sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
            return True

        def failed_sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
            return False

        async def passed_async_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
            await trio.lowlevel.checkpoint()
            return True

        async def failed_async_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
            await trio.lowlevel.checkpoint()
            return False

        topic_1 = "TEST_SYNC_VALIDATOR"
        topic_2 = "TEST_ASYNC_VALIDATOR"

        # Pick the validator for each topic according to the parametrization.
        sync_choice = (
            passed_sync_validator if is_topic_1_val_passed else failed_sync_validator
        )
        node.set_topic_validator(topic_1, sync_choice, False)

        async_choice = (
            passed_async_validator if is_topic_2_val_passed else failed_async_validator
        )
        node.set_topic_validator(topic_2, async_choice, True)

        msg = make_pubsub_msg(
            origin_id=node.my_id,
            topic_ids=[topic_1, topic_2],
            data=b"1234",
            seqno=b"\x00" * 8,
        )

        with patch(
            "libp2p.pubsub.pubsub.Pubsub._run_async_validator",
            new=mock_run_async_validator,
        ):
            if not (is_topic_1_val_passed and is_topic_2_val_passed):
                with pytest.raises(ValidationError):
                    await node.validate_msg(node.my_id, msg)
            else:
                await node.validate_msg(node.my_id, msg)

            assert state["max_observed"] <= CONCURRENCY_LIMIT, (
                f"Max concurrency observed: {state['max_observed']}"
            )
|
||
|
||
|
||
@pytest.mark.trio
async def test_continuously_read_stream(monkeypatch, nursery, security_protocol):
    """
    Check which handler ``continuously_read_stream`` invokes per RPC kind.

    ``push_msg``, ``handle_subscription`` and ``router.handle_rpc`` are
    replaced with mocks that set a ``trio.Event``. For each RPC written to
    the stream the test checks which events fire within 0.1 seconds.
    """

    class Events(NamedTuple):
        push_msg: trio.Event
        handle_subscription: trio.Event
        handle_rpc: trio.Event

    async def wait_for_event_occurring(event):
        await trio.lowlevel.checkpoint()
        with trio.fail_after(0.1):
            await event.wait()

    async def expect_not_fired(event):
        # The event must stay unset, so waiting for it has to time out.
        with pytest.raises(trio.TooSlowError):
            await wait_for_event_occurring(event)

    async def send(rpc):
        await stream_pair[1].write(encode_varint_prefixed(rpc.SerializeToString()))

    @contextmanager
    def mock_methods():
        push_msg_fired = trio.Event()
        subscription_fired = trio.Event()
        rpc_fired = trio.Event()

        async def mock_push_msg(msg_forwarder, msg):
            push_msg_fired.set()
            await trio.lowlevel.checkpoint()

        def mock_handle_subscription(origin_id, sub_message):
            subscription_fired.set()

        async def mock_handle_rpc(rpc, sender_peer_id):
            rpc_fired.set()
            await trio.lowlevel.checkpoint()

        with monkeypatch.context() as m:
            m.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg)
            m.setattr(pubsubs_fsub[0], "handle_subscription", mock_handle_subscription)
            m.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc)
            yield Events(push_msg_fired, subscription_fired, rpc_fired)

    async with (
        PubsubFactory.create_batch_with_floodsub(
            1, security_protocol=security_protocol
        ) as pubsubs_fsub,
        net_stream_pair_factory(security_protocol=security_protocol) as stream_pair,
    ):
        await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
        # Kick off the task `continuously_read_stream`
        nursery.start_soon(pubsubs_fsub[0].continuously_read_stream, stream_pair[0])

        # Test: `push_msg` is called when publishing to a subscribed topic.
        publish_subscribed_topic = rpc_pb2.RPC(
            publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])]
        )
        with mock_methods() as events:
            await send(publish_subscribed_topic)
            await wait_for_event_occurring(events.push_msg)
            # Make sure the other events are not emitted.
            await expect_not_fired(events.handle_subscription)
            await expect_not_fired(events.handle_rpc)

        # Test: `push_msg` is not called when publishing to a topic-not-subscribed.
        publish_not_subscribed_topic = rpc_pb2.RPC(
            publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])]
        )
        with mock_methods() as events:
            await send(publish_not_subscribed_topic)
            await expect_not_fired(events.push_msg)

        # Test: `handle_subscription` is called when a subscription message is
        # received.
        subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()])
        with mock_methods() as events:
            await send(subscription_msg)
            await wait_for_event_occurring(events.handle_subscription)
            # Make sure the other events are not emitted.
            await expect_not_fired(events.push_msg)
            await expect_not_fired(events.handle_rpc)

        # Test: `handle_rpc` is called when a control message is received.
        control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage())
        with mock_methods() as events:
            await send(control_msg)
            await wait_for_event_occurring(events.handle_rpc)
            # Make sure the other events are not emitted.
            await expect_not_fired(events.push_msg)
            await expect_not_fired(events.handle_subscription)

        # After all messages, close the write end to signal EOF
        await stream_pair[1].close()
        # Now reading should raise StreamEOF
        with pytest.raises(StreamEOF):
            await stream_pair[0].read(1)
|
||
|
||
|
||
# TODO: Add the following tests after they are aligned with Go.
|
||
# (Issue #191: https://github.com/libp2p/py-libp2p/issues/191)
|
||
# - `test_stream_handler`
|
||
# - `test_handle_peer_queue`
|
||
|
||
|
||
@pytest.mark.trio
async def test_handle_subscription():
    """``handle_subscription`` keeps ``peer_topics`` in sync with SubOpts."""
    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        node = pubsubs_fsub[0]

        assert len(node.peer_topics) == 0
        sub_msg_0 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=TESTING_TOPIC)
        first_peer, second_peer = (IDFactory() for _ in range(2))

        # Test: One peer is subscribed
        node.handle_subscription(first_peer, sub_msg_0)
        assert len(node.peer_topics) == 1
        assert TESTING_TOPIC in node.peer_topics
        assert len(node.peer_topics[TESTING_TOPIC]) == 1
        assert first_peer in node.peer_topics[TESTING_TOPIC]

        # Test: Another peer is subscribed
        node.handle_subscription(second_peer, sub_msg_0)
        assert len(node.peer_topics) == 1
        assert len(node.peer_topics[TESTING_TOPIC]) == 2
        assert second_peer in node.peer_topics[TESTING_TOPIC]

        # Test: Subscribe to another topic
        another_topic = "ANOTHER_TOPIC"
        sub_msg_1 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=another_topic)
        node.handle_subscription(first_peer, sub_msg_1)
        assert len(node.peer_topics) == 2
        assert another_topic in node.peer_topics
        assert first_peer in node.peer_topics[another_topic]

        # Test: unsubscribe
        unsub_msg = rpc_pb2.RPC.SubOpts(subscribe=False, topicid=TESTING_TOPIC)
        node.handle_subscription(first_peer, unsub_msg)
        assert first_peer not in node.peer_topics[TESTING_TOPIC]
|
||
|
||
|
||
@pytest.mark.trio
async def test_handle_talk():
    """``notify_subscriptions`` delivers only messages for subscribed topics."""
    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        node = pubsubs_fsub[0]
        sub = await node.subscribe(TESTING_TOPIC)

        # A message on the subscribed topic.
        msg_0 = make_pubsub_msg(
            origin_id=node.my_id,
            topic_ids=[TESTING_TOPIC],
            data=b"1234",
            seqno=b"\x00" * 8,
        )
        node.notify_subscriptions(msg_0)

        # A message on a topic nobody subscribed to.
        msg_1 = make_pubsub_msg(
            origin_id=node.my_id,
            topic_ids=["NOT_SUBSCRIBED"],
            data=b"1234",
            seqno=b"\x11" * 8,
        )
        node.notify_subscriptions(msg_1)

        assert len(node.topic_ids) == 1
        assert sub == node.subscribed_topics_receive[TESTING_TOPIC]
        assert (await sub.get()) == msg_0
|
||
|
||
|
||
@pytest.mark.trio
async def test_message_all_peers(monkeypatch, security_protocol):
    """``message_all_peers`` writes the varint-prefixed payload to peer streams."""
    async with (
        PubsubFactory.create_batch_with_floodsub(
            1, security_protocol=security_protocol
        ) as pubsubs_fsub,
        net_stream_pair_factory(security_protocol=security_protocol) as stream_pair,
    ):
        node = pubsubs_fsub[0]
        local_end, remote_end = stream_pair[0], stream_pair[1]

        peer_id = IDFactory()
        mock_peers = {peer_id: local_end}
        with monkeypatch.context() as m:
            # Pretend the node has exactly one peer, reachable via `local_end`.
            m.setattr(node, "peers", mock_peers)

            empty_rpc_bytes = rpc_pb2.RPC().SerializeToString()
            empty_rpc_bytes_len_prefixed = encode_varint_prefixed(empty_rpc_bytes)
            await node.message_all_peers(empty_rpc_bytes)

            received = await remote_end.read(MAX_READ_LEN)
            assert received == empty_rpc_bytes_len_prefixed
|
||
|
||
|
||
@pytest.mark.trio
async def test_subscribe_and_publish():
    """Published messages arrive on the subscription in publishing order."""
    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        pubsub = pubsubs_fsub[0]

        list_data = [b"d0", b"d1"]
        event_receive_data_started = trio.Event()

        async def receive_data(topic):
            received = 0
            event_receive_data_started.set()
            assert topic not in pubsub.topic_ids
            subscription = await pubsub.subscribe(topic)
            async with subscription:
                assert topic in pubsub.topic_ids
                async for msg in subscription:
                    assert msg.data == list_data[received]
                    received += 1
                    if received == len(list_data):
                        break
            # Leaving the `async with` block must have dropped the topic.
            assert topic not in pubsub.topic_ids

        async def publish_data(topic):
            # Hold back until the receiver task has started.
            await event_receive_data_started.wait()
            for payload in list_data:
                await pubsub.publish(topic, payload)

        async with trio.open_nursery() as nursery:
            nursery.start_soon(receive_data, TESTING_TOPIC)
            nursery.start_soon(publish_data, TESTING_TOPIC)
|
||
|
||
|
||
@pytest.mark.trio
async def test_subscribe_and_publish_full_channel():
    """
    Messages published while the subscription channel is full are dropped.

    The channel holds ``SUBSCRIPTION_CHANNEL_SIZE`` messages. Once a message
    has been consumed, a newly published message is accepted again.
    """
    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        pubsub = pubsubs_fsub[0]

        extra_data_0 = b"extra_data_0"
        extra_data_1 = b"extra_data_1"

        # Assume `SUBSCRIPTION_CHANNEL_SIZE` is smaller than `2**(4*8)`.
        list_data = [
            index.to_bytes(4, "big") for index in range(SUBSCRIPTION_CHANNEL_SIZE)
        ]
        # Expect `extra_data_0` is dropped and `extra_data_1` is appended.
        first_expected, *remaining_expected = list_data + [extra_data_1]

        subscription = await pubsub.subscribe(TESTING_TOPIC)
        # Fill the channel up to its capacity.
        for payload in list_data:
            await pubsub.publish(TESTING_TOPIC, payload)

        # The channel is already full, so `extra_data_0` should be dropped.
        await pubsub.publish(TESTING_TOPIC, extra_data_0)
        # Consuming one message frees a slot in the channel.
        assert (await subscription.get()).data == first_expected
        # `extra_data_1` should now be appended to the channel.
        await pubsub.publish(TESTING_TOPIC, extra_data_1)

        for expected in remaining_expected:
            assert (await subscription.get()).data == expected
|
||
|
||
|
||
@pytest.mark.trio
async def test_publish_push_msg_is_called(monkeypatch):
    """Each ``publish`` calls ``push_msg`` with our own id and a new seqno."""
    recorded = []  # (msg_forwarder, msg) pairs, in call order

    async def push_msg(msg_forwarder, msg):
        recorded.append((msg_forwarder, msg))
        await trio.lowlevel.checkpoint()

    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        node = pubsubs_fsub[0]
        with monkeypatch.context() as m:
            m.setattr(node, "push_msg", push_msg)

            await node.publish(TESTING_TOPIC, TESTING_DATA)
            await node.publish(TESTING_TOPIC, TESTING_DATA)

            msg_forwarders = [forwarder for forwarder, _ in recorded]
            msgs = [msg for _, msg in recorded]

            assert len(msgs) == 2, (
                "`push_msg` should be called every time `publish` is called"
            )
            assert (msg_forwarders[0] == msg_forwarders[1]) and (
                msg_forwarders[1] == node.my_id
            )
            assert msgs[0].seqno != msgs[1].seqno, (
                "`seqno` should be different every time"
            )
|
||
|
||
|
||
@pytest.mark.trio
async def test_push_msg(monkeypatch):
    """
    Exercise ``push_msg`` with accepted and rejected messages.

    ``router.publish`` is replaced with a mock that sets an event, so each
    case can check whether the message reached the router.
    """
    async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
        node = pubsubs_fsub[0]

        def build_msg(seqno):
            return make_pubsub_msg(
                origin_id=node.my_id,
                topic_ids=[TESTING_TOPIC],
                data=TESTING_DATA,
                seqno=seqno,
            )

        @contextmanager
        def mock_router_publish():
            published = trio.Event()

            async def router_publish(*args, **kwargs):
                published.set()
                await trio.lowlevel.checkpoint()

            with monkeypatch.context() as m:
                m.setattr(node.router, "publish", router_publish)
                yield published

        msg_0 = build_msg(b"\x00" * 8)

        with mock_router_publish() as event:
            # Test: `msg` is not seen before `push_msg`, and is seen after `push_msg`.
            assert not node._is_msg_seen(msg_0)
            await node.push_msg(node.my_id, msg_0)
            assert node._is_msg_seen(msg_0)
            # Test: Ensure `router.publish` is called in `push_msg`
            with trio.fail_after(0.1):
                await event.wait()

        with mock_router_publish() as event:
            # Test: pushing the same message again is rejected, so
            # `router_publish` is not called.
            await node.push_msg(node.my_id, msg_0)
            await trio.sleep(0.01)
            assert not event.is_set()

            # Test: a new message pushed with a forwarder other than ourselves
            # is rejected, so `router_publish` is not called.
            msg_0A = build_msg(b"\x33" * 8)
            await node.push_msg(pubsubs_fsub[1].my_id, msg_0A)
            await trio.sleep(0.01)
            assert not event.is_set()

            sub = await node.subscribe(TESTING_TOPIC)
            # Test: `push_msg` succeeds with another unseen msg.
            msg_1 = build_msg(b"\x11" * 8)
            assert not node._is_msg_seen(msg_1)
            await node.push_msg(node.my_id, msg_1)
            assert node._is_msg_seen(msg_1)
            with trio.fail_after(0.1):
                await event.wait()
            # Test: Subscribers are notified when `push_msg` new messages.
            assert (await sub.get()) == msg_1

        with mock_router_publish() as event:
            # Test: with a topic validator that always fails, the pushed
            # message does not pass validation and `router_publish` is not
            # called.
            def failed_sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
                return False

            node.set_topic_validator(TESTING_TOPIC, failed_sync_validator, False)

            msg_2 = build_msg(b"\x22" * 8)

            await node.push_msg(node.my_id, msg_2)
            await trio.sleep(0.01)
            assert not event.is_set()
|
||
|
||
|
||
@pytest.mark.trio
async def test_strict_signing():
    """With strict signing on, a published message is seen once by both peers."""
    async with PubsubFactory.create_batch_with_floodsub(
        2, strict_signing=True
    ) as pubsubs_fsub:
        sender = pubsubs_fsub[0]
        receiver = pubsubs_fsub[1]

        await connect(sender.host, receiver.host)
        await sender.subscribe(TESTING_TOPIC)
        await receiver.subscribe(TESTING_TOPIC)
        # Let the subscriptions propagate.
        await trio.sleep(1)

        await sender.publish(TESTING_TOPIC, TESTING_DATA)
        # Let the message propagate.
        await trio.sleep(1)

        assert sender.seen_messages.length() == 1
        assert receiver.seen_messages.length() == 1
|
||
|
||
|
||
@pytest.mark.trio
async def test_strict_signing_failed_validation(monkeypatch):
    """
    With strict signing on, only a correctly signed message reaches the router.

    ``_is_msg_seen`` is forced to return ``False`` so the same message object
    can be pushed repeatedly with different ``key`` / ``signature`` values.
    """
    async with PubsubFactory.create_batch_with_floodsub(
        2, strict_signing=True
    ) as pubsubs_fsub:
        node = pubsubs_fsub[0]

        msg = make_pubsub_msg(
            origin_id=node.my_id,
            topic_ids=[TESTING_TOPIC],
            data=TESTING_DATA,
            seqno=b"\x00" * 8,
        )
        priv_key = node.sign_key
        assert priv_key is not None, (
            "Private key should not be None when strict_signing=True"
        )
        signature = priv_key.sign(
            PUBSUB_SIGNING_PREFIX.encode() + msg.SerializeToString()
        )

        event = trio.Event()

        def _is_msg_seen(msg):
            return False

        # Use router publish to check if `push_msg` succeed.
        async def router_publish(*args, **kwargs):
            await trio.lowlevel.checkpoint()
            # The event will only be set if `push_msg` succeed.
            event.set()

        async def push_and_settle():
            await node.push_msg(node.my_id, msg)
            await trio.sleep(0.01)

        monkeypatch.setattr(node, "_is_msg_seen", _is_msg_seen)
        monkeypatch.setattr(node.router, "publish", router_publish)

        own_key = node.host.get_public_key
        other_key = pubsubs_fsub[1].host.get_public_key

        # Test: no signature attached in `msg`
        await push_and_settle()
        assert not event.is_set()

        # Test: `msg.key` does not match `msg.from_id`
        msg.key = other_key().serialize()
        msg.signature = signature
        await push_and_settle()
        assert not event.is_set()

        # Test: invalid signature
        msg.key = own_key().serialize()
        msg.signature = b"\x12" * 100
        await push_and_settle()
        assert not event.is_set()

        # Finally, assert the signature indeed will pass validation
        msg.key = own_key().serialize()
        msg.signature = signature
        await push_and_settle()
        assert event.is_set()
|
||
|
||
|
||
@pytest.mark.trio
async def test_blacklist_basic_operations():
    """Test basic blacklist operations: add, remove, check, clear."""
    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        pubsub = pubsubs_fsub[0]

        def blacklist_size():
            return len(pubsub.get_blacklisted_peers())

        # Three distinct peer ids to work with.
        peer1, peer2, peer3 = (IDFactory() for _ in range(3))

        # Nothing is blacklisted to begin with.
        assert blacklist_size() == 0
        for peer in (peer1, peer2, peer3):
            assert not pubsub.is_peer_blacklisted(peer)

        # Blacklist the first two peers.
        pubsub.add_to_blacklist(peer1)
        pubsub.add_to_blacklist(peer2)

        assert blacklist_size() == 2
        assert pubsub.is_peer_blacklisted(peer1)
        assert pubsub.is_peer_blacklisted(peer2)
        assert not pubsub.is_peer_blacklisted(peer3)

        # Take the first peer off the blacklist again.
        pubsub.remove_from_blacklist(peer1)

        assert blacklist_size() == 1
        assert not pubsub.is_peer_blacklisted(peer1)
        assert pubsub.is_peer_blacklisted(peer2)
        assert not pubsub.is_peer_blacklisted(peer3)

        # Blacklist the third peer, then wipe the whole blacklist.
        pubsub.add_to_blacklist(peer3)
        assert blacklist_size() == 2

        pubsub.clear_blacklist()
        assert blacklist_size() == 0
        for peer in (peer1, peer2, peer3):
            assert not pubsub.is_peer_blacklisted(peer)

        # Adding the same peer twice must not grow the blacklist.
        pubsub.add_to_blacklist(peer1)
        pubsub.add_to_blacklist(peer1)
        assert blacklist_size() == 1

        # Removing a peer that is not blacklisted must not raise.
        pubsub.remove_from_blacklist(peer2)
        assert blacklist_size() == 1
|
||
|
||
|
||
@pytest.mark.trio
async def test_blacklist_blocks_new_peer_connections(monkeypatch):
    """
    Test that blacklisted peers are rejected when trying to connect.

    ``pubsub.host.new_stream`` and ``pubsub.router.add_peer`` are swapped for
    recording stubs. After ``_handle_new_peer`` is awaited for a blacklisted
    peer, neither stub may have run and the peer must not appear in
    ``pubsub.peers``.
    """
    # Imported once at test scope rather than on every call of the stub below.
    from unittest.mock import (
        AsyncMock,
        Mock,
    )

    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        pubsub = pubsubs_fsub[0]

        # Create a peer ID and blacklist it before any connection attempt
        blacklisted_peer = IDFactory()
        pubsub.add_to_blacklist(blacklisted_peer)

        new_stream_called = False

        async def mock_new_stream(*args, **kwargs):
            nonlocal new_stream_called
            new_stream_called = True
            # Return a stream-shaped mock so that, if this stub is ever reached,
            # the caller can keep going and the failure shows up in the
            # assertions below.
            mock_stream = Mock()
            mock_stream.write = AsyncMock()
            mock_stream.reset = AsyncMock()
            mock_stream.get_protocol = Mock(return_value="test_protocol")
            return mock_stream

        router_add_peer_called = False

        def mock_add_peer(*args, **kwargs):
            nonlocal router_add_peer_called
            router_add_peer_called = True

        with monkeypatch.context() as m:
            m.setattr(pubsub.host, "new_stream", mock_new_stream)
            m.setattr(pubsub.router, "add_peer", mock_add_peer)

            # Attempt to handle the blacklisted peer
            await pubsub._handle_new_peer(blacklisted_peer)

            # Verify that neither new_stream nor router.add_peer was called
            assert not new_stream_called, (
                "new_stream should not be called to get hello packet"
            )
            assert not router_add_peer_called, (
                "Router.add_peer should not be called for blacklisted peer"
            )
            assert blacklisted_peer not in pubsub.peers, (
                "Blacklisted peer should not be in peers dict"
            )
|
||
|
||
|
||
@pytest.mark.trio
async def test_blacklist_blocks_messages_from_blacklisted_originator():
    """Check that a message whose ``from`` peer is blacklisted gets dropped."""
    async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
        receiver, sender = pubsubs_fsub
        # Reuse the second node's ID as the blacklisted originator.
        banned_origin = sender.my_id
        receiver.add_to_blacklist(banned_origin)

        banned_msg = make_pubsub_msg(
            origin_id=banned_origin,
            topic_ids=[TESTING_TOPIC],
            data=TESTING_DATA,
            seqno=b"\x00" * 8,
        )

        await receiver.subscribe(TESTING_TOPIC)

        # Every invocation of the patched router.publish is recorded here.
        publish_calls = []

        async def recording_publish(msg_forwarder: ID, pubsub_msg: rpc_pb2.Message):
            publish_calls.append((msg_forwarder, pubsub_msg))
            await trio.lowlevel.checkpoint()

        # patch.object puts the real router.publish back when the block exits.
        with patch.object(receiver.router, "publish", recording_publish):
            await receiver.push_msg(banned_origin, banned_msg)

            assert not publish_calls, (
                "Router.publish should not be called for blacklisted originator"
            )
            assert not receiver._is_msg_seen(banned_msg), (
                "Message from blacklisted originator should not be marked as seen"
            )
|
||
|
||
|
||
@pytest.mark.trio
async def test_blacklist_allows_non_blacklisted_peers():
    """
    Test that non-blacklisted peers can send messages normally.

    Two messages are pushed through the same, non-blacklisted forwarder: one
    originated by an allowed peer and one originated by a blacklisted peer.
    Only the first may reach ``router.publish``, the seen cache and the
    topic subscription.
    """
    async with PubsubFactory.create_batch_with_floodsub(3) as pubsubs_fsub:
        pubsub = pubsubs_fsub[0]
        allowed_peer = pubsubs_fsub[1].my_id
        blacklisted_peer = pubsubs_fsub[2].my_id

        # Blacklist one peer but not the other
        pubsub.add_to_blacklist(blacklisted_peer)

        # Create messages from both peers
        msg_from_allowed = make_pubsub_msg(
            origin_id=allowed_peer,
            topic_ids=[TESTING_TOPIC],
            data=b"allowed_data",
            seqno=b"\x00" * 8,
        )

        msg_from_blacklisted = make_pubsub_msg(
            origin_id=blacklisted_peer,
            topic_ids=[TESTING_TOPIC],
            data=b"blacklisted_data",
            seqno=b"\x11" * 8,
        )

        # Subscribe to the topic
        sub = await pubsub.subscribe(TESTING_TOPIC)

        # Track router.publish calls
        router_publish_calls = []

        async def mock_router_publish(msg_forwarder: ID, pubsub_msg: rpc_pb2.Message):
            router_publish_calls.append((msg_forwarder, pubsub_msg))
            await trio.lowlevel.checkpoint()

        original_router_publish = pubsub.router.publish
        pubsub.router.publish = mock_router_publish

        try:
            # Send message from allowed peer (should succeed)
            await pubsub.push_msg(allowed_peer, msg_from_allowed)

            # Send message originated by the blacklisted peer (should be
            # rejected even though the forwarder itself is allowed)
            await pubsub.push_msg(allowed_peer, msg_from_blacklisted)

            # Verify only allowed message was processed
            assert len(router_publish_calls) == 1, (
                "Only one message should be processed"
            )
            assert pubsub._is_msg_seen(msg_from_allowed), (
                "Allowed message should be marked as seen"
            )
            assert not pubsub._is_msg_seen(msg_from_blacklisted), (
                "Blacklisted message should not be marked as seen"
            )

            # Verify subscription received the allowed message. Bound the wait
            # so that a missing delivery fails the test instead of hanging it.
            with trio.fail_after(1):
                received_msg = await sub.get()
            assert received_msg.data == b"allowed_data"

        finally:
            pubsub.router.publish = original_router_publish
|
||
|
||
|
||
@pytest.mark.trio
async def test_blacklist_integration_with_existing_functionality():
    """Check that the blacklist and the seen-message cache work side by side."""
    async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
        node, remote = pubsubs_fsub
        remote_id = remote.my_id

        # Start with the remote peer blacklisted.
        node.add_to_blacklist(remote_id)

        probe = make_pubsub_msg(
            origin_id=remote_id,
            topic_ids=[TESTING_TOPIC],
            data=TESTING_DATA,
            seqno=b"\x00" * 8,
        )

        # While blacklisted, the message must not reach the seen cache.
        await node.push_msg(remote_id, probe)
        assert not node._is_msg_seen(probe)

        # Lift the blacklist; the very same message is now accepted.
        node.remove_from_blacklist(remote_id)
        await node.subscribe(TESTING_TOPIC)
        await node.push_msg(remote_id, probe)
        assert node._is_msg_seen(probe)

        # A third push of the same message has to be stopped by the seen
        # cache (not by the blacklist), so router.publish must stay untouched.
        publish_calls = []

        async def recording_publish(msg_forwarder: ID, pubsub_msg: rpc_pb2.Message):
            publish_calls.append((msg_forwarder, pubsub_msg))
            await trio.lowlevel.checkpoint()

        # patch.object puts the real router.publish back when the block exits.
        with patch.object(node.router, "publish", recording_publish):
            await node.push_msg(remote_id, probe)
            assert not publish_calls, (
                "Duplicate message should be rejected by seen cache"
            )
|
||
|
||
|
||
@pytest.mark.trio
async def test_blacklist_blocks_messages_from_blacklisted_source():
    """
    Test that messages from blacklisted source (forwarder) are rejected.

    The message is originated by a separate peer that is NOT blacklisted, so
    the only reason for a rejection is the blacklisted forwarder. Using the
    blacklisted peer as originator too would let the test pass through the
    originator check alone.
    """
    async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
        pubsub = pubsubs_fsub[0]
        blacklisted_forwarder = pubsubs_fsub[1].my_id

        # Add the forwarder to blacklist
        pubsub.add_to_blacklist(blacklisted_forwarder)

        # Use a fresh, non-blacklisted peer ID as the message originator
        clean_originator = IDFactory()
        assert not pubsub.is_peer_blacklisted(clean_originator)

        # Create a message that merely passes through the blacklisted forwarder
        msg = make_pubsub_msg(
            origin_id=clean_originator,
            topic_ids=[TESTING_TOPIC],
            data=TESTING_DATA,
            seqno=b"\x00" * 8,
        )

        # Subscribe to the topic so we can check if message is processed
        await pubsub.subscribe(TESTING_TOPIC)

        # Track if router.publish is called (it must not be for this forwarder)
        router_publish_called = False

        async def mock_router_publish(msg_forwarder: ID, pubsub_msg: rpc_pb2.Message):
            nonlocal router_publish_called
            router_publish_called = True
            await trio.lowlevel.checkpoint()

        original_router_publish = pubsub.router.publish
        pubsub.router.publish = mock_router_publish

        try:
            # Attempt to push message from blacklisted forwarder
            await pubsub.push_msg(blacklisted_forwarder, msg)

            # Verify message was rejected
            assert not router_publish_called, (
                "Router.publish should not be called for blacklisted forwarder"
            )
            assert not pubsub._is_msg_seen(msg), (
                "Message from blacklisted forwarder should not be marked as seen"
            )

        finally:
            pubsub.router.publish = original_router_publish
|
||
|
||
|
||
@pytest.mark.trio
async def test_blacklist_tears_down_existing_connection():
    """
    Verify that blacklisting an already-connected peer removes it from
    ``pubsub.peers`` and from ``pubsub.peer_topics``.

    Only the bookkeeping is asserted here; whether the peer's stream gets
    reset is not checked by this test.
    """
    # Create two pubsub instances (floodsub), so they can connect to each other
    async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
        pubsub0, pubsub1 = pubsubs_fsub

        # 1) Connect peer1 to peer0
        await connect(pubsub0.host, pubsub1.host)
        # Give handle_peer_queue some time to run
        await trio.sleep(0.1)

        # After connect, pubsub0.peers should contain pubsub1.my_id
        assert pubsub1.my_id in pubsub0.peers

        # 2) Manually record a subscription from peer1 under TESTING_TOPIC,
        #    so that peer1 shows up in pubsub0.peer_topics[TESTING_TOPIC].
        sub_msg = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=TESTING_TOPIC)
        pubsub0.handle_subscription(pubsub1.my_id, sub_msg)

        assert TESTING_TOPIC in pubsub0.peer_topics
        assert pubsub1.my_id in pubsub0.peer_topics[TESTING_TOPIC]

        # 3) Now blacklist peer1
        pubsub0.add_to_blacklist(pubsub1.my_id)

        # Allow the asynchronous teardown task (_teardown_if_connected) to run
        await trio.sleep(0.1)

        # 4a) pubsub0.peers should no longer contain peer1
        assert pubsub1.my_id not in pubsub0.peers

        # 4b) pubsub0.peer_topics[TESTING_TOPIC] should no longer contain peer1.
        #     If the whole topic entry was pruned there is nothing left to
        #     check, so no else-branch is needed.
        if TESTING_TOPIC in pubsub0.peer_topics:
            assert pubsub1.my_id not in pubsub0.peer_topics[TESTING_TOPIC]
|