diff --git a/libp2p/io/exceptions.py b/libp2p/io/exceptions.py new file mode 100644 index 00000000..4d06ece4 --- /dev/null +++ b/libp2p/io/exceptions.py @@ -0,0 +1,10 @@ +class MsgioException(Exception): + pass + + +class MissingLengthException(MsgioException): + pass + + +class MissingMessageException(MsgioException): + pass diff --git a/libp2p/io/msgio.py b/libp2p/io/msgio.py index f745c180..65fde685 100644 --- a/libp2p/io/msgio.py +++ b/libp2p/io/msgio.py @@ -1,4 +1,6 @@ -import asyncio +from libp2p.network.connection.raw_connection_interface import IRawConnection + +from .exceptions import MissingLengthException, MissingMessageException SIZE_LEN_BYTES = 4 @@ -10,7 +12,13 @@ def encode(msg_bytes: bytes) -> bytes: return len_prefix + msg_bytes -async def read_next_message(reader: asyncio.StreamReader) -> bytes: - len_bytes = await reader.readexactly(SIZE_LEN_BYTES) +async def read_next_message(reader: IRawConnection) -> bytes: + len_bytes = await reader.read(SIZE_LEN_BYTES) + if len(len_bytes) != SIZE_LEN_BYTES: + raise MissingLengthException() len_int = int.from_bytes(len_bytes, "big") - return await reader.readexactly(len_int) + next_msg = await reader.read(len_int) + if len(next_msg) != len_int: + # TODO makes sense to keep reading until this condition is true? + raise MissingMessageException() + return next_msg