Merge pull request #303 from mhchia/fix/refactor-mplex-swarm-host

Refactor: `Mplex`, `Swarm`, and `BasicHost`
This commit is contained in:
Kevin Mai-Husan Chia
2019-09-21 17:27:09 +08:00
committed by GitHub
14 changed files with 221 additions and 202 deletions

View File

@ -1,12 +1,19 @@
from typing import Any, List, Sequence import asyncio
import logging
from typing import List, Sequence
import multiaddr import multiaddr
from libp2p.host.exceptions import StreamFailure
from libp2p.network.network_interface import INetwork from libp2p.network.network_interface import INetwork
from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo from libp2p.peer.peerinfo import PeerInfo
from libp2p.peer.peerstore_interface import IPeerStore from libp2p.peer.peerstore_interface import IPeerStore
from libp2p.protocol_muxer.exceptions import MultiselectClientError, MultiselectError
from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator
from libp2p.routing.kademlia.kademlia_peer_router import KadmeliaPeerRouter from libp2p.routing.kademlia.kademlia_peer_router import KadmeliaPeerRouter
from libp2p.typing import StreamHandlerFn, TProtocol from libp2p.typing import StreamHandlerFn, TProtocol
@ -18,17 +25,28 @@ from .host_interface import IHost
# telling it to listen on the given listen addresses. # telling it to listen on the given listen addresses.
logger = logging.getLogger("libp2p.network.basic_host")
logger.setLevel(logging.DEBUG)
class BasicHost(IHost): class BasicHost(IHost):
_network: INetwork _network: INetwork
_router: KadmeliaPeerRouter _router: KadmeliaPeerRouter
peerstore: IPeerStore peerstore: IPeerStore
multiselect: Multiselect
multiselect_client: MultiselectClient
# default options constructor # default options constructor
def __init__(self, network: INetwork, router: KadmeliaPeerRouter = None) -> None: def __init__(self, network: INetwork, router: KadmeliaPeerRouter = None) -> None:
self._network = network self._network = network
self._network.set_stream_handler(self._swarm_stream_handler)
self._router = router self._router = router
self.peerstore = self._network.peerstore self.peerstore = self._network.peerstore
# Protocol muxing
self.multiselect = Multiselect()
self.multiselect_client = MultiselectClient()
def get_id(self) -> ID: def get_id(self) -> ID:
""" """
@ -48,11 +66,11 @@ class BasicHost(IHost):
""" """
return self.peerstore return self.peerstore
# FIXME: Replace with correct return type def get_mux(self) -> Multiselect:
def get_mux(self) -> Any:
""" """
:return: mux instance of host :return: mux instance of host
""" """
return self.multiselect
def get_addrs(self) -> List[multiaddr.Multiaddr]: def get_addrs(self) -> List[multiaddr.Multiaddr]:
""" """
@ -74,7 +92,7 @@ class BasicHost(IHost):
:param protocol_id: protocol id used on stream :param protocol_id: protocol id used on stream
:param stream_handler: a stream handler function :param stream_handler: a stream handler function
""" """
self._network.set_stream_handler(protocol_id, stream_handler) self.multiselect.add_handler(protocol_id, stream_handler)
# `protocol_ids` can be a list of `protocol_id` # `protocol_ids` can be a list of `protocol_id`
# stream will decide which `protocol_id` to run on # stream will decide which `protocol_id` to run on
@ -86,7 +104,21 @@ class BasicHost(IHost):
:param protocol_ids: available protocol ids to use for stream :param protocol_ids: available protocol ids to use for stream
:return: stream: new stream created :return: stream: new stream created
""" """
return await self._network.new_stream(peer_id, protocol_ids)
net_stream = await self._network.new_stream(peer_id, protocol_ids)
# Perform protocol muxing to determine protocol to use
try:
selected_protocol = await self.multiselect_client.select_one_of(
list(protocol_ids), MultiselectCommunicator(net_stream)
)
except MultiselectClientError as error:
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
await net_stream.reset()
raise StreamFailure("failt to open a stream to peer %s", peer_id) from error
net_stream.set_protocol(selected_protocol)
return net_stream
async def connect(self, peer_info: PeerInfo) -> None: async def connect(self, peer_info: PeerInfo) -> None:
""" """
@ -111,3 +143,16 @@ class BasicHost(IHost):
async def close(self) -> None: async def close(self) -> None:
await self._network.close() await self._network.close()
# Reference: `BasicHost.newStreamHandler` in Go.
async def _swarm_stream_handler(self, net_stream: INetStream) -> None:
# Perform protocol muxing to determine protocol to use
try:
protocol, handler = await self.multiselect.negotiate(
MultiselectCommunicator(net_stream)
)
except MultiselectError:
await net_stream.reset()
return
net_stream.set_protocol(protocol)
asyncio.ensure_future(handler(net_stream))

15
libp2p/host/exceptions.py Normal file
View File

@ -0,0 +1,15 @@
from libp2p.exceptions import BaseLibp2pError
class HostException(BaseLibp2pError):
"""
A generic exception in `IHost`.
"""
class ConnectionFailure(HostException):
pass
class StreamFailure(HostException):
pass

View File

@ -0,0 +1,18 @@
from abc import abstractmethod
from typing import Tuple
from libp2p.io.abc import Closer
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.stream_muxer.abc import IMuxedConn
class INetConn(Closer):
conn: IMuxedConn
@abstractmethod
async def new_stream(self) -> INetStream:
...
@abstractmethod
async def get_streams(self) -> Tuple[INetStream, ...]:
...

View File

@ -0,0 +1,76 @@
import asyncio
from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple
from libp2p.network.connection.net_connection_interface import INetConn
from libp2p.network.stream.net_stream import NetStream
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
if TYPE_CHECKING:
from libp2p.network.swarm import Swarm # noqa: F401
"""
Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee14e334f7caeedad722/swarm_conn.go # noqa: E501
"""
class SwarmConn(INetConn):
conn: IMuxedConn
swarm: "Swarm"
streams: Set[NetStream]
event_closed: asyncio.Event
_tasks: List["asyncio.Future[Any]"]
def __init__(self, conn: IMuxedConn, swarm: "Swarm") -> None:
self.conn = conn
self.swarm = swarm
self.streams = set()
self.event_closed = asyncio.Event()
self._tasks = []
async def close(self) -> None:
if self.event_closed.is_set():
return
self.event_closed.set()
await self.conn.close()
for task in self._tasks:
task.cancel()
# TODO: Reset streams for local.
# TODO: Notify closed.
async def _handle_new_streams(self) -> None:
# TODO: Break the loop when anything wrong in the connection.
while True:
stream = await self.conn.accept_stream()
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
await self.run_task(self._handle_muxed_stream(stream))
await self.close()
async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None:
net_stream = await self._add_stream(muxed_stream)
if self.swarm.common_stream_handler is not None:
await self.run_task(self.swarm.common_stream_handler(net_stream))
async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
net_stream = NetStream(muxed_stream)
# Call notifiers since event occurred
for notifee in self.swarm.notifees:
await notifee.opened_stream(self.swarm, net_stream)
return net_stream
async def start(self) -> None:
await self.run_task(self._handle_new_streams())
async def run_task(self, coro: Awaitable[Any]) -> None:
self._tasks.append(asyncio.ensure_future(coro))
async def new_stream(self) -> NetStream:
muxed_stream = await self.conn.open_stream()
return await self._add_stream(muxed_stream)
async def get_streams(self) -> Tuple[NetStream, ...]:
return tuple(self.streams)

View File

@ -3,9 +3,9 @@ from typing import TYPE_CHECKING, Dict, Sequence
from multiaddr import Multiaddr from multiaddr import Multiaddr
from libp2p.network.connection.net_connection_interface import INetConn
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.peer.peerstore_interface import IPeerStore from libp2p.peer.peerstore_interface import IPeerStore
from libp2p.stream_muxer.abc import IMuxedConn
from libp2p.transport.listener_interface import IListener from libp2p.transport.listener_interface import IListener
from libp2p.typing import StreamHandlerFn, TProtocol from libp2p.typing import StreamHandlerFn, TProtocol
@ -18,7 +18,7 @@ if TYPE_CHECKING:
class INetwork(ABC): class INetwork(ABC):
peerstore: IPeerStore peerstore: IPeerStore
connections: Dict[ID, IMuxedConn] connections: Dict[ID, INetConn]
listeners: Dict[str, IListener] listeners: Dict[str, IListener]
@abstractmethod @abstractmethod
@ -28,7 +28,7 @@ class INetwork(ABC):
""" """
@abstractmethod @abstractmethod
async def dial_peer(self, peer_id: ID) -> IMuxedConn: async def dial_peer(self, peer_id: ID) -> INetConn:
""" """
dial_peer try to create a connection to peer_id dial_peer try to create a connection to peer_id
@ -37,15 +37,6 @@ class INetwork(ABC):
:return: muxed connection :return: muxed connection
""" """
@abstractmethod
def set_stream_handler(
self, protocol_id: TProtocol, stream_handler: StreamHandlerFn
) -> None:
"""
:param protocol_id: protocol id used on stream
:param stream_handler: a stream handler instance
"""
@abstractmethod @abstractmethod
async def new_stream( async def new_stream(
self, peer_id: ID, protocol_ids: Sequence[TProtocol] self, peer_id: ID, protocol_ids: Sequence[TProtocol]
@ -56,6 +47,12 @@ class INetwork(ABC):
:return: net stream instance :return: net stream instance
""" """
@abstractmethod
def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None:
"""
Set the stream handler for all incoming streams.
"""
@abstractmethod @abstractmethod
async def listen(self, *multiaddrs: Sequence[Multiaddr]) -> bool: async def listen(self, *multiaddrs: Sequence[Multiaddr]) -> bool:
""" """

View File

@ -1,18 +1,15 @@
import asyncio import asyncio
import logging import logging
from typing import Callable, Dict, List, Sequence from typing import Dict, List, Optional, Sequence
from multiaddr import Multiaddr from multiaddr import Multiaddr
from libp2p.network.connection.net_connection_interface import INetConn
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.peer.peerstore import PeerStoreError from libp2p.peer.peerstore import PeerStoreError
from libp2p.peer.peerstore_interface import IPeerStore from libp2p.peer.peerstore_interface import IPeerStore
from libp2p.protocol_muxer.exceptions import MultiselectClientError
from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator
from libp2p.routing.interfaces import IPeerRouting from libp2p.routing.interfaces import IPeerRouting
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream from libp2p.stream_muxer.abc import IMuxedConn
from libp2p.transport.exceptions import ( from libp2p.transport.exceptions import (
MuxerUpgradeFailure, MuxerUpgradeFailure,
OpenConnectionError, OpenConnectionError,
@ -24,12 +21,11 @@ from libp2p.transport.upgrader import TransportUpgrader
from libp2p.typing import StreamHandlerFn, TProtocol from libp2p.typing import StreamHandlerFn, TProtocol
from .connection.raw_connection import RawConnection from .connection.raw_connection import RawConnection
from .connection.swarm_connection import SwarmConn
from .exceptions import SwarmException from .exceptions import SwarmException
from .network_interface import INetwork from .network_interface import INetwork
from .notifee_interface import INotifee from .notifee_interface import INotifee
from .stream.net_stream import NetStream
from .stream.net_stream_interface import INetStream from .stream.net_stream_interface import INetStream
from .typing import GenericProtocolHandlerFn
logger = logging.getLogger("libp2p.network.swarm") logger = logging.getLogger("libp2p.network.swarm")
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
@ -44,12 +40,9 @@ class Swarm(INetwork):
router: IPeerRouting router: IPeerRouting
# TODO: Connection and `peer_id` are 1-1 mapping in our implementation, # TODO: Connection and `peer_id` are 1-1 mapping in our implementation,
# whereas in Go one `peer_id` may point to multiple connections. # whereas in Go one `peer_id` may point to multiple connections.
connections: Dict[ID, IMuxedConn] connections: Dict[ID, INetConn]
listeners: Dict[str, IListener] listeners: Dict[str, IListener]
stream_handlers: Dict[INetStream, Callable[[INetStream], None]] common_stream_handler: Optional[StreamHandlerFn]
multiselect: Multiselect
multiselect_client: MultiselectClient
notifees: List[INotifee] notifees: List[INotifee]
@ -68,31 +61,19 @@ class Swarm(INetwork):
self.router = router self.router = router
self.connections = dict() self.connections = dict()
self.listeners = dict() self.listeners = dict()
self.stream_handlers = dict()
# Protocol muxing
self.multiselect = Multiselect()
self.multiselect_client = MultiselectClient()
# Create Notifee array # Create Notifee array
self.notifees = [] self.notifees = []
# Create generic protocol handler self.common_stream_handler = None
self.generic_protocol_handler = create_generic_protocol_handler(self)
def get_peer_id(self) -> ID: def get_peer_id(self) -> ID:
return self.self_id return self.self_id
def set_stream_handler( def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None:
self, protocol_id: TProtocol, stream_handler: StreamHandlerFn self.common_stream_handler = stream_handler
) -> None:
"""
:param protocol_id: protocol id used on stream
:param stream_handler: a stream handler instance
"""
self.multiselect.add_handler(protocol_id, stream_handler)
async def dial_peer(self, peer_id: ID) -> IMuxedConn: async def dial_peer(self, peer_id: ID) -> INetConn:
""" """
dial_peer try to create a connection to peer_id dial_peer try to create a connection to peer_id
:param peer_id: peer if we want to dial :param peer_id: peer if we want to dial
@ -145,9 +126,7 @@ class Swarm(INetwork):
logger.debug("upgraded security for peer %s", peer_id) logger.debug("upgraded security for peer %s", peer_id)
try: try:
muxed_conn = await self.upgrader.upgrade_connection( muxed_conn = await self.upgrader.upgrade_connection(secured_conn, peer_id)
secured_conn, self.generic_protocol_handler, peer_id
)
except MuxerUpgradeFailure as error: except MuxerUpgradeFailure as error:
error_msg = "fail to upgrade mux for peer %s" error_msg = "fail to upgrade mux for peer %s"
logger.debug(error_msg, peer_id) logger.debug(error_msg, peer_id)
@ -156,20 +135,15 @@ class Swarm(INetwork):
logger.debug("upgraded mux for peer %s", peer_id) logger.debug("upgraded mux for peer %s", peer_id)
# Store muxed connection in connections swarm_conn = await self.add_conn(muxed_conn)
self.connections[peer_id] = muxed_conn
# Call notifiers since event occurred
for notifee in self.notifees:
await notifee.connected(self, muxed_conn)
logger.debug("successfully dialed peer %s", peer_id) logger.debug("successfully dialed peer %s", peer_id)
return muxed_conn return swarm_conn
async def new_stream( async def new_stream(
self, peer_id: ID, protocol_ids: Sequence[TProtocol] self, peer_id: ID, protocol_ids: Sequence[TProtocol]
) -> NetStream: ) -> INetStream:
""" """
:param peer_id: peer_id of destination :param peer_id: peer_id of destination
:param protocol_id: protocol id :param protocol_id: protocol id
@ -182,37 +156,10 @@ class Swarm(INetwork):
protocol_ids, protocol_ids,
) )
muxed_conn = await self.dial_peer(peer_id) swarm_conn = await self.dial_peer(peer_id)
# Use muxed conn to open stream, which returns a muxed stream
muxed_stream = await muxed_conn.open_stream()
# Perform protocol muxing to determine protocol to use
try:
selected_protocol = await self.multiselect_client.select_one_of(
list(protocol_ids), MultiselectCommunicator(muxed_stream)
)
except MultiselectClientError as error:
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
await muxed_stream.reset()
raise SwarmException(
"failt to open a stream to peer %s", peer_id
) from error
# Create a net stream with the selected protocol
net_stream = NetStream(muxed_stream)
net_stream.set_protocol(selected_protocol)
logger.debug(
"successfully opened a stream to peer %s, over protocol %s",
peer_id,
selected_protocol,
)
# Call notifiers since event occurred
for notifee in self.notifees:
await notifee.opened_stream(self, net_stream)
net_stream = await swarm_conn.new_stream()
logger.debug("successfully opened a stream to peer %s", peer_id)
return net_stream return net_stream
async def listen(self, *multiaddrs: Multiaddr) -> bool: async def listen(self, *multiaddrs: Multiaddr) -> bool:
@ -262,7 +209,7 @@ class Swarm(INetwork):
try: try:
muxed_conn = await self.upgrader.upgrade_connection( muxed_conn = await self.upgrader.upgrade_connection(
secured_conn, self.generic_protocol_handler, peer_id secured_conn, peer_id
) )
except MuxerUpgradeFailure as error: except MuxerUpgradeFailure as error:
error_msg = "fail to upgrade mux for peer %s" error_msg = "fail to upgrade mux for peer %s"
@ -270,11 +217,8 @@ class Swarm(INetwork):
await secured_conn.close() await secured_conn.close()
raise SwarmException(error_msg % peer_id) from error raise SwarmException(error_msg % peer_id) from error
logger.debug("upgraded mux for peer %s", peer_id) logger.debug("upgraded mux for peer %s", peer_id)
# Store muxed_conn with peer id
self.connections[peer_id] = muxed_conn await self.add_conn(muxed_conn)
# Call notifiers since event occurred
for notifee in self.notifees:
await notifee.connected(self, muxed_conn)
logger.debug("successfully opened connection to peer %s", peer_id) logger.debug("successfully opened connection to peer %s", peer_id)
@ -334,30 +278,13 @@ class Swarm(INetwork):
logger.debug("successfully close the connection to peer %s", peer_id) logger.debug("successfully close the connection to peer %s", peer_id)
async def add_conn(self, muxed_conn: IMuxedConn) -> SwarmConn:
def create_generic_protocol_handler(swarm: Swarm) -> GenericProtocolHandlerFn: swarm_conn = SwarmConn(muxed_conn, self)
""" # Store muxed_conn with peer id
Create a generic protocol handler from the given swarm. We use swarm self.connections[muxed_conn.peer_id] = swarm_conn
to extract the multiselect module so that generic_protocol_handler
can use multiselect when generic_protocol_handler is called
from a different class
"""
multiselect = swarm.multiselect
async def generic_protocol_handler(muxed_stream: IMuxedStream) -> None:
# Perform protocol muxing to determine protocol to use
protocol, handler = await multiselect.negotiate(
MultiselectCommunicator(muxed_stream)
)
net_stream = NetStream(muxed_stream)
net_stream.set_protocol(protocol)
# Call notifiers since event occurred # Call notifiers since event occurred
for notifee in swarm.notifees: for notifee in self.notifees:
await notifee.opened_stream(swarm, net_stream) # TODO: Call with other type of conn?
await notifee.connected(self, muxed_conn)
# Give to stream handler await swarm_conn.start()
asyncio.ensure_future(handler(net_stream)) return swarm_conn
return generic_protocol_handler

View File

@ -1,5 +0,0 @@
from typing import Awaitable, Callable
from libp2p.stream_muxer.abc import IMuxedStream
GenericProtocolHandlerFn = Callable[[IMuxedStream], Awaitable[None]]

View File

@ -1,15 +1,8 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from libp2p.io.abc import ReadWriteCloser from libp2p.io.abc import ReadWriteCloser
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.stream_muxer.mplex.constants import HeaderTags
from libp2p.stream_muxer.mplex.datastructures import StreamID
if TYPE_CHECKING:
# Prevent GenericProtocolHandlerFn introducing circular dependencies
from libp2p.network.typing import GenericProtocolHandlerFn # noqa: F401
class IMuxedConn(ABC): class IMuxedConn(ABC):
@ -20,16 +13,10 @@ class IMuxedConn(ABC):
peer_id: ID peer_id: ID
@abstractmethod @abstractmethod
def __init__( def __init__(self, conn: ISecureConn, peer_id: ID) -> None:
self,
conn: ISecureConn,
generic_protocol_handler: "GenericProtocolHandlerFn",
peer_id: ID,
) -> None:
""" """
create a new muxed connection create a new muxed connection
:param conn: an instance of secured connection :param conn: an instance of secured connection
:param generic_protocol_handler: generic protocol handler
for new muxed streams for new muxed streams
:param peer_id: peer_id of peer the connection is to :param peer_id: peer_id of peer the connection is to
""" """
@ -60,22 +47,11 @@ class IMuxedConn(ABC):
""" """
@abstractmethod @abstractmethod
async def accept_stream(self, stream_id: StreamID, name: str) -> None: async def accept_stream(self) -> "IMuxedStream":
""" """
accepts a muxed stream opened by the other end accepts a muxed stream opened by the other end
""" """
@abstractmethod
async def send_message(
self, flag: HeaderTags, data: bytes, stream_id: StreamID
) -> int:
"""
sends a message over the connection
:param header: header to use
:param data: data to send in the message
:param stream_id: stream the message is in
"""
class IMuxedStream(ReadWriteCloser): class IMuxedStream(ReadWriteCloser):

View File

@ -4,9 +4,7 @@ from typing import Dict, List, Optional, Tuple
from libp2p.exceptions import ParseError from libp2p.exceptions import ParseError
from libp2p.io.exceptions import IncompleteReadError from libp2p.io.exceptions import IncompleteReadError
from libp2p.network.typing import GenericProtocolHandlerFn
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.protocol_muxer.exceptions import MultiselectError
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
@ -37,17 +35,13 @@ class Mplex(IMuxedConn):
next_channel_id: int next_channel_id: int
streams: Dict[StreamID, MplexStream] streams: Dict[StreamID, MplexStream]
streams_lock: asyncio.Lock streams_lock: asyncio.Lock
new_stream_queue: "asyncio.Queue[IMuxedStream]"
shutdown: asyncio.Event shutdown: asyncio.Event
_tasks: List["asyncio.Future[Any]"] _tasks: List["asyncio.Future[Any]"]
# TODO: `generic_protocol_handler` should be refactored out of mplex conn. # TODO: `generic_protocol_handler` should be refactored out of mplex conn.
def __init__( def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None:
self,
secured_conn: ISecureConn,
generic_protocol_handler: GenericProtocolHandlerFn,
peer_id: ID,
) -> None:
""" """
create a new muxed connection create a new muxed connection
:param secured_conn: an instance of ``ISecureConn`` :param secured_conn: an instance of ``ISecureConn``
@ -59,15 +53,13 @@ class Mplex(IMuxedConn):
self.next_channel_id = 0 self.next_channel_id = 0
# Store generic protocol handler
self.generic_protocol_handler = generic_protocol_handler
# Set peer_id # Set peer_id
self.peer_id = peer_id self.peer_id = peer_id
# Mapping from stream ID -> buffer of messages for that stream # Mapping from stream ID -> buffer of messages for that stream
self.streams = {} self.streams = {}
self.streams_lock = asyncio.Lock() self.streams_lock = asyncio.Lock()
self.new_stream_queue = asyncio.Queue()
self.shutdown = asyncio.Event() self.shutdown = asyncio.Event()
self._tasks = [] self._tasks = []
@ -104,9 +96,9 @@ class Mplex(IMuxedConn):
return next_id return next_id
async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream: async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream:
stream = MplexStream(name, stream_id, self)
async with self.streams_lock: async with self.streams_lock:
stream = MplexStream(name, stream_id, self) self.streams[stream_id] = stream
self.streams[stream_id] = stream
return stream return stream
async def open_stream(self) -> IMuxedStream: async def open_stream(self) -> IMuxedStream:
@ -122,19 +114,11 @@ class Mplex(IMuxedConn):
await self.send_message(HeaderTags.NewStream, name.encode(), stream_id) await self.send_message(HeaderTags.NewStream, name.encode(), stream_id)
return stream return stream
async def accept_stream(self, stream_id: StreamID, name: str) -> None: async def accept_stream(self) -> IMuxedStream:
""" """
accepts a muxed stream opened by the other end accepts a muxed stream opened by the other end
""" """
stream = await self._initialize_stream(stream_id, name) return await self.new_stream_queue.get()
# Perform protocol negotiation for the stream.
try:
await self.generic_protocol_handler(stream)
except MultiselectError:
# Un-register and reset the stream
del self.streams[stream_id]
await stream.reset()
return
async def send_message( async def send_message(
self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID
@ -187,11 +171,11 @@ class Mplex(IMuxedConn):
# `NewStream` for the same id is received twice... # `NewStream` for the same id is received twice...
# TODO: Shutdown # TODO: Shutdown
pass pass
self._tasks.append( mplex_stream = await self._initialize_stream(
asyncio.ensure_future( stream_id, message.decode()
self.accept_stream(stream_id, message.decode())
)
) )
# TODO: Check if `self` is shutdown.
await self.new_stream_queue.put(mplex_stream)
elif flag in ( elif flag in (
HeaderTags.MessageInitiator.value, HeaderTags.MessageInitiator.value,
HeaderTags.MessageReceiver.value, HeaderTags.MessageReceiver.value,

View File

@ -2,7 +2,6 @@ from collections import OrderedDict
from typing import Mapping, Type from typing import Mapping, Type
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.network.typing import GenericProtocolHandlerFn
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.protocol_muxer.multiselect import Multiselect from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.protocol_muxer.multiselect_client import MultiselectClient from libp2p.protocol_muxer.multiselect_client import MultiselectClient
@ -69,11 +68,6 @@ class MuxerMultistream:
protocol, _ = await self.multiselect.negotiate(communicator) protocol, _ = await self.multiselect.negotiate(communicator)
return self.transports[protocol] return self.transports[protocol]
async def new_conn( async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
self,
conn: ISecureConn,
generic_protocol_handler: GenericProtocolHandlerFn,
peer_id: ID,
) -> IMuxedConn:
transport_class = await self.select_transport(conn) transport_class = await self.select_transport(conn)
return transport_class(conn, generic_protocol_handler, peer_id) return transport_class(conn, peer_id)

View File

@ -1,7 +1,6 @@
from typing import Mapping from typing import Mapping
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.network.typing import GenericProtocolHandlerFn
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.protocol_muxer.exceptions import MultiselectClientError, MultiselectError from libp2p.protocol_muxer.exceptions import MultiselectClientError, MultiselectError
from libp2p.security.exceptions import HandshakeFailure from libp2p.security.exceptions import HandshakeFailure
@ -57,19 +56,12 @@ class TransportUpgrader:
"handshake failed when upgrading to secure connection" "handshake failed when upgrading to secure connection"
) from error ) from error
async def upgrade_connection( async def upgrade_connection(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
self,
conn: ISecureConn,
generic_protocol_handler: GenericProtocolHandlerFn,
peer_id: ID,
) -> IMuxedConn:
""" """
Upgrade secured connection to a muxed connection Upgrade secured connection to a muxed connection
""" """
try: try:
return await self.muxer_multistream.new_conn( return await self.muxer_multistream.new_conn(conn, peer_id)
conn, generic_protocol_handler, peer_id
)
except (MultiselectError, MultiselectClientError) as error: except (MultiselectError, MultiselectClientError) as error:
raise MuxerUpgradeFailure( raise MuxerUpgradeFailure(
"failed to negotiate the multiplexer protocol" "failed to negotiate the multiplexer protocol"

View File

@ -2,7 +2,7 @@ import asyncio
import pytest import pytest
from libp2p.network.exceptions import SwarmException from libp2p.host.exceptions import StreamFailure
from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.peer.peerinfo import info_from_p2p_addr
from tests.utils import set_up_nodes_by_transport_opt from tests.utils import set_up_nodes_by_transport_opt
@ -84,7 +84,7 @@ async def no_common_protocol(host_a, host_b):
host_a.set_stream_handler(PROTOCOL_ID, stream_handler) host_a.set_stream_handler(PROTOCOL_ID, stream_handler)
# try to creates a new new with a procotol not known by the other host # try to creates a new new with a procotol not known by the other host
with pytest.raises(SwarmException): with pytest.raises(StreamFailure):
await host_b.new_stream(host_a.get_id(), ["/fakeproto/0.0.1"]) await host_b.new_stream(host_a.get_id(), ["/fakeproto/0.0.1"])

View File

@ -1,6 +1,6 @@
import pytest import pytest
from libp2p.network.exceptions import SwarmException from libp2p.host.exceptions import StreamFailure
from tests.utils import echo_stream_handler, set_up_nodes_by_transport_opt from tests.utils import echo_stream_handler, set_up_nodes_by_transport_opt
# TODO: Add tests for multiple streams being opened on different # TODO: Add tests for multiple streams being opened on different
@ -47,7 +47,7 @@ async def test_single_protocol_succeeds():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_single_protocol_fails(): async def test_single_protocol_fails():
with pytest.raises(SwarmException): with pytest.raises(StreamFailure):
await perform_simple_test("", ["/echo/1.0.0"], ["/potato/1.0.0"]) await perform_simple_test("", ["/echo/1.0.0"], ["/potato/1.0.0"])
# Cleanup not reached on error # Cleanup not reached on error
@ -77,7 +77,7 @@ async def test_multiple_protocol_second_is_valid_succeeds():
async def test_multiple_protocol_fails(): async def test_multiple_protocol_fails():
protocols_for_client = ["/rock/1.0.0", "/foo/1.0.0", "/bar/1.0.0"] protocols_for_client = ["/rock/1.0.0", "/foo/1.0.0", "/bar/1.0.0"]
protocols_for_listener = ["/aspyn/1.0.0", "/rob/1.0.0", "/zx/1.0.0", "/alex/1.0.0"] protocols_for_listener = ["/aspyn/1.0.0", "/rob/1.0.0", "/zx/1.0.0", "/alex/1.0.0"]
with pytest.raises(SwarmException): with pytest.raises(StreamFailure):
await perform_simple_test("", protocols_for_client, protocols_for_listener) await perform_simple_test("", protocols_for_client, protocols_for_listener)
# Cleanup not reached on error # Cleanup not reached on error

View File

@ -53,8 +53,8 @@ async def perform_simple_test(
node2_conn = node2.get_network().connections[peer_id_for_node(node1)] node2_conn = node2.get_network().connections[peer_id_for_node(node1)]
# Perform assertion # Perform assertion
assertion_func(node1_conn.secured_conn) assertion_func(node1_conn.conn.secured_conn)
assertion_func(node2_conn.secured_conn) assertion_func(node2_conn.conn.secured_conn)
# Success, terminate pending tasks. # Success, terminate pending tasks.