BaseMsgReadWriter

- Change `BaseMsgReadWriter` to encode/decode messages with abstract
method, which can be implemented by the subclasses. This allows us to
create subclasses `FixedSizeLenMsgReadWriter` and
`VarIntLenMsgReadWriter`.
This commit is contained in:
mhchia
2020-02-20 21:48:03 +08:00
parent 88f660a9c5
commit 6016ea731b
5 changed files with 59 additions and 30 deletions

View File

@ -23,3 +23,7 @@ class MissingMessageException(MsgioException):
class DecryptionFailedException(MsgioException): class DecryptionFailedException(MsgioException):
pass pass
class MessageTooLarge(MsgioException):
pass

View File

@ -5,11 +5,13 @@ from that repo: "a simple package to r/w length-delimited slices."
NOTE: currently missing the capability to indicate lengths by "varint" method. NOTE: currently missing the capability to indicate lengths by "varint" method.
""" """
from abc import abstractmethod
from typing import Optional
from libp2p.io.abc import MsgReadWriteCloser, Reader, ReadWriteCloser from libp2p.io.abc import MsgReadWriteCloser, Reader, ReadWriteCloser
from libp2p.io.utils import read_exactly from libp2p.io.utils import read_exactly
from libp2p.utils import decode_uvarint_from_stream, encode_varint_prefixed
from .exceptions import MessageTooLarge
BYTE_ORDER = "big" BYTE_ORDER = "big"
@ -31,34 +33,57 @@ def encode_msg_with_length(msg_bytes: bytes, size_len_bytes: int) -> bytes:
class BaseMsgReadWriter(MsgReadWriteCloser): class BaseMsgReadWriter(MsgReadWriteCloser):
next_length: Optional[int]
read_write_closer: ReadWriteCloser read_write_closer: ReadWriteCloser
size_len_bytes: int size_len_bytes: int
def __init__(self, read_write_closer: ReadWriteCloser) -> None: def __init__(self, read_write_closer: ReadWriteCloser) -> None:
self.read_write_closer = read_write_closer self.read_write_closer = read_write_closer
self.next_length = None
async def read_msg(self) -> bytes: async def read_msg(self) -> bytes:
length = await self.next_msg_len() length = await self.next_msg_len()
return await read_exactly(self.read_write_closer, length)
data = await read_exactly(self.read_write_closer, length) @abstractmethod
if len(data) < length:
self.next_length = length - len(data)
else:
self.next_length = None
return data
async def next_msg_len(self) -> int: async def next_msg_len(self) -> int:
if self.next_length is None: ...
self.next_length = await read_length(
self.read_write_closer, self.size_len_bytes @abstractmethod
) def encode_msg(self, msg: bytes) -> bytes:
return self.next_length ...
async def close(self) -> None: async def close(self) -> None:
await self.read_write_closer.close() await self.read_write_closer.close()
async def write_msg(self, msg: bytes) -> None: async def write_msg(self, msg: bytes) -> None:
data = encode_msg_with_length(msg, self.size_len_bytes) encoded_msg = self.encode_msg(msg)
await self.read_write_closer.write(data) await self.read_write_closer.write(encoded_msg)
class FixedSizeLenMsgReadWriter(BaseMsgReadWriter):
size_len_bytes: int
async def next_msg_len(self) -> int:
return await read_length(self.read_write_closer, self.size_len_bytes)
def encode_msg(self, msg: bytes) -> bytes:
return encode_msg_with_length(msg, self.size_len_bytes)
class VarIntLengthMsgReadWriter(BaseMsgReadWriter):
max_msg_size: int
async def next_msg_len(self) -> int:
msg_len = await decode_uvarint_from_stream(self.read_write_closer)
if msg_len > self.max_msg_size:
raise MessageTooLarge(
f"msg_len={msg_len} > max_msg_size={self.max_msg_size}"
)
return msg_len
def encode_msg(self, msg: bytes) -> bytes:
msg_len = len(msg)
if msg_len > self.max_msg_size:
raise MessageTooLarge(
f"msg_len={msg_len} > max_msg_size={self.max_msg_size}"
)
return encode_varint_prefixed(msg)

View File

@ -3,7 +3,7 @@ from libp2p.crypto.keys import PrivateKey, PublicKey
from libp2p.crypto.pb import crypto_pb2 from libp2p.crypto.pb import crypto_pb2
from libp2p.crypto.serialization import deserialize_public_key from libp2p.crypto.serialization import deserialize_public_key
from libp2p.io.abc import ReadWriteCloser from libp2p.io.abc import ReadWriteCloser
from libp2p.io.msgio import BaseMsgReadWriter from libp2p.io.msgio import FixedSizeLenMsgReadWriter
from libp2p.network.connection.exceptions import RawConnError from libp2p.network.connection.exceptions import RawConnError
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID from libp2p.peer.id import ID
@ -23,7 +23,7 @@ PLAINTEXT_PROTOCOL_ID = TProtocol("/plaintext/2.0.0")
SIZE_PLAINTEXT_LEN_BYTES = 4 SIZE_PLAINTEXT_LEN_BYTES = 4
class PlaintextHandshakeReadWriter(BaseMsgReadWriter): class PlaintextHandshakeReadWriter(FixedSizeLenMsgReadWriter):
size_len_bytes = SIZE_PLAINTEXT_LEN_BYTES size_len_bytes = SIZE_PLAINTEXT_LEN_BYTES

View File

@ -3,7 +3,7 @@ from typing import cast
from noise.connection import NoiseConnection as NoiseState from noise.connection import NoiseConnection as NoiseState
from libp2p.io.abc import EncryptedMsgReadWriter, MsgReadWriteCloser, ReadWriteCloser from libp2p.io.abc import EncryptedMsgReadWriter, MsgReadWriteCloser, ReadWriteCloser
from libp2p.io.msgio import BaseMsgReadWriter, encode_msg_with_length from libp2p.io.msgio import FixedSizeLenMsgReadWriter, encode_msg_with_length
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
SIZE_NOISE_MESSAGE_LEN = 2 SIZE_NOISE_MESSAGE_LEN = 2
@ -19,7 +19,7 @@ BYTE_ORDER = "big"
# <-2 bytes-><- max=65533 bytes -> # <-2 bytes-><- max=65533 bytes ->
class NoisePacketReadWriter(BaseMsgReadWriter): class NoisePacketReadWriter(FixedSizeLenMsgReadWriter):
size_len_bytes = SIZE_NOISE_MESSAGE_LEN size_len_bytes = SIZE_NOISE_MESSAGE_LEN

View File

@ -19,7 +19,7 @@ from libp2p.crypto.keys import PrivateKey, PublicKey
from libp2p.crypto.serialization import deserialize_public_key from libp2p.crypto.serialization import deserialize_public_key
from libp2p.io.abc import EncryptedMsgReadWriter from libp2p.io.abc import EncryptedMsgReadWriter
from libp2p.io.exceptions import DecryptionFailedException, IOException from libp2p.io.exceptions import DecryptionFailedException, IOException
from libp2p.io.msgio import BaseMsgReadWriter from libp2p.io.msgio import FixedSizeLenMsgReadWriter
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID as PeerID from libp2p.peer.id import ID as PeerID
from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.base_transport import BaseSecureTransport
@ -50,18 +50,18 @@ DEFAULT_SUPPORTED_CIPHERS = "AES-128"
DEFAULT_SUPPORTED_HASHES = "SHA256" DEFAULT_SUPPORTED_HASHES = "SHA256"
class MsgIOReadWriter(BaseMsgReadWriter): class SecioPacketReadWriter(FixedSizeLenMsgReadWriter):
size_len_bytes = SIZE_SECIO_LEN_BYTES size_len_bytes = SIZE_SECIO_LEN_BYTES
class SecioMsgReadWriter(EncryptedMsgReadWriter): class SecioMsgReadWriter(EncryptedMsgReadWriter):
read_writer: MsgIOReadWriter read_writer: SecioPacketReadWriter
def __init__( def __init__(
self, self,
local_encryption_parameters: AuthenticatedEncryptionParameters, local_encryption_parameters: AuthenticatedEncryptionParameters,
remote_encryption_parameters: AuthenticatedEncryptionParameters, remote_encryption_parameters: AuthenticatedEncryptionParameters,
read_writer: MsgIOReadWriter, read_writer: SecioPacketReadWriter,
) -> None: ) -> None:
self.local_encryption_parameters = local_encryption_parameters self.local_encryption_parameters = local_encryption_parameters
self.remote_encryption_parameters = remote_encryption_parameters self.remote_encryption_parameters = remote_encryption_parameters
@ -170,7 +170,7 @@ class SessionParameters:
pass pass
async def _response_to_msg(read_writer: MsgIOReadWriter, msg: bytes) -> bytes: async def _response_to_msg(read_writer: SecioPacketReadWriter, msg: bytes) -> bytes:
await read_writer.write_msg(msg) await read_writer.write_msg(msg)
return await read_writer.read_msg() return await read_writer.read_msg()
@ -234,7 +234,7 @@ async def _establish_session_parameters(
local_peer: PeerID, local_peer: PeerID,
local_private_key: PrivateKey, local_private_key: PrivateKey,
remote_peer: Optional[PeerID], remote_peer: Optional[PeerID],
conn: MsgIOReadWriter, conn: SecioPacketReadWriter,
nonce: bytes, nonce: bytes,
) -> Tuple[SessionParameters, bytes]: ) -> Tuple[SessionParameters, bytes]:
# establish shared encryption parameters # establish shared encryption parameters
@ -326,7 +326,7 @@ async def _establish_session_parameters(
def _mk_session_from( def _mk_session_from(
local_private_key: PrivateKey, local_private_key: PrivateKey,
session_parameters: SessionParameters, session_parameters: SessionParameters,
conn: MsgIOReadWriter, conn: SecioPacketReadWriter,
is_initiator: bool, is_initiator: bool,
) -> SecureSession: ) -> SecureSession:
key_set1, key_set2 = initialize_pair_for_encryption( key_set1, key_set2 = initialize_pair_for_encryption(
@ -371,7 +371,7 @@ async def create_secure_session(
to the ``remote_peer``. Raise `SecioException` when `conn` closed. to the ``remote_peer``. Raise `SecioException` when `conn` closed.
Raise `InconsistentNonce` when handshake failed Raise `InconsistentNonce` when handshake failed
""" """
msg_io = MsgIOReadWriter(conn) msg_io = SecioPacketReadWriter(conn)
try: try:
session_parameters, remote_nonce = await _establish_session_parameters( session_parameters, remote_nonce = await _establish_session_parameters(
local_peer, local_private_key, remote_peer, msg_io, local_nonce local_peer, local_private_key, remote_peer, msg_io, local_nonce