This commit is contained in:
mhchia
2020-02-17 19:02:18 +08:00
parent 874c6bbca4
commit 2df47a943c
5 changed files with 63 additions and 75 deletions

View File

@ -39,9 +39,9 @@ class MsgReader(ABC):
async def read_msg(self) -> bytes: async def read_msg(self) -> bytes:
... ...
@abstractmethod # @abstractmethod
async def next_msg_len(self) -> int: # async def next_msg_len(self) -> int:
... # ...
class MsgWriter(ABC): class MsgWriter(ABC):
@ -52,3 +52,17 @@ class MsgWriter(ABC):
class MsgReadWriter(MsgReader, MsgWriter): class MsgReadWriter(MsgReader, MsgWriter):
pass pass
class Encrypter(ABC):
@abstractmethod
def encrypt(self, data: bytes) -> bytes:
...
@abstractmethod
def decrypt(self, data: bytes) -> bytes:
...
class EncryptedMsgReadWriter(MsgReadWriter, Encrypter):
pass

View File

@ -11,8 +11,7 @@ from typing import Optional
from libp2p.io.abc import MsgReadWriter, Reader, ReadWriteCloser from libp2p.io.abc import MsgReadWriter, Reader, ReadWriteCloser
from libp2p.io.utils import read_exactly from libp2p.io.utils import read_exactly
SIZE_NOISE_LEN_BYTES = 2
SIZE_SECIO_LEN_BYTES = 4
BYTE_ORDER = "big" BYTE_ORDER = "big"
@ -22,7 +21,13 @@ async def read_length(reader: Reader, size_len_bytes: int) -> int:
def encode_msg_with_length(msg_bytes: bytes, size_len_bytes: int) -> bytes: def encode_msg_with_length(msg_bytes: bytes, size_len_bytes: int) -> bytes:
len_prefix = len(msg_bytes).to_bytes(size_len_bytes, byteorder=BYTE_ORDER) try:
len_prefix = len(msg_bytes).to_bytes(size_len_bytes, byteorder=BYTE_ORDER)
except OverflowError:
raise ValueError(
"msg_bytes is too large for `size_len_bytes` bytes length: "
f"msg_bytes={msg_bytes}, size_len_bytes={size_len_bytes}"
)
return len_prefix + msg_bytes return len_prefix + msg_bytes
@ -58,7 +63,3 @@ class BaseMsgReadWriter(MsgReadWriter):
async def write_msg(self, msg: bytes) -> None: async def write_msg(self, msg: bytes) -> None:
data = encode_msg_with_length(msg, self.size_len_bytes) data = encode_msg_with_length(msg, self.size_len_bytes)
await self.read_write_closer.write(data) await self.read_write_closer.write(data)
class MsgIOReadWriter(BaseMsgReadWriter):
size_len_bytes = SIZE_SECIO_LEN_BYTES

View File

@ -3,10 +3,12 @@ from typing import cast
from noise.connection import NoiseConnection as NoiseState from noise.connection import NoiseConnection as NoiseState
from libp2p.io.abc import ReadWriter from libp2p.io.abc import ReadWriteCloser, MsgReadWriter, EncryptedMsgReadWriter
from libp2p.io.msgio import BaseMsgReadWriter, encode_msg_with_length
from libp2p.io.utils import read_exactly from libp2p.io.utils import read_exactly
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
SIZE_NOISE_MESSAGE_LEN = 2 SIZE_NOISE_MESSAGE_LEN = 2
MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1 MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1
SIZE_NOISE_MESSAGE_BODY_LEN = 2 SIZE_NOISE_MESSAGE_BODY_LEN = 2
@ -20,53 +22,12 @@ BYTE_ORDER = "big"
# <-2 bytes-><- max=65533 bytes -> # <-2 bytes-><- max=65533 bytes ->
def encode_data(data: bytes, size_len: int) -> bytes: class NoisePacketReadWriter(BaseMsgReadWriter):
len_data = len(data) size_len_bytes = SIZE_NOISE_MESSAGE_LEN
try:
len_bytes = len_data.to_bytes(size_len, BYTE_ORDER)
except OverflowError as e:
raise ValueError from e
return len_bytes + data
class MsgReader(ABC):
@abstractmethod
async def read_msg(self) -> bytes:
...
class MsgWriter(ABC):
@abstractmethod
async def write_msg(self, msg: bytes) -> None:
...
class MsgReadWriter(MsgReader, MsgWriter):
pass
# TODO: Add comments
class NoisePacketReadWriter(MsgReadWriter):
"""Encode and decode the low level noise messages."""
read_writer: ReadWriter
def __init__(self, read_writer: ReadWriter) -> None:
self.read_writer = read_writer
async def read_msg(self) -> bytes:
len_bytes = await read_exactly(self.read_writer, SIZE_NOISE_MESSAGE_LEN)
len_int = int.from_bytes(len_bytes, BYTE_ORDER)
return await read_exactly(self.read_writer, len_int)
async def write_msg(self, msg: bytes) -> None:
encoded_data = encode_data(msg, SIZE_NOISE_MESSAGE_LEN)
await self.read_writer.write(encoded_data)
# TODO: Add comments
def encode_msg_body(msg_body: bytes) -> bytes: def encode_msg_body(msg_body: bytes) -> bytes:
encoded_msg_body = encode_data(msg_body, SIZE_NOISE_MESSAGE_BODY_LEN) encoded_msg_body = encode_msg_with_length(msg_body, SIZE_NOISE_MESSAGE_BODY_LEN)
if len(encoded_msg_body) > MAX_NOISE_MESSAGE_BODY_LEN: if len(encoded_msg_body) > MAX_NOISE_MESSAGE_BODY_LEN:
raise ValueError( raise ValueError(
f"msg_body is too long: {len(msg_body)}, " f"msg_body is too long: {len(msg_body)}, "
@ -88,39 +49,36 @@ def decode_msg_body(noise_msg: bytes) -> bytes:
] ]
class NoiseHandshakeReadWriter(MsgReadWriter): class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
read_writer: MsgReadWriter read_writer: MsgReadWriter
noise_state: NoiseState noise_state: NoiseState
def __init__(self, conn: IRawConnection, noise_state: NoiseState) -> None: def __init__(self, conn: IRawConnection, noise_state: NoiseState) -> None:
self.read_writer = NoisePacketReadWriter(cast(ReadWriter, conn)) self.read_writer = NoisePacketReadWriter(cast(ReadWriteCloser, conn))
self.noise_state = noise_state self.noise_state = noise_state
async def write_msg(self, data: bytes) -> None: async def write_msg(self, data: bytes) -> None:
noise_msg = encode_msg_body(data) noise_msg = encode_msg_body(data)
data_encrypted = self.noise_state.write_message(noise_msg) data_encrypted = self.encrypt(noise_msg)
await self.read_writer.write_msg(data_encrypted) await self.read_writer.write_msg(data_encrypted)
async def read_msg(self) -> bytes: async def read_msg(self) -> bytes:
noise_msg_encrypted = await self.read_writer.read_msg() noise_msg_encrypted = await self.read_writer.read_msg()
noise_msg = self.noise_state.read_message(noise_msg_encrypted) noise_msg = self.decrypt(noise_msg_encrypted)
return decode_msg_body(noise_msg) return decode_msg_body(noise_msg)
class NoiseTransportReadWriter(MsgReadWriter): class NoiseHandshakeReadWriter(BaseNoiseMsgReadWriter):
read_writer: MsgReadWriter def encrypt(self, data: bytes) -> bytes:
noise_state: NoiseState return self.noise_state.write_message(data)
def __init__(self, conn: IRawConnection, noise_state: NoiseState) -> None: def decrypt(self, data: bytes) -> bytes:
self.read_writer = NoisePacketReadWriter(cast(ReadWriter, conn)) return self.noise_state.read_message(data)
self.noise_state = noise_state
async def write_msg(self, data: bytes) -> None:
noise_msg = encode_msg_body(data)
data_encrypted = self.noise_state.encrypt(noise_msg)
await self.read_writer.write_msg(data_encrypted)
async def read_msg(self) -> bytes: class NoiseTransportReadWriter(BaseNoiseMsgReadWriter):
noise_msg_encrypted = await self.read_writer.read_msg() def encrypt(self, data: bytes) -> bytes:
noise_msg = self.noise_state.decrypt(noise_msg_encrypted) return self.noise_state.encrypt(data)
return decode_msg_body(noise_msg)
def decrypt(self, data: bytes) -> bytes:
return self.noise_state.decrypt(data)

View File

@ -16,7 +16,7 @@ from .exceptions import (
NoiseStateError, NoiseStateError,
PeerIDMismatchesPubkey, PeerIDMismatchesPubkey,
) )
from .io import NoiseHandshakeReadWriter from .io import encode_msg_body, decode_msg_body, NoiseHandshakeReadWriter
from .messages import ( from .messages import (
NoiseHandshakePayload, NoiseHandshakePayload,
make_handshake_payload_sig, make_handshake_payload_sig,
@ -56,6 +56,16 @@ class BasePattern(IPattern):
) )
return NoiseHandshakePayload(self.libp2p_privkey.get_public_key(), signature) return NoiseHandshakePayload(self.libp2p_privkey.get_public_key(), signature)
async def write_msg(self, conn: IRawConnection, data: bytes) -> None:
noise_msg = encode_msg_body(data)
data_encrypted = self.noise_state.write_message(noise_msg)
await self.read_writer.write_msg(data_encrypted)
async def read_msg(self) -> bytes:
noise_msg_encrypted = await self.read_writer.read_msg()
noise_msg = self.noise_state.read_message(noise_msg_encrypted)
return decode_msg_body(noise_msg)
class PatternXX(BasePattern): class PatternXX(BasePattern):
def __init__( def __init__(

View File

@ -19,7 +19,7 @@ from libp2p.crypto.key_exchange import create_ephemeral_key_pair
from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.crypto.keys import PrivateKey, PublicKey
from libp2p.crypto.serialization import deserialize_public_key from libp2p.crypto.serialization import deserialize_public_key
from libp2p.io.exceptions import DecryptionFailedException, IOException from libp2p.io.exceptions import DecryptionFailedException, IOException
from libp2p.io.msgio import MsgIOReadWriter from libp2p.io.msgio import BaseMsgReadWriter
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID as PeerID from libp2p.peer.id import ID as PeerID
from libp2p.security.base_session import BaseSession from libp2p.security.base_session import BaseSession
@ -41,6 +41,7 @@ from .pb.spipe_pb2 import Exchange, Propose
ID = TProtocol("/secio/1.0.0") ID = TProtocol("/secio/1.0.0")
NONCE_SIZE = 16 # bytes NONCE_SIZE = 16 # bytes
SIZE_SECIO_LEN_BYTES = 4
# NOTE: the following is only a subset of allowable parameters according to the # NOTE: the following is only a subset of allowable parameters according to the
# `secio` specification. # `secio` specification.
@ -49,6 +50,10 @@ DEFAULT_SUPPORTED_CIPHERS = "AES-128"
DEFAULT_SUPPORTED_HASHES = "SHA256" DEFAULT_SUPPORTED_HASHES = "SHA256"
class MsgIOReadWriter(BaseMsgReadWriter):
size_len_bytes = SIZE_SECIO_LEN_BYTES
class SecureSession(BaseSession): class SecureSession(BaseSession):
buf: io.BytesIO buf: io.BytesIO
low_watermark: int low_watermark: int