mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-07 21:50:53 +00:00
Restructure mplex and mplex_stream
This commit is contained in:
@ -16,7 +16,6 @@ from libp2p.utils import (
|
||||
|
||||
from .constants import HeaderTags
|
||||
from .datastructures import StreamID
|
||||
from .exceptions import StreamNotFound
|
||||
from .mplex_stream import MplexStream
|
||||
|
||||
MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0")
|
||||
@ -32,10 +31,9 @@ class Mplex(IMuxedConn):
|
||||
# TODO: `dataIn` in go implementation. Should be size of 8.
|
||||
# TODO: Also, `dataIn` is closed indicating EOF in Go. We don't have similar strategies
|
||||
# to let the `MplexStream`s know that EOF arrived (#235).
|
||||
buffers: Dict[StreamID, "asyncio.Queue[bytes]"]
|
||||
stream_queue: "asyncio.Queue[StreamID]"
|
||||
next_channel_id: int
|
||||
buffers_lock: asyncio.Lock
|
||||
streams: Dict[StreamID, MplexStream]
|
||||
streams_lock: asyncio.Lock
|
||||
shutdown: asyncio.Event
|
||||
|
||||
_tasks: List["asyncio.Future[Any]"]
|
||||
@ -65,12 +63,10 @@ class Mplex(IMuxedConn):
|
||||
self.peer_id = peer_id
|
||||
|
||||
# Mapping from stream ID -> buffer of messages for that stream
|
||||
self.buffers = {}
|
||||
self.buffers_lock = asyncio.Lock()
|
||||
self.streams = {}
|
||||
self.streams_lock = asyncio.Lock()
|
||||
self.shutdown = asyncio.Event()
|
||||
|
||||
self.stream_queue = asyncio.Queue()
|
||||
|
||||
self._tasks = []
|
||||
|
||||
# Kick off reading
|
||||
@ -95,29 +91,6 @@ class Mplex(IMuxedConn):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def read_buffer(self, stream_id: StreamID) -> bytes:
|
||||
"""
|
||||
Read a message from buffer of the stream specified by `stream_id`,
|
||||
check secured connection for new messages.
|
||||
`StreamNotFound` is raised when stream `stream_id` is not found in `Mplex`.
|
||||
:param stream_id: stream id of stream to read from
|
||||
:return: message read
|
||||
"""
|
||||
if stream_id not in self.buffers:
|
||||
raise StreamNotFound(f"stream {stream_id} is not found")
|
||||
return await self.buffers[stream_id].get()
|
||||
|
||||
async def read_buffer_nonblocking(self, stream_id: StreamID) -> Optional[bytes]:
|
||||
"""
|
||||
Read a message from buffer of the stream specified by `stream_id`, non-blockingly.
|
||||
`StreamNotFound` is raised when stream `stream_id` is not found in `Mplex`.
|
||||
"""
|
||||
if stream_id not in self.buffers:
|
||||
raise StreamNotFound(f"stream {stream_id} is not found")
|
||||
if self.buffers[stream_id].empty():
|
||||
return None
|
||||
return await self.buffers[stream_id].get()
|
||||
|
||||
def _get_next_channel_id(self) -> int:
|
||||
"""
|
||||
Get next available stream id
|
||||
@ -127,6 +100,12 @@ class Mplex(IMuxedConn):
|
||||
self.next_channel_id += 1
|
||||
return next_id
|
||||
|
||||
async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream:
|
||||
async with self.streams_lock:
|
||||
stream = MplexStream(name, stream_id, self)
|
||||
self.streams[stream_id] = stream
|
||||
return stream
|
||||
|
||||
async def open_stream(self) -> IMuxedStream:
|
||||
"""
|
||||
creates a new muxed_stream
|
||||
@ -134,19 +113,18 @@ class Mplex(IMuxedConn):
|
||||
"""
|
||||
channel_id = self._get_next_channel_id()
|
||||
stream_id = StreamID(channel_id=channel_id, is_initiator=True)
|
||||
name = str(channel_id)
|
||||
stream = MplexStream(name, stream_id, self)
|
||||
self.buffers[stream_id] = asyncio.Queue()
|
||||
# Default stream name is the `channel_id`
|
||||
name = str(channel_id)
|
||||
stream = await self._initialize_stream(stream_id, name)
|
||||
await self.send_message(HeaderTags.NewStream, name.encode(), stream_id)
|
||||
return stream
|
||||
|
||||
async def accept_stream(self, name: str) -> None:
|
||||
async def accept_stream(self, stream_id: StreamID, name: str) -> None:
|
||||
"""
|
||||
accepts a muxed stream opened by the other end
|
||||
"""
|
||||
stream_id = await self.stream_queue.get()
|
||||
stream = MplexStream(name, stream_id, self)
|
||||
stream = await self._initialize_stream(stream_id, name)
|
||||
# Perform protocol negotiation for the stream.
|
||||
self._tasks.append(asyncio.ensure_future(self.generic_protocol_handler(stream)))
|
||||
|
||||
async def send_message(
|
||||
@ -185,22 +163,30 @@ class Mplex(IMuxedConn):
|
||||
|
||||
while True:
|
||||
channel_id, flag, message = await self.read_message()
|
||||
|
||||
if channel_id is not None and flag is not None and message is not None:
|
||||
stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
|
||||
if stream_id not in self.buffers:
|
||||
self.buffers[stream_id] = asyncio.Queue()
|
||||
await self.stream_queue.put(stream_id)
|
||||
|
||||
is_stream_id_seen: bool
|
||||
async with self.streams_lock:
|
||||
is_stream_id_seen = stream_id in self.streams
|
||||
# Other consequent stream message should wait until the stream get accepted
|
||||
# TODO: Handle more tags, and refactor `HeaderTags`
|
||||
if flag == HeaderTags.NewStream.value:
|
||||
# new stream detected on connection
|
||||
await self.accept_stream(message.decode())
|
||||
if is_stream_id_seen:
|
||||
# `NewStream` for the same id is received twice...
|
||||
pass
|
||||
await self.accept_stream(stream_id, message.decode())
|
||||
elif flag in (
|
||||
HeaderTags.MessageInitiator.value,
|
||||
HeaderTags.MessageReceiver.value,
|
||||
):
|
||||
await self.buffers[stream_id].put(message)
|
||||
if not is_stream_id_seen:
|
||||
# We receive a message of the stream `stream_id` which is not accepted
|
||||
# before. It is abnormal. Possibly disconnect?
|
||||
# TODO: Warn and emit logs about this.
|
||||
continue
|
||||
async with self.streams_lock:
|
||||
stream = self.streams[stream_id]
|
||||
await stream.incoming_data.put(message)
|
||||
# elif flag in (
|
||||
# HeaderTags.CloseInitiator.value,
|
||||
# HeaderTags.CloseReceiver.value
|
||||
|
||||
Reference in New Issue
Block a user