diff --git a/Makefile b/Makefile index e5b6509c..45412701 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ FILES_TO_LINT = libp2p tests examples setup.py -PB = libp2p/crypto/pb/crypto.proto libp2p/pubsub/pb/rpc.proto libp2p/security/insecure/pb/plaintext.proto +PB = libp2p/crypto/pb/crypto.proto libp2p/pubsub/pb/rpc.proto libp2p/security/insecure/pb/plaintext.proto libp2p/security/secio/pb/spipe.proto PY = $(PB:.proto=_pb2.py) PYI = $(PB:.proto=_pb2.pyi) diff --git a/libp2p/__init__.py b/libp2p/__init__.py index 10e71dba..dce33b9c 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -14,6 +14,7 @@ from libp2p.peer.peerstore_interface import IPeerStore from libp2p.routing.interfaces import IPeerRouting from libp2p.routing.kademlia.kademlia_peer_router import KadmeliaPeerRouter from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport +import libp2p.security.secio.transport as secio from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex from libp2p.stream_muxer.muxer_multistream import MuxerClassType @@ -98,7 +99,8 @@ def initialize_default_swarm( muxer_transports_by_protocol = muxer_opt or {MPLEX_PROTOCOL_ID: Mplex} security_transports_by_protocol = sec_opt or { - TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair) + TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair), + TProtocol(secio.ID): secio.Transport(key_pair), } upgrader = TransportUpgrader( security_transports_by_protocol, muxer_transports_by_protocol diff --git a/libp2p/crypto/authenticated_encryption.py b/libp2p/crypto/authenticated_encryption.py new file mode 100644 index 00000000..6900f1f4 --- /dev/null +++ b/libp2p/crypto/authenticated_encryption.py @@ -0,0 +1,128 @@ +from dataclasses import dataclass +import hmac +from typing import Tuple + +from Crypto.Cipher import AES +import Crypto.Util.Counter as Counter + + +class InvalidMACException(Exception): + pass + + +@dataclass(frozen=True) +class EncryptionParameters: + cipher_type: str + hash_type: str + iv: bytes + mac_key: bytes + cipher_key: bytes + + +class MacAndCipher: + def __init__(self, parameters: EncryptionParameters) -> None: + self.authenticator = hmac.new( + parameters.mac_key, digestmod=parameters.hash_type + ) + iv_bit_size = 8 * len(parameters.iv) + cipher = AES.new( + parameters.cipher_key, + AES.MODE_CTR, + counter=Counter.new( + iv_bit_size, + initial_value=int.from_bytes(parameters.iv, byteorder="big"), + ), + ) + self.cipher = cipher + + def encrypt(self, data: bytes) -> bytes: + return self.cipher.encrypt(data) + + def authenticate(self, data: bytes) -> bytes: + authenticator = self.authenticator.copy() + authenticator.update(data) + return authenticator.digest() + + def decrypt_if_valid(self, data_with_tag: bytes) -> bytes: + tag_position = len(data_with_tag) - self.authenticator.digest_size + data = data_with_tag[:tag_position] + tag = data_with_tag[tag_position:] + + authenticator = self.authenticator.copy() + authenticator.update(data) + expected_tag = authenticator.digest() + + if not hmac.compare_digest(tag, expected_tag): + raise InvalidMACException(expected_tag, tag) + + return self.cipher.decrypt(data) + + +def initialize_pair( + cipher_type: str, hash_type: str, secret: bytes +) -> Tuple[EncryptionParameters, EncryptionParameters]: + """ + Return a pair of ``Keys`` for use in securing a + communications channel with authenticated encryption + derived from the ``secret`` and using the + requested ``cipher_type`` and ``hash_type``. + """ + if cipher_type != "AES-128": + raise NotImplementedError() + if hash_type != "SHA256": + raise NotImplementedError() + + iv_size = 16 + cipher_key_size = 16 + hmac_key_size = 20 + seed = "key expansion".encode() + + params_size = iv_size + cipher_key_size + hmac_key_size + result = bytearray(2 * params_size) + + authenticator = hmac.new(secret, digestmod=hash_type) + authenticator.update(seed) + tag = authenticator.digest() + + i = 0 + len_result = 2 * params_size + while i < len_result: + authenticator = hmac.new(secret, digestmod=hash_type) + + authenticator.update(tag) + authenticator.update(seed) + + another_tag = authenticator.digest() + + remaining_bytes = len(another_tag) + + if i + remaining_bytes > len_result: + remaining_bytes = len_result - i + + result[i : i + remaining_bytes] = another_tag[0:remaining_bytes] + + i += remaining_bytes + + authenticator = hmac.new(secret, digestmod=hash_type) + authenticator.update(tag) + tag = authenticator.digest() + + first_half = result[:params_size] + second_half = result[params_size:] + + return ( + EncryptionParameters( + cipher_type, + hash_type, + first_half[0:iv_size], + first_half[iv_size + cipher_key_size :], + first_half[iv_size : iv_size + cipher_key_size], + ), + EncryptionParameters( + cipher_type, + hash_type, + second_half[0:iv_size], + second_half[iv_size + cipher_key_size :], + second_half[iv_size : iv_size + cipher_key_size], + ), + ) diff --git a/libp2p/crypto/ecc.py b/libp2p/crypto/ecc.py new file mode 100644 index 00000000..8ede8f8c --- /dev/null +++ b/libp2p/crypto/ecc.py @@ -0,0 +1,56 @@ +from typing import cast + +from Crypto.PublicKey import ECC +from Crypto.PublicKey.ECC import EccKey + +from libp2p.crypto.keys import KeyPair, KeyType, PrivateKey, PublicKey + + +class ECCPublicKey(PublicKey): + def __init__(self, impl: EccKey) -> None: + self.impl = impl + + def to_bytes(self) -> bytes: + return cast(bytes, self.impl.export_key(format="DER")) + + @classmethod + def from_bytes(cls, data: bytes) -> "ECCPublicKey": + public_key_impl = ECC.import_key(data) + return cls(public_key_impl) + + def get_type(self) -> KeyType: + return KeyType.ECC_P256 + + def verify(self, data: bytes, signature: bytes) -> bool: + raise NotImplementedError + + +class ECCPrivateKey(PrivateKey): + def __init__(self, impl: EccKey) -> None: + self.impl = impl + + @classmethod + def new(cls, curve: str) -> "ECCPrivateKey": + private_key_impl = ECC.generate(curve=curve) + return cls(private_key_impl) + + def to_bytes(self) -> bytes: + return cast(bytes, self.impl.export_key(format="DER")) + + def get_type(self) -> KeyType: + return KeyType.ECC_P256 + + def sign(self, data: bytes) -> bytes: + raise NotImplementedError + + def get_public_key(self) -> PublicKey: + return ECCPublicKey(self.impl.public_key()) + + +def create_new_key_pair(curve: str) -> KeyPair: + """ + Return a new ECC keypair with the requested ``curve`` type, e.g. "P-256". + """ + private_key = ECCPrivateKey.new(curve) + public_key = private_key.get_public_key() + return KeyPair(private_key, public_key) diff --git a/libp2p/crypto/key_exchange.py b/libp2p/crypto/key_exchange.py new file mode 100644 index 00000000..4e895c95 --- /dev/null +++ b/libp2p/crypto/key_exchange.py @@ -0,0 +1,29 @@ +from typing import Callable, Tuple, cast + +from Crypto.Math.Numbers import Integer +import Crypto.PublicKey.ECC as ECC + +from libp2p.crypto.ecc import ECCPrivateKey, create_new_key_pair +from libp2p.crypto.keys import PublicKey + +SharedKeyGenerator = Callable[[bytes], bytes] + + +def create_ephemeral_key_pair(curve_type: str) -> Tuple[PublicKey, SharedKeyGenerator]: + """ + Facilitates ECDH key exchange. + """ + if curve_type != "P-256": + raise NotImplementedError() + + key_pair = create_new_key_pair(curve_type) + + def _key_exchange(serialized_remote_public_key: bytes) -> bytes: + remote_public_key = ECC.import_key(serialized_remote_public_key) + curve_point = remote_public_key.pointQ + private_key = cast(ECCPrivateKey, key_pair.private_key) + secret_point = curve_point * private_key.impl.d + byte_size = secret_point.size_in_bytes() + return cast(Integer, secret_point.x).to_bytes(byte_size) + + return key_pair.public_key, _key_exchange diff --git a/libp2p/crypto/keys.py b/libp2p/crypto/keys.py index 31caca00..5bcc5a37 100644 --- a/libp2p/crypto/keys.py +++ b/libp2p/crypto/keys.py @@ -11,6 +11,7 @@ class KeyType(Enum): Ed25519 = 1 Secp256k1 = 2 ECDSA = 3 + ECC_P256 = 4 class Key(ABC): @@ -32,6 +33,11 @@ class Key(ABC): """ ... + def __eq__(self, other: object) -> bool: + if not isinstance(other, Key): + return NotImplemented + return self.to_bytes() == other.to_bytes() + class PublicKey(Key): """ @@ -60,14 +66,16 @@ class PublicKey(Key): """ return self._serialize_to_protobuf().SerializeToString() + @classmethod + def deserialize_from_protobuf(cls, protobuf_data: bytes) -> protobuf.PublicKey: + return protobuf.PublicKey.FromString(protobuf_data) + class PrivateKey(Key): """ A ``PrivateKey`` represents a cryptographic private key. """ - protobuf_constructor = protobuf.PrivateKey - @abstractmethod def sign(self, data: bytes) -> bytes: ... @@ -91,6 +99,10 @@ class PrivateKey(Key): """ return self._serialize_to_protobuf().SerializeToString() + @classmethod + def deserialize_from_protobuf(cls, protobuf_data: bytes) -> protobuf.PrivateKey: + return protobuf.PrivateKey.FromString(protobuf_data) + @dataclass(frozen=True) class KeyPair: diff --git a/libp2p/crypto/secp256k1.py b/libp2p/crypto/secp256k1.py index e2d5fb26..475c1673 100644 --- a/libp2p/crypto/secp256k1.py +++ b/libp2p/crypto/secp256k1.py @@ -11,15 +11,20 @@ class Secp256k1PublicKey(PublicKey): return self.impl.format() @classmethod - def from_bytes(cls, key_bytes: bytes) -> "Secp256k1PublicKey": - secp256k1_pubkey = coincurve.PublicKey(key_bytes) - return cls(secp256k1_pubkey) + def from_bytes(cls, data: bytes) -> "Secp256k1PublicKey": + impl = coincurve.PublicKey(data) + return cls(impl) + + @classmethod + def deserialize(cls, data: bytes) -> "Secp256k1PublicKey": + protobuf_key = cls.deserialize_from_protobuf(data) + return cls.from_bytes(protobuf_key.data) def get_type(self) -> KeyType: return KeyType.Secp256k1 def verify(self, data: bytes, signature: bytes) -> bool: - raise NotImplementedError + return self.impl.verify(signature, data) class Secp256k1PrivateKey(PrivateKey): @@ -34,11 +39,21 @@ class Secp256k1PrivateKey(PrivateKey): def to_bytes(self) -> bytes: return self.impl.secret + @classmethod + def from_bytes(cls, data: bytes) -> "Secp256k1PrivateKey": + impl = coincurve.PrivateKey(data) + return cls(impl) + + @classmethod + def deserialize(cls, data: bytes) -> "Secp256k1PrivateKey": + protobuf_key = cls.deserialize_from_protobuf(data) + return cls.from_bytes(protobuf_key.data) + def get_type(self) -> KeyType: return KeyType.Secp256k1 def sign(self, data: bytes) -> bytes: - raise NotImplementedError + return self.impl.sign(data) def get_public_key(self) -> PublicKey: public_key_impl = coincurve.PublicKey.from_secret(self.impl.secret) diff --git a/libp2p/crypto/serialization.py b/libp2p/crypto/serialization.py new file mode 100644 index 00000000..5b6b2764 --- /dev/null +++ b/libp2p/crypto/serialization.py @@ -0,0 +1,22 @@ +from libp2p.crypto.keys import KeyType, PrivateKey, PublicKey +from libp2p.crypto.secp256k1 import Secp256k1PrivateKey, Secp256k1PublicKey + +key_type_to_public_key_deserializer = { + KeyType.Secp256k1.value: Secp256k1PublicKey.from_bytes +} + +key_type_to_private_key_deserializer = { + KeyType.Secp256k1.value: Secp256k1PrivateKey.from_bytes +} + + +def deserialize_public_key(data: bytes) -> PublicKey: + f = PublicKey.deserialize_from_protobuf(data) + deserializer = key_type_to_public_key_deserializer[f.key_type] + return deserializer(f.data) + + +def deserialize_private_key(data: bytes) -> PrivateKey: + f = PrivateKey.deserialize_from_protobuf(data) + deserializer = key_type_to_private_key_deserializer[f.key_type] + return deserializer(f.data) diff --git a/libp2p/io/__init__.py b/libp2p/io/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libp2p/io/exceptions.py b/libp2p/io/exceptions.py new file mode 100644 index 00000000..6e1376fa --- /dev/null +++ b/libp2p/io/exceptions.py @@ -0,0 +1,13 @@ +from libp2p.exceptions import BaseLibp2pError + + +class MsgioException(BaseLibp2pError): + pass + + +class MissingLengthException(MsgioException): + pass + + +class MissingMessageException(MsgioException): + pass diff --git a/libp2p/io/msgio.py b/libp2p/io/msgio.py new file mode 100644 index 00000000..65fde685 --- /dev/null +++ b/libp2p/io/msgio.py @@ -0,0 +1,24 @@ +from libp2p.network.connection.raw_connection_interface import IRawConnection + +from .exceptions import MissingLengthException, MissingMessageException + +SIZE_LEN_BYTES = 4 + +# TODO unify w/ https://github.com/libp2p/py-libp2p/blob/1aed52856f56a4b791696bbcbac31b5f9c2e88c9/libp2p/utils.py#L85-L99 # noqa: E501 + + +def encode(msg_bytes: bytes) -> bytes: + len_prefix = len(msg_bytes).to_bytes(SIZE_LEN_BYTES, "big") + return len_prefix + msg_bytes + + +async def read_next_message(reader: IRawConnection) -> bytes: + len_bytes = await reader.read(SIZE_LEN_BYTES) + if len(len_bytes) != SIZE_LEN_BYTES: + raise MissingLengthException() + len_int = int.from_bytes(len_bytes, "big") + next_msg = await reader.read(len_int) + if len(next_msg) != len_int: + # TODO makes sense to keep reading until this condition is true? + raise MissingMessageException() + return next_msg diff --git a/libp2p/peer/id.py b/libp2p/peer/id.py index 303f519f..c1b52f02 100644 --- a/libp2p/peer/id.py +++ b/libp2p/peer/id.py @@ -6,6 +6,11 @@ import multihash from libp2p.crypto.keys import PublicKey +# NOTE: ``FRIENDLY_IDS`` renders a ``str`` representation of ``ID`` as a +# short string of a prefix of the base58 representation. This feature is primarily +# intended for debugging, logging, etc. +FRIENDLY_IDS = True + class ID: _bytes: bytes @@ -32,7 +37,13 @@ class ID: def __repr__(self) -> str: return "" - __str__ = pretty = to_string = to_base58 + pretty = to_string = to_base58 + + def __str__(self) -> str: + if FRIENDLY_IDS: + return self.to_string()[2:8] + else: + return self.to_string() def __eq__(self, other: object) -> bool: if isinstance(other, str): diff --git a/libp2p/security/base_session.py b/libp2p/security/base_session.py index 54562b18..2d76198a 100644 --- a/libp2p/security/base_session.py +++ b/libp2p/security/base_session.py @@ -3,7 +3,6 @@ from typing import Optional from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID -from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.secure_conn_interface import ISecureConn @@ -20,14 +19,18 @@ class BaseSession(ISecureConn): remote_permanent_pubkey: PublicKey def __init__( - self, transport: BaseSecureTransport, conn: IRawConnection, peer_id: ID + self, + local_peer: ID, + local_private_key: PrivateKey, + conn: IRawConnection, + peer_id: Optional[ID] = None, ) -> None: - self.local_peer = transport.local_peer - self.local_private_key = transport.local_private_key - self.conn = conn + self.local_peer = local_peer + self.local_private_key = local_private_key self.remote_peer_id = peer_id self.remote_permanent_pubkey = None + self.conn = conn self.initiator = self.conn.initiator async def write(self, data: bytes) -> None: diff --git a/libp2p/security/base_transport.py b/libp2p/security/base_transport.py index 0f096bfc..10d7b663 100644 --- a/libp2p/security/base_transport.py +++ b/libp2p/security/base_transport.py @@ -1,14 +1,30 @@ +import secrets +from typing import Callable + from libp2p.crypto.keys import KeyPair from libp2p.peer.id import ID from libp2p.security.secure_transport_interface import ISecureTransport +def default_secure_bytes_provider(n: int) -> bytes: + return secrets.token_bytes(n) + + class BaseSecureTransport(ISecureTransport): """ ``BaseSecureTransport`` is not fully instantiated from its abstract classes as it is only meant to be used in clases that derive from it. + + Clients can provide a strategy to get cryptographically secure bytes of a given length. + A default implementation is provided using the ``secrets`` module from the + standard library. """ - def __init__(self, local_key_pair: KeyPair) -> None: + def __init__( + self, + local_key_pair: KeyPair, + secure_bytes_provider: Callable[[int], bytes] = default_secure_bytes_provider, + ) -> None: self.local_private_key = local_key_pair.private_key self.local_peer = ID.from_pubkey(local_key_pair.public_key) + self.secure_bytes_provider = secure_bytes_provider diff --git a/libp2p/security/insecure/transport.py b/libp2p/security/insecure/transport.py index 2aad45c0..8ad2e614 100644 --- a/libp2p/security/insecure/transport.py +++ b/libp2p/security/insecure/transport.py @@ -76,7 +76,7 @@ class InsecureTransport(BaseSecureTransport): for an inbound connection (i.e. we are not the initiator) :return: secure connection object (that implements secure_conn_interface) """ - session = InsecureSession(self, conn, ID(b"")) + session = InsecureSession(self.local_peer, self.local_private_key, conn) await session.run_handshake() return session @@ -86,7 +86,9 @@ class InsecureTransport(BaseSecureTransport): for an inbound connection (i.e. we are the initiator) :return: secure connection object (that implements secure_conn_interface) """ - session = InsecureSession(self, conn, peer_id) + session = InsecureSession( + self.local_peer, self.local_private_key, conn, peer_id + ) await session.run_handshake() return session diff --git a/libp2p/security/secio/exceptions.py b/libp2p/security/secio/exceptions.py new file mode 100644 index 00000000..f9ea8cf5 --- /dev/null +++ b/libp2p/security/secio/exceptions.py @@ -0,0 +1,27 @@ +class SecioException(Exception): + pass + + +class SelfEncryption(SecioException): + """ + Raised to indicate that a host is attempting to encrypt communications + with itself. + """ + + pass + + +class PeerMismatchException(SecioException): + pass + + +class InvalidSignatureOnExchange(SecioException): + pass + + +class HandshakeFailed(SecioException): + pass + + +class IncompatibleChoices(SecioException): + pass diff --git a/libp2p/security/secio/pb/spipe.proto b/libp2p/security/secio/pb/spipe.proto new file mode 100644 index 00000000..942a9a5f --- /dev/null +++ b/libp2p/security/secio/pb/spipe.proto @@ -0,0 +1,16 @@ +syntax = "proto2"; + +package spipe.pb; + +message Propose { + optional bytes rand = 1; + optional bytes public_key = 2; + optional string exchanges = 3; + optional string ciphers = 4; + optional string hashes = 5; +} + +message Exchange { + optional bytes ephemeral_public_key = 1; + optional bytes signature = 2; +} \ No newline at end of file diff --git a/libp2p/security/secio/pb/spipe_pb2.py b/libp2p/security/secio/pb/spipe_pb2.py new file mode 100644 index 00000000..87a2f37d --- /dev/null +++ b/libp2p/security/secio/pb/spipe_pb2.py @@ -0,0 +1,144 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: spipe.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='spipe.proto', + package='spipe.pb', + syntax='proto2', + serialized_options=None, + serialized_pb=_b('\n\x0bspipe.proto\x12\x08spipe.pb\"_\n\x07Propose\x12\x0c\n\x04rand\x18\x01 \x01(\x0c\x12\x12\n\npublic_key\x18\x02 \x01(\x0c\x12\x11\n\texchanges\x18\x03 \x01(\t\x12\x0f\n\x07\x63iphers\x18\x04 \x01(\t\x12\x0e\n\x06hashes\x18\x05 \x01(\t\";\n\x08\x45xchange\x12\x1c\n\x14\x65phemeral_public_key\x18\x01 \x01(\x0c\x12\x11\n\tsignature\x18\x02 \x01(\x0c') +) + + + + +_PROPOSE = _descriptor.Descriptor( + name='Propose', + full_name='spipe.pb.Propose', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='rand', full_name='spipe.pb.Propose.rand', index=0, + number=1, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=_b(""), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='public_key', full_name='spipe.pb.Propose.public_key', index=1, + number=2, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=_b(""), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='exchanges', full_name='spipe.pb.Propose.exchanges', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='ciphers', full_name='spipe.pb.Propose.ciphers', index=3, + number=4, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='hashes', full_name='spipe.pb.Propose.hashes', index=4, + number=5, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=25, + serialized_end=120, +) + + +_EXCHANGE = _descriptor.Descriptor( + name='Exchange', + full_name='spipe.pb.Exchange', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='ephemeral_public_key', full_name='spipe.pb.Exchange.ephemeral_public_key', index=0, + number=1, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=_b(""), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='signature', full_name='spipe.pb.Exchange.signature', index=1, + number=2, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=_b(""), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=122, + serialized_end=181, +) + +DESCRIPTOR.message_types_by_name['Propose'] = _PROPOSE +DESCRIPTOR.message_types_by_name['Exchange'] = _EXCHANGE +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +Propose = _reflection.GeneratedProtocolMessageType('Propose', (_message.Message,), dict( + DESCRIPTOR = _PROPOSE, + __module__ = 'spipe_pb2' + # @@protoc_insertion_point(class_scope:spipe.pb.Propose) + )) +_sym_db.RegisterMessage(Propose) + +Exchange = _reflection.GeneratedProtocolMessageType('Exchange', (_message.Message,), dict( + DESCRIPTOR = _EXCHANGE, + __module__ = 'spipe_pb2' + # @@protoc_insertion_point(class_scope:spipe.pb.Exchange) + )) +_sym_db.RegisterMessage(Exchange) + + +# @@protoc_insertion_point(module_scope) diff --git a/libp2p/security/secio/pb/spipe_pb2.pyi b/libp2p/security/secio/pb/spipe_pb2.pyi new file mode 100644 index 00000000..2025ff13 --- /dev/null +++ b/libp2p/security/secio/pb/spipe_pb2.pyi @@ -0,0 +1,67 @@ +# @generated by generate_proto_mypy_stubs.py. Do not edit! +import sys +from google.protobuf.descriptor import ( + Descriptor as google___protobuf___descriptor___Descriptor, +) + +from google.protobuf.message import ( + Message as google___protobuf___message___Message, +) + +from typing import ( + Optional as typing___Optional, + Text as typing___Text, +) + +from typing_extensions import ( + Literal as typing_extensions___Literal, +) + + +class Propose(google___protobuf___message___Message): + DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... + rand = ... # type: bytes + public_key = ... # type: bytes + exchanges = ... # type: typing___Text + ciphers = ... # type: typing___Text + hashes = ... # type: typing___Text + + def __init__(self, + *, + rand : typing___Optional[bytes] = None, + public_key : typing___Optional[bytes] = None, + exchanges : typing___Optional[typing___Text] = None, + ciphers : typing___Optional[typing___Text] = None, + hashes : typing___Optional[typing___Text] = None, + ) -> None: ... + @classmethod + def FromString(cls, s: bytes) -> Propose: ... + def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + if sys.version_info >= (3,): + def HasField(self, field_name: typing_extensions___Literal[u"ciphers",u"exchanges",u"hashes",u"public_key",u"rand"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[u"ciphers",u"exchanges",u"hashes",u"public_key",u"rand"]) -> None: ... + else: + def HasField(self, field_name: typing_extensions___Literal[u"ciphers",b"ciphers",u"exchanges",b"exchanges",u"hashes",b"hashes",u"public_key",b"public_key",u"rand",b"rand"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[u"ciphers",b"ciphers",u"exchanges",b"exchanges",u"hashes",b"hashes",u"public_key",b"public_key",u"rand",b"rand"]) -> None: ... + +class Exchange(google___protobuf___message___Message): + DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... + ephemeral_public_key = ... # type: bytes + signature = ... # type: bytes + + def __init__(self, + *, + ephemeral_public_key : typing___Optional[bytes] = None, + signature : typing___Optional[bytes] = None, + ) -> None: ... + @classmethod + def FromString(cls, s: bytes) -> Exchange: ... + def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + if sys.version_info >= (3,): + def HasField(self, field_name: typing_extensions___Literal[u"ephemeral_public_key",u"signature"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[u"ephemeral_public_key",u"signature"]) -> None: ... + else: + def HasField(self, field_name: typing_extensions___Literal[u"ephemeral_public_key",b"ephemeral_public_key",u"signature",b"signature"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[u"ephemeral_public_key",b"ephemeral_public_key",u"signature",b"signature"]) -> None: ... diff --git a/libp2p/security/secio/transport.py b/libp2p/security/secio/transport.py new file mode 100644 index 00000000..4c3dbc08 --- /dev/null +++ b/libp2p/security/secio/transport.py @@ -0,0 +1,403 @@ +from dataclasses import dataclass +from typing import Optional, Tuple + +import multihash + +from libp2p.crypto.authenticated_encryption import ( + EncryptionParameters as AuthenticatedEncryptionParameters, +) +from libp2p.crypto.authenticated_encryption import ( + initialize_pair as initialize_pair_for_encryption, +) +from libp2p.crypto.authenticated_encryption import MacAndCipher as Encrypter +from libp2p.crypto.ecc import ECCPublicKey +from libp2p.crypto.key_exchange import create_ephemeral_key_pair +from libp2p.crypto.keys import PrivateKey, PublicKey +from libp2p.crypto.serialization import deserialize_public_key +from libp2p.io.msgio import encode as encode_message +from libp2p.io.msgio import read_next_message +from libp2p.network.connection.raw_connection_interface import IRawConnection +from libp2p.peer.id import ID as PeerID +from libp2p.security.base_session import BaseSession +from libp2p.security.base_transport import BaseSecureTransport +from libp2p.security.secure_conn_interface import ISecureConn + +from .exceptions import ( + HandshakeFailed, + IncompatibleChoices, + InvalidSignatureOnExchange, + PeerMismatchException, + SecioException, + SelfEncryption, +) +from .pb.spipe_pb2 import Exchange, Propose + +ID = "/secio/1.0.0" + +NONCE_SIZE = 16 # bytes + +# NOTE: the following is only a subset of allowable parameters according to the +# `secio` specification. +DEFAULT_SUPPORTED_EXCHANGES = "P-256" +DEFAULT_SUPPORTED_CIPHERS = "AES-128" +DEFAULT_SUPPORTED_HASHES = "SHA256" + + +class SecureSession(BaseSession): + def __init__( + self, + local_peer: PeerID, + local_private_key: PrivateKey, + local_encryption_parameters: AuthenticatedEncryptionParameters, + remote_peer: PeerID, + remote_encryption_parameters: AuthenticatedEncryptionParameters, + conn: IRawConnection, + ) -> None: + super().__init__(local_peer, local_private_key, conn, remote_peer) + + self.local_encryption_parameters = local_encryption_parameters + self.remote_encryption_parameters = remote_encryption_parameters + self._initialize_authenticated_encryption_for_local_peer() + self._initialize_authenticated_encryption_for_remote_peer() + + def _initialize_authenticated_encryption_for_local_peer(self) -> None: + self.local_encrypter = Encrypter(self.local_encryption_parameters) + + def _initialize_authenticated_encryption_for_remote_peer(self) -> None: + self.remote_encrypter = Encrypter(self.remote_encryption_parameters) + + async def read(self, n: int = -1) -> bytes: + return await self._read_msg() + + async def _read_msg(self) -> bytes: + # TODO do we need to serialize reads? + msg = await read_next_message(self.conn) + return self.remote_encrypter.decrypt_if_valid(msg) + + async def write(self, data: bytes) -> None: + await self._write_msg(data) + + async def _write_msg(self, data: bytes) -> None: + # TODO do we need to serialize writes? + encrypted_data = self.local_encrypter.encrypt(data) + tag = self.local_encrypter.authenticate(encrypted_data) + msg = encode_message(encrypted_data + tag) + await self.conn.write(msg) + + +@dataclass(frozen=True) +class Proposal: + """ + A ``Proposal`` represents the set of session parameters one peer in a pair of + peers attempting to negotiate a `secio` channel prefers. + """ + + nonce: bytes + public_key: PublicKey + exchanges: str = DEFAULT_SUPPORTED_EXCHANGES # comma separated list + ciphers: str = DEFAULT_SUPPORTED_CIPHERS # comma separated list + hashes: str = DEFAULT_SUPPORTED_HASHES # comma separated list + + def serialize(self) -> bytes: + protobuf = Propose( + rand=self.nonce, + public_key=self.public_key.serialize(), + exchanges=self.exchanges, + ciphers=self.ciphers, + hashes=self.hashes, + ) + return protobuf.SerializeToString() + + @classmethod + def deserialize(cls, protobuf_bytes: bytes) -> "Proposal": + protobuf = Propose.FromString(protobuf_bytes) + + nonce = protobuf.rand + public_key_protobuf_bytes = protobuf.public_key + public_key = deserialize_public_key(public_key_protobuf_bytes) + exchanges = protobuf.exchanges + ciphers = protobuf.ciphers + hashes = protobuf.hashes + + return cls(nonce, public_key, exchanges, ciphers, hashes) + + def calculate_peer_id(self) -> PeerID: + return PeerID.from_pubkey(self.public_key) + + +@dataclass +class EncryptionParameters: + permanent_public_key: PublicKey + + curve_type: str + cipher_type: str + hash_type: str + + ephemeral_public_key: PublicKey + + def __init__(self) -> None: + pass + + +@dataclass +class SessionParameters: + local_peer: PeerID + local_encryption_parameters: EncryptionParameters + + remote_peer: PeerID + remote_encryption_parameters: EncryptionParameters + + # order is a comparator used to break the symmetry b/t each pair of peers + order: int + shared_key: bytes + + def __init__(self) -> None: + pass + + +async def _response_to_msg(conn: IRawConnection, msg: bytes) -> bytes: + await conn.write(encode_message(msg)) + return await read_next_message(conn) + + +def _mk_multihash_sha256(data: bytes) -> bytes: + return multihash.digest(data, "sha2-256") + + +def _mk_score(public_key: PublicKey, nonce: bytes) -> bytes: + return _mk_multihash_sha256(public_key.serialize() + nonce) + + +def _select_parameter_from_order( + order: int, supported_parameters: str, available_parameters: str +) -> str: + if order < 0: + first_choices = available_parameters.split(",") + second_choices = supported_parameters.split(",") + elif order > 0: + first_choices = supported_parameters.split(",") + second_choices = available_parameters.split(",") + else: + return supported_parameters.split(",")[0] + + for first, second in zip(first_choices, second_choices): + if first == second: + return first + raise IncompatibleChoices() + + +def _select_encryption_parameters( + local_proposal: Proposal, remote_proposal: Proposal +) -> Tuple[str, str, str, int]: + first_score = _mk_score(remote_proposal.public_key, local_proposal.nonce) + second_score = _mk_score(local_proposal.public_key, remote_proposal.nonce) + + order = 0 + if first_score < second_score: + order = -1 + elif second_score < first_score: + order = 1 + + if order == 0: + raise SelfEncryption() + + return ( + _select_parameter_from_order( + order, DEFAULT_SUPPORTED_EXCHANGES, remote_proposal.exchanges + ), + _select_parameter_from_order( + order, DEFAULT_SUPPORTED_CIPHERS, remote_proposal.ciphers + ), + _select_parameter_from_order( + order, DEFAULT_SUPPORTED_HASHES, remote_proposal.hashes + ), + order, + ) + + +async def _establish_session_parameters( + local_peer: PeerID, + local_private_key: PrivateKey, + remote_peer: Optional[PeerID], + conn: IRawConnection, + nonce: bytes, +) -> Tuple[SessionParameters, bytes]: + # establish shared encryption parameters + session_parameters = SessionParameters() + session_parameters.local_peer = local_peer + + local_encryption_parameters = EncryptionParameters() + session_parameters.local_encryption_parameters = local_encryption_parameters + + local_public_key = local_private_key.get_public_key() + local_encryption_parameters.permanent_public_key = local_public_key + + local_proposal = Proposal(nonce, local_public_key) + serialized_local_proposal = local_proposal.serialize() + serialized_remote_proposal = await _response_to_msg(conn, serialized_local_proposal) + + remote_encryption_parameters = EncryptionParameters() + session_parameters.remote_encryption_parameters = remote_encryption_parameters + remote_proposal = Proposal.deserialize(serialized_remote_proposal) + remote_encryption_parameters.permanent_public_key = remote_proposal.public_key + + remote_peer_from_proposal = remote_proposal.calculate_peer_id() + if not remote_peer: + remote_peer = remote_peer_from_proposal + elif remote_peer != remote_peer_from_proposal: + raise PeerMismatchException() + session_parameters.remote_peer = remote_peer + + curve_param, cipher_param, hash_param, order = _select_encryption_parameters( + local_proposal, remote_proposal + ) + local_encryption_parameters.curve_type = curve_param + local_encryption_parameters.cipher_type = cipher_param + local_encryption_parameters.hash_type = hash_param + remote_encryption_parameters.curve_type = curve_param + remote_encryption_parameters.cipher_type = cipher_param + remote_encryption_parameters.hash_type = hash_param + session_parameters.order = order + + # exchange ephemeral pub keys + local_ephemeral_public_key, shared_key_generator = create_ephemeral_key_pair( + curve_param + ) + local_encryption_parameters.ephemeral_public_key = local_ephemeral_public_key + local_selection = ( + serialized_local_proposal + + serialized_remote_proposal + + local_ephemeral_public_key.to_bytes() + ) + exchange_signature = local_private_key.sign(local_selection) + local_exchange = Exchange( + ephemeral_public_key=local_ephemeral_public_key.to_bytes(), + signature=exchange_signature, + ) + + serialized_local_exchange = local_exchange.SerializeToString() + serialized_remote_exchange = await _response_to_msg(conn, serialized_local_exchange) + + remote_exchange = Exchange() + remote_exchange.ParseFromString(serialized_remote_exchange) + + remote_ephemeral_public_key_bytes = remote_exchange.ephemeral_public_key + remote_ephemeral_public_key = ECCPublicKey.from_bytes( + remote_ephemeral_public_key_bytes + ) + remote_encryption_parameters.ephemeral_public_key = remote_ephemeral_public_key + remote_selection = ( + serialized_remote_proposal + + serialized_local_proposal + + remote_ephemeral_public_key_bytes + ) + valid_signature = remote_encryption_parameters.permanent_public_key.verify( + remote_selection, remote_exchange.signature + ) + if not valid_signature: + raise InvalidSignatureOnExchange() + + shared_key = shared_key_generator(remote_ephemeral_public_key_bytes) + session_parameters.shared_key = shared_key + + return session_parameters, remote_proposal.nonce + + +def _mk_session_from( + local_private_key: PrivateKey, + session_parameters: SessionParameters, + conn: IRawConnection, +) -> SecureSession: + key_set1, key_set2 = initialize_pair_for_encryption( + session_parameters.local_encryption_parameters.cipher_type, + session_parameters.local_encryption_parameters.hash_type, + session_parameters.shared_key, + ) + + if session_parameters.order < 0: + key_set1, key_set2 = key_set2, key_set1 + + session = SecureSession( + session_parameters.local_peer, + local_private_key, + key_set1, + session_parameters.remote_peer, + key_set2, + conn, + ) + return session + + +async def _finish_handshake(session: ISecureConn, remote_nonce: bytes) -> bytes: + await session.write(remote_nonce) + return await session.read() + + +async def create_secure_session( + local_nonce: bytes, + local_peer: PeerID, + local_private_key: PrivateKey, + conn: IRawConnection, + remote_peer: PeerID = None, +) -> ISecureConn: + """ + Attempt the initial `secio` handshake with the remote peer. + If successful, return an object that provides secure communication to the + ``remote_peer``. + """ + try: + session_parameters, remote_nonce = await _establish_session_parameters( + local_peer, local_private_key, remote_peer, conn, local_nonce + ) + except SecioException as e: + await conn.close() + raise e + + session = _mk_session_from(local_private_key, session_parameters, conn) + + received_nonce = await _finish_handshake(session, remote_nonce) + if received_nonce != local_nonce: + await conn.close() + raise HandshakeFailed() + + return session + + +class Transport(BaseSecureTransport): + """ + ``Transport`` provides a security upgrader for a ``IRawConnection``, + following the `secio` protocol defined in the libp2p specs. + """ + + def get_nonce(self) -> bytes: + return self.secure_bytes_provider(NONCE_SIZE) + + async def secure_inbound(self, conn: IRawConnection) -> ISecureConn: + """ + Secure the connection, either locally or by communicating with opposing node via conn, + for an inbound connection (i.e. we are not the initiator) + :return: secure connection object (that implements secure_conn_interface) + """ + local_nonce = self.get_nonce() + local_peer = self.local_peer + local_private_key = self.local_private_key + + return await create_secure_session( + local_nonce, local_peer, local_private_key, conn + ) + + async def secure_outbound( + self, conn: IRawConnection, peer_id: PeerID + ) -> ISecureConn: + """ + Secure the connection, either locally or by communicating with opposing node via conn, + for an inbound connection (i.e. we are the initiator) + :return: secure connection object (that implements secure_conn_interface) + """ + local_nonce = self.get_nonce() + local_peer = self.local_peer + local_private_key = self.local_private_key + + return await create_secure_session( + local_nonce, local_peer, local_private_key, conn, peer_id + ) diff --git a/libp2p/security/simple/transport.py b/libp2p/security/simple/transport.py new file mode 100644 index 00000000..28187d1d --- /dev/null +++ b/libp2p/security/simple/transport.py @@ -0,0 +1,79 @@ +import asyncio + +from libp2p.crypto.keys import KeyPair +from libp2p.network.connection.raw_connection_interface import IRawConnection +from libp2p.peer.id import ID +from libp2p.security.base_transport import BaseSecureTransport +from libp2p.security.insecure.transport import InsecureSession +from libp2p.security.secure_conn_interface import ISecureConn +from libp2p.transport.exceptions import SecurityUpgradeFailure +from libp2p.utils import encode_fixedint_prefixed, read_fixedint_prefixed + + +class SimpleSecurityTransport(BaseSecureTransport): + key_phrase: str + + def __init__(self, local_key_pair: KeyPair, key_phrase: str) -> None: + super().__init__(local_key_pair) + self.key_phrase = key_phrase + + async def secure_inbound(self, conn: IRawConnection) -> ISecureConn: + """ + Secure the connection, either locally or by communicating with opposing node via conn, + for an inbound connection (i.e. we are not the initiator) + :return: secure connection object (that implements secure_conn_interface) + """ + await conn.write(encode_fixedint_prefixed(self.key_phrase.encode())) + incoming = (await read_fixedint_prefixed(conn)).decode() + + if incoming != self.key_phrase: + raise SecurityUpgradeFailure( + "Key phrase differed between nodes. Expected " + self.key_phrase + ) + + session = InsecureSession( + self.local_peer, self.local_private_key, conn, ID(b"") + ) + # NOTE: Here we calls `run_handshake` for both sides to exchange their public keys and + # peer ids, otherwise tests fail. However, it seems pretty weird that + # `SimpleSecurityTransport` sends peer id through `Insecure`. + await session.run_handshake() + # NOTE: this is abusing the abstraction we have here + # but this code may be deprecated soon and this exists + # mainly to satisfy a test that will go along w/ it + # FIXME: Enable type check back when we can deprecate the simple transport. + session.key_phrase = self.key_phrase # type: ignore + return session + + async def secure_outbound(self, conn: IRawConnection, peer_id: ID) -> ISecureConn: + """ + Secure the connection, either locally or by communicating with opposing node via conn, + for an inbound connection (i.e. we are the initiator) + :return: secure connection object (that implements secure_conn_interface) + """ + await conn.write(encode_fixedint_prefixed(self.key_phrase.encode())) + incoming = (await read_fixedint_prefixed(conn)).decode() + + # Force context switch, as this security transport is built for testing locally + # in a single event loop + await asyncio.sleep(0) + + if incoming != self.key_phrase: + raise SecurityUpgradeFailure( + "Key phrase differed between nodes. Expected " + self.key_phrase + ) + + session = InsecureSession( + self.local_peer, self.local_private_key, conn, peer_id + ) + + # NOTE: Here we calls `run_handshake` for both sides to exchange their public keys and + # peer ids, otherwise tests fail. However, it seems pretty weird that + # `SimpleSecurityTransport` sends peer id through `Insecure`. + await session.run_handshake() + # NOTE: this is abusing the abstraction we have here + # but this code may be deprecated soon and this exists + # mainly to satisfy a test that will go along w/ it + # FIXME: Enable type check back when we can deprecate the simple transport. + session.key_phrase = self.key_phrase # type: ignore + return session diff --git a/tests/crypto/test_secp256k1.py b/tests/crypto/test_secp256k1.py new file mode 100644 index 00000000..81e9eb23 --- /dev/null +++ b/tests/crypto/test_secp256k1.py @@ -0,0 +1,22 @@ +from libp2p.crypto.secp256k1 import create_new_key_pair +from libp2p.crypto.serialization import deserialize_private_key, deserialize_public_key + + +def test_public_key_serialize_deserialize_round_trip(): + key_pair = create_new_key_pair() + public_key = key_pair.public_key + + public_key_bytes = public_key.serialize() + another_public_key = deserialize_public_key(public_key_bytes) + + assert public_key == another_public_key + + +def test_private_key_serialize_deserialize_round_trip(): + key_pair = create_new_key_pair() + private_key = key_pair.private_key + + private_key_bytes = private_key.serialize() + another_private_key = deserialize_private_key(private_key_bytes) + + assert private_key == another_private_key diff --git a/tests/peer/test_peerid.py b/tests/peer/test_peerid.py index ea244ce8..e808a3b6 100644 --- a/tests/peer/test_peerid.py +++ b/tests/peer/test_peerid.py @@ -4,10 +4,14 @@ import base58 import multihash from libp2p.crypto.rsa import create_new_key_pair +import libp2p.peer.id as PeerID from libp2p.peer.id import ID ALPHABETS = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz" +# ensure we are not in "debug" mode for the following tests +PeerID.FRIENDLY_IDS = False + def test_eq_impl_for_bytes(): random_id_string = "" diff --git a/tests/security/test_secio.py b/tests/security/test_secio.py new file mode 100644 index 00000000..673cbc58 --- /dev/null +++ b/tests/security/test_secio.py @@ -0,0 +1,101 @@ +import asyncio + +import pytest + +from libp2p.crypto.secp256k1 import create_new_key_pair +from libp2p.network.connection.raw_connection_interface import IRawConnection +from libp2p.peer.id import ID +from libp2p.security.secio.transport import NONCE_SIZE, create_secure_session + + +class InMemoryConnection(IRawConnection): + def __init__(self, peer, initiator=False): + self.peer = peer + self.recv_queue = asyncio.Queue() + self.send_queue = asyncio.Queue() + self.initiator = initiator + + self.current_msg = None + self.current_position = 0 + + self.closed = False + + async def write(self, data: bytes) -> None: + if self.closed: + raise Exception("InMemoryConnection is closed for writing") + + await self.send_queue.put(data) + + async def read(self, n: int = -1) -> bytes: + """ + NOTE: have to buffer the current message and juggle packets + off the recv queue to satisfy the semantics of this function. + """ + if self.closed: + raise Exception("InMemoryConnection is closed for reading") + + if not self.current_msg: + self.current_msg = await self.recv_queue.get() + self.current_position = 0 + + if n < 0: + msg = self.current_msg + self.current_msg = None + return msg + + next_msg = self.current_msg[self.current_position : self.current_position + n] + self.current_position += n + if self.current_position == len(self.current_msg): + self.current_msg = None + return next_msg + + async def close(self) -> None: + self.closed = True + + +async def create_pipe(local_conn, remote_conn): + try: + while True: + next_msg = await local_conn.send_queue.get() + await remote_conn.recv_queue.put(next_msg) + except asyncio.CancelledError: + return + + +@pytest.mark.asyncio +async def test_create_secure_session(): + local_nonce = b"\x01" * NONCE_SIZE + local_key_pair = create_new_key_pair(b"a") + local_peer = ID.from_pubkey(local_key_pair.public_key) + + remote_nonce = b"\x02" * NONCE_SIZE + remote_key_pair = create_new_key_pair(b"b") + remote_peer = ID.from_pubkey(remote_key_pair.public_key) + + local_conn = InMemoryConnection(local_peer, initiator=True) + remote_conn = InMemoryConnection(remote_peer) + + local_pipe_task = asyncio.create_task(create_pipe(local_conn, remote_conn)) + remote_pipe_task = asyncio.create_task(create_pipe(remote_conn, local_conn)) + + local_session_builder = create_secure_session( + local_nonce, local_peer, local_key_pair.private_key, local_conn, remote_peer + ) + remote_session_builder = create_secure_session( + remote_nonce, remote_peer, remote_key_pair.private_key, remote_conn + ) + local_secure_conn, remote_secure_conn = await asyncio.gather( + local_session_builder, remote_session_builder + ) + + msg = b"abc" + await local_secure_conn.write(msg) + received_msg = await remote_secure_conn.read() + assert received_msg == msg + + await asyncio.gather(local_secure_conn.close(), remote_secure_conn.close()) + + local_pipe_task.cancel() + remote_pipe_task.cancel() + await local_pipe_task + await remote_pipe_task