Mirror of https://github.com/varun-r-mallya/py-libp2p.git (synced 2025-12-31 20:36:24 +00:00)
ft. modernise py-libp2p (#618)

* fix pyproject.toml, add ruff
* rm lock
* make progress
* add poetry lock ignore
* fix type issues
* fix tcp type errors
* fix text example - type error - wrong args
* add setuptools to dev
* test ci
* fix docs build
* fix type issues for new_swarm & new_host
* fix types in gossipsub
* fix type issues in noise
* wip: factories
* revert factories
* fix more type issues
* more type fixes
* fix: add null checks for noise protocol initialization and key handling
* corrected argument errors in PeerId and Multiaddr in peer tests
* fix: Noise - remove redundant type casts in BaseNoiseMsgReadWriter
* fix: update test_notify.py to use SwarmFactory.create_batch_and_listen, fix type hints, and comment out ClosedStream assertions
* Fix type checks for pubsub module
* Fix type checks for pubsub module tests
* noise: add checks for uninitialized protocol and key states in PatternXX
* pubsub: add None checks for optional fields in FloodSub and Pubsub
* Fix type hints and improve testing
* remove redundant checks
* fix build issues
* add optional to trio service
* fix types
* fix type errors
* Fix type errors
* fixed more type checks in crypto and peer_data files
* wip: factories
* replaced Union with Optional
* fix: type error in interp-utils and peerinfo
* replace pyright with pyrefly
* add pyrefly.toml
* wip: fix multiselect issues
* try typecheck
* base check
* mcache test fixes, typecheck ci update
* fix ci
* will this work
* minor fix
* use poetry
* fix workflow
* use cache, fix err
* fix pyrefly.toml
* fix pyrefly.toml
* fix cache in ci
* deploy commit
* add main baseline
* update to v5
* improve typecheck ci (#14)
* fix typo
* remove holepunching code (#16)
* fix gossipsub type errors (#17)
* fix: ensure initiator includes remote peer id in handshake (#15)
* fix ci (#19)
* typefix: custom_types | core/peerinfo/test_peer_info | io/abc | pubsub/floodsub | protocol_muxer/multiselect (#18)
* fix: typefixes in PeerInfo (#21)
* fix minor type issue (#22)
* fix type errors in pubsub (#24)
* fix: minor typefixes in tests (#23)
* Fix failing tests for type-fixed test/pubsub (#8)
* move pyrefly & ruff to pyproject.toml & rm .project-template (#28)
* move the async_context file to tests/core
* move crypto test to crypto folder
* fix: some typefixes (#25)
* fix type errors
* fix type issues
* fix: update gRPC API usage in autonat_pb2_grpc.py (#31)
* md: typecheck ci
* rm comments
* clean up: from review suggestions
* use `| None` over `Optional` as per new Python standards
* drop support for py3.9
* newsfragments

Signed-off-by: sukhman <sukhmansinghsaluja@gmail.com>
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
Co-authored-by: acul71 <luca.pisani@birdo.net>
Co-authored-by: kaneki003 <sakshamchauhan707@gmail.com>
Co-authored-by: sukhman <sukhmansinghsaluja@gmail.com>
Co-authored-by: varun-r-mallya <varunrmallya@gmail.com>
Co-authored-by: varunrmallya <100590632+varun-r-mallya@users.noreply.github.com>
Co-authored-by: lla-dane <abhinavagarwalla6@gmail.com>
Co-authored-by: Collins <ArtemisfowlX@protonmail.com>
Co-authored-by: Abhinav Agarwalla <120122716+lla-dane@users.noreply.github.com>
Co-authored-by: guha-rahul <52607971+guha-rahul@users.noreply.github.com>
Co-authored-by: Sukhman Singh <63765293+sukhman-sukh@users.noreply.github.com>
Co-authored-by: acul71 <34693171+acul71@users.noreply.github.com>
Co-authored-by: pacrob <5199899+pacrob@users.noreply.github.com>
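One change called out above deserves a note: switching annotations from `Optional[X]` to `X | None` is tied to dropping py3.9, because the `X | Y` union syntax in annotations only evaluates at runtime on Python 3.10+. A minimal before/after sketch (illustrative function, not code from this commit):

# Editor's sketch, not part of this commit.
from typing import Optional


def lookup_old(peer_id: str) -> Optional[str]:
    # Pre-3.10 spelling: requires the typing.Optional import.
    return None


def lookup_new(peer_id: str) -> str | None:
    # Modern spelling: no import needed, but evaluating `str | None`
    # at runtime requires Python 3.10+, hence dropping py3.9.
    return None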
@@ -20,7 +20,6 @@ async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
         such as send crypto and set crypto
     :param assertion_func: assertions for testing the results of the actions are correct
     """
-
     async with DummyAccountNode.create(num_nodes) as dummy_nodes:
         # Create connections between nodes according to `adjacency_map`
         async with trio.open_nursery() as nursery:
@@ -46,7 +46,7 @@ async def test_simple_two_nodes():
 async def test_timed_cache_two_nodes():
     # Two nodes using LastSeenCache with a TTL of 120 seconds
     def get_msg_id(msg):
-        return (msg.data, msg.from_id)
+        return msg.data + msg.from_id
 
     async with PubsubFactory.create_batch_with_floodsub(
         2, seen_ttl=120, msg_id_constructor=get_msg_id
@@ -5,6 +5,7 @@ import trio
 
 from libp2p.pubsub.gossipsub import (
     PROTOCOL_ID,
+    GossipSub,
 )
 from libp2p.tools.utils import (
     connect,
@@ -24,7 +25,10 @@ async def test_join():
     async with PubsubFactory.create_batch_with_gossipsub(
         4, degree=4, degree_low=3, degree_high=5, heartbeat_interval=1, time_to_live=1
     ) as pubsubs_gsub:
-        gossipsubs = [pubsub.router for pubsub in pubsubs_gsub]
+        gossipsubs = []
+        for pubsub in pubsubs_gsub:
+            if isinstance(pubsub.router, GossipSub):
+                gossipsubs.append(pubsub.router)
         hosts = [pubsub.host for pubsub in pubsubs_gsub]
         hosts_indices = list(range(len(pubsubs_gsub)))
 
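The pattern above recurs throughout this diff: `pubsub.router` is typed as the generic router interface, so GossipSub-only attributes fail type checking until the value is narrowed with `isinstance`. A self-contained sketch of the narrowing idea, using stand-in classes rather than py-libp2p's real hierarchy:

# Editor's sketch with stand-in classes, not part of this commit.
class BaseRouter:
    """Generic router interface: no `mesh` attribute."""


class GossipRouter(BaseRouter):
    def __init__(self) -> None:
        self.mesh: dict[str, set[str]] = {}


def collect_gossip_routers(routers: list[BaseRouter]) -> list[GossipRouter]:
    narrowed: list[GossipRouter] = []
    for router in routers:
        # isinstance() narrows `router` to GossipRouter, so the
        # checker accepts GossipRouter-only attribute access later.
        if isinstance(router, GossipRouter):
            narrowed.append(router)
    return narrowed


print(len(collect_gossip_routers([BaseRouter(), GossipRouter()])))  # prints 1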
@@ -86,7 +90,9 @@ async def test_join():
 @pytest.mark.trio
 async def test_leave():
     async with PubsubFactory.create_batch_with_gossipsub(1) as pubsubs_gsub:
-        gossipsub = pubsubs_gsub[0].router
+        router = pubsubs_gsub[0].router
+        assert isinstance(router, GossipSub)
+        gossipsub = router
         topic = "test_leave"
 
         assert topic not in gossipsub.mesh
@@ -104,7 +110,11 @@ async def test_leave():
 @pytest.mark.trio
 async def test_handle_graft(monkeypatch):
     async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
-        gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
+        gossipsub_routers = []
+        for pubsub in pubsubs_gsub:
+            if isinstance(pubsub.router, GossipSub):
+                gossipsub_routers.append(pubsub.router)
+        gossipsubs = tuple(gossipsub_routers)
 
         index_alice = 0
         id_alice = pubsubs_gsub[index_alice].my_id
@@ -156,7 +166,11 @@ async def test_handle_prune():
     async with PubsubFactory.create_batch_with_gossipsub(
         2, heartbeat_interval=3
     ) as pubsubs_gsub:
-        gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
+        gossipsub_routers = []
+        for pubsub in pubsubs_gsub:
+            if isinstance(pubsub.router, GossipSub):
+                gossipsub_routers.append(pubsub.router)
+        gossipsubs = tuple(gossipsub_routers)
 
         index_alice = 0
         id_alice = pubsubs_gsub[index_alice].my_id
@@ -382,7 +396,9 @@ async def test_mesh_heartbeat(initial_mesh_peer_count, monkeypatch):
 
         fake_peer_ids = [IDFactory() for _ in range(total_peer_count)]
         peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids}
-        monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol)
+        router = pubsubs_gsub[0].router
+        assert isinstance(router, GossipSub)
+        monkeypatch.setattr(router, "peer_protocol", peer_protocol)
 
         peer_topics = {topic: set(fake_peer_ids)}
         # Monkeypatch the peer subscriptions
@@ -394,27 +410,21 @@ async def test_mesh_heartbeat(initial_mesh_peer_count, monkeypatch):
         mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices]
         router_mesh = {topic: set(mesh_peers)}
         # Monkeypatch our mesh peers
-        monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh)
+        monkeypatch.setattr(router, "mesh", router_mesh)
 
-        peers_to_graft, peers_to_prune = pubsubs_gsub[0].router.mesh_heartbeat()
-        if initial_mesh_peer_count > pubsubs_gsub[0].router.degree:
+        peers_to_graft, peers_to_prune = router.mesh_heartbeat()
+        if initial_mesh_peer_count > router.degree:
             # If number of initial mesh peers is more than `GossipSubDegree`,
             # we should PRUNE mesh peers
             assert len(peers_to_graft) == 0
-            assert (
-                len(peers_to_prune)
-                == initial_mesh_peer_count - pubsubs_gsub[0].router.degree
-            )
+            assert len(peers_to_prune) == initial_mesh_peer_count - router.degree
             for peer in peers_to_prune:
                 assert peer in mesh_peers
-        elif initial_mesh_peer_count < pubsubs_gsub[0].router.degree:
+        elif initial_mesh_peer_count < router.degree:
             # If number of initial mesh peers is less than `GossipSubDegree`,
             # we should GRAFT more peers
             assert len(peers_to_prune) == 0
-            assert (
-                len(peers_to_graft)
-                == pubsubs_gsub[0].router.degree - initial_mesh_peer_count
-            )
+            assert len(peers_to_graft) == router.degree - initial_mesh_peer_count
             for peer in peers_to_graft:
                 assert peer not in mesh_peers
         else:
@@ -436,7 +446,10 @@ async def test_gossip_heartbeat(initial_peer_count, monkeypatch):
 
         fake_peer_ids = [IDFactory() for _ in range(total_peer_count)]
         peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids}
-        monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol)
+        router_obj = pubsubs_gsub[0].router
+        assert isinstance(router_obj, GossipSub)
+        router = router_obj
+        monkeypatch.setattr(router, "peer_protocol", peer_protocol)
 
         topic_mesh_peer_count = 14
         # Split into mesh peers and fanout peers
@@ -453,14 +466,14 @@ async def test_gossip_heartbeat(initial_peer_count, monkeypatch):
         mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices]
         router_mesh = {topic_mesh: set(mesh_peers)}
         # Monkeypatch our mesh peers
-        monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh)
+        monkeypatch.setattr(router, "mesh", router_mesh)
         fanout_peer_indices = random.sample(
             range(topic_mesh_peer_count, total_peer_count), initial_peer_count
         )
         fanout_peers = [fake_peer_ids[i] for i in fanout_peer_indices]
         router_fanout = {topic_fanout: set(fanout_peers)}
         # Monkeypatch our fanout peers
-        monkeypatch.setattr(pubsubs_gsub[0].router, "fanout", router_fanout)
+        monkeypatch.setattr(router, "fanout", router_fanout)
 
         def window(topic):
             if topic == topic_mesh:
@@ -471,20 +484,18 @@ async def test_gossip_heartbeat(initial_peer_count, monkeypatch):
                 return []
 
         # Monkeypatch the memory cache messages
-        monkeypatch.setattr(pubsubs_gsub[0].router.mcache, "window", window)
+        monkeypatch.setattr(router.mcache, "window", window)
 
-        peers_to_gossip = pubsubs_gsub[0].router.gossip_heartbeat()
+        peers_to_gossip = router.gossip_heartbeat()
         # If our mesh peer count is less than `GossipSubDegree`, we should gossip to up
         # to `GossipSubDegree` peers (exclude mesh peers).
-        if topic_mesh_peer_count - initial_peer_count < pubsubs_gsub[0].router.degree:
+        if topic_mesh_peer_count - initial_peer_count < router.degree:
             # The same goes for fanout so it's two times the number of peers to gossip.
             assert len(peers_to_gossip) == 2 * (
                 topic_mesh_peer_count - initial_peer_count
             )
-        elif (
-            topic_mesh_peer_count - initial_peer_count >= pubsubs_gsub[0].router.degree
-        ):
-            assert len(peers_to_gossip) == 2 * (pubsubs_gsub[0].router.degree)
+        elif topic_mesh_peer_count - initial_peer_count >= router.degree:
+            assert len(peers_to_gossip) == 2 * (router.degree)
 
         for peer in peers_to_gossip:
             if peer in peer_topics[topic_mesh]:
@@ -4,6 +4,9 @@ import trio
 from libp2p.peer.peerinfo import (
     info_from_p2p_addr,
 )
+from libp2p.pubsub.gossipsub import (
+    GossipSub,
+)
 from libp2p.tools.utils import (
     connect,
 )
@@ -82,31 +85,33 @@ async def test_reject_graft():
         await pubsubs_gsub_1[0].router.join(topic)
 
         # Pre-Graft assertions
-        assert (
-            topic in pubsubs_gsub_0[0].router.mesh
-        ), "topic not in mesh for gossipsub 0"
-        assert (
-            topic in pubsubs_gsub_1[0].router.mesh
-        ), "topic not in mesh for gossipsub 1"
-        assert (
-            host_1.get_id() not in pubsubs_gsub_0[0].router.mesh[topic]
-        ), "gossipsub 1 in mesh topic for gossipsub 0"
-        assert (
-            host_0.get_id() not in pubsubs_gsub_1[0].router.mesh[topic]
-        ), "gossipsub 0 in mesh topic for gossipsub 1"
+        assert topic in pubsubs_gsub_0[0].router.mesh, (
+            "topic not in mesh for gossipsub 0"
+        )
+        assert topic in pubsubs_gsub_1[0].router.mesh, (
+            "topic not in mesh for gossipsub 1"
+        )
+        assert host_1.get_id() not in pubsubs_gsub_0[0].router.mesh[topic], (
+            "gossipsub 1 in mesh topic for gossipsub 0"
+        )
+        assert host_0.get_id() not in pubsubs_gsub_1[0].router.mesh[topic], (
+            "gossipsub 0 in mesh topic for gossipsub 1"
+        )
 
         # Gossipsub 1 emits a graft request to Gossipsub 0
-        await pubsubs_gsub_0[0].router.emit_graft(topic, host_1.get_id())
+        router_obj = pubsubs_gsub_0[0].router
+        assert isinstance(router_obj, GossipSub)
+        await router_obj.emit_graft(topic, host_1.get_id())
 
         await trio.sleep(1)
 
         # Post-Graft assertions
-        assert (
-            host_1.get_id() not in pubsubs_gsub_0[0].router.mesh[topic]
-        ), "gossipsub 1 in mesh topic for gossipsub 0"
-        assert (
-            host_0.get_id() not in pubsubs_gsub_1[0].router.mesh[topic]
-        ), "gossipsub 0 in mesh topic for gossipsub 1"
+        assert host_1.get_id() not in pubsubs_gsub_0[0].router.mesh[topic], (
+            "gossipsub 1 in mesh topic for gossipsub 0"
+        )
+        assert host_0.get_id() not in pubsubs_gsub_1[0].router.mesh[topic], (
+            "gossipsub 0 in mesh topic for gossipsub 1"
+        )
 
     except Exception as e:
         print(f"Test failed with error: {e}")
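Most of the assertion churn in this and the surrounding hunks is one mechanical style change: the formatter now keeps the asserted condition on the `assert` line and parenthesizes only the failure message. Both spellings are equivalent at runtime; a minimal sketch:

# Editor's sketch of the two equivalent assert styles.
x = 1

# Old wrapping style: the condition is parenthesized and the
# message trails the closing parenthesis.
assert (
    x == 1
), "x should be 1"

# New style: the condition stays inline; only the message wraps.
assert x == 1, (
    "x should be 1"
)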
@@ -139,12 +144,12 @@ async def test_heartbeat_reconnect():
         await trio.sleep(1)
 
         # Verify initial connection
-        assert (
-            host_1.get_id() in pubsubs_gsub_0[0].peers
-        ), "Initial connection not established for gossipsub 0"
-        assert (
-            host_0.get_id() in pubsubs_gsub_1[0].peers
-        ), "Initial connection not established for gossipsub 0"
+        assert host_1.get_id() in pubsubs_gsub_0[0].peers, (
+            "Initial connection not established for gossipsub 0"
+        )
+        assert host_0.get_id() in pubsubs_gsub_1[0].peers, (
+            "Initial connection not established for gossipsub 0"
+        )
 
         # Simulate disconnection
         await host_0.disconnect(host_1.get_id())
@@ -153,17 +158,17 @@
         await trio.sleep(1)
 
         # Verify that peers are removed after disconnection
-        assert (
-            host_0.get_id() not in pubsubs_gsub_1[0].peers
-        ), "Peer 0 still in gossipsub 1 after disconnection"
+        assert host_0.get_id() not in pubsubs_gsub_1[0].peers, (
+            "Peer 0 still in gossipsub 1 after disconnection"
+        )
 
         # Wait for heartbeat to reestablish connection
         await trio.sleep(2)
 
         # Verify connection reestablishment
-        assert (
-            host_0.get_id() in pubsubs_gsub_1[0].peers
-        ), "Reconnection not established for gossipsub 0"
+        assert host_0.get_id() in pubsubs_gsub_1[0].peers, (
+            "Reconnection not established for gossipsub 0"
+        )
 
     except Exception as e:
         print(f"Test failed with error: {e}")
@@ -1,15 +1,26 @@
+from collections.abc import (
+    Sequence,
+)
+
+from libp2p.peer.id import (
+    ID,
+)
 from libp2p.pubsub.mcache import (
     MessageCache,
 )
+from libp2p.pubsub.pb import (
+    rpc_pb2,
+)
 
 
-class Msg:
-    __slots__ = ["topicIDs", "seqno", "from_id"]
-
-    def __init__(self, topicIDs, seqno, from_id):
-        self.topicIDs = topicIDs
-        self.seqno = seqno
-        self.from_id = from_id
+def make_msg(
+    topic_ids: Sequence[str],
+    seqno: bytes,
+    from_id: ID,
+) -> rpc_pb2.Message:
+    return rpc_pb2.Message(
+        from_id=from_id.to_bytes(), seqno=seqno, topicIDs=list(topic_ids)
+    )
 
 
 def test_mcache():
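Replacing the hand-rolled `Msg` class with `make_msg` means the cache tests exercise the same `rpc_pb2.Message` protobuf type that production code stores. Assuming the imports shown in this hunk, building one of these messages directly looks like:

# Editor's sketch using the same API this hunk introduces.
from libp2p.peer.id import ID
from libp2p.pubsub.pb import rpc_pb2

# Equivalent to make_msg(["test"], (42).to_bytes(1, "big"), ID(b"test")).
msg = rpc_pb2.Message(
    from_id=ID(b"test").to_bytes(),
    seqno=(42).to_bytes(1, "big"),
    topicIDs=["test"],
)
print(msg.seqno)  # b'*'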
@@ -19,7 +30,7 @@ def test_mcache():
     msgs = []
 
     for i in range(60):
-        msgs.append(Msg(["test"], i, "test"))
+        msgs.append(make_msg(["test"], i.to_bytes(1, "big"), ID(b"test")))
 
     for i in range(10):
         mcache.put(msgs[i])
@@ -1,6 +1,7 @@
 from contextlib import (
     contextmanager,
 )
+import inspect
 from typing import (
     NamedTuple,
 )
@@ -14,6 +15,9 @@ from libp2p.exceptions import (
 from libp2p.network.stream.exceptions import (
     StreamEOF,
 )
+from libp2p.peer.id import (
+    ID,
+)
 from libp2p.pubsub.pb import (
     rpc_pb2,
 )
@@ -121,16 +125,18 @@ async def test_set_and_remove_topic_validator():
     async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
         is_sync_validator_called = False
 
-        def sync_validator(peer_id, msg):
+        def sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
             nonlocal is_sync_validator_called
             is_sync_validator_called = True
+            return True
 
         is_async_validator_called = False
 
-        async def async_validator(peer_id, msg):
+        async def async_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
             nonlocal is_async_validator_called
             is_async_validator_called = True
             await trio.lowlevel.checkpoint()
+            return True
 
         topic = "TEST_VALIDATOR"
 
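The annotated signatures and explicit `return True` matter because an un-annotated validator implicitly returns `None`, which a checker cannot reconcile with a callback type that demands `bool`. A standalone sketch of that contract, with `str`/`bytes` standing in for `ID` and `rpc_pb2.Message`:

# Editor's sketch; str/bytes stand in for ID and rpc_pb2.Message.
from collections.abc import Callable

SyncValidator = Callable[[str, bytes], bool]


def good_validator(peer_id: str, msg: bytes) -> bool:
    return True  # explicit bool: satisfies SyncValidator


validators: list[SyncValidator] = [good_validator]
print(all(v("peer-1", b"data") for v in validators))  # True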
@@ -144,7 +150,13 @@ async def test_set_and_remove_topic_validator():
         assert not topic_validator.is_async
 
         # Validate with sync validator
-        topic_validator.validator(peer_id=IDFactory(), msg="msg")
+        test_msg = make_pubsub_msg(
+            origin_id=IDFactory(),
+            topic_ids=[topic],
+            data=b"test",
+            seqno=b"\x00" * 8,
+        )
+        topic_validator.validator(IDFactory(), test_msg)
 
         assert is_sync_validator_called
         assert not is_async_validator_called
@@ -158,7 +170,20 @@ async def test_set_and_remove_topic_validator():
         assert topic_validator.is_async
 
         # Validate with async validator
-        await topic_validator.validator(peer_id=IDFactory(), msg="msg")
+        test_msg = make_pubsub_msg(
+            origin_id=IDFactory(),
+            topic_ids=[topic],
+            data=b"test",
+            seqno=b"\x00" * 8,
+        )
+        validator = topic_validator.validator
+        if topic_validator.is_async:
+            import inspect
+
+            if inspect.iscoroutinefunction(validator):
+                await validator(IDFactory(), test_msg)
+        else:
+            validator(IDFactory(), test_msg)
 
         assert is_async_validator_called
         assert not is_sync_validator_called
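The runtime dispatch added above exists because `topic_validator.validator` can be either a plain function or a coroutine function, and awaiting the wrong kind is a bug. A self-contained sketch of the same `inspect.iscoroutinefunction` dispatch, using trio as the tests do:

# Editor's sketch of sync/async dispatch, not part of this commit.
import inspect

import trio


def sync_validator(msg: str) -> bool:
    return True


async def async_validator(msg: str) -> bool:
    await trio.lowlevel.checkpoint()
    return True


async def run_any(validator, msg: str) -> bool:
    # Coroutine functions must be awaited; plain functions must not be.
    if inspect.iscoroutinefunction(validator):
        return await validator(msg)
    return validator(msg)


async def main() -> None:
    print(await run_any(sync_validator, "msg"))   # True
    print(await run_any(async_validator, "msg"))  # True


trio.run(main)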
@@ -170,20 +195,18 @@ async def test_set_and_remove_topic_validator():
 
 @pytest.mark.trio
 async def test_get_msg_validators():
+    calls = [0, 0]  # [sync, async]
+
+    def sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
+        calls[0] += 1
+        return True
+
+    async def async_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
+        calls[1] += 1
+        await trio.lowlevel.checkpoint()
+        return True
+
     async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
-        times_sync_validator_called = 0
-
-        def sync_validator(peer_id, msg):
-            nonlocal times_sync_validator_called
-            times_sync_validator_called += 1
-
-        times_async_validator_called = 0
-
-        async def async_validator(peer_id, msg):
-            nonlocal times_async_validator_called
-            times_async_validator_called += 1
-            await trio.lowlevel.checkpoint()
-
         topic_1 = "TEST_VALIDATOR_1"
         topic_2 = "TEST_VALIDATOR_2"
         topic_3 = "TEST_VALIDATOR_3"
@@ -204,13 +227,15 @@ async def test_get_msg_validators():
 
         topic_validators = pubsubs_fsub[0].get_msg_validators(msg)
         for topic_validator in topic_validators:
+            validator = topic_validator.validator
             if topic_validator.is_async:
-                await topic_validator.validator(peer_id=IDFactory(), msg="msg")
+                if inspect.iscoroutinefunction(validator):
+                    await validator(IDFactory(), msg)
             else:
-                topic_validator.validator(peer_id=IDFactory(), msg="msg")
+                validator(IDFactory(), msg)
 
-        assert times_sync_validator_called == 2
-        assert times_async_validator_called == 1
+        assert calls[0] == 2
+        assert calls[1] == 1
 
 
 @pytest.mark.parametrize(
@@ -221,17 +246,17 @@
 async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
     async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
 
-        def passed_sync_validator(peer_id, msg):
+        def passed_sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
             return True
 
-        def failed_sync_validator(peer_id, msg):
+        def failed_sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
             return False
 
-        async def passed_async_validator(peer_id, msg):
+        async def passed_async_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
             await trio.lowlevel.checkpoint()
             return True
 
-        async def failed_async_validator(peer_id, msg):
+        async def failed_async_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
             await trio.lowlevel.checkpoint()
             return False
 
@@ -297,11 +322,12 @@ async def test_continuously_read_stream(monkeypatch, nursery, security_protocol)
             m.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc)
             yield Events(event_push_msg, event_handle_subscription, event_handle_rpc)
 
-    async with PubsubFactory.create_batch_with_floodsub(
-        1, security_protocol=security_protocol
-    ) as pubsubs_fsub, net_stream_pair_factory(
-        security_protocol=security_protocol
-    ) as stream_pair:
+    async with (
+        PubsubFactory.create_batch_with_floodsub(
+            1, security_protocol=security_protocol
+        ) as pubsubs_fsub,
+        net_stream_pair_factory(security_protocol=security_protocol) as stream_pair,
+    ):
         await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
         # Kick off the task `continuously_read_stream`
         nursery.start_soon(pubsubs_fsub[0].continuously_read_stream, stream_pair[0])
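The rewritten `async with` above uses parenthesized context managers, a form whose grammar is only documented from Python 3.10 onward; together with the `X | None` unions, it is another reason the commit drops py3.9. A minimal sketch with stand-in factories:

# Editor's sketch; resource() stands in for the test factories.
from contextlib import asynccontextmanager

import trio


@asynccontextmanager
async def resource(name: str):
    yield name


async def main() -> None:
    # One parenthesized async-with replaces the harder-to-format
    # chained form used before this change.
    async with (
        resource("pubsub") as a,
        resource("stream_pair") as b,
    ):
        print(a, b)


trio.run(main)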
@@ -429,11 +455,12 @@ async def test_handle_talk():
 
 @pytest.mark.trio
 async def test_message_all_peers(monkeypatch, security_protocol):
-    async with PubsubFactory.create_batch_with_floodsub(
-        1, security_protocol=security_protocol
-    ) as pubsubs_fsub, net_stream_pair_factory(
-        security_protocol=security_protocol
-    ) as stream_pair:
+    async with (
+        PubsubFactory.create_batch_with_floodsub(
+            1, security_protocol=security_protocol
+        ) as pubsubs_fsub,
+        net_stream_pair_factory(security_protocol=security_protocol) as stream_pair,
+    ):
         peer_id = IDFactory()
         mock_peers = {peer_id: stream_pair[0]}
         with monkeypatch.context() as m:
@@ -530,15 +557,15 @@ async def test_publish_push_msg_is_called(monkeypatch):
         await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
         await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
 
-        assert (
-            len(msgs) == 2
-        ), "`push_msg` should be called every time `publish` is called"
+        assert len(msgs) == 2, (
+            "`push_msg` should be called every time `publish` is called"
+        )
         assert (msg_forwarders[0] == msg_forwarders[1]) and (
             msg_forwarders[1] == pubsubs_fsub[0].my_id
         )
-        assert (
-            msgs[0].seqno != msgs[1].seqno
-        ), "`seqno` should be different every time"
+        assert msgs[0].seqno != msgs[1].seqno, (
+            "`seqno` should be different every time"
+        )
 
 
 @pytest.mark.trio
@@ -611,7 +638,7 @@ async def test_push_msg(monkeypatch):
         # Test: add a topic validator and `push_msg` the message that
         # does not pass the validation.
        # `router_publish` is not called then.
-        def failed_sync_validator(peer_id, msg):
+        def failed_sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
             return False
 
         pubsubs_fsub[0].set_topic_validator(
@@ -659,6 +686,9 @@ async def test_strict_signing_failed_validation(monkeypatch):
             seqno=b"\x00" * 8,
         )
         priv_key = pubsubs_fsub[0].sign_key
+        assert priv_key is not None, (
+            "Private key should not be None when strict_signing=True"
+        )
         signature = priv_key.sign(
             PUBSUB_SIGNING_PREFIX.encode() + msg.SerializeToString()
         )
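`sign_key` is an optional attribute, so strict type checking rejects calling `.sign()` on it directly; the inserted assert documents the invariant (strict signing implies a key) and narrows the type at the same time. A generic sketch of the pattern:

# Editor's sketch of assert-based None narrowing.
class Signer:
    def sign(self, data: bytes) -> bytes:
        return b"sig:" + data


def sign_payload(key: Signer | None, payload: bytes) -> bytes:
    # Narrows `key` from `Signer | None` to `Signer` for the checker,
    # and fails loudly if the invariant is ever violated at runtime.
    assert key is not None, "key must be set when strict signing is enabled"
    return key.sign(payload)


print(sign_payload(Signer(), b"hello"))  # b'sig:hello'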
@@ -803,15 +833,15 @@ async def test_blacklist_blocks_new_peer_connections(monkeypatch):
         await pubsub._handle_new_peer(blacklisted_peer)
 
         # Verify that both new_stream and router.add_peer was not called
-        assert (
-            not new_stream_called
-        ), "new_stream should be not be called to get hello packet"
-        assert (
-            not router_add_peer_called
-        ), "Router.add_peer should not be called for blacklisted peer"
-        assert (
-            blacklisted_peer not in pubsub.peers
-        ), "Blacklisted peer should not be in peers dict"
+        assert not new_stream_called, (
+            "new_stream should be not be called to get hello packet"
+        )
+        assert not router_add_peer_called, (
+            "Router.add_peer should not be called for blacklisted peer"
+        )
+        assert blacklisted_peer not in pubsub.peers, (
+            "Blacklisted peer should not be in peers dict"
+        )
 
 
 @pytest.mark.trio
@@ -838,7 +868,7 @@ async def test_blacklist_blocks_messages_from_blacklisted_originator():
         # Track if router.publish is called
         router_publish_called = False
 
-        async def mock_router_publish(*args, **kwargs):
+        async def mock_router_publish(msg_forwarder: ID, pubsub_msg: rpc_pb2.Message):
             nonlocal router_publish_called
             router_publish_called = True
             await trio.lowlevel.checkpoint()
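Giving the monkeypatched stand-in the real `(msg_forwarder, pubsub_msg)` signature, instead of `*args, **kwargs`, lets the checker verify the mock stays call-compatible with the method it replaces. A generic sketch of the difference:

# Editor's sketch with a stand-in Router, not py-libp2p's class.
import trio


class Router:
    async def publish(self, msg_forwarder: str, pubsub_msg: bytes) -> None:
        print("real publish", msg_forwarder, len(pubsub_msg))


# Loose mock: accepts any call, so signature drift goes unnoticed.
async def loose_mock(*args, **kwargs) -> None:
    await trio.lowlevel.checkpoint()


# Typed mock: mirrors Router.publish, so a checker flags mismatches.
async def typed_mock(msg_forwarder: str, pubsub_msg: bytes) -> None:
    await trio.lowlevel.checkpoint()


async def main() -> None:
    router = Router()
    router.publish = typed_mock  # monkeypatch-style replacement
    await router.publish("peer-1", b"payload")


trio.run(main)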
@@ -851,12 +881,12 @@
             await pubsub.push_msg(blacklisted_originator, msg)
 
             # Verify message was rejected
-            assert (
-                not router_publish_called
-            ), "Router.publish should not be called for blacklisted originator"
-            assert not pubsub._is_msg_seen(
-                msg
-            ), "Message from blacklisted originator should not be marked as seen"
+            assert not router_publish_called, (
+                "Router.publish should not be called for blacklisted originator"
+            )
+            assert not pubsub._is_msg_seen(msg), (
+                "Message from blacklisted originator should not be marked as seen"
+            )
 
         finally:
             pubsub.router.publish = original_router_publish
@@ -894,8 +924,8 @@ async def test_blacklist_allows_non_blacklisted_peers():
         # Track router.publish calls
         router_publish_calls = []
 
-        async def mock_router_publish(*args, **kwargs):
-            router_publish_calls.append(args)
+        async def mock_router_publish(msg_forwarder: ID, pubsub_msg: rpc_pb2.Message):
+            router_publish_calls.append((msg_forwarder, pubsub_msg))
             await trio.lowlevel.checkpoint()
 
         original_router_publish = pubsub.router.publish
@@ -909,15 +939,15 @@
             await pubsub.push_msg(allowed_peer, msg_from_blacklisted)
 
             # Verify only allowed message was processed
-            assert (
-                len(router_publish_calls) == 1
-            ), "Only one message should be processed"
-            assert pubsub._is_msg_seen(
-                msg_from_allowed
-            ), "Allowed message should be marked as seen"
-            assert not pubsub._is_msg_seen(
-                msg_from_blacklisted
-            ), "Blacklisted message should not be marked as seen"
+            assert len(router_publish_calls) == 1, (
+                "Only one message should be processed"
+            )
+            assert pubsub._is_msg_seen(msg_from_allowed), (
+                "Allowed message should be marked as seen"
+            )
+            assert not pubsub._is_msg_seen(msg_from_blacklisted), (
+                "Blacklisted message should not be marked as seen"
+            )
 
             # Verify subscription received the allowed message
             received_msg = await sub.get()
@@ -960,7 +990,7 @@ async def test_blacklist_integration_with_existing_functionality():
         # due to seen cache (not blacklist)
         router_publish_called = False
 
-        async def mock_router_publish(*args, **kwargs):
+        async def mock_router_publish(msg_forwarder: ID, pubsub_msg: rpc_pb2.Message):
             nonlocal router_publish_called
             router_publish_called = True
             await trio.lowlevel.checkpoint()
@@ -970,9 +1000,9 @@
 
         try:
             await pubsub.push_msg(other_peer, msg)
-            assert (
-                not router_publish_called
-            ), "Duplicate message should be rejected by seen cache"
+            assert not router_publish_called, (
+                "Duplicate message should be rejected by seen cache"
+            )
         finally:
             pubsub.router.publish = original_router_publish
 
@@ -1001,7 +1031,7 @@ async def test_blacklist_blocks_messages_from_blacklisted_source():
         # Track if router.publish is called (it shouldn't be for blacklisted forwarder)
         router_publish_called = False
 
-        async def mock_router_publish(*args, **kwargs):
+        async def mock_router_publish(msg_forwarder: ID, pubsub_msg: rpc_pb2.Message):
             nonlocal router_publish_called
             router_publish_called = True
             await trio.lowlevel.checkpoint()
@@ -1014,12 +1044,12 @@
             await pubsub.push_msg(blacklisted_forwarder, msg)
 
             # Verify message was rejected
-            assert (
-                not router_publish_called
-            ), "Router.publish should not be called for blacklisted forwarder"
-            assert not pubsub._is_msg_seen(
-                msg
-            ), "Message from blacklisted forwarder should not be marked as seen"
+            assert not router_publish_called, (
+                "Router.publish should not be called for blacklisted forwarder"
+            )
+            assert not pubsub._is_msg_seen(msg), (
+                "Message from blacklisted forwarder should not be marked as seen"
+            )
 
         finally:
             pubsub.router.publish = original_router_publish