diff --git a/libp2p/protocol_muxer/multiselect.py b/libp2p/protocol_muxer/multiselect.py index 99ad37f5..05b112dc 100644 --- a/libp2p/protocol_muxer/multiselect.py +++ b/libp2p/protocol_muxer/multiselect.py @@ -59,41 +59,44 @@ class Multiselect(IMultiselectMuxer): :return: selected protocol name, handler function :raise MultiselectError: raised when negotiation failed """ - await self.handshake(communicator) + try: + with trio.fail_after(DEFAULT_NEGOTIATE_TIMEOUT): + await self.handshake(communicator) - while True: - try: - with trio.fail_after(DEFAULT_NEGOTIATE_TIMEOUT): - command = await communicator.read() - except trio.TooSlowError: - raise MultiselectError("handshake read timeout") - except MultiselectCommunicatorError as error: - raise MultiselectError() from error - - if command == "ls": - supported_protocols = [p for p in self.handlers.keys() if p is not None] - response = "\n".join(supported_protocols) + "\n" - - try: - await communicator.write(response) - except MultiselectCommunicatorError as error: - raise MultiselectError() from error - - else: - protocol = TProtocol(command) - if protocol in self.handlers: + while True: try: - await communicator.write(protocol) + command = await communicator.read() except MultiselectCommunicatorError as error: raise MultiselectError() from error - return protocol, self.handlers[protocol] - try: - await communicator.write(PROTOCOL_NOT_FOUND_MSG) - except MultiselectCommunicatorError as error: - raise MultiselectError() from error + if command == "ls": + supported_protocols = [ + p for p in self.handlers.keys() if p is not None + ] + response = "\n".join(supported_protocols) + "\n" - raise MultiselectError("Negotiation failed: no matching protocol") + try: + await communicator.write(response) + except MultiselectCommunicatorError as error: + raise MultiselectError() from error + + else: + protocol = TProtocol(command) + if protocol in self.handlers: + try: + await communicator.write(protocol) + except MultiselectCommunicatorError as error: + raise MultiselectError() from error + + return protocol, self.handlers[protocol] + try: + await communicator.write(PROTOCOL_NOT_FOUND_MSG) + except MultiselectCommunicatorError as error: + raise MultiselectError() from error + + raise MultiselectError("Negotiation failed: no matching protocol") + except trio.TooSlowError: + raise MultiselectError("handshake read timeout") async def handshake(self, communicator: IMultiselectCommunicator) -> None: """ @@ -108,10 +111,7 @@ class Multiselect(IMultiselectMuxer): raise MultiselectError() from error try: - with trio.fail_after(DEFAULT_NEGOTIATE_TIMEOUT): # Timeout after 5 seconds - handshake_contents = await communicator.read() - except trio.TooSlowError: - raise MultiselectError("protocol selection response timed out") + handshake_contents = await communicator.read() except MultiselectCommunicatorError as error: raise MultiselectError() from error diff --git a/libp2p/protocol_muxer/multiselect_client.py b/libp2p/protocol_muxer/multiselect_client.py index 74335241..75d5ca05 100644 --- a/libp2p/protocol_muxer/multiselect_client.py +++ b/libp2p/protocol_muxer/multiselect_client.py @@ -42,10 +42,8 @@ class MultiselectClient(IMultiselectClient): raise MultiselectClientError() from error try: - with trio.fail_after(DEFAULT_NEGOTIATE_TIMEOUT): - handshake_contents = await communicator.read() - except trio.TooSlowError: - raise MultiselectClientError("handshake read timed out") + handshake_contents = await communicator.read() + except MultiselectCommunicatorError as error: raise MultiselectClientError() from error @@ -65,16 +63,22 @@ class MultiselectClient(IMultiselectClient): :return: selected protocol :raise MultiselectClientError: raised when protocol negotiation failed """ - await self.handshake(communicator) + try: + with trio.fail_after(DEFAULT_NEGOTIATE_TIMEOUT): + await self.handshake(communicator) - for protocol in protocols: - try: - selected_protocol = await self.try_select(communicator, protocol) - return selected_protocol - except MultiselectClientError: - pass + for protocol in protocols: + try: + selected_protocol = await self.try_select( + communicator, protocol + ) + return selected_protocol + except MultiselectClientError: + pass - raise MultiselectClientError("protocols not supported") + raise MultiselectClientError("protocols not supported") + except trio.TooSlowError: + raise MultiselectClientError("response timed out") async def query_multistream_command( self, communicator: IMultiselectCommunicator, command: str @@ -88,26 +92,28 @@ class MultiselectClient(IMultiselectClient): :raise MultiselectClientError: If the communicator fails to process data. :return: list of strings representing the response from peer. """ - await self.handshake(communicator) - - if command == "ls": - try: - await communicator.write("ls") - except MultiselectCommunicatorError as error: - raise MultiselectClientError() from error - else: - raise ValueError("Command not supported") - try: - with trio.fail_after(DEFAULT_NEGOTIATE_TIMEOUT): # Timeout after 5 seconds - response = await communicator.read() - response_list = response.strip().splitlines() + with trio.fail_after(DEFAULT_NEGOTIATE_TIMEOUT): + await self.handshake(communicator) + + if command == "ls": + try: + await communicator.write("ls") + except MultiselectCommunicatorError as error: + raise MultiselectClientError() from error + else: + raise ValueError("Command not supported") + + try: + response = await communicator.read() + response_list = response.strip().splitlines() + + except MultiselectCommunicatorError as error: + raise MultiselectClientError() from error + + return response_list except trio.TooSlowError: raise MultiselectClientError("command response timed out") - except MultiselectCommunicatorError as error: - raise MultiselectClientError() from error - - return response_list async def try_select( self, communicator: IMultiselectCommunicator, protocol: TProtocol @@ -126,10 +132,8 @@ class MultiselectClient(IMultiselectClient): raise MultiselectClientError() from error try: - with trio.fail_after(DEFAULT_NEGOTIATE_TIMEOUT): # Timeout after 5 seconds - response = await communicator.read() - except trio.TooSlowError: - raise MultiselectClientError("protocol selection response timed out") + response = await communicator.read() + except MultiselectCommunicatorError as error: raise MultiselectClientError() from error