diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index e4d71a6d..19d1b766 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -10,6 +10,7 @@ from libp2p.protocol_muxer.multiselect_client import MultiselectClient from libp2p.protocol_muxer.multiselect_communicator import StreamCommunicator from libp2p.routing.interfaces import IPeerRouting from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream +from libp2p.transport.exceptions import UpgradeFailure from libp2p.transport.listener_interface import IListener from libp2p.transport.transport_interface import ITransport from libp2p.transport.upgrader import TransportUpgrader @@ -197,13 +198,18 @@ class Swarm(INetwork): # Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure # the conn and then mux the conn # FIXME: This dummy `ID(b"")` for the remote peer is useless. - secured_conn = await self.upgrader.upgrade_security( - raw_conn, ID(b""), False - ) - peer_id = secured_conn.get_remote_peer() - muxed_conn = await self.upgrader.upgrade_connection( - secured_conn, self.generic_protocol_handler, peer_id - ) + try: + secured_conn = await self.upgrader.upgrade_security( + raw_conn, ID(b""), False + ) + peer_id = secured_conn.get_remote_peer() + muxed_conn = await self.upgrader.upgrade_connection( + secured_conn, self.generic_protocol_handler, peer_id + ) + except UpgradeFailure: + # TODO: Add logging to indicate the failure + raw_conn.close() + return # Store muxed_conn with peer id self.connections[peer_id] = muxed_conn diff --git a/libp2p/protocol_muxer/multiselect.py b/libp2p/protocol_muxer/multiselect.py index 8ce66d93..9fd3de84 100644 --- a/libp2p/protocol_muxer/multiselect.py +++ b/libp2p/protocol_muxer/multiselect.py @@ -80,7 +80,8 @@ class Multiselect(IMultiselectMuxer): # Confirm that the protocols are the same if not validate_handshake(handshake_contents): raise MultiselectError( - f"multiselect protocol ID mismatch: handshake_contents={handshake_contents}" + "multiselect protocol ID mismatch: " + f"received handshake_contents={handshake_contents}" ) # Handshake succeeded if this point is reached diff --git a/libp2p/security/insecure/exceptions.py b/libp2p/security/insecure/exceptions.py deleted file mode 100644 index d2570e7d..00000000 --- a/libp2p/security/insecure/exceptions.py +++ /dev/null @@ -1,2 +0,0 @@ -class UpgradeFailure(Exception): - pass diff --git a/libp2p/security/insecure/transport.py b/libp2p/security/insecure/transport.py index 64f08f04..337d20f0 100644 --- a/libp2p/security/insecure/transport.py +++ b/libp2p/security/insecure/transport.py @@ -4,10 +4,10 @@ from libp2p.peer.id import ID from libp2p.security.base_session import BaseSession from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.secure_conn_interface import ISecureConn +from libp2p.transport.exceptions import SecurityUpgradeFailure from libp2p.typing import TProtocol from libp2p.utils import encode_fixedint_prefixed, read_fixedint_prefixed -from .exceptions import UpgradeFailure from .pb import plaintext_pb2 # Reference: https://github.com/libp2p/go-libp2p-core/blob/master/sec/insecure/insecure.go @@ -17,7 +17,6 @@ PLAINTEXT_PROTOCOL_ID = TProtocol("/plaintext/2.0.0") class InsecureSession(BaseSession): - # FIXME: Update the read/write to `BaseSession` async def run_handshake(self): msg = make_exchange_message(self.local_private_key.get_public_key()) msg_bytes = msg.SerializeToString() @@ -61,7 +60,7 @@ class InsecureTransport(BaseSecureTransport): # TODO: Check if `remote_public_key is not None`. If so, check if `session.remote_peer` received_peer_id = session.get_remote_peer() if received_peer_id != peer_id: - raise UpgradeFailure( + raise SecurityUpgradeFailure( "remote peer sent unexpected peer ID. " f"expected={peer_id} received={received_peer_id}" ) diff --git a/libp2p/security/simple/transport.py b/libp2p/security/simple/transport.py index 8eed2a63..e63e651e 100644 --- a/libp2p/security/simple/transport.py +++ b/libp2p/security/simple/transport.py @@ -6,6 +6,7 @@ from libp2p.peer.id import ID from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.insecure.transport import InsecureSession from libp2p.security.secure_conn_interface import ISecureConn +from libp2p.transport.exceptions import SecurityUpgradeFailure class SimpleSecurityTransport(BaseSecureTransport): @@ -25,14 +26,14 @@ class SimpleSecurityTransport(BaseSecureTransport): incoming = (await conn.read()).decode() if incoming != self.key_phrase: - raise Exception( + raise SecurityUpgradeFailure( "Key phrase differed between nodes. Expected " + self.key_phrase ) session = InsecureSession(self, conn, ID(b"")) - # TODO: Calls handshake to make them know the peer id each other, otherwise tests fail. - # However, it seems pretty weird that `SimpleSecurityTransport` sends peer id through - # `Insecure`. + # NOTE: Here we calls `run_handshake` for both sides to exchange their public keys and + # peer ids, otherwise tests fail. However, it seems pretty weird that + # `SimpleSecurityTransport` sends peer id through `Insecure`. await session.run_handshake() # NOTE: this is abusing the abstraction we have here # but this code may be deprecated soon and this exists @@ -55,14 +56,14 @@ class SimpleSecurityTransport(BaseSecureTransport): await asyncio.sleep(0) if incoming != self.key_phrase: - raise Exception( + raise SecurityUpgradeFailure( "Key phrase differed between nodes. Expected " + self.key_phrase ) session = InsecureSession(self, conn, peer_id) - # TODO: Calls handshake to make them know the peer id each other, otherwise tests fail. - # However, it seems pretty weird that `SimpleSecurityTransport` sends peer id through - # `Insecure`. + # NOTE: Here we calls `run_handshake` for both sides to exchange their public keys and + # peer ids, otherwise tests fail. However, it seems pretty weird that + # `SimpleSecurityTransport` sends peer id through `Insecure`. await session.run_handshake() # NOTE: this is abusing the abstraction we have here # but this code may be deprecated soon and this exists diff --git a/libp2p/transport/exceptions.py b/libp2p/transport/exceptions.py new file mode 100644 index 00000000..5826f83c --- /dev/null +++ b/libp2p/transport/exceptions.py @@ -0,0 +1,7 @@ +# TODO: Add `BaseLibp2pError` and `UpgradeFailure` can inherit from it? +class UpgradeFailure(Exception): + pass + + +class SecurityUpgradeFailure(UpgradeFailure): + pass