mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-10 07:00:54 +00:00
updated reading to read until you see a message for your stream (#100)
* updated reading to read until you see a message for your stream * added timeout to decode uvarint * resolved comments * shortened long line
This commit is contained in:
@ -29,7 +29,7 @@ class Mplex(IMuxedConn):
|
||||
# The initiator of the raw connection need not read upon construction time.
|
||||
# It should read when the user decides that it wants to read from the constructed stream.
|
||||
if not self.initiator:
|
||||
asyncio.ensure_future(self.handle_incoming())
|
||||
asyncio.ensure_future(self.handle_incoming(None))
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
@ -52,7 +52,7 @@ class Mplex(IMuxedConn):
|
||||
# Empty buffer or nonexistent stream
|
||||
# TODO: propagate up timeout exception and catch
|
||||
if stream_id not in self.buffers or self.buffers[stream_id].empty():
|
||||
await self.handle_incoming()
|
||||
await self.handle_incoming(stream_id)
|
||||
if stream_id in self.buffers:
|
||||
return await self._read_buffer_exists(stream_id)
|
||||
|
||||
@ -126,20 +126,26 @@ class Mplex(IMuxedConn):
|
||||
await self.raw_conn.writer.drain()
|
||||
return len(_bytes)
|
||||
|
||||
async def handle_incoming(self):
|
||||
async def handle_incoming(self, my_stream_id):
|
||||
"""
|
||||
Read a message off of the raw connection and add it to the corresponding message buffer
|
||||
"""
|
||||
# TODO Deal with other types of messages using flag (currently _)
|
||||
# TODO call read_message in loop to handle case message for other stream was in conn
|
||||
|
||||
stream_id, _, message = await self.read_message()
|
||||
continue_reading = True
|
||||
i = 0
|
||||
while continue_reading:
|
||||
i += 1
|
||||
stream_id, _, message = await self.read_message()
|
||||
continue_reading = (stream_id is not None and
|
||||
stream_id != my_stream_id and
|
||||
my_stream_id is not None)
|
||||
|
||||
if stream_id not in self.buffers:
|
||||
self.buffers[stream_id] = asyncio.Queue()
|
||||
await self.stream_queue.put(stream_id)
|
||||
if stream_id not in self.buffers:
|
||||
self.buffers[stream_id] = asyncio.Queue()
|
||||
await self.stream_queue.put(stream_id)
|
||||
|
||||
await self.buffers[stream_id].put(message)
|
||||
await self.buffers[stream_id].put(message)
|
||||
|
||||
async def read_chunk(self):
|
||||
"""
|
||||
@ -158,12 +164,15 @@ class Mplex(IMuxedConn):
|
||||
Read a single message off of the raw connection
|
||||
:return: stream_id, flag, message contents
|
||||
"""
|
||||
|
||||
# Timeout is set to a relatively small value to alleviate wait time to exit
|
||||
# loop in handle_incoming
|
||||
timeout = .1
|
||||
try:
|
||||
header = await decode_uvarint_from_stream(self.raw_conn.reader)
|
||||
length = await decode_uvarint_from_stream(self.raw_conn.reader)
|
||||
message = await asyncio.wait_for(self.raw_conn.reader.read(length), timeout=5)
|
||||
header = await decode_uvarint_from_stream(self.raw_conn.reader, timeout)
|
||||
length = await decode_uvarint_from_stream(self.raw_conn.reader, timeout)
|
||||
message = await asyncio.wait_for(self.raw_conn.reader.read(length), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
print("message malformed")
|
||||
return None, None, None
|
||||
|
||||
flag = header & 0x07
|
||||
|
||||
@ -29,12 +29,11 @@ def decode_uvarint(buff, index):
|
||||
|
||||
return result, index + 1
|
||||
|
||||
|
||||
async def decode_uvarint_from_stream(reader):
|
||||
async def decode_uvarint_from_stream(reader, timeout):
|
||||
shift = 0
|
||||
result = 0
|
||||
while True:
|
||||
byte = await asyncio.wait_for(reader.read(1), timeout=5)
|
||||
byte = await asyncio.wait_for(reader.read(1), timeout=timeout)
|
||||
i = struct.unpack('>H', b'\x00' + byte)[0]
|
||||
result |= (i & 0x7f) << shift
|
||||
shift += 7
|
||||
|
||||
Reference in New Issue
Block a user