Refactor HeaderTags

This commit is contained in:
Chih Cheng Liang
2019-08-02 17:14:43 +08:00
parent 29fbb9e40a
commit 36b7e8ded9
4 changed files with 31 additions and 27 deletions

View File

@ -1,8 +1,9 @@
import asyncio
from ..muxed_connection_interface import IMuxedConn
from .constants import HeaderTags
from .utils import encode_uvarint, decode_uvarint_from_stream
from .mplex_stream import MplexStream
from .utils import decode_uvarint_from_stream, encode_uvarint, get_flag
from ..muxed_connection_interface import IMuxedConn
class Mplex(IMuxedConn):
@ -78,7 +79,7 @@ class Mplex(IMuxedConn):
stream_id = self.raw_conn.next_stream_id()
stream = MplexStream(stream_id, multi_addr, self)
self.buffers[stream_id] = asyncio.Queue()
await self.send_message(get_flag(self.initiator, "NEW_STREAM"), None, stream_id)
await self.send_message(HeaderTags.NewStream, None, stream_id)
return stream
async def accept_stream(self):
@ -90,7 +91,7 @@ class Mplex(IMuxedConn):
stream = MplexStream(stream_id, False, self)
asyncio.ensure_future(self.generic_protocol_handler(stream))
async def send_message(self, flag, data, stream_id):
async def send_message(self, flag: HeaderTags, data, stream_id):
"""
sends a message over the connection
:param header: header to use
@ -99,7 +100,7 @@ class Mplex(IMuxedConn):
:return: True if success
"""
# << by 3, then or with flag
header = (stream_id << 3) | flag
header = (stream_id << 3) | flag.value
header = encode_uvarint(header)
if data is None:
@ -135,7 +136,7 @@ class Mplex(IMuxedConn):
self.buffers[stream_id] = asyncio.Queue()
await self.stream_queue.put(stream_id)
if flag is get_flag(True, "NEW_STREAM"):
if flag == HeaderTags.NewStream.value:
# new stream detected on connection
await self.accept_stream()