diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index e0026de8..3df0ac29 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -86,6 +86,9 @@ class Pubsub(IPubsub, Service): strict_signing: bool sign_key: PrivateKey + event_handle_peer_queue_started: trio.Event + event_handle_dead_peer_queue_started: trio.Event + def __init__( self, host: IHost, @@ -159,6 +162,9 @@ class Pubsub(IPubsub, Service): self.counter = int(time.time()) + self.event_handle_peer_queue_started = trio.Event() + self.event_handle_dead_peer_queue_started = trio.Event() + async def run(self) -> None: self.manager.run_daemon_task(self.handle_peer_queue) self.manager.run_daemon_task(self.handle_dead_peer_queue) @@ -331,12 +337,14 @@ class Pubsub(IPubsub, Service): """Continuously read from peer queue and each time a new peer is found, open a stream to the peer using a supported pubsub protocol pubsub protocols we support.""" + self.event_handle_peer_queue_started.set() async with self.peer_receive_channel: async for peer_id in self.peer_receive_channel: # Add Peer self.manager.run_task(self._handle_new_peer, peer_id) async def handle_dead_peer_queue(self) -> None: + self.event_handle_dead_peer_queue_started.set() """Continuously read from dead peer channel and close the stream between that peer and remove peer info from pubsub and pubsub router.""" diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index ae99f345..a20b59fa 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -245,6 +245,8 @@ class PubsubFactory(factory.Factory): strict_signing=strict_signing, ) async with background_trio_service(pubsub): + await pubsub.event_handle_peer_queue_started.wait() + await pubsub.event_handle_dead_peer_queue_started.wait() yield pubsub @classmethod diff --git a/tests/pubsub/test_gossipsub.py b/tests/pubsub/test_gossipsub.py index 4630c85f..a423fbd6 100644 --- a/tests/pubsub/test_gossipsub.py +++ b/tests/pubsub/test_gossipsub.py @@ -106,6 +106,7 @@ async def test_handle_graft(monkeypatch): async def emit_prune(topic, sender_peer_id): event_emit_prune.set() + await trio.hazmat.checkpoint() monkeypatch.setattr(gossipsubs[index_bob], "emit_prune", emit_prune) diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index c05aecbb..4bea6dd2 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -103,6 +103,7 @@ async def test_set_and_remove_topic_validator(): async def async_validator(peer_id, msg): nonlocal is_async_validator_called is_async_validator_called = True + await trio.hazmat.checkpoint() topic = "TEST_VALIDATOR" @@ -237,6 +238,7 @@ async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed): @pytest.mark.trio async def test_continuously_read_stream(monkeypatch, nursery, is_host_secure): async def wait_for_event_occurring(event): + await trio.hazmat.checkpoint() with trio.fail_after(0.1): await event.wait() @@ -418,6 +420,7 @@ async def test_publish(monkeypatch): async def push_msg(msg_forwarder, msg): msg_forwarders.append(msg_forwarder) msgs.append(msg) + await trio.hazmat.checkpoint() async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: with monkeypatch.context() as m: @@ -454,7 +457,7 @@ async def test_push_msg(monkeypatch): async def router_publish(*args, **kwargs): event.set() - await trio.sleep(0) + await trio.hazmat.checkpoint() with monkeypatch.context() as m: m.setattr(pubsubs_fsub[0].router, "publish", router_publish) @@ -555,6 +558,7 @@ async def test_strict_signing_failed_validation(monkeypatch): # Use router publish to check if `push_msg` succeed. async def router_publish(*args, **kwargs): + await trio.hazmat.checkpoint() # The event will only be set if `push_msg` succeed. event.set()