diff --git a/stream_muxer/mplex/mplex.py b/stream_muxer/mplex/mplex.py index 87f72ad3..21ec0d6b 100644 --- a/stream_muxer/mplex/mplex.py +++ b/stream_muxer/mplex/mplex.py @@ -1,13 +1,15 @@ import asyncio -from .utils import encode_uvarint, decode_uvarint +from .utils import encode_uvarint, decode_uvarint_from_stream from .mplex_stream import MplexStream from ..muxed_connection_interface import IMuxedConn class Mplex(IMuxedConn): + # pylint: disable=too-many-instance-attributes """ reference: https://github.com/libp2p/go-mplex/blob/master/multiplex.go """ + def __init__(self, conn, initiator): """ create a new muxed connection @@ -21,8 +23,8 @@ class Mplex(IMuxedConn): self.buffers = {} self.stream_queue = asyncio.Queue() - self.conn_lock = asyncio.Lock() self._next_id = 0 + self.data_buffer = bytearray() # The initiator need not read upon construction time. # It should read when the user decides that it wants to read from the constructed stream. @@ -30,6 +32,10 @@ class Mplex(IMuxedConn): asyncio.ensure_future(self.handle_incoming()) def _next_stream_id(self): + """ + Get next available stream id + :return: next available stream id for the connection + """ next_id = self._next_id self._next_id += 1 return next_id @@ -47,14 +53,31 @@ class Mplex(IMuxedConn): """ async def read_buffer(self, stream_id): + """ + Read a message from stream_id's buffer, check raw connection for new messages + :param stream_id: stream id of stream to read from + :return: message read + """ # Empty buffer or nonexistent stream # TODO: propagate up timeout exception and catch - if stream_id not in self.buffers or not self.buffers[stream_id]: + if stream_id not in self.buffers or self.buffers[stream_id].empty(): await self.handle_incoming() + if stream_id in self.buffers: + return await self._read_buffer_exists(stream_id) - data = self.buffers[stream_id] - self.buffers[stream_id] = bytearray() - return data + return None + + async def _read_buffer_exists(self, stream_id): + """ + Reads from raw connection with the assumption that the message buffer for stream_id exsits + :param stream_id: stream id of stream to read from + :return: message read + """ + try: + data = await asyncio.wait_for(self.buffers[stream_id].get(), timeout=5) + return data + except asyncio.TimeoutError: + return None async def open_stream(self, protocol_id, peer_id, multi_addr): """ @@ -67,7 +90,7 @@ class Mplex(IMuxedConn): """ stream_id = self._next_stream_id() stream = MplexStream(stream_id, multi_addr, self) - self.buffers[stream_id] = bytearray() + self.buffers[stream_id] = asyncio.Queue() return stream async def accept_stream(self): @@ -92,6 +115,7 @@ class Mplex(IMuxedConn): # << by 3, then or with flag header = (stream_id << 3) | flag header = encode_uvarint(header) + if data is None: data_length = encode_uvarint(0) _bytes = header + data_length @@ -102,30 +126,56 @@ class Mplex(IMuxedConn): return await self.write_to_stream(_bytes) async def write_to_stream(self, _bytes): + """ + writes a byte array to a raw connection + :param _bytes: byte array to write + :return: length written + """ self.raw_conn.writer.write(_bytes) await self.raw_conn.writer.drain() return len(_bytes) async def handle_incoming(self): - data = bytearray() + """ + Read a message off of the raw connection and add it to the corresponding message buffer + """ + # TODO Deal with other types of messages using flag (currently _) + # TODO call read_message in loop to handle case message for other stream was in conn + + stream_id, _, message = await self.read_message() + + if stream_id not in self.buffers: + self.buffers[stream_id] = asyncio.Queue() + await self.stream_queue.put(stream_id) + + await self.buffers[stream_id].put(message) + + async def read_chunk(self): + """ + Read a chunk of bytes off of the raw connection into data_buffer + """ + # unused now but possibly useful in the future try: - chunk = await asyncio.wait_for(self.raw_conn.reader.read(1024), timeout=5) - data += chunk - - header, end_index = decode_uvarint(data, 0) - length, end_index = decode_uvarint(data, end_index) - - message = data[end_index:end_index + length + 1] - - # Deal with other types of messages - # TODO use flag - # flag = header & 0x07 - stream_id = header >> 3 - - if stream_id not in self.buffers: - self.buffers[stream_id] = message - await self.stream_queue.put(stream_id) - else: - self.buffers[stream_id] = self.buffers[stream_id] + message + chunk = await asyncio.wait_for(self.raw_conn.reader.read(-1), timeout=5) + self.data_buffer += chunk except asyncio.TimeoutError: print('timeout!') + return + + async def read_message(self): + """ + Read a single message off of the raw connection + :return: stream_id, flag, message contents + """ + try: + header = await decode_uvarint_from_stream(self.raw_conn.reader) + length = await decode_uvarint_from_stream(self.raw_conn.reader) + message = await asyncio.wait_for(self.raw_conn.reader.read(length), timeout=5) + except asyncio.TimeoutError: + print("message malformed") + return None, None, None + + flag = header & 0x07 + stream_id = header >> 3 + + return stream_id, flag, message diff --git a/stream_muxer/mplex/utils.py b/stream_muxer/mplex/utils.py index 824ad931..94ffe845 100644 --- a/stream_muxer/mplex/utils.py +++ b/stream_muxer/mplex/utils.py @@ -1,3 +1,6 @@ +import asyncio +import struct + def encode_uvarint(number): """Pack `number` into varint bytes""" buf = b'' @@ -23,3 +26,16 @@ def decode_uvarint(buff, index): index += 1 return result, index + 1 + +async def decode_uvarint_from_stream(reader): + shift = 0 + result = 0 + while True: + byte = await asyncio.wait_for(reader.read(1), timeout=5) + i = struct.unpack('>H', b'\x00' + byte)[0] + result |= (i & 0x7f) << shift + shift += 7 + if not i & 0x80: + break + + return result diff --git a/tests/libp2p/test_libp2p.py b/tests/libp2p/test_libp2p.py index 39f2ec74..da5b6ff3 100644 --- a/tests/libp2p/test_libp2p.py +++ b/tests/libp2p/test_libp2p.py @@ -20,7 +20,7 @@ async def test_simple_messages(): # Associate the peer with local ip address (see default parameters of Libp2p()) node_a.get_peerstore().add_addr("node_b", "/ip4/127.0.0.1/tcp/8000", 10) - + print("node_a about to open stream") stream = await node_a.new_stream("node_b", "/echo/1.0.0") messages = ["hello" + str(x) for x in range(10)] for message in messages: @@ -36,8 +36,8 @@ async def test_simple_messages(): @pytest.mark.asyncio async def test_double_response(): - hostA = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/8002/ipfs/hostA"]) - hostB = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/8003/ipfs/hostB"]) + node_a = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/8002/ipfs/node_a"]) + node_b = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/8003/ipfs/node_b"]) async def stream_handler(stream): while True: @@ -52,12 +52,12 @@ async def test_double_response(): print("sending response:" + response) await stream.write(response.encode()) - hostB.set_stream_handler("/echo/1.0.0", stream_handler) + node_b.set_stream_handler("/echo/1.0.0", stream_handler) # Associate the peer with local ip address (see default parameters of Libp2p()) - hostA.get_peerstore().add_addr("hostB", "/ip4/127.0.0.1/tcp/8003", 10) - print("hostA about to open stream") - stream = await hostA.new_stream("hostB", "/echo/1.0.0") + node_a.get_peerstore().add_addr("node_b", "/ip4/127.0.0.1/tcp/8003", 10) + print("node_a about to open stream") + stream = await node_a.new_stream("node_b", "/echo/1.0.0") messages = ["hello" + str(x) for x in range(10)] for message in messages: await stream.write(message.encode())