diff --git a/libp2p/__init__.py b/libp2p/__init__.py index fa7ebefd..542a71c1 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -84,6 +84,8 @@ DEFAULT_MUXER = "YAMUX" # Multiplexer options MUXER_YAMUX = "YAMUX" MUXER_MPLEX = "MPLEX" +DEFAULT_NEGOTIATE_TIMEOUT = 5 + def set_default_muxer(muxer_name: Literal["YAMUX", "MPLEX"]) -> None: @@ -249,6 +251,7 @@ def new_host( muxer_preference: Literal["YAMUX", "MPLEX"] | None = None, listen_addrs: Sequence[multiaddr.Multiaddr] | None = None, enable_mDNS: bool = False, + negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, ) -> IHost: """ Create a new libp2p host based on the given parameters. @@ -274,6 +277,6 @@ def new_host( if disc_opt is not None: return RoutedHost(swarm, disc_opt, enable_mDNS) - return BasicHost(swarm, enable_mDNS) + return BasicHost(network=swarm,enable_mDNS=enable_mDNS , negotitate_timeout=negotiate_timeout) __version__ = __version("libp2p") diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 798186cf..cc93be08 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -71,6 +71,7 @@ if TYPE_CHECKING: logger = logging.getLogger("libp2p.network.basic_host") +DEFAULT_NEGOTIATE_TIMEOUT = 5 class BasicHost(IHost): @@ -92,10 +93,12 @@ class BasicHost(IHost): network: INetworkService, enable_mDNS: bool = False, default_protocols: Optional["OrderedDict[TProtocol, StreamHandlerFn]"] = None, + negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, ) -> None: self._network = network self._network.set_stream_handler(self._swarm_stream_handler) self.peerstore = self._network.peerstore + self.negotiate_timeout = negotitate_timeout # Protocol muxing default_protocols = default_protocols or get_default_protocols(self) self.multiselect = Multiselect(dict(default_protocols.items())) @@ -189,7 +192,10 @@ class BasicHost(IHost): self.multiselect.add_handler(protocol_id, stream_handler) async def new_stream( - self, peer_id: ID, protocol_ids: Sequence[TProtocol] + self, + peer_id: ID, + protocol_ids: Sequence[TProtocol], + negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, ) -> INetStream: """ :param peer_id: peer_id that host is connecting @@ -201,7 +207,9 @@ class BasicHost(IHost): # Perform protocol muxing to determine protocol to use try: selected_protocol = await self.multiselect_client.select_one_of( - list(protocol_ids), MultiselectCommunicator(net_stream) + list(protocol_ids), + MultiselectCommunicator(net_stream), + negotitate_timeout, ) except MultiselectClientError as error: logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error) @@ -211,7 +219,12 @@ class BasicHost(IHost): net_stream.set_protocol(selected_protocol) return net_stream - async def send_command(self, peer_id: ID, command: str) -> list[str]: + async def send_command( + self, + peer_id: ID, + command: str, + response_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, + ) -> list[str]: """ Send a multistream-select command to the specified peer and return the response. @@ -225,7 +238,7 @@ class BasicHost(IHost): try: response = await self.multiselect_client.query_multistream_command( - MultiselectCommunicator(new_stream), command + MultiselectCommunicator(new_stream), command, response_timeout ) except MultiselectClientError as error: logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error) @@ -264,7 +277,7 @@ class BasicHost(IHost): # Perform protocol muxing to determine protocol to use try: protocol, handler = await self.multiselect.negotiate( - MultiselectCommunicator(net_stream) + MultiselectCommunicator(net_stream), self.negotiate_timeout ) except MultiselectError as error: peer_id = net_stream.muxed_conn.peer_id diff --git a/libp2p/protocol_muxer/multiselect.py b/libp2p/protocol_muxer/multiselect.py index 8f6e0e74..3f6ef02f 100644 --- a/libp2p/protocol_muxer/multiselect.py +++ b/libp2p/protocol_muxer/multiselect.py @@ -1,3 +1,5 @@ +import trio + from libp2p.abc import ( IMultiselectCommunicator, IMultiselectMuxer, @@ -14,6 +16,7 @@ from .exceptions import ( MULTISELECT_PROTOCOL_ID = "/multistream/1.0.0" PROTOCOL_NOT_FOUND_MSG = "na" +DEFAULT_NEGOTIATE_TIMEOUT = 5 class Multiselect(IMultiselectMuxer): @@ -47,47 +50,56 @@ class Multiselect(IMultiselectMuxer): # FIXME: Make TProtocol Optional[TProtocol] to keep types consistent async def negotiate( - self, communicator: IMultiselectCommunicator + self, + communicator: IMultiselectCommunicator, + negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, ) -> tuple[TProtocol, StreamHandlerFn | None]: """ Negotiate performs protocol selection. :param stream: stream to negotiate on + :param negotiate_timeout: timeout for negotiation :return: selected protocol name, handler function :raise MultiselectError: raised when negotiation failed """ - await self.handshake(communicator) + try: + with trio.fail_after(negotiate_timeout): + await self.handshake(communicator) - while True: - try: - command = await communicator.read() - except MultiselectCommunicatorError as error: - raise MultiselectError() from error - - if command == "ls": - supported_protocols = [p for p in self.handlers.keys() if p is not None] - response = "\n".join(supported_protocols) + "\n" - - try: - await communicator.write(response) - except MultiselectCommunicatorError as error: - raise MultiselectError() from error - - else: - protocol = TProtocol(command) - if protocol in self.handlers: + while True: try: - await communicator.write(protocol) + command = await communicator.read() except MultiselectCommunicatorError as error: raise MultiselectError() from error - return protocol, self.handlers[protocol] - try: - await communicator.write(PROTOCOL_NOT_FOUND_MSG) - except MultiselectCommunicatorError as error: - raise MultiselectError() from error + if command == "ls": + supported_protocols = [ + p for p in self.handlers.keys() if p is not None + ] + response = "\n".join(supported_protocols) + "\n" - raise MultiselectError("Negotiation failed: no matching protocol") + try: + await communicator.write(response) + except MultiselectCommunicatorError as error: + raise MultiselectError() from error + + else: + protocol = TProtocol(command) + if protocol in self.handlers: + try: + await communicator.write(protocol) + except MultiselectCommunicatorError as error: + raise MultiselectError() from error + + return protocol, self.handlers[protocol] + try: + await communicator.write(PROTOCOL_NOT_FOUND_MSG) + except MultiselectCommunicatorError as error: + raise MultiselectError() from error + + raise MultiselectError("Negotiation failed: no matching protocol") + except trio.TooSlowError: + raise MultiselectError("handshake read timeout") async def handshake(self, communicator: IMultiselectCommunicator) -> None: """ diff --git a/libp2p/protocol_muxer/multiselect_client.py b/libp2p/protocol_muxer/multiselect_client.py index 8d8c02a1..a5b35006 100644 --- a/libp2p/protocol_muxer/multiselect_client.py +++ b/libp2p/protocol_muxer/multiselect_client.py @@ -2,6 +2,8 @@ from collections.abc import ( Sequence, ) +import trio + from libp2p.abc import ( IMultiselectClient, IMultiselectCommunicator, @@ -17,6 +19,7 @@ from .exceptions import ( MULTISELECT_PROTOCOL_ID = "/multistream/1.0.0" PROTOCOL_NOT_FOUND_MSG = "na" +DEFAULT_NEGOTIATE_TIMEOUT = 5 class MultiselectClient(IMultiselectClient): @@ -40,6 +43,7 @@ class MultiselectClient(IMultiselectClient): try: handshake_contents = await communicator.read() + except MultiselectCommunicatorError as error: raise MultiselectClientError() from error @@ -47,7 +51,10 @@ class MultiselectClient(IMultiselectClient): raise MultiselectClientError("multiselect protocol ID mismatch") async def select_one_of( - self, protocols: Sequence[TProtocol], communicator: IMultiselectCommunicator + self, + protocols: Sequence[TProtocol], + communicator: IMultiselectCommunicator, + negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, ) -> TProtocol: """ For each protocol, send message to multiselect selecting protocol and @@ -56,22 +63,32 @@ class MultiselectClient(IMultiselectClient): :param protocol: protocol to select :param communicator: communicator to use to communicate with counterparty + :param negotiate_timeout: timeout for negotiation :return: selected protocol :raise MultiselectClientError: raised when protocol negotiation failed """ - await self.handshake(communicator) + try: + with trio.fail_after(negotitate_timeout): + await self.handshake(communicator) - for protocol in protocols: - try: - selected_protocol = await self.try_select(communicator, protocol) - return selected_protocol - except MultiselectClientError: - pass + for protocol in protocols: + try: + selected_protocol = await self.try_select( + communicator, protocol + ) + return selected_protocol + except MultiselectClientError: + pass - raise MultiselectClientError("protocols not supported") + raise MultiselectClientError("protocols not supported") + except trio.TooSlowError: + raise MultiselectClientError("response timed out") async def query_multistream_command( - self, communicator: IMultiselectCommunicator, command: str + self, + communicator: IMultiselectCommunicator, + command: str, + response_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, ) -> list[str]: """ Send a multistream-select command over the given communicator and return @@ -79,26 +96,32 @@ class MultiselectClient(IMultiselectClient): :param communicator: communicator to use to communicate with counterparty :param command: supported multistream-select command(e.g., ls) + :param negotiate_timeout: timeout for negotiation :raise MultiselectClientError: If the communicator fails to process data. :return: list of strings representing the response from peer. """ - await self.handshake(communicator) - - if command == "ls": - try: - await communicator.write("ls") - except MultiselectCommunicatorError as error: - raise MultiselectClientError() from error - else: - raise ValueError("Command not supported") - try: - response = await communicator.read() - response_list = response.strip().splitlines() - except MultiselectCommunicatorError as error: - raise MultiselectClientError() from error + with trio.fail_after(response_timeout): + await self.handshake(communicator) - return response_list + if command == "ls": + try: + await communicator.write("ls") + except MultiselectCommunicatorError as error: + raise MultiselectClientError() from error + else: + raise ValueError("Command not supported") + + try: + response = await communicator.read() + response_list = response.strip().splitlines() + + except MultiselectCommunicatorError as error: + raise MultiselectClientError() from error + + return response_list + except trio.TooSlowError: + raise MultiselectClientError("command response timed out") async def try_select( self, communicator: IMultiselectCommunicator, protocol: TProtocol @@ -118,6 +141,7 @@ class MultiselectClient(IMultiselectClient): try: response = await communicator.read() + except MultiselectCommunicatorError as error: raise MultiselectClientError() from error diff --git a/libp2p/stream_muxer/muxer_multistream.py b/libp2p/stream_muxer/muxer_multistream.py index b4aa5d57..76699c67 100644 --- a/libp2p/stream_muxer/muxer_multistream.py +++ b/libp2p/stream_muxer/muxer_multistream.py @@ -31,9 +31,6 @@ from libp2p.stream_muxer.yamux.yamux import ( Yamux, ) -# FIXME: add negotiate timeout to `MuxerMultistream` -DEFAULT_NEGOTIATE_TIMEOUT = 60 - class MuxerMultistream: """ diff --git a/newsfragments/696.bugfix.rst b/newsfragments/696.bugfix.rst new file mode 100644 index 00000000..d5686418 --- /dev/null +++ b/newsfragments/696.bugfix.rst @@ -0,0 +1,4 @@ +Add timeout wrappers in: +1. multiselect.py: `negotiate` function +2. multiselect_client.py: `select_one_of` , `query_multistream_command` functions +to prevent indefinite hangs when a remote peer does not respond. diff --git a/tests/core/protocol_muxer/test_negotiate_timeout.py b/tests/core/protocol_muxer/test_negotiate_timeout.py new file mode 100644 index 00000000..a50d65f6 --- /dev/null +++ b/tests/core/protocol_muxer/test_negotiate_timeout.py @@ -0,0 +1,59 @@ +import pytest +import trio + +from libp2p.abc import ( + IMultiselectCommunicator, +) +from libp2p.custom_types import TProtocol +from libp2p.protocol_muxer.exceptions import ( + MultiselectClientError, + MultiselectError, +) +from libp2p.protocol_muxer.multiselect import Multiselect +from libp2p.protocol_muxer.multiselect_client import MultiselectClient + + +class DummyMultiselectCommunicator(IMultiselectCommunicator): + """ + Dummy MultiSelectCommunicator to test out negotiate timmeout. + """ + + def __init__(self) -> None: + return + + async def write(self, msg_str: str) -> None: + """Goes into infinite loop when .write is called""" + await trio.sleep_forever() + + async def read(self) -> str: + """Returns a dummy read""" + return "dummy_read" + + +@pytest.mark.trio +async def test_select_one_of_timeout(): + ECHO = TProtocol("/echo/1.0.0") + communicator = DummyMultiselectCommunicator() + + client = MultiselectClient() + + with pytest.raises(MultiselectClientError, match="response timed out"): + await client.select_one_of([ECHO], communicator, 2) + + +@pytest.mark.trio +async def test_query_multistream_command_timeout(): + communicator = DummyMultiselectCommunicator() + client = MultiselectClient() + + with pytest.raises(MultiselectClientError, match="response timed out"): + await client.query_multistream_command(communicator, "ls", 2) + + +@pytest.mark.trio +async def test_negotiate_timeout(): + communicator = DummyMultiselectCommunicator() + server = Multiselect() + + with pytest.raises(MultiselectError, match="handshake read timeout"): + await server.negotiate(communicator, 2)