diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py
index e25d75f0..29d544eb 100644
--- a/libp2p/network/connection/swarm_connection.py
+++ b/libp2p/network/connection/swarm_connection.py
@@ -43,6 +43,9 @@ class SwarmConn(INetConn):
         # We *could* optimize this but it really isn't worth it.
         for stream in self.streams:
             await stream.reset()
+        # Force a context switch for stream handlers to process the stream
+        # reset event we just emitted, before we cancel the stream handler tasks.
+        await asyncio.sleep(0.1)
         for task in self._tasks:
             task.cancel()
diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py
index a697d7da..7bb40cee 100644
--- a/libp2p/network/swarm.py
+++ b/libp2p/network/swarm.py
@@ -248,8 +248,8 @@ class Swarm(INetwork):
         # TODO: Should be changed to close multiple connections,
         # if we have several connections per peer in the future.
         connection = self.connections[peer_id]
-        # NOTE: `connection.close` performs `del self.connections[peer_id]` for us,
-        # so we don't need to remove the entry here.
+        # NOTE: `connection.close` will perform `del self.connections[peer_id]`
+        # and `notify_disconnected` for us.
         await connection.close()
 
         logger.debug("successfully close the connection to peer %s", peer_id)
diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py
index 7be2d288..d5e02677 100644
--- a/libp2p/pubsub/pubsub.py
+++ b/libp2p/pubsub/pubsub.py
@@ -32,6 +32,7 @@ from .validators import signature_validator
 
 if TYPE_CHECKING:
     from .pubsub_router_interface import IPubsubRouter  # noqa: F401
+    from typing import Any  # noqa: F401
 
 
 logger = logging.getLogger("libp2p.pubsub")
@@ -60,6 +61,7 @@ class Pubsub:
     router: "IPubsubRouter"
 
     peer_queue: "asyncio.Queue[ID]"
+    dead_peer_queue: "asyncio.Queue[ID]"
 
     protocols: List[TProtocol]
@@ -78,6 +80,8 @@ class Pubsub:
     # TODO: Be sure it is increased atomically every time.
    counter: int  # uint64
 
+    _tasks: List["asyncio.Future[Any]"]
+
     def __init__(
         self, host: IHost, router: "IPubsubRouter", my_id: ID, cache_size: int = None
     ) -> None:
@@ -100,7 +104,10 @@ class Pubsub:
 
         # Register a notifee
         self.peer_queue = asyncio.Queue()
-        self.host.get_network().register_notifee(PubsubNotifee(self.peer_queue))
+        self.dead_peer_queue = asyncio.Queue()
+        self.host.get_network().register_notifee(
+            PubsubNotifee(self.peer_queue, self.dead_peer_queue)
+        )
 
         # Register stream handlers for each pubsub router protocol to handle
         # the pubsub streams opened on those protocols
@@ -135,8 +142,10 @@ class Pubsub:
         self.counter = time.time_ns()
 
+        self._tasks = []
         # Call handle peer to keep waiting for updates to peer queue
-        asyncio.ensure_future(self.handle_peer_queue())
+        self._tasks.append(asyncio.ensure_future(self.handle_peer_queue()))
+        self._tasks.append(asyncio.ensure_future(self.handle_dead_peer_queue()))
 
     def get_hello_packet(self) -> rpc_pb2.RPC:
         """Generate subscription message with all topics we are subscribed to
@@ -158,13 +167,7 @@
         peer_id = stream.muxed_conn.peer_id
 
         while True:
-            try:
-                incoming: bytes = await read_varint_prefixed_bytes(stream)
-            except (ParseError, IncompleteReadError) as error:
-                logger.debug(
-                    "read corrupted data from peer %s, error=%s", peer_id, error
-                )
-                continue
+            incoming: bytes = await read_varint_prefixed_bytes(stream)
             rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
             rpc_incoming.ParseFromString(incoming)
             if rpc_incoming.publish:
@@ -175,7 +178,11 @@
                 logger.debug(
                     "received `publish` message %s from peer %s", msg, peer_id
                 )
-                asyncio.ensure_future(self.push_msg(msg_forwarder=peer_id, msg=msg))
+                self._tasks.append(
+                    asyncio.ensure_future(
+                        self.push_msg(msg_forwarder=peer_id, msg=msg)
+                    )
+                )
 
             if rpc_incoming.subscriptions:
                 # deal with RPC.subscriptions
@@ -247,13 +254,19 @@
         :param stream: newly created stream
         """
+        peer_id = stream.muxed_conn.peer_id
+
         try:
             await self.continuously_read_stream(stream)
-        except (StreamEOF, StreamReset) as error:
-            logger.debug("fail to read from stream, error=%s", error)
+        except (StreamEOF, StreamReset, ParseError, IncompleteReadError) as error:
+            logger.debug(
+                "fail to read from peer %s, error=%s, "
+                "closing the stream and removing the peer from record",
+                peer_id,
+                error,
+            )
             await stream.reset()
-            # TODO: what to do when the stream is terminated?
-            # disconnect the peer?
+            self._handle_dead_peer(peer_id)
 
     async def _handle_new_peer(self, peer_id: ID) -> None:
         try:
@@ -277,6 +290,19 @@
         logger.debug("added new peer %s", peer_id)
 
+    def _handle_dead_peer(self, peer_id: ID) -> None:
+        if peer_id not in self.peers:
+            return
+        del self.peers[peer_id]
+
+        for topic in self.peer_topics:
+            if peer_id in self.peer_topics[topic]:
+                self.peer_topics[topic].remove(peer_id)
+
+        self.router.remove_peer(peer_id)
+
+        logger.debug("removed dead peer %s", peer_id)
+
     async def handle_peer_queue(self) -> None:
         """
         Continuously read from peer queue and each time a new peer is found,
         open a stream to the peer using a supported pubsub protocol
@@ -285,14 +311,17 @@
         pubsub protocols we support
         """
         while True:
             peer_id: ID = await self.peer_queue.get()
             # Add Peer
-            asyncio.ensure_future(self._handle_new_peer(peer_id))
-            # Force context switch
-            await asyncio.sleep(0)
+            self._tasks.append(asyncio.ensure_future(self._handle_new_peer(peer_id)))
+
+    async def handle_dead_peer_queue(self) -> None:
+        """Continuously read from the dead peer queue; close the stream to each
+        dead peer and remove its info from pubsub and the pubsub router."""
+        while True:
+            peer_id: ID = await self.dead_peer_queue.get()
+            # Remove Peer
+            self._handle_dead_peer(peer_id)
 
     def handle_subscription(
         self, origin_id: ID, sub_message: rpc_pb2.RPC.SubOpts
@@ -514,3 +543,11 @@
         if not self.my_topics:
             return False
         return any(topic in self.my_topics for topic in msg.topicIDs)
+
+    async def close(self) -> None:
+        for task in self._tasks:
+            task.cancel()
+            try:
+                await task
+            except asyncio.CancelledError:
+                pass
diff --git a/libp2p/pubsub/pubsub_notifee.py b/libp2p/pubsub/pubsub_notifee.py
index 19be6123..6afa9ad2 100644
--- a/libp2p/pubsub/pubsub_notifee.py
+++ b/libp2p/pubsub/pubsub_notifee.py
@@ -15,13 +15,21 @@
 class PubsubNotifee(INotifee):
 
     initiator_peers_queue: "asyncio.Queue[ID]"
+    dead_peers_queue: "asyncio.Queue[ID]"
 
-    def __init__(self, initiator_peers_queue: "asyncio.Queue[ID]") -> None:
+    def __init__(
+        self,
+        initiator_peers_queue: "asyncio.Queue[ID]",
+        dead_peers_queue: "asyncio.Queue[ID]",
+    ) -> None:
         """
         :param initiator_peers_queue: queue to add new peers to so that pubsub
             can process new peers after we connect to them
+        :param dead_peers_queue: queue to add dead peers to so that pubsub
+            can process dead peers after we disconnect from them
         """
         self.initiator_peers_queue = initiator_peers_queue
+        self.dead_peers_queue = dead_peers_queue
 
     async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
         pass
@@ -41,7 +49,14 @@
         await self.initiator_peers_queue.put(conn.muxed_conn.peer_id)
 
     async def disconnected(self, network: INetwork, conn: INetConn) -> None:
-        pass
+        """
+        Add peer_id to dead_peers_queue, so that pubsub and its router can
+        remove this peer_id and close the stream in between.
+
+        :param network: network the connection was opened on
+        :param conn: connection that was closed
+        """
+        await self.dead_peers_queue.put(conn.muxed_conn.peer_id)
 
     async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
         pass
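
Reviewer note: the `await asyncio.sleep(0.1)` added to SwarmConn.close is scheduling, not synchronization: cancelling a task right after emitting an event can win the race before the handler ever wakes up to observe it. A tiny demo of the ordering this works around, using a hypothetical handler coroutine in place of the real stream handlers:

    import asyncio

    async def handler(reset_event: asyncio.Event) -> None:
        # Stand-in for a stream handler waiting on a reset notification.
        await reset_event.wait()
        print("handler saw the reset before being cancelled")

    async def main() -> None:
        reset_event = asyncio.Event()
        task = asyncio.ensure_future(handler(reset_event))
        reset_event.set()         # emit the "reset" event
        await asyncio.sleep(0.1)  # yield, as SwarmConn.close now does
        task.cancel()             # without the sleep, this can fire first
        try:
            await task
        except asyncio.CancelledError:
            print("handler was cancelled before it could observe the reset")

    asyncio.run(main())

A real handshake, such as having handlers set an asyncio.Event that close() awaits, would remove the race outright; the fixed 0.1 s sleep only makes losing it unlikely.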
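Reviewer note: the dead-peer path added here is a notifee-to-queue-to-consumer pipeline: the network layer only enqueues the disconnected peer's ID, and a long-lived pubsub task drains the queue and does the bookkeeping. Below is a minimal self-contained sketch of that shape, assuming plain strings in place of libp2p's ID type; DemoNotifee and DemoPubsub are hypothetical stand-ins, not py-libp2p classes:

    import asyncio
    from typing import Set

    class DemoNotifee:
        # Stand-in for PubsubNotifee: it only knows how to enqueue dead peers.
        def __init__(self, dead_peers_queue: "asyncio.Queue[str]") -> None:
            self.dead_peers_queue = dead_peers_queue

        async def disconnected(self, peer_id: str) -> None:
            # Never block the network layer on pubsub bookkeeping; just enqueue.
            await self.dead_peers_queue.put(peer_id)

    class DemoPubsub:
        # Stand-in for Pubsub: it owns the queue and the consumer loop.
        def __init__(self) -> None:
            self.peers: Set[str] = {"QmPeerA", "QmPeerB"}
            self.dead_peer_queue: "asyncio.Queue[str]" = asyncio.Queue()

        def _handle_dead_peer(self, peer_id: str) -> None:
            # Idempotent removal, mirroring Pubsub._handle_dead_peer.
            if peer_id not in self.peers:
                return
            self.peers.remove(peer_id)
            print("removed dead peer", peer_id)

        async def handle_dead_peer_queue(self) -> None:
            while True:
                peer_id = await self.dead_peer_queue.get()
                self._handle_dead_peer(peer_id)

    async def main() -> None:
        pubsub = DemoPubsub()
        notifee = DemoNotifee(pubsub.dead_peer_queue)
        consumer = asyncio.ensure_future(pubsub.handle_dead_peer_queue())
        await notifee.disconnected("QmPeerA")  # network layer reports a drop
        await asyncio.sleep(0.1)               # let the consumer drain the queue
        assert pubsub.peers == {"QmPeerB"}
        consumer.cancel()

    asyncio.run(main())

Decoupling through the queue also explains why _handle_dead_peer guards with `if peer_id not in self.peers`: the same peer can be reported dead twice, once by the failing stream handler and once by the notifee, and the second removal must be a no-op.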
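Reviewer note: the switch from fire-and-forget asyncio.ensure_future(...) to self._tasks.append(asyncio.ensure_future(...)) is what makes the new Pubsub.close workable: only tasks we kept references to can be cancelled and awaited on shutdown. A minimal sketch of the same cancel-and-await teardown pattern, with a hypothetical Service class standing in for Pubsub:

    import asyncio
    from typing import Any, List

    class Service:
        def __init__(self) -> None:
            self._tasks: List["asyncio.Future[Any]"] = []

        def start(self) -> None:
            # Keep a reference to every background task we spawn.
            self._tasks.append(asyncio.ensure_future(self._loop()))

        async def _loop(self) -> None:
            while True:
                await asyncio.sleep(3600)  # stand-in for `await queue.get()`

        async def close(self) -> None:
            for task in self._tasks:
                task.cancel()
                try:
                    # Await after cancelling so the task actually unwinds
                    # before close() returns; swallow the expected error.
                    await task
                except asyncio.CancelledError:
                    pass

    async def main() -> None:
        service = Service()
        service.start()
        await service.close()
        print("background tasks torn down cleanly")

    asyncio.run(main())

Awaiting each task after cancel is the important half of the pattern: cancel() only requests cancellation, and returning from close() before the task unwinds would leave the same dangling-task problem the change is meant to fix.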