BaseMsgReadWriter

- Change `BaseMsgReadWriter` to encode/decode messages with abstract
method, which can be implemented by the subclasses. This allows us to
create subclasses `FixedSizeLenMsgReadWriter` and
`VarIntLenMsgReadWriter`.
This commit is contained in:
mhchia
2020-02-20 21:48:03 +08:00
parent 88f660a9c5
commit 6016ea731b
5 changed files with 59 additions and 30 deletions

View File

@ -5,11 +5,13 @@ from that repo: "a simple package to r/w length-delimited slices."
NOTE: currently missing the capability to indicate lengths by "varint" method.
"""
from typing import Optional
from abc import abstractmethod
from libp2p.io.abc import MsgReadWriteCloser, Reader, ReadWriteCloser
from libp2p.io.utils import read_exactly
from libp2p.utils import decode_uvarint_from_stream, encode_varint_prefixed
from .exceptions import MessageTooLarge
BYTE_ORDER = "big"
@ -31,34 +33,57 @@ def encode_msg_with_length(msg_bytes: bytes, size_len_bytes: int) -> bytes:
class BaseMsgReadWriter(MsgReadWriteCloser):
next_length: Optional[int]
read_write_closer: ReadWriteCloser
size_len_bytes: int
def __init__(self, read_write_closer: ReadWriteCloser) -> None:
self.read_write_closer = read_write_closer
self.next_length = None
async def read_msg(self) -> bytes:
length = await self.next_msg_len()
return await read_exactly(self.read_write_closer, length)
data = await read_exactly(self.read_write_closer, length)
if len(data) < length:
self.next_length = length - len(data)
else:
self.next_length = None
return data
@abstractmethod
async def next_msg_len(self) -> int:
if self.next_length is None:
self.next_length = await read_length(
self.read_write_closer, self.size_len_bytes
)
return self.next_length
...
@abstractmethod
def encode_msg(self, msg: bytes) -> bytes:
...
async def close(self) -> None:
await self.read_write_closer.close()
async def write_msg(self, msg: bytes) -> None:
data = encode_msg_with_length(msg, self.size_len_bytes)
await self.read_write_closer.write(data)
encoded_msg = self.encode_msg(msg)
await self.read_write_closer.write(encoded_msg)
class FixedSizeLenMsgReadWriter(BaseMsgReadWriter):
size_len_bytes: int
async def next_msg_len(self) -> int:
return await read_length(self.read_write_closer, self.size_len_bytes)
def encode_msg(self, msg: bytes) -> bytes:
return encode_msg_with_length(msg, self.size_len_bytes)
class VarIntLengthMsgReadWriter(BaseMsgReadWriter):
max_msg_size: int
async def next_msg_len(self) -> int:
msg_len = await decode_uvarint_from_stream(self.read_write_closer)
if msg_len > self.max_msg_size:
raise MessageTooLarge(
f"msg_len={msg_len} > max_msg_size={self.max_msg_size}"
)
return msg_len
def encode_msg(self, msg: bytes) -> bytes:
msg_len = len(msg)
if msg_len > self.max_msg_size:
raise MessageTooLarge(
f"msg_len={msg_len} > max_msg_size={self.max_msg_size}"
)
return encode_varint_prefixed(msg)