diff --git a/libp2p/protocol_muxer/multiselect.py b/libp2p/protocol_muxer/multiselect.py index 0c3dc727..c76e07f8 100644 --- a/libp2p/protocol_muxer/multiselect.py +++ b/libp2p/protocol_muxer/multiselect.py @@ -2,7 +2,7 @@ from typing import Dict, Tuple from libp2p.typing import StreamHandlerFn, TProtocol -from .exceptions import MultiselectError +from .exceptions import MultiselectCommunicatorError, MultiselectError from .multiselect_communicator_interface import IMultiselectCommunicator from .multiselect_muxer_interface import IMultiselectMuxer @@ -46,7 +46,10 @@ class Multiselect(IMultiselectMuxer): # Read and respond to commands until a valid protocol ID is sent while True: # Read message - command = await communicator.read() + try: + command = await communicator.read() + except MultiselectCommunicatorError as error: + raise MultiselectError(str(error)) # Command is ls or a protocol if command == "ls": @@ -76,7 +79,10 @@ class Multiselect(IMultiselectMuxer): await communicator.write(MULTISELECT_PROTOCOL_ID) # Read in the protocol ID from other party - handshake_contents = await communicator.read() + try: + handshake_contents = await communicator.read() + except MultiselectCommunicatorError as error: + raise MultiselectError(str(error)) # Confirm that the protocols are the same if not validate_handshake(handshake_contents): diff --git a/libp2p/protocol_muxer/multiselect_client.py b/libp2p/protocol_muxer/multiselect_client.py index fcd55d08..51af025a 100644 --- a/libp2p/protocol_muxer/multiselect_client.py +++ b/libp2p/protocol_muxer/multiselect_client.py @@ -2,7 +2,7 @@ from typing import Sequence from libp2p.typing import TProtocol -from .exceptions import MultiselectClientError +from .exceptions import MultiselectClientError, MultiselectCommunicatorError from .multiselect_client_interface import IMultiselectClient from .multiselect_communicator_interface import IMultiselectCommunicator @@ -30,7 +30,10 @@ class MultiselectClient(IMultiselectClient): await communicator.write(MULTISELECT_PROTOCOL_ID) # Read in the protocol ID from other party - handshake_contents = await communicator.read() + try: + handshake_contents = await communicator.read() + except MultiselectCommunicatorError as error: + raise MultiselectClientError(str(error)) # Confirm that the protocols are the same if not validate_handshake(handshake_contents): @@ -79,7 +82,10 @@ class MultiselectClient(IMultiselectClient): await communicator.write(protocol) # Get what counterparty says in response - response = await communicator.read() + try: + response = await communicator.read() + except MultiselectCommunicatorError as error: + raise MultiselectClientError(str(error)) # Return protocol if response is equal to protocol or raise error if response == protocol: