mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-12 16:10:57 +00:00
Merge pull request #254 from ralexstokes/implement-secio
Implement `secio`
This commit is contained in:
2
Makefile
2
Makefile
@ -1,5 +1,5 @@
|
|||||||
FILES_TO_LINT = libp2p tests examples setup.py
|
FILES_TO_LINT = libp2p tests examples setup.py
|
||||||
PB = libp2p/crypto/pb/crypto.proto libp2p/pubsub/pb/rpc.proto libp2p/security/insecure/pb/plaintext.proto
|
PB = libp2p/crypto/pb/crypto.proto libp2p/pubsub/pb/rpc.proto libp2p/security/insecure/pb/plaintext.proto libp2p/security/secio/pb/spipe.proto
|
||||||
PY = $(PB:.proto=_pb2.py)
|
PY = $(PB:.proto=_pb2.py)
|
||||||
PYI = $(PB:.proto=_pb2.pyi)
|
PYI = $(PB:.proto=_pb2.pyi)
|
||||||
|
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from libp2p.peer.peerstore_interface import IPeerStore
|
|||||||
from libp2p.routing.interfaces import IPeerRouting
|
from libp2p.routing.interfaces import IPeerRouting
|
||||||
from libp2p.routing.kademlia.kademlia_peer_router import KadmeliaPeerRouter
|
from libp2p.routing.kademlia.kademlia_peer_router import KadmeliaPeerRouter
|
||||||
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||||
|
import libp2p.security.secio.transport as secio
|
||||||
from libp2p.security.secure_transport_interface import ISecureTransport
|
from libp2p.security.secure_transport_interface import ISecureTransport
|
||||||
from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex
|
from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex
|
||||||
from libp2p.stream_muxer.muxer_multistream import MuxerClassType
|
from libp2p.stream_muxer.muxer_multistream import MuxerClassType
|
||||||
@ -98,7 +99,8 @@ def initialize_default_swarm(
|
|||||||
|
|
||||||
muxer_transports_by_protocol = muxer_opt or {MPLEX_PROTOCOL_ID: Mplex}
|
muxer_transports_by_protocol = muxer_opt or {MPLEX_PROTOCOL_ID: Mplex}
|
||||||
security_transports_by_protocol = sec_opt or {
|
security_transports_by_protocol = sec_opt or {
|
||||||
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
|
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair),
|
||||||
|
TProtocol(secio.ID): secio.Transport(key_pair),
|
||||||
}
|
}
|
||||||
upgrader = TransportUpgrader(
|
upgrader = TransportUpgrader(
|
||||||
security_transports_by_protocol, muxer_transports_by_protocol
|
security_transports_by_protocol, muxer_transports_by_protocol
|
||||||
|
|||||||
128
libp2p/crypto/authenticated_encryption.py
Normal file
128
libp2p/crypto/authenticated_encryption.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
import hmac
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from Crypto.Cipher import AES
|
||||||
|
import Crypto.Util.Counter as Counter
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidMACException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class EncryptionParameters:
|
||||||
|
cipher_type: str
|
||||||
|
hash_type: str
|
||||||
|
iv: bytes
|
||||||
|
mac_key: bytes
|
||||||
|
cipher_key: bytes
|
||||||
|
|
||||||
|
|
||||||
|
class MacAndCipher:
|
||||||
|
def __init__(self, parameters: EncryptionParameters) -> None:
|
||||||
|
self.authenticator = hmac.new(
|
||||||
|
parameters.mac_key, digestmod=parameters.hash_type
|
||||||
|
)
|
||||||
|
iv_bit_size = 8 * len(parameters.iv)
|
||||||
|
cipher = AES.new(
|
||||||
|
parameters.cipher_key,
|
||||||
|
AES.MODE_CTR,
|
||||||
|
counter=Counter.new(
|
||||||
|
iv_bit_size,
|
||||||
|
initial_value=int.from_bytes(parameters.iv, byteorder="big"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.cipher = cipher
|
||||||
|
|
||||||
|
def encrypt(self, data: bytes) -> bytes:
|
||||||
|
return self.cipher.encrypt(data)
|
||||||
|
|
||||||
|
def authenticate(self, data: bytes) -> bytes:
|
||||||
|
authenticator = self.authenticator.copy()
|
||||||
|
authenticator.update(data)
|
||||||
|
return authenticator.digest()
|
||||||
|
|
||||||
|
def decrypt_if_valid(self, data_with_tag: bytes) -> bytes:
|
||||||
|
tag_position = len(data_with_tag) - self.authenticator.digest_size
|
||||||
|
data = data_with_tag[:tag_position]
|
||||||
|
tag = data_with_tag[tag_position:]
|
||||||
|
|
||||||
|
authenticator = self.authenticator.copy()
|
||||||
|
authenticator.update(data)
|
||||||
|
expected_tag = authenticator.digest()
|
||||||
|
|
||||||
|
if not hmac.compare_digest(tag, expected_tag):
|
||||||
|
raise InvalidMACException(expected_tag, tag)
|
||||||
|
|
||||||
|
return self.cipher.decrypt(data)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_pair(
|
||||||
|
cipher_type: str, hash_type: str, secret: bytes
|
||||||
|
) -> Tuple[EncryptionParameters, EncryptionParameters]:
|
||||||
|
"""
|
||||||
|
Return a pair of ``Keys`` for use in securing a
|
||||||
|
communications channel with authenticated encryption
|
||||||
|
derived from the ``secret`` and using the
|
||||||
|
requested ``cipher_type`` and ``hash_type``.
|
||||||
|
"""
|
||||||
|
if cipher_type != "AES-128":
|
||||||
|
raise NotImplementedError()
|
||||||
|
if hash_type != "SHA256":
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
iv_size = 16
|
||||||
|
cipher_key_size = 16
|
||||||
|
hmac_key_size = 20
|
||||||
|
seed = "key expansion".encode()
|
||||||
|
|
||||||
|
params_size = iv_size + cipher_key_size + hmac_key_size
|
||||||
|
result = bytearray(2 * params_size)
|
||||||
|
|
||||||
|
authenticator = hmac.new(secret, digestmod=hash_type)
|
||||||
|
authenticator.update(seed)
|
||||||
|
tag = authenticator.digest()
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
len_result = 2 * params_size
|
||||||
|
while i < len_result:
|
||||||
|
authenticator = hmac.new(secret, digestmod=hash_type)
|
||||||
|
|
||||||
|
authenticator.update(tag)
|
||||||
|
authenticator.update(seed)
|
||||||
|
|
||||||
|
another_tag = authenticator.digest()
|
||||||
|
|
||||||
|
remaining_bytes = len(another_tag)
|
||||||
|
|
||||||
|
if i + remaining_bytes > len_result:
|
||||||
|
remaining_bytes = len_result - i
|
||||||
|
|
||||||
|
result[i : i + remaining_bytes] = another_tag[0:remaining_bytes]
|
||||||
|
|
||||||
|
i += remaining_bytes
|
||||||
|
|
||||||
|
authenticator = hmac.new(secret, digestmod=hash_type)
|
||||||
|
authenticator.update(tag)
|
||||||
|
tag = authenticator.digest()
|
||||||
|
|
||||||
|
first_half = result[:params_size]
|
||||||
|
second_half = result[params_size:]
|
||||||
|
|
||||||
|
return (
|
||||||
|
EncryptionParameters(
|
||||||
|
cipher_type,
|
||||||
|
hash_type,
|
||||||
|
first_half[0:iv_size],
|
||||||
|
first_half[iv_size + cipher_key_size :],
|
||||||
|
first_half[iv_size : iv_size + cipher_key_size],
|
||||||
|
),
|
||||||
|
EncryptionParameters(
|
||||||
|
cipher_type,
|
||||||
|
hash_type,
|
||||||
|
second_half[0:iv_size],
|
||||||
|
second_half[iv_size + cipher_key_size :],
|
||||||
|
second_half[iv_size : iv_size + cipher_key_size],
|
||||||
|
),
|
||||||
|
)
|
||||||
56
libp2p/crypto/ecc.py
Normal file
56
libp2p/crypto/ecc.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from Crypto.PublicKey import ECC
|
||||||
|
from Crypto.PublicKey.ECC import EccKey
|
||||||
|
|
||||||
|
from libp2p.crypto.keys import KeyPair, KeyType, PrivateKey, PublicKey
|
||||||
|
|
||||||
|
|
||||||
|
class ECCPublicKey(PublicKey):
|
||||||
|
def __init__(self, impl: EccKey) -> None:
|
||||||
|
self.impl = impl
|
||||||
|
|
||||||
|
def to_bytes(self) -> bytes:
|
||||||
|
return cast(bytes, self.impl.export_key(format="DER"))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_bytes(cls, data: bytes) -> "ECCPublicKey":
|
||||||
|
public_key_impl = ECC.import_key(data)
|
||||||
|
return cls(public_key_impl)
|
||||||
|
|
||||||
|
def get_type(self) -> KeyType:
|
||||||
|
return KeyType.ECC_P256
|
||||||
|
|
||||||
|
def verify(self, data: bytes, signature: bytes) -> bool:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class ECCPrivateKey(PrivateKey):
|
||||||
|
def __init__(self, impl: EccKey) -> None:
|
||||||
|
self.impl = impl
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def new(cls, curve: str) -> "ECCPrivateKey":
|
||||||
|
private_key_impl = ECC.generate(curve=curve)
|
||||||
|
return cls(private_key_impl)
|
||||||
|
|
||||||
|
def to_bytes(self) -> bytes:
|
||||||
|
return cast(bytes, self.impl.export_key(format="DER"))
|
||||||
|
|
||||||
|
def get_type(self) -> KeyType:
|
||||||
|
return KeyType.ECC_P256
|
||||||
|
|
||||||
|
def sign(self, data: bytes) -> bytes:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_public_key(self) -> PublicKey:
|
||||||
|
return ECCPublicKey(self.impl.public_key())
|
||||||
|
|
||||||
|
|
||||||
|
def create_new_key_pair(curve: str) -> KeyPair:
|
||||||
|
"""
|
||||||
|
Return a new ECC keypair with the requested ``curve`` type, e.g. "P-256".
|
||||||
|
"""
|
||||||
|
private_key = ECCPrivateKey.new(curve)
|
||||||
|
public_key = private_key.get_public_key()
|
||||||
|
return KeyPair(private_key, public_key)
|
||||||
29
libp2p/crypto/key_exchange.py
Normal file
29
libp2p/crypto/key_exchange.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from typing import Callable, Tuple, cast
|
||||||
|
|
||||||
|
from Crypto.Math.Numbers import Integer
|
||||||
|
import Crypto.PublicKey.ECC as ECC
|
||||||
|
|
||||||
|
from libp2p.crypto.ecc import ECCPrivateKey, create_new_key_pair
|
||||||
|
from libp2p.crypto.keys import PublicKey
|
||||||
|
|
||||||
|
SharedKeyGenerator = Callable[[bytes], bytes]
|
||||||
|
|
||||||
|
|
||||||
|
def create_ephemeral_key_pair(curve_type: str) -> Tuple[PublicKey, SharedKeyGenerator]:
|
||||||
|
"""
|
||||||
|
Facilitates ECDH key exchange.
|
||||||
|
"""
|
||||||
|
if curve_type != "P-256":
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
key_pair = create_new_key_pair(curve_type)
|
||||||
|
|
||||||
|
def _key_exchange(serialized_remote_public_key: bytes) -> bytes:
|
||||||
|
remote_public_key = ECC.import_key(serialized_remote_public_key)
|
||||||
|
curve_point = remote_public_key.pointQ
|
||||||
|
private_key = cast(ECCPrivateKey, key_pair.private_key)
|
||||||
|
secret_point = curve_point * private_key.impl.d
|
||||||
|
byte_size = secret_point.size_in_bytes()
|
||||||
|
return cast(Integer, secret_point.x).to_bytes(byte_size)
|
||||||
|
|
||||||
|
return key_pair.public_key, _key_exchange
|
||||||
@ -11,6 +11,7 @@ class KeyType(Enum):
|
|||||||
Ed25519 = 1
|
Ed25519 = 1
|
||||||
Secp256k1 = 2
|
Secp256k1 = 2
|
||||||
ECDSA = 3
|
ECDSA = 3
|
||||||
|
ECC_P256 = 4
|
||||||
|
|
||||||
|
|
||||||
class Key(ABC):
|
class Key(ABC):
|
||||||
@ -32,6 +33,11 @@ class Key(ABC):
|
|||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
if not isinstance(other, Key):
|
||||||
|
return NotImplemented
|
||||||
|
return self.to_bytes() == other.to_bytes()
|
||||||
|
|
||||||
|
|
||||||
class PublicKey(Key):
|
class PublicKey(Key):
|
||||||
"""
|
"""
|
||||||
@ -60,14 +66,16 @@ class PublicKey(Key):
|
|||||||
"""
|
"""
|
||||||
return self._serialize_to_protobuf().SerializeToString()
|
return self._serialize_to_protobuf().SerializeToString()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def deserialize_from_protobuf(cls, protobuf_data: bytes) -> protobuf.PublicKey:
|
||||||
|
return protobuf.PublicKey.FromString(protobuf_data)
|
||||||
|
|
||||||
|
|
||||||
class PrivateKey(Key):
|
class PrivateKey(Key):
|
||||||
"""
|
"""
|
||||||
A ``PrivateKey`` represents a cryptographic private key.
|
A ``PrivateKey`` represents a cryptographic private key.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
protobuf_constructor = protobuf.PrivateKey
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def sign(self, data: bytes) -> bytes:
|
def sign(self, data: bytes) -> bytes:
|
||||||
...
|
...
|
||||||
@ -91,6 +99,10 @@ class PrivateKey(Key):
|
|||||||
"""
|
"""
|
||||||
return self._serialize_to_protobuf().SerializeToString()
|
return self._serialize_to_protobuf().SerializeToString()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def deserialize_from_protobuf(cls, protobuf_data: bytes) -> protobuf.PrivateKey:
|
||||||
|
return protobuf.PrivateKey.FromString(protobuf_data)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class KeyPair:
|
class KeyPair:
|
||||||
|
|||||||
@ -11,15 +11,20 @@ class Secp256k1PublicKey(PublicKey):
|
|||||||
return self.impl.format()
|
return self.impl.format()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_bytes(cls, key_bytes: bytes) -> "Secp256k1PublicKey":
|
def from_bytes(cls, data: bytes) -> "Secp256k1PublicKey":
|
||||||
secp256k1_pubkey = coincurve.PublicKey(key_bytes)
|
impl = coincurve.PublicKey(data)
|
||||||
return cls(secp256k1_pubkey)
|
return cls(impl)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def deserialize(cls, data: bytes) -> "Secp256k1PublicKey":
|
||||||
|
protobuf_key = cls.deserialize_from_protobuf(data)
|
||||||
|
return cls.from_bytes(protobuf_key.data)
|
||||||
|
|
||||||
def get_type(self) -> KeyType:
|
def get_type(self) -> KeyType:
|
||||||
return KeyType.Secp256k1
|
return KeyType.Secp256k1
|
||||||
|
|
||||||
def verify(self, data: bytes, signature: bytes) -> bool:
|
def verify(self, data: bytes, signature: bytes) -> bool:
|
||||||
raise NotImplementedError
|
return self.impl.verify(signature, data)
|
||||||
|
|
||||||
|
|
||||||
class Secp256k1PrivateKey(PrivateKey):
|
class Secp256k1PrivateKey(PrivateKey):
|
||||||
@ -34,11 +39,21 @@ class Secp256k1PrivateKey(PrivateKey):
|
|||||||
def to_bytes(self) -> bytes:
|
def to_bytes(self) -> bytes:
|
||||||
return self.impl.secret
|
return self.impl.secret
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_bytes(cls, data: bytes) -> "Secp256k1PrivateKey":
|
||||||
|
impl = coincurve.PrivateKey(data)
|
||||||
|
return cls(impl)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def deserialize(cls, data: bytes) -> "Secp256k1PrivateKey":
|
||||||
|
protobuf_key = cls.deserialize_from_protobuf(data)
|
||||||
|
return cls.from_bytes(protobuf_key.data)
|
||||||
|
|
||||||
def get_type(self) -> KeyType:
|
def get_type(self) -> KeyType:
|
||||||
return KeyType.Secp256k1
|
return KeyType.Secp256k1
|
||||||
|
|
||||||
def sign(self, data: bytes) -> bytes:
|
def sign(self, data: bytes) -> bytes:
|
||||||
raise NotImplementedError
|
return self.impl.sign(data)
|
||||||
|
|
||||||
def get_public_key(self) -> PublicKey:
|
def get_public_key(self) -> PublicKey:
|
||||||
public_key_impl = coincurve.PublicKey.from_secret(self.impl.secret)
|
public_key_impl = coincurve.PublicKey.from_secret(self.impl.secret)
|
||||||
|
|||||||
22
libp2p/crypto/serialization.py
Normal file
22
libp2p/crypto/serialization.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
from libp2p.crypto.keys import KeyType, PrivateKey, PublicKey
|
||||||
|
from libp2p.crypto.secp256k1 import Secp256k1PrivateKey, Secp256k1PublicKey
|
||||||
|
|
||||||
|
key_type_to_public_key_deserializer = {
|
||||||
|
KeyType.Secp256k1.value: Secp256k1PublicKey.from_bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
key_type_to_private_key_deserializer = {
|
||||||
|
KeyType.Secp256k1.value: Secp256k1PrivateKey.from_bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_public_key(data: bytes) -> PublicKey:
|
||||||
|
f = PublicKey.deserialize_from_protobuf(data)
|
||||||
|
deserializer = key_type_to_public_key_deserializer[f.key_type]
|
||||||
|
return deserializer(f.data)
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_private_key(data: bytes) -> PrivateKey:
|
||||||
|
f = PrivateKey.deserialize_from_protobuf(data)
|
||||||
|
deserializer = key_type_to_private_key_deserializer[f.key_type]
|
||||||
|
return deserializer(f.data)
|
||||||
0
libp2p/io/__init__.py
Normal file
0
libp2p/io/__init__.py
Normal file
13
libp2p/io/exceptions.py
Normal file
13
libp2p/io/exceptions.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from libp2p.exceptions import BaseLibp2pError
|
||||||
|
|
||||||
|
|
||||||
|
class MsgioException(BaseLibp2pError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MissingLengthException(MsgioException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MissingMessageException(MsgioException):
|
||||||
|
pass
|
||||||
24
libp2p/io/msgio.py
Normal file
24
libp2p/io/msgio.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
||||||
|
|
||||||
|
from .exceptions import MissingLengthException, MissingMessageException
|
||||||
|
|
||||||
|
SIZE_LEN_BYTES = 4
|
||||||
|
|
||||||
|
# TODO unify w/ https://github.com/libp2p/py-libp2p/blob/1aed52856f56a4b791696bbcbac31b5f9c2e88c9/libp2p/utils.py#L85-L99 # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
|
def encode(msg_bytes: bytes) -> bytes:
|
||||||
|
len_prefix = len(msg_bytes).to_bytes(SIZE_LEN_BYTES, "big")
|
||||||
|
return len_prefix + msg_bytes
|
||||||
|
|
||||||
|
|
||||||
|
async def read_next_message(reader: IRawConnection) -> bytes:
|
||||||
|
len_bytes = await reader.read(SIZE_LEN_BYTES)
|
||||||
|
if len(len_bytes) != SIZE_LEN_BYTES:
|
||||||
|
raise MissingLengthException()
|
||||||
|
len_int = int.from_bytes(len_bytes, "big")
|
||||||
|
next_msg = await reader.read(len_int)
|
||||||
|
if len(next_msg) != len_int:
|
||||||
|
# TODO makes sense to keep reading until this condition is true?
|
||||||
|
raise MissingMessageException()
|
||||||
|
return next_msg
|
||||||
@ -6,6 +6,11 @@ import multihash
|
|||||||
|
|
||||||
from libp2p.crypto.keys import PublicKey
|
from libp2p.crypto.keys import PublicKey
|
||||||
|
|
||||||
|
# NOTE: ``FRIENDLY_IDS`` renders a ``str`` representation of ``ID`` as a
|
||||||
|
# short string of a prefix of the base58 representation. This feature is primarily
|
||||||
|
# intended for debugging, logging, etc.
|
||||||
|
FRIENDLY_IDS = True
|
||||||
|
|
||||||
|
|
||||||
class ID:
|
class ID:
|
||||||
_bytes: bytes
|
_bytes: bytes
|
||||||
@ -32,7 +37,13 @@ class ID:
|
|||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return "<libp2p.peer.id.ID 0x" + self._bytes.hex() + ">"
|
return "<libp2p.peer.id.ID 0x" + self._bytes.hex() + ">"
|
||||||
|
|
||||||
__str__ = pretty = to_string = to_base58
|
pretty = to_string = to_base58
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
if FRIENDLY_IDS:
|
||||||
|
return self.to_string()[2:8]
|
||||||
|
else:
|
||||||
|
return self.to_string()
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
if isinstance(other, str):
|
if isinstance(other, str):
|
||||||
|
|||||||
@ -3,7 +3,6 @@ from typing import Optional
|
|||||||
from libp2p.crypto.keys import PrivateKey, PublicKey
|
from libp2p.crypto.keys import PrivateKey, PublicKey
|
||||||
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
from libp2p.security.base_transport import BaseSecureTransport
|
|
||||||
from libp2p.security.secure_conn_interface import ISecureConn
|
from libp2p.security.secure_conn_interface import ISecureConn
|
||||||
|
|
||||||
|
|
||||||
@ -20,14 +19,18 @@ class BaseSession(ISecureConn):
|
|||||||
remote_permanent_pubkey: PublicKey
|
remote_permanent_pubkey: PublicKey
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, transport: BaseSecureTransport, conn: IRawConnection, peer_id: ID
|
self,
|
||||||
|
local_peer: ID,
|
||||||
|
local_private_key: PrivateKey,
|
||||||
|
conn: IRawConnection,
|
||||||
|
peer_id: Optional[ID] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.local_peer = transport.local_peer
|
self.local_peer = local_peer
|
||||||
self.local_private_key = transport.local_private_key
|
self.local_private_key = local_private_key
|
||||||
self.conn = conn
|
|
||||||
self.remote_peer_id = peer_id
|
self.remote_peer_id = peer_id
|
||||||
self.remote_permanent_pubkey = None
|
self.remote_permanent_pubkey = None
|
||||||
|
|
||||||
|
self.conn = conn
|
||||||
self.initiator = self.conn.initiator
|
self.initiator = self.conn.initiator
|
||||||
|
|
||||||
async def write(self, data: bytes) -> None:
|
async def write(self, data: bytes) -> None:
|
||||||
|
|||||||
@ -1,14 +1,30 @@
|
|||||||
|
import secrets
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
from libp2p.crypto.keys import KeyPair
|
from libp2p.crypto.keys import KeyPair
|
||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
from libp2p.security.secure_transport_interface import ISecureTransport
|
from libp2p.security.secure_transport_interface import ISecureTransport
|
||||||
|
|
||||||
|
|
||||||
|
def default_secure_bytes_provider(n: int) -> bytes:
|
||||||
|
return secrets.token_bytes(n)
|
||||||
|
|
||||||
|
|
||||||
class BaseSecureTransport(ISecureTransport):
|
class BaseSecureTransport(ISecureTransport):
|
||||||
"""
|
"""
|
||||||
``BaseSecureTransport`` is not fully instantiated from its abstract classes as it
|
``BaseSecureTransport`` is not fully instantiated from its abstract classes as it
|
||||||
is only meant to be used in clases that derive from it.
|
is only meant to be used in clases that derive from it.
|
||||||
|
|
||||||
|
Clients can provide a strategy to get cryptographically secure bytes of a given length.
|
||||||
|
A default implementation is provided using the ``secrets`` module from the
|
||||||
|
standard library.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, local_key_pair: KeyPair) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
local_key_pair: KeyPair,
|
||||||
|
secure_bytes_provider: Callable[[int], bytes] = default_secure_bytes_provider,
|
||||||
|
) -> None:
|
||||||
self.local_private_key = local_key_pair.private_key
|
self.local_private_key = local_key_pair.private_key
|
||||||
self.local_peer = ID.from_pubkey(local_key_pair.public_key)
|
self.local_peer = ID.from_pubkey(local_key_pair.public_key)
|
||||||
|
self.secure_bytes_provider = secure_bytes_provider
|
||||||
|
|||||||
@ -76,7 +76,7 @@ class InsecureTransport(BaseSecureTransport):
|
|||||||
for an inbound connection (i.e. we are not the initiator)
|
for an inbound connection (i.e. we are not the initiator)
|
||||||
:return: secure connection object (that implements secure_conn_interface)
|
:return: secure connection object (that implements secure_conn_interface)
|
||||||
"""
|
"""
|
||||||
session = InsecureSession(self, conn, ID(b""))
|
session = InsecureSession(self.local_peer, self.local_private_key, conn)
|
||||||
await session.run_handshake()
|
await session.run_handshake()
|
||||||
return session
|
return session
|
||||||
|
|
||||||
@ -86,7 +86,9 @@ class InsecureTransport(BaseSecureTransport):
|
|||||||
for an inbound connection (i.e. we are the initiator)
|
for an inbound connection (i.e. we are the initiator)
|
||||||
:return: secure connection object (that implements secure_conn_interface)
|
:return: secure connection object (that implements secure_conn_interface)
|
||||||
"""
|
"""
|
||||||
session = InsecureSession(self, conn, peer_id)
|
session = InsecureSession(
|
||||||
|
self.local_peer, self.local_private_key, conn, peer_id
|
||||||
|
)
|
||||||
await session.run_handshake()
|
await session.run_handshake()
|
||||||
return session
|
return session
|
||||||
|
|
||||||
|
|||||||
27
libp2p/security/secio/exceptions.py
Normal file
27
libp2p/security/secio/exceptions.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
class SecioException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SelfEncryption(SecioException):
|
||||||
|
"""
|
||||||
|
Raised to indicate that a host is attempting to encrypt communications
|
||||||
|
with itself.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PeerMismatchException(SecioException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidSignatureOnExchange(SecioException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class HandshakeFailed(SecioException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class IncompatibleChoices(SecioException):
|
||||||
|
pass
|
||||||
16
libp2p/security/secio/pb/spipe.proto
Normal file
16
libp2p/security/secio/pb/spipe.proto
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
syntax = "proto2";
|
||||||
|
|
||||||
|
package spipe.pb;
|
||||||
|
|
||||||
|
message Propose {
|
||||||
|
optional bytes rand = 1;
|
||||||
|
optional bytes public_key = 2;
|
||||||
|
optional string exchanges = 3;
|
||||||
|
optional string ciphers = 4;
|
||||||
|
optional string hashes = 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
message Exchange {
|
||||||
|
optional bytes ephemeral_public_key = 1;
|
||||||
|
optional bytes signature = 2;
|
||||||
|
}
|
||||||
144
libp2p/security/secio/pb/spipe_pb2.py
Normal file
144
libp2p/security/secio/pb/spipe_pb2.py
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||||
|
# source: spipe.proto
|
||||||
|
|
||||||
|
import sys
|
||||||
|
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
|
||||||
|
from google.protobuf import descriptor as _descriptor
|
||||||
|
from google.protobuf import message as _message
|
||||||
|
from google.protobuf import reflection as _reflection
|
||||||
|
from google.protobuf import symbol_database as _symbol_database
|
||||||
|
# @@protoc_insertion_point(imports)
|
||||||
|
|
||||||
|
_sym_db = _symbol_database.Default()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
DESCRIPTOR = _descriptor.FileDescriptor(
|
||||||
|
name='spipe.proto',
|
||||||
|
package='spipe.pb',
|
||||||
|
syntax='proto2',
|
||||||
|
serialized_options=None,
|
||||||
|
serialized_pb=_b('\n\x0bspipe.proto\x12\x08spipe.pb\"_\n\x07Propose\x12\x0c\n\x04rand\x18\x01 \x01(\x0c\x12\x12\n\npublic_key\x18\x02 \x01(\x0c\x12\x11\n\texchanges\x18\x03 \x01(\t\x12\x0f\n\x07\x63iphers\x18\x04 \x01(\t\x12\x0e\n\x06hashes\x18\x05 \x01(\t\";\n\x08\x45xchange\x12\x1c\n\x14\x65phemeral_public_key\x18\x01 \x01(\x0c\x12\x11\n\tsignature\x18\x02 \x01(\x0c')
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
_PROPOSE = _descriptor.Descriptor(
|
||||||
|
name='Propose',
|
||||||
|
full_name='spipe.pb.Propose',
|
||||||
|
filename=None,
|
||||||
|
file=DESCRIPTOR,
|
||||||
|
containing_type=None,
|
||||||
|
fields=[
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='rand', full_name='spipe.pb.Propose.rand', index=0,
|
||||||
|
number=1, type=12, cpp_type=9, label=1,
|
||||||
|
has_default_value=False, default_value=_b(""),
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR),
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='public_key', full_name='spipe.pb.Propose.public_key', index=1,
|
||||||
|
number=2, type=12, cpp_type=9, label=1,
|
||||||
|
has_default_value=False, default_value=_b(""),
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR),
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='exchanges', full_name='spipe.pb.Propose.exchanges', index=2,
|
||||||
|
number=3, type=9, cpp_type=9, label=1,
|
||||||
|
has_default_value=False, default_value=_b("").decode('utf-8'),
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR),
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='ciphers', full_name='spipe.pb.Propose.ciphers', index=3,
|
||||||
|
number=4, type=9, cpp_type=9, label=1,
|
||||||
|
has_default_value=False, default_value=_b("").decode('utf-8'),
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR),
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='hashes', full_name='spipe.pb.Propose.hashes', index=4,
|
||||||
|
number=5, type=9, cpp_type=9, label=1,
|
||||||
|
has_default_value=False, default_value=_b("").decode('utf-8'),
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR),
|
||||||
|
],
|
||||||
|
extensions=[
|
||||||
|
],
|
||||||
|
nested_types=[],
|
||||||
|
enum_types=[
|
||||||
|
],
|
||||||
|
serialized_options=None,
|
||||||
|
is_extendable=False,
|
||||||
|
syntax='proto2',
|
||||||
|
extension_ranges=[],
|
||||||
|
oneofs=[
|
||||||
|
],
|
||||||
|
serialized_start=25,
|
||||||
|
serialized_end=120,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_EXCHANGE = _descriptor.Descriptor(
|
||||||
|
name='Exchange',
|
||||||
|
full_name='spipe.pb.Exchange',
|
||||||
|
filename=None,
|
||||||
|
file=DESCRIPTOR,
|
||||||
|
containing_type=None,
|
||||||
|
fields=[
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='ephemeral_public_key', full_name='spipe.pb.Exchange.ephemeral_public_key', index=0,
|
||||||
|
number=1, type=12, cpp_type=9, label=1,
|
||||||
|
has_default_value=False, default_value=_b(""),
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR),
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='signature', full_name='spipe.pb.Exchange.signature', index=1,
|
||||||
|
number=2, type=12, cpp_type=9, label=1,
|
||||||
|
has_default_value=False, default_value=_b(""),
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR),
|
||||||
|
],
|
||||||
|
extensions=[
|
||||||
|
],
|
||||||
|
nested_types=[],
|
||||||
|
enum_types=[
|
||||||
|
],
|
||||||
|
serialized_options=None,
|
||||||
|
is_extendable=False,
|
||||||
|
syntax='proto2',
|
||||||
|
extension_ranges=[],
|
||||||
|
oneofs=[
|
||||||
|
],
|
||||||
|
serialized_start=122,
|
||||||
|
serialized_end=181,
|
||||||
|
)
|
||||||
|
|
||||||
|
DESCRIPTOR.message_types_by_name['Propose'] = _PROPOSE
|
||||||
|
DESCRIPTOR.message_types_by_name['Exchange'] = _EXCHANGE
|
||||||
|
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
|
||||||
|
|
||||||
|
Propose = _reflection.GeneratedProtocolMessageType('Propose', (_message.Message,), dict(
|
||||||
|
DESCRIPTOR = _PROPOSE,
|
||||||
|
__module__ = 'spipe_pb2'
|
||||||
|
# @@protoc_insertion_point(class_scope:spipe.pb.Propose)
|
||||||
|
))
|
||||||
|
_sym_db.RegisterMessage(Propose)
|
||||||
|
|
||||||
|
Exchange = _reflection.GeneratedProtocolMessageType('Exchange', (_message.Message,), dict(
|
||||||
|
DESCRIPTOR = _EXCHANGE,
|
||||||
|
__module__ = 'spipe_pb2'
|
||||||
|
# @@protoc_insertion_point(class_scope:spipe.pb.Exchange)
|
||||||
|
))
|
||||||
|
_sym_db.RegisterMessage(Exchange)
|
||||||
|
|
||||||
|
|
||||||
|
# @@protoc_insertion_point(module_scope)
|
||||||
67
libp2p/security/secio/pb/spipe_pb2.pyi
Normal file
67
libp2p/security/secio/pb/spipe_pb2.pyi
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
# @generated by generate_proto_mypy_stubs.py. Do not edit!
|
||||||
|
import sys
|
||||||
|
from google.protobuf.descriptor import (
|
||||||
|
Descriptor as google___protobuf___descriptor___Descriptor,
|
||||||
|
)
|
||||||
|
|
||||||
|
from google.protobuf.message import (
|
||||||
|
Message as google___protobuf___message___Message,
|
||||||
|
)
|
||||||
|
|
||||||
|
from typing import (
|
||||||
|
Optional as typing___Optional,
|
||||||
|
Text as typing___Text,
|
||||||
|
)
|
||||||
|
|
||||||
|
from typing_extensions import (
|
||||||
|
Literal as typing_extensions___Literal,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Propose(google___protobuf___message___Message):
|
||||||
|
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
|
||||||
|
rand = ... # type: bytes
|
||||||
|
public_key = ... # type: bytes
|
||||||
|
exchanges = ... # type: typing___Text
|
||||||
|
ciphers = ... # type: typing___Text
|
||||||
|
hashes = ... # type: typing___Text
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
*,
|
||||||
|
rand : typing___Optional[bytes] = None,
|
||||||
|
public_key : typing___Optional[bytes] = None,
|
||||||
|
exchanges : typing___Optional[typing___Text] = None,
|
||||||
|
ciphers : typing___Optional[typing___Text] = None,
|
||||||
|
hashes : typing___Optional[typing___Text] = None,
|
||||||
|
) -> None: ...
|
||||||
|
@classmethod
|
||||||
|
def FromString(cls, s: bytes) -> Propose: ...
|
||||||
|
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
|
||||||
|
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
|
||||||
|
if sys.version_info >= (3,):
|
||||||
|
def HasField(self, field_name: typing_extensions___Literal[u"ciphers",u"exchanges",u"hashes",u"public_key",u"rand"]) -> bool: ...
|
||||||
|
def ClearField(self, field_name: typing_extensions___Literal[u"ciphers",u"exchanges",u"hashes",u"public_key",u"rand"]) -> None: ...
|
||||||
|
else:
|
||||||
|
def HasField(self, field_name: typing_extensions___Literal[u"ciphers",b"ciphers",u"exchanges",b"exchanges",u"hashes",b"hashes",u"public_key",b"public_key",u"rand",b"rand"]) -> bool: ...
|
||||||
|
def ClearField(self, field_name: typing_extensions___Literal[u"ciphers",b"ciphers",u"exchanges",b"exchanges",u"hashes",b"hashes",u"public_key",b"public_key",u"rand",b"rand"]) -> None: ...
|
||||||
|
|
||||||
|
class Exchange(google___protobuf___message___Message):
|
||||||
|
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
|
||||||
|
ephemeral_public_key = ... # type: bytes
|
||||||
|
signature = ... # type: bytes
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
*,
|
||||||
|
ephemeral_public_key : typing___Optional[bytes] = None,
|
||||||
|
signature : typing___Optional[bytes] = None,
|
||||||
|
) -> None: ...
|
||||||
|
@classmethod
|
||||||
|
def FromString(cls, s: bytes) -> Exchange: ...
|
||||||
|
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
|
||||||
|
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
|
||||||
|
if sys.version_info >= (3,):
|
||||||
|
def HasField(self, field_name: typing_extensions___Literal[u"ephemeral_public_key",u"signature"]) -> bool: ...
|
||||||
|
def ClearField(self, field_name: typing_extensions___Literal[u"ephemeral_public_key",u"signature"]) -> None: ...
|
||||||
|
else:
|
||||||
|
def HasField(self, field_name: typing_extensions___Literal[u"ephemeral_public_key",b"ephemeral_public_key",u"signature",b"signature"]) -> bool: ...
|
||||||
|
def ClearField(self, field_name: typing_extensions___Literal[u"ephemeral_public_key",b"ephemeral_public_key",u"signature",b"signature"]) -> None: ...
|
||||||
403
libp2p/security/secio/transport.py
Normal file
403
libp2p/security/secio/transport.py
Normal file
@ -0,0 +1,403 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import multihash
|
||||||
|
|
||||||
|
from libp2p.crypto.authenticated_encryption import (
|
||||||
|
EncryptionParameters as AuthenticatedEncryptionParameters,
|
||||||
|
)
|
||||||
|
from libp2p.crypto.authenticated_encryption import (
|
||||||
|
initialize_pair as initialize_pair_for_encryption,
|
||||||
|
)
|
||||||
|
from libp2p.crypto.authenticated_encryption import MacAndCipher as Encrypter
|
||||||
|
from libp2p.crypto.ecc import ECCPublicKey
|
||||||
|
from libp2p.crypto.key_exchange import create_ephemeral_key_pair
|
||||||
|
from libp2p.crypto.keys import PrivateKey, PublicKey
|
||||||
|
from libp2p.crypto.serialization import deserialize_public_key
|
||||||
|
from libp2p.io.msgio import encode as encode_message
|
||||||
|
from libp2p.io.msgio import read_next_message
|
||||||
|
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
||||||
|
from libp2p.peer.id import ID as PeerID
|
||||||
|
from libp2p.security.base_session import BaseSession
|
||||||
|
from libp2p.security.base_transport import BaseSecureTransport
|
||||||
|
from libp2p.security.secure_conn_interface import ISecureConn
|
||||||
|
|
||||||
|
from .exceptions import (
|
||||||
|
HandshakeFailed,
|
||||||
|
IncompatibleChoices,
|
||||||
|
InvalidSignatureOnExchange,
|
||||||
|
PeerMismatchException,
|
||||||
|
SecioException,
|
||||||
|
SelfEncryption,
|
||||||
|
)
|
||||||
|
from .pb.spipe_pb2 import Exchange, Propose
|
||||||
|
|
||||||
|
ID = "/secio/1.0.0"
|
||||||
|
|
||||||
|
NONCE_SIZE = 16 # bytes
|
||||||
|
|
||||||
|
# NOTE: the following is only a subset of allowable parameters according to the
|
||||||
|
# `secio` specification.
|
||||||
|
DEFAULT_SUPPORTED_EXCHANGES = "P-256"
|
||||||
|
DEFAULT_SUPPORTED_CIPHERS = "AES-128"
|
||||||
|
DEFAULT_SUPPORTED_HASHES = "SHA256"
|
||||||
|
|
||||||
|
|
||||||
|
class SecureSession(BaseSession):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
local_peer: PeerID,
|
||||||
|
local_private_key: PrivateKey,
|
||||||
|
local_encryption_parameters: AuthenticatedEncryptionParameters,
|
||||||
|
remote_peer: PeerID,
|
||||||
|
remote_encryption_parameters: AuthenticatedEncryptionParameters,
|
||||||
|
conn: IRawConnection,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(local_peer, local_private_key, conn, remote_peer)
|
||||||
|
|
||||||
|
self.local_encryption_parameters = local_encryption_parameters
|
||||||
|
self.remote_encryption_parameters = remote_encryption_parameters
|
||||||
|
self._initialize_authenticated_encryption_for_local_peer()
|
||||||
|
self._initialize_authenticated_encryption_for_remote_peer()
|
||||||
|
|
||||||
|
def _initialize_authenticated_encryption_for_local_peer(self) -> None:
|
||||||
|
self.local_encrypter = Encrypter(self.local_encryption_parameters)
|
||||||
|
|
||||||
|
def _initialize_authenticated_encryption_for_remote_peer(self) -> None:
|
||||||
|
self.remote_encrypter = Encrypter(self.remote_encryption_parameters)
|
||||||
|
|
||||||
|
async def read(self, n: int = -1) -> bytes:
|
||||||
|
return await self._read_msg()
|
||||||
|
|
||||||
|
async def _read_msg(self) -> bytes:
|
||||||
|
# TODO do we need to serialize reads?
|
||||||
|
msg = await read_next_message(self.conn)
|
||||||
|
return self.remote_encrypter.decrypt_if_valid(msg)
|
||||||
|
|
||||||
|
async def write(self, data: bytes) -> None:
|
||||||
|
await self._write_msg(data)
|
||||||
|
|
||||||
|
async def _write_msg(self, data: bytes) -> None:
|
||||||
|
# TODO do we need to serialize writes?
|
||||||
|
encrypted_data = self.local_encrypter.encrypt(data)
|
||||||
|
tag = self.local_encrypter.authenticate(encrypted_data)
|
||||||
|
msg = encode_message(encrypted_data + tag)
|
||||||
|
await self.conn.write(msg)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Proposal:
|
||||||
|
"""
|
||||||
|
A ``Proposal`` represents the set of session parameters one peer in a pair of
|
||||||
|
peers attempting to negotiate a `secio` channel prefers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
nonce: bytes
|
||||||
|
public_key: PublicKey
|
||||||
|
exchanges: str = DEFAULT_SUPPORTED_EXCHANGES # comma separated list
|
||||||
|
ciphers: str = DEFAULT_SUPPORTED_CIPHERS # comma separated list
|
||||||
|
hashes: str = DEFAULT_SUPPORTED_HASHES # comma separated list
|
||||||
|
|
||||||
|
def serialize(self) -> bytes:
|
||||||
|
protobuf = Propose(
|
||||||
|
rand=self.nonce,
|
||||||
|
public_key=self.public_key.serialize(),
|
||||||
|
exchanges=self.exchanges,
|
||||||
|
ciphers=self.ciphers,
|
||||||
|
hashes=self.hashes,
|
||||||
|
)
|
||||||
|
return protobuf.SerializeToString()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def deserialize(cls, protobuf_bytes: bytes) -> "Proposal":
|
||||||
|
protobuf = Propose.FromString(protobuf_bytes)
|
||||||
|
|
||||||
|
nonce = protobuf.rand
|
||||||
|
public_key_protobuf_bytes = protobuf.public_key
|
||||||
|
public_key = deserialize_public_key(public_key_protobuf_bytes)
|
||||||
|
exchanges = protobuf.exchanges
|
||||||
|
ciphers = protobuf.ciphers
|
||||||
|
hashes = protobuf.hashes
|
||||||
|
|
||||||
|
return cls(nonce, public_key, exchanges, ciphers, hashes)
|
||||||
|
|
||||||
|
def calculate_peer_id(self) -> PeerID:
|
||||||
|
return PeerID.from_pubkey(self.public_key)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EncryptionParameters:
|
||||||
|
permanent_public_key: PublicKey
|
||||||
|
|
||||||
|
curve_type: str
|
||||||
|
cipher_type: str
|
||||||
|
hash_type: str
|
||||||
|
|
||||||
|
ephemeral_public_key: PublicKey
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SessionParameters:
|
||||||
|
local_peer: PeerID
|
||||||
|
local_encryption_parameters: EncryptionParameters
|
||||||
|
|
||||||
|
remote_peer: PeerID
|
||||||
|
remote_encryption_parameters: EncryptionParameters
|
||||||
|
|
||||||
|
# order is a comparator used to break the symmetry b/t each pair of peers
|
||||||
|
order: int
|
||||||
|
shared_key: bytes
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def _response_to_msg(conn: IRawConnection, msg: bytes) -> bytes:
|
||||||
|
await conn.write(encode_message(msg))
|
||||||
|
return await read_next_message(conn)
|
||||||
|
|
||||||
|
|
||||||
|
def _mk_multihash_sha256(data: bytes) -> bytes:
|
||||||
|
return multihash.digest(data, "sha2-256")
|
||||||
|
|
||||||
|
|
||||||
|
def _mk_score(public_key: PublicKey, nonce: bytes) -> bytes:
|
||||||
|
return _mk_multihash_sha256(public_key.serialize() + nonce)
|
||||||
|
|
||||||
|
|
||||||
|
def _select_parameter_from_order(
|
||||||
|
order: int, supported_parameters: str, available_parameters: str
|
||||||
|
) -> str:
|
||||||
|
if order < 0:
|
||||||
|
first_choices = available_parameters.split(",")
|
||||||
|
second_choices = supported_parameters.split(",")
|
||||||
|
elif order > 0:
|
||||||
|
first_choices = supported_parameters.split(",")
|
||||||
|
second_choices = available_parameters.split(",")
|
||||||
|
else:
|
||||||
|
return supported_parameters.split(",")[0]
|
||||||
|
|
||||||
|
for first, second in zip(first_choices, second_choices):
|
||||||
|
if first == second:
|
||||||
|
return first
|
||||||
|
raise IncompatibleChoices()
|
||||||
|
|
||||||
|
|
||||||
|
def _select_encryption_parameters(
|
||||||
|
local_proposal: Proposal, remote_proposal: Proposal
|
||||||
|
) -> Tuple[str, str, str, int]:
|
||||||
|
first_score = _mk_score(remote_proposal.public_key, local_proposal.nonce)
|
||||||
|
second_score = _mk_score(local_proposal.public_key, remote_proposal.nonce)
|
||||||
|
|
||||||
|
order = 0
|
||||||
|
if first_score < second_score:
|
||||||
|
order = -1
|
||||||
|
elif second_score < first_score:
|
||||||
|
order = 1
|
||||||
|
|
||||||
|
if order == 0:
|
||||||
|
raise SelfEncryption()
|
||||||
|
|
||||||
|
return (
|
||||||
|
_select_parameter_from_order(
|
||||||
|
order, DEFAULT_SUPPORTED_EXCHANGES, remote_proposal.exchanges
|
||||||
|
),
|
||||||
|
_select_parameter_from_order(
|
||||||
|
order, DEFAULT_SUPPORTED_CIPHERS, remote_proposal.ciphers
|
||||||
|
),
|
||||||
|
_select_parameter_from_order(
|
||||||
|
order, DEFAULT_SUPPORTED_HASHES, remote_proposal.hashes
|
||||||
|
),
|
||||||
|
order,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _establish_session_parameters(
|
||||||
|
local_peer: PeerID,
|
||||||
|
local_private_key: PrivateKey,
|
||||||
|
remote_peer: Optional[PeerID],
|
||||||
|
conn: IRawConnection,
|
||||||
|
nonce: bytes,
|
||||||
|
) -> Tuple[SessionParameters, bytes]:
|
||||||
|
# establish shared encryption parameters
|
||||||
|
session_parameters = SessionParameters()
|
||||||
|
session_parameters.local_peer = local_peer
|
||||||
|
|
||||||
|
local_encryption_parameters = EncryptionParameters()
|
||||||
|
session_parameters.local_encryption_parameters = local_encryption_parameters
|
||||||
|
|
||||||
|
local_public_key = local_private_key.get_public_key()
|
||||||
|
local_encryption_parameters.permanent_public_key = local_public_key
|
||||||
|
|
||||||
|
local_proposal = Proposal(nonce, local_public_key)
|
||||||
|
serialized_local_proposal = local_proposal.serialize()
|
||||||
|
serialized_remote_proposal = await _response_to_msg(conn, serialized_local_proposal)
|
||||||
|
|
||||||
|
remote_encryption_parameters = EncryptionParameters()
|
||||||
|
session_parameters.remote_encryption_parameters = remote_encryption_parameters
|
||||||
|
remote_proposal = Proposal.deserialize(serialized_remote_proposal)
|
||||||
|
remote_encryption_parameters.permanent_public_key = remote_proposal.public_key
|
||||||
|
|
||||||
|
remote_peer_from_proposal = remote_proposal.calculate_peer_id()
|
||||||
|
if not remote_peer:
|
||||||
|
remote_peer = remote_peer_from_proposal
|
||||||
|
elif remote_peer != remote_peer_from_proposal:
|
||||||
|
raise PeerMismatchException()
|
||||||
|
session_parameters.remote_peer = remote_peer
|
||||||
|
|
||||||
|
curve_param, cipher_param, hash_param, order = _select_encryption_parameters(
|
||||||
|
local_proposal, remote_proposal
|
||||||
|
)
|
||||||
|
local_encryption_parameters.curve_type = curve_param
|
||||||
|
local_encryption_parameters.cipher_type = cipher_param
|
||||||
|
local_encryption_parameters.hash_type = hash_param
|
||||||
|
remote_encryption_parameters.curve_type = curve_param
|
||||||
|
remote_encryption_parameters.cipher_type = cipher_param
|
||||||
|
remote_encryption_parameters.hash_type = hash_param
|
||||||
|
session_parameters.order = order
|
||||||
|
|
||||||
|
# exchange ephemeral pub keys
|
||||||
|
local_ephemeral_public_key, shared_key_generator = create_ephemeral_key_pair(
|
||||||
|
curve_param
|
||||||
|
)
|
||||||
|
local_encryption_parameters.ephemeral_public_key = local_ephemeral_public_key
|
||||||
|
local_selection = (
|
||||||
|
serialized_local_proposal
|
||||||
|
+ serialized_remote_proposal
|
||||||
|
+ local_ephemeral_public_key.to_bytes()
|
||||||
|
)
|
||||||
|
exchange_signature = local_private_key.sign(local_selection)
|
||||||
|
local_exchange = Exchange(
|
||||||
|
ephemeral_public_key=local_ephemeral_public_key.to_bytes(),
|
||||||
|
signature=exchange_signature,
|
||||||
|
)
|
||||||
|
|
||||||
|
serialized_local_exchange = local_exchange.SerializeToString()
|
||||||
|
serialized_remote_exchange = await _response_to_msg(conn, serialized_local_exchange)
|
||||||
|
|
||||||
|
remote_exchange = Exchange()
|
||||||
|
remote_exchange.ParseFromString(serialized_remote_exchange)
|
||||||
|
|
||||||
|
remote_ephemeral_public_key_bytes = remote_exchange.ephemeral_public_key
|
||||||
|
remote_ephemeral_public_key = ECCPublicKey.from_bytes(
|
||||||
|
remote_ephemeral_public_key_bytes
|
||||||
|
)
|
||||||
|
remote_encryption_parameters.ephemeral_public_key = remote_ephemeral_public_key
|
||||||
|
remote_selection = (
|
||||||
|
serialized_remote_proposal
|
||||||
|
+ serialized_local_proposal
|
||||||
|
+ remote_ephemeral_public_key_bytes
|
||||||
|
)
|
||||||
|
valid_signature = remote_encryption_parameters.permanent_public_key.verify(
|
||||||
|
remote_selection, remote_exchange.signature
|
||||||
|
)
|
||||||
|
if not valid_signature:
|
||||||
|
raise InvalidSignatureOnExchange()
|
||||||
|
|
||||||
|
shared_key = shared_key_generator(remote_ephemeral_public_key_bytes)
|
||||||
|
session_parameters.shared_key = shared_key
|
||||||
|
|
||||||
|
return session_parameters, remote_proposal.nonce
|
||||||
|
|
||||||
|
|
||||||
|
def _mk_session_from(
|
||||||
|
local_private_key: PrivateKey,
|
||||||
|
session_parameters: SessionParameters,
|
||||||
|
conn: IRawConnection,
|
||||||
|
) -> SecureSession:
|
||||||
|
key_set1, key_set2 = initialize_pair_for_encryption(
|
||||||
|
session_parameters.local_encryption_parameters.cipher_type,
|
||||||
|
session_parameters.local_encryption_parameters.hash_type,
|
||||||
|
session_parameters.shared_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
if session_parameters.order < 0:
|
||||||
|
key_set1, key_set2 = key_set2, key_set1
|
||||||
|
|
||||||
|
session = SecureSession(
|
||||||
|
session_parameters.local_peer,
|
||||||
|
local_private_key,
|
||||||
|
key_set1,
|
||||||
|
session_parameters.remote_peer,
|
||||||
|
key_set2,
|
||||||
|
conn,
|
||||||
|
)
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
async def _finish_handshake(session: ISecureConn, remote_nonce: bytes) -> bytes:
|
||||||
|
await session.write(remote_nonce)
|
||||||
|
return await session.read()
|
||||||
|
|
||||||
|
|
||||||
|
async def create_secure_session(
|
||||||
|
local_nonce: bytes,
|
||||||
|
local_peer: PeerID,
|
||||||
|
local_private_key: PrivateKey,
|
||||||
|
conn: IRawConnection,
|
||||||
|
remote_peer: PeerID = None,
|
||||||
|
) -> ISecureConn:
|
||||||
|
"""
|
||||||
|
Attempt the initial `secio` handshake with the remote peer.
|
||||||
|
If successful, return an object that provides secure communication to the
|
||||||
|
``remote_peer``.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
session_parameters, remote_nonce = await _establish_session_parameters(
|
||||||
|
local_peer, local_private_key, remote_peer, conn, local_nonce
|
||||||
|
)
|
||||||
|
except SecioException as e:
|
||||||
|
await conn.close()
|
||||||
|
raise e
|
||||||
|
|
||||||
|
session = _mk_session_from(local_private_key, session_parameters, conn)
|
||||||
|
|
||||||
|
received_nonce = await _finish_handshake(session, remote_nonce)
|
||||||
|
if received_nonce != local_nonce:
|
||||||
|
await conn.close()
|
||||||
|
raise HandshakeFailed()
|
||||||
|
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
class Transport(BaseSecureTransport):
|
||||||
|
"""
|
||||||
|
``Transport`` provides a security upgrader for a ``IRawConnection``,
|
||||||
|
following the `secio` protocol defined in the libp2p specs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_nonce(self) -> bytes:
|
||||||
|
return self.secure_bytes_provider(NONCE_SIZE)
|
||||||
|
|
||||||
|
async def secure_inbound(self, conn: IRawConnection) -> ISecureConn:
|
||||||
|
"""
|
||||||
|
Secure the connection, either locally or by communicating with opposing node via conn,
|
||||||
|
for an inbound connection (i.e. we are not the initiator)
|
||||||
|
:return: secure connection object (that implements secure_conn_interface)
|
||||||
|
"""
|
||||||
|
local_nonce = self.get_nonce()
|
||||||
|
local_peer = self.local_peer
|
||||||
|
local_private_key = self.local_private_key
|
||||||
|
|
||||||
|
return await create_secure_session(
|
||||||
|
local_nonce, local_peer, local_private_key, conn
|
||||||
|
)
|
||||||
|
|
||||||
|
async def secure_outbound(
|
||||||
|
self, conn: IRawConnection, peer_id: PeerID
|
||||||
|
) -> ISecureConn:
|
||||||
|
"""
|
||||||
|
Secure the connection, either locally or by communicating with opposing node via conn,
|
||||||
|
for an inbound connection (i.e. we are the initiator)
|
||||||
|
:return: secure connection object (that implements secure_conn_interface)
|
||||||
|
"""
|
||||||
|
local_nonce = self.get_nonce()
|
||||||
|
local_peer = self.local_peer
|
||||||
|
local_private_key = self.local_private_key
|
||||||
|
|
||||||
|
return await create_secure_session(
|
||||||
|
local_nonce, local_peer, local_private_key, conn, peer_id
|
||||||
|
)
|
||||||
79
libp2p/security/simple/transport.py
Normal file
79
libp2p/security/simple/transport.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
from libp2p.crypto.keys import KeyPair
|
||||||
|
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
||||||
|
from libp2p.peer.id import ID
|
||||||
|
from libp2p.security.base_transport import BaseSecureTransport
|
||||||
|
from libp2p.security.insecure.transport import InsecureSession
|
||||||
|
from libp2p.security.secure_conn_interface import ISecureConn
|
||||||
|
from libp2p.transport.exceptions import SecurityUpgradeFailure
|
||||||
|
from libp2p.utils import encode_fixedint_prefixed, read_fixedint_prefixed
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleSecurityTransport(BaseSecureTransport):
|
||||||
|
key_phrase: str
|
||||||
|
|
||||||
|
def __init__(self, local_key_pair: KeyPair, key_phrase: str) -> None:
|
||||||
|
super().__init__(local_key_pair)
|
||||||
|
self.key_phrase = key_phrase
|
||||||
|
|
||||||
|
async def secure_inbound(self, conn: IRawConnection) -> ISecureConn:
|
||||||
|
"""
|
||||||
|
Secure the connection, either locally or by communicating with opposing node via conn,
|
||||||
|
for an inbound connection (i.e. we are not the initiator)
|
||||||
|
:return: secure connection object (that implements secure_conn_interface)
|
||||||
|
"""
|
||||||
|
await conn.write(encode_fixedint_prefixed(self.key_phrase.encode()))
|
||||||
|
incoming = (await read_fixedint_prefixed(conn)).decode()
|
||||||
|
|
||||||
|
if incoming != self.key_phrase:
|
||||||
|
raise SecurityUpgradeFailure(
|
||||||
|
"Key phrase differed between nodes. Expected " + self.key_phrase
|
||||||
|
)
|
||||||
|
|
||||||
|
session = InsecureSession(
|
||||||
|
self.local_peer, self.local_private_key, conn, ID(b"")
|
||||||
|
)
|
||||||
|
# NOTE: Here we calls `run_handshake` for both sides to exchange their public keys and
|
||||||
|
# peer ids, otherwise tests fail. However, it seems pretty weird that
|
||||||
|
# `SimpleSecurityTransport` sends peer id through `Insecure`.
|
||||||
|
await session.run_handshake()
|
||||||
|
# NOTE: this is abusing the abstraction we have here
|
||||||
|
# but this code may be deprecated soon and this exists
|
||||||
|
# mainly to satisfy a test that will go along w/ it
|
||||||
|
# FIXME: Enable type check back when we can deprecate the simple transport.
|
||||||
|
session.key_phrase = self.key_phrase # type: ignore
|
||||||
|
return session
|
||||||
|
|
||||||
|
async def secure_outbound(self, conn: IRawConnection, peer_id: ID) -> ISecureConn:
|
||||||
|
"""
|
||||||
|
Secure the connection, either locally or by communicating with opposing node via conn,
|
||||||
|
for an inbound connection (i.e. we are the initiator)
|
||||||
|
:return: secure connection object (that implements secure_conn_interface)
|
||||||
|
"""
|
||||||
|
await conn.write(encode_fixedint_prefixed(self.key_phrase.encode()))
|
||||||
|
incoming = (await read_fixedint_prefixed(conn)).decode()
|
||||||
|
|
||||||
|
# Force context switch, as this security transport is built for testing locally
|
||||||
|
# in a single event loop
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
if incoming != self.key_phrase:
|
||||||
|
raise SecurityUpgradeFailure(
|
||||||
|
"Key phrase differed between nodes. Expected " + self.key_phrase
|
||||||
|
)
|
||||||
|
|
||||||
|
session = InsecureSession(
|
||||||
|
self.local_peer, self.local_private_key, conn, peer_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE: Here we calls `run_handshake` for both sides to exchange their public keys and
|
||||||
|
# peer ids, otherwise tests fail. However, it seems pretty weird that
|
||||||
|
# `SimpleSecurityTransport` sends peer id through `Insecure`.
|
||||||
|
await session.run_handshake()
|
||||||
|
# NOTE: this is abusing the abstraction we have here
|
||||||
|
# but this code may be deprecated soon and this exists
|
||||||
|
# mainly to satisfy a test that will go along w/ it
|
||||||
|
# FIXME: Enable type check back when we can deprecate the simple transport.
|
||||||
|
session.key_phrase = self.key_phrase # type: ignore
|
||||||
|
return session
|
||||||
22
tests/crypto/test_secp256k1.py
Normal file
22
tests/crypto/test_secp256k1.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||||
|
from libp2p.crypto.serialization import deserialize_private_key, deserialize_public_key
|
||||||
|
|
||||||
|
|
||||||
|
def test_public_key_serialize_deserialize_round_trip():
|
||||||
|
key_pair = create_new_key_pair()
|
||||||
|
public_key = key_pair.public_key
|
||||||
|
|
||||||
|
public_key_bytes = public_key.serialize()
|
||||||
|
another_public_key = deserialize_public_key(public_key_bytes)
|
||||||
|
|
||||||
|
assert public_key == another_public_key
|
||||||
|
|
||||||
|
|
||||||
|
def test_private_key_serialize_deserialize_round_trip():
|
||||||
|
key_pair = create_new_key_pair()
|
||||||
|
private_key = key_pair.private_key
|
||||||
|
|
||||||
|
private_key_bytes = private_key.serialize()
|
||||||
|
another_private_key = deserialize_private_key(private_key_bytes)
|
||||||
|
|
||||||
|
assert private_key == another_private_key
|
||||||
@ -4,10 +4,14 @@ import base58
|
|||||||
import multihash
|
import multihash
|
||||||
|
|
||||||
from libp2p.crypto.rsa import create_new_key_pair
|
from libp2p.crypto.rsa import create_new_key_pair
|
||||||
|
import libp2p.peer.id as PeerID
|
||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
|
|
||||||
ALPHABETS = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
|
ALPHABETS = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
|
||||||
|
|
||||||
|
# ensure we are not in "debug" mode for the following tests
|
||||||
|
PeerID.FRIENDLY_IDS = False
|
||||||
|
|
||||||
|
|
||||||
def test_eq_impl_for_bytes():
|
def test_eq_impl_for_bytes():
|
||||||
random_id_string = ""
|
random_id_string = ""
|
||||||
|
|||||||
101
tests/security/test_secio.py
Normal file
101
tests/security/test_secio.py
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||||
|
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
||||||
|
from libp2p.peer.id import ID
|
||||||
|
from libp2p.security.secio.transport import NONCE_SIZE, create_secure_session
|
||||||
|
|
||||||
|
|
||||||
|
class InMemoryConnection(IRawConnection):
|
||||||
|
def __init__(self, peer, initiator=False):
|
||||||
|
self.peer = peer
|
||||||
|
self.recv_queue = asyncio.Queue()
|
||||||
|
self.send_queue = asyncio.Queue()
|
||||||
|
self.initiator = initiator
|
||||||
|
|
||||||
|
self.current_msg = None
|
||||||
|
self.current_position = 0
|
||||||
|
|
||||||
|
self.closed = False
|
||||||
|
|
||||||
|
async def write(self, data: bytes) -> None:
|
||||||
|
if self.closed:
|
||||||
|
raise Exception("InMemoryConnection is closed for writing")
|
||||||
|
|
||||||
|
await self.send_queue.put(data)
|
||||||
|
|
||||||
|
async def read(self, n: int = -1) -> bytes:
|
||||||
|
"""
|
||||||
|
NOTE: have to buffer the current message and juggle packets
|
||||||
|
off the recv queue to satisfy the semantics of this function.
|
||||||
|
"""
|
||||||
|
if self.closed:
|
||||||
|
raise Exception("InMemoryConnection is closed for reading")
|
||||||
|
|
||||||
|
if not self.current_msg:
|
||||||
|
self.current_msg = await self.recv_queue.get()
|
||||||
|
self.current_position = 0
|
||||||
|
|
||||||
|
if n < 0:
|
||||||
|
msg = self.current_msg
|
||||||
|
self.current_msg = None
|
||||||
|
return msg
|
||||||
|
|
||||||
|
next_msg = self.current_msg[self.current_position : self.current_position + n]
|
||||||
|
self.current_position += n
|
||||||
|
if self.current_position == len(self.current_msg):
|
||||||
|
self.current_msg = None
|
||||||
|
return next_msg
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
self.closed = True
|
||||||
|
|
||||||
|
|
||||||
|
async def create_pipe(local_conn, remote_conn):
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
next_msg = await local_conn.send_queue.get()
|
||||||
|
await remote_conn.recv_queue.put(next_msg)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_secure_session():
|
||||||
|
local_nonce = b"\x01" * NONCE_SIZE
|
||||||
|
local_key_pair = create_new_key_pair(b"a")
|
||||||
|
local_peer = ID.from_pubkey(local_key_pair.public_key)
|
||||||
|
|
||||||
|
remote_nonce = b"\x02" * NONCE_SIZE
|
||||||
|
remote_key_pair = create_new_key_pair(b"b")
|
||||||
|
remote_peer = ID.from_pubkey(remote_key_pair.public_key)
|
||||||
|
|
||||||
|
local_conn = InMemoryConnection(local_peer, initiator=True)
|
||||||
|
remote_conn = InMemoryConnection(remote_peer)
|
||||||
|
|
||||||
|
local_pipe_task = asyncio.create_task(create_pipe(local_conn, remote_conn))
|
||||||
|
remote_pipe_task = asyncio.create_task(create_pipe(remote_conn, local_conn))
|
||||||
|
|
||||||
|
local_session_builder = create_secure_session(
|
||||||
|
local_nonce, local_peer, local_key_pair.private_key, local_conn, remote_peer
|
||||||
|
)
|
||||||
|
remote_session_builder = create_secure_session(
|
||||||
|
remote_nonce, remote_peer, remote_key_pair.private_key, remote_conn
|
||||||
|
)
|
||||||
|
local_secure_conn, remote_secure_conn = await asyncio.gather(
|
||||||
|
local_session_builder, remote_session_builder
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = b"abc"
|
||||||
|
await local_secure_conn.write(msg)
|
||||||
|
received_msg = await remote_secure_conn.read()
|
||||||
|
assert received_msg == msg
|
||||||
|
|
||||||
|
await asyncio.gather(local_secure_conn.close(), remote_secure_conn.close())
|
||||||
|
|
||||||
|
local_pipe_task.cancel()
|
||||||
|
remote_pipe_task.cancel()
|
||||||
|
await local_pipe_task
|
||||||
|
await remote_pipe_task
|
||||||
Reference in New Issue
Block a user