Fix MplexStream error

When receiving a `NewStream`, the message of that packet is the
name of the stream, which should be handled, rather than letting it go
into the message queue.
This commit is contained in:
mhchia
2019-08-26 20:26:22 +08:00
parent 5b122d04b2
commit d59870ebbf
4 changed files with 25 additions and 34 deletions

View File

@ -156,15 +156,10 @@ class Swarm(INetwork):
if not addrs: if not addrs:
raise SwarmException("No known addresses to peer") raise SwarmException("No known addresses to peer")
multiaddr = addrs[0]
muxed_conn = await self.dial_peer(peer_id) muxed_conn = await self.dial_peer(peer_id)
# Use muxed conn to open stream, which returns # Use muxed conn to open stream, which returns a muxed stream
# a muxed stream muxed_stream = await muxed_conn.open_stream()
# TODO: Remove protocol id from being passed into muxed_conn
# FIXME: Remove multiaddr from being passed into muxed_conn
muxed_stream = await muxed_conn.open_stream(protocol_ids[0], multiaddr)
# Perform protocol muxing to determine protocol to use # Perform protocol muxing to determine protocol to use
selected_protocol = await self.multiselect_client.select_one_of( selected_protocol = await self.multiselect_client.select_one_of(

View File

@ -1,8 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from multiaddr import Multiaddr
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.stream_muxer.mplex.constants import HeaderTags from libp2p.stream_muxer.mplex.constants import HeaderTags
@ -66,20 +64,15 @@ class IMuxedConn(ABC):
Read a message from `stream_id`'s buffer, non-blockingly. Read a message from `stream_id`'s buffer, non-blockingly.
""" """
# FIXME: Remove multiaddr from being passed into muxed_conn
@abstractmethod @abstractmethod
async def open_stream( async def open_stream(self) -> "IMuxedStream":
self, protocol_id: str, multi_addr: Multiaddr
) -> "IMuxedStream":
""" """
creates a new muxed_stream creates a new muxed_stream
:param protocol_id: protocol_id of stream :return: a new ``IMuxedStream`` stream
:param multi_addr: multi_addr that stream connects to
:return: a new stream
""" """
@abstractmethod @abstractmethod
async def accept_stream(self) -> None: async def accept_stream(self, name: str) -> None:
""" """
accepts a muxed stream opened by the other end accepts a muxed stream opened by the other end
""" """

View File

@ -1,8 +1,6 @@
import asyncio import asyncio
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
from multiaddr import Multiaddr
from libp2p.network.typing import GenericProtocolHandlerFn from libp2p.network.typing import GenericProtocolHandlerFn
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
@ -31,6 +29,7 @@ class Mplex(IMuxedConn):
stream_queue: "asyncio.Queue[int]" stream_queue: "asyncio.Queue[int]"
next_stream_id: int next_stream_id: int
# TODO: `generic_protocol_handler` should be refactored out of mplex conn.
def __init__( def __init__(
self, self,
secured_conn: ISecureConn, secured_conn: ISecureConn,
@ -114,28 +113,25 @@ class Mplex(IMuxedConn):
self.next_stream_id += 2 self.next_stream_id += 2
return next_id return next_id
# FIXME: Remove multiaddr from being passed into muxed_conn async def open_stream(self) -> IMuxedStream:
async def open_stream(
self, protocol_id: str, multi_addr: Multiaddr
) -> IMuxedStream:
""" """
creates a new muxed_stream creates a new muxed_stream
:param protocol_id: protocol_id of stream :return: a new ``MplexStream``
:param multi_addr: multi_addr that stream connects to
:return: a new muxed stream
""" """
stream_id = self._get_next_stream_id() stream_id = self._get_next_stream_id()
stream = MplexStream(stream_id, True, self) name = str(stream_id).encode()
stream = MplexStream(name, stream_id, True, self)
self.buffers[stream_id] = asyncio.Queue() self.buffers[stream_id] = asyncio.Queue()
await self.send_message(HeaderTags.NewStream, None, stream_id) # Default stream name is the `stream_id`
await self.send_message(HeaderTags.NewStream, name, stream_id)
return stream return stream
async def accept_stream(self) -> None: async def accept_stream(self, name: str) -> None:
""" """
accepts a muxed stream opened by the other end accepts a muxed stream opened by the other end
""" """
stream_id = await self.stream_queue.get() stream_id = await self.stream_queue.get()
stream = MplexStream(stream_id, False, self) stream = MplexStream(name, stream_id, False, self)
asyncio.ensure_future(self.generic_protocol_handler(stream)) asyncio.ensure_future(self.generic_protocol_handler(stream))
async def send_message(self, flag: HeaderTags, data: bytes, stream_id: int) -> int: async def send_message(self, flag: HeaderTags, data: bytes, stream_id: int) -> int:
@ -181,11 +177,14 @@ class Mplex(IMuxedConn):
self.buffers[stream_id] = asyncio.Queue() self.buffers[stream_id] = asyncio.Queue()
await self.stream_queue.put(stream_id) await self.stream_queue.put(stream_id)
# TODO: Handle more tags, and refactor `HeaderTags`
if flag == HeaderTags.NewStream.value: if flag == HeaderTags.NewStream.value:
# new stream detected on connection # new stream detected on connection
await self.accept_stream() await self.accept_stream(message)
elif flag in (
if message: HeaderTags.MessageInitiator.value,
HeaderTags.MessageReceiver.value,
):
await self.buffers[stream_id].put(message) await self.buffers[stream_id].put(message)
# Force context switch # Force context switch

View File

@ -10,6 +10,7 @@ class MplexStream(IMuxedStream):
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
""" """
name: str
stream_id: int stream_id: int
initiator: bool initiator: bool
mplex_conn: IMuxedConn mplex_conn: IMuxedConn
@ -21,13 +22,16 @@ class MplexStream(IMuxedStream):
_buf: bytearray _buf: bytearray
def __init__(self, stream_id: int, initiator: bool, mplex_conn: IMuxedConn) -> None: def __init__(
self, name: str, stream_id: int, initiator: bool, mplex_conn: IMuxedConn
) -> None:
""" """
create new MuxedStream in muxer create new MuxedStream in muxer
:param stream_id: stream stream id :param stream_id: stream stream id
:param initiator: boolean if this is an initiator :param initiator: boolean if this is an initiator
:param mplex_conn: muxed connection of this muxed_stream :param mplex_conn: muxed connection of this muxed_stream
""" """
self.name = name
self.stream_id = stream_id self.stream_id = stream_id
self.initiator = initiator self.initiator = initiator
self.mplex_conn = mplex_conn self.mplex_conn = mplex_conn