From a9ad37bc6f8a715086feac6262b42dbf978e079e Mon Sep 17 00:00:00 2001 From: mhchia Date: Wed, 18 Sep 2019 15:44:45 +0800 Subject: [PATCH] Add mplex tests and fix error in `SwarmConn.close` --- libp2p/__init__.py | 8 ++-- libp2p/host/basic_host.py | 2 +- libp2p/network/connection/swarm_connection.py | 13 ++++++- libp2p/security/security_multistream.py | 11 +++--- libp2p/stream_muxer/mplex/mplex.py | 3 -- libp2p/stream_muxer/mplex/mplex_stream.py | 1 + libp2p/stream_muxer/muxer_multistream.py | 19 ++++----- libp2p/transport/typing.py | 9 ++++- libp2p/transport/upgrader.py | 11 ++---- tests/factories.py | 39 ++++++++++++++----- tests/network/test_swarm_conn.py | 2 + tests/stream_muxer/__init__.py | 0 tests/stream_muxer/conftest.py | 16 ++++++++ tests/stream_muxer/test_mplex_conn.py | 6 +++ 14 files changed, 96 insertions(+), 44 deletions(-) create mode 100644 tests/stream_muxer/__init__.py create mode 100644 tests/stream_muxer/conftest.py create mode 100644 tests/stream_muxer/test_mplex_conn.py diff --git a/libp2p/__init__.py b/libp2p/__init__.py index b4d2a9a2..cbff1e44 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -17,8 +17,8 @@ from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTr import libp2p.security.secio.transport as secio from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex -from libp2p.stream_muxer.muxer_multistream import MuxerClassType from libp2p.transport.tcp.tcp import TCP +from libp2p.transport.typing import TMuxerClass, TMuxerOptions, TSecurityOptions from libp2p.transport.upgrader import TransportUpgrader from libp2p.typing import TProtocol @@ -74,8 +74,8 @@ def initialize_default_swarm( key_pair: KeyPair, id_opt: ID = None, transport_opt: Sequence[str] = None, - muxer_opt: Mapping[TProtocol, MuxerClassType] = None, - sec_opt: Mapping[TProtocol, ISecureTransport] = None, + muxer_opt: TMuxerOptions = None, + sec_opt: TSecurityOptions = None, peerstore_opt: IPeerStore = None, disc_opt: IPeerRouting = None, ) -> Swarm: @@ -114,7 +114,7 @@ async def new_node( key_pair: KeyPair = None, swarm_opt: INetwork = None, transport_opt: Sequence[str] = None, - muxer_opt: Mapping[TProtocol, MuxerClassType] = None, + muxer_opt: Mapping[TProtocol, TMuxerClass] = None, sec_opt: Mapping[TProtocol, ISecureTransport] = None, peerstore_opt: IPeerStore = None, disc_opt: IPeerRouting = None, diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 5f0ccfd1..912d3ef5 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -141,4 +141,4 @@ class BasicHost(IHost): MultiselectCommunicator(net_stream) ) net_stream.set_protocol(protocol) - asyncio.ensure_future(handler(net_stream)) + await handler(net_stream) diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 78e6ead6..6714bb83 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -66,10 +66,19 @@ class SwarmConn(INetConn): await self.close() + async def _call_stream_handler(self, net_stream: NetStream) -> None: + try: + await self.swarm.common_stream_handler(net_stream) + # TODO: More exact exceptions + except Exception: + # TODO: Emit logs. + # TODO: Clean up and remove the stream from SwarmConn if there is anything wrong. + self.remove_stream(net_stream) + async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None: net_stream = await self._add_stream(muxed_stream) if self.swarm.common_stream_handler is not None: - await self.run_task(self.swarm.common_stream_handler(net_stream)) + await self.run_task(self._call_stream_handler(net_stream)) async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: net_stream = NetStream(muxed_stream) @@ -97,4 +106,6 @@ class SwarmConn(INetConn): # TODO: Called by `Stream` whenever it is time to remove the stream. def remove_stream(self, stream: NetStream) -> None: + if stream not in self.streams: + return self.streams.remove(stream) diff --git a/libp2p/security/security_multistream.py b/libp2p/security/security_multistream.py index 466d60a8..06f4b8a5 100644 --- a/libp2p/security/security_multistream.py +++ b/libp2p/security/security_multistream.py @@ -1,6 +1,5 @@ from abc import ABC from collections import OrderedDict -from typing import Mapping from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID @@ -9,6 +8,7 @@ from libp2p.protocol_muxer.multiselect_client import MultiselectClient from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_transport_interface import ISecureTransport +from libp2p.transport.typing import TSecurityOptions from libp2p.typing import TProtocol @@ -31,15 +31,14 @@ class SecurityMultistream(ABC): multiselect: Multiselect multiselect_client: MultiselectClient - def __init__( - self, secure_transports_by_protocol: Mapping[TProtocol, ISecureTransport] - ) -> None: + def __init__(self, secure_transports_by_protocol: TSecurityOptions = None) -> None: self.transports = OrderedDict() self.multiselect = Multiselect() self.multiselect_client = MultiselectClient() - for protocol, transport in secure_transports_by_protocol.items(): - self.add_transport(protocol, transport) + if secure_transports_by_protocol is not None: + for protocol, transport in secure_transports_by_protocol.items(): + self.add_transport(protocol, transport) def add_transport(self, protocol: TProtocol, transport: ISecureTransport) -> None: """ diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 6781fed4..f6602264 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -29,9 +29,6 @@ class Mplex(IMuxedConn): secured_conn: ISecureConn peer_id: ID - # TODO: `dataIn` in go implementation. Should be size of 8. - # TODO: Also, `dataIn` is closed indicating EOF in Go. We don't have similar strategies - # to let the `MplexStream`s know that EOF arrived (#235). next_channel_id: int streams: Dict[StreamID, MplexStream] streams_lock: asyncio.Lock diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 06b90faa..221e238e 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -24,6 +24,7 @@ class MplexStream(IMuxedStream): close_lock: asyncio.Lock + # NOTE: `dataIn` is size of 8 in Go implementation. incoming_data: "asyncio.Queue[bytes]" event_local_closed: asyncio.Event diff --git a/libp2p/stream_muxer/muxer_multistream.py b/libp2p/stream_muxer/muxer_multistream.py index 806c90d6..d5067490 100644 --- a/libp2p/stream_muxer/muxer_multistream.py +++ b/libp2p/stream_muxer/muxer_multistream.py @@ -1,5 +1,4 @@ from collections import OrderedDict -from typing import Mapping, Type from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID @@ -7,12 +6,11 @@ from libp2p.protocol_muxer.multiselect import Multiselect from libp2p.protocol_muxer.multiselect_client import MultiselectClient from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator from libp2p.security.secure_conn_interface import ISecureConn +from libp2p.transport.typing import TMuxerClass, TMuxerOptions from libp2p.typing import TProtocol from .abc import IMuxedConn -MuxerClassType = Type[IMuxedConn] - # FIXME: add negotiate timeout to `MuxerMultistream` DEFAULT_NEGOTIATE_TIMEOUT = 60 @@ -24,20 +22,19 @@ class MuxerMultistream: """ # NOTE: Can be changed to `typing.OrderedDict` since Python 3.7.2. - transports: "OrderedDict[TProtocol, MuxerClassType]" + transports: "OrderedDict[TProtocol, TMuxerClass]" multiselect: Multiselect multiselect_client: MultiselectClient - def __init__( - self, muxer_transports_by_protocol: Mapping[TProtocol, MuxerClassType] - ) -> None: + def __init__(self, muxer_transports_by_protocol: TMuxerOptions = None) -> None: self.transports = OrderedDict() self.multiselect = Multiselect() self.multiselect_client = MultiselectClient() - for protocol, transport in muxer_transports_by_protocol.items(): - self.add_transport(protocol, transport) + if muxer_transports_by_protocol is not None: + for protocol, transport in muxer_transports_by_protocol.items(): + self.add_transport(protocol, transport) - def add_transport(self, protocol: TProtocol, transport: MuxerClassType) -> None: + def add_transport(self, protocol: TProtocol, transport: TMuxerClass) -> None: """ Add a protocol and its corresponding transport to multistream-select(multiselect). The order that a protocol is added is exactly the precedence it is negotiated in @@ -51,7 +48,7 @@ class MuxerMultistream: self.transports[protocol] = transport self.multiselect.add_handler(protocol, None) - async def select_transport(self, conn: IRawConnection) -> MuxerClassType: + async def select_transport(self, conn: IRawConnection) -> TMuxerClass: """ Select a transport that both us and the node on the other end of conn support and agree on diff --git a/libp2p/transport/typing.py b/libp2p/transport/typing.py index 6d0047c5..f9b31dcb 100644 --- a/libp2p/transport/typing.py +++ b/libp2p/transport/typing.py @@ -1,4 +1,11 @@ from asyncio import StreamReader, StreamWriter -from typing import Awaitable, Callable +from typing import Awaitable, Callable, Mapping, Type + +from libp2p.security.secure_transport_interface import ISecureTransport +from libp2p.stream_muxer.abc import IMuxedConn +from libp2p.typing import TProtocol THandler = Callable[[StreamReader, StreamWriter], Awaitable[None]] +TSecurityOptions = Mapping[TProtocol, ISecureTransport] +TMuxerClass = Type[IMuxedConn] +TMuxerOptions = Mapping[TProtocol, TMuxerClass] diff --git a/libp2p/transport/upgrader.py b/libp2p/transport/upgrader.py index 233c4d5f..877fd239 100644 --- a/libp2p/transport/upgrader.py +++ b/libp2p/transport/upgrader.py @@ -1,19 +1,16 @@ -from typing import Mapping - from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID from libp2p.protocol_muxer.exceptions import MultiselectClientError, MultiselectError from libp2p.security.secure_conn_interface import ISecureConn -from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.security.security_multistream import SecurityMultistream from libp2p.stream_muxer.abc import IMuxedConn -from libp2p.stream_muxer.muxer_multistream import MuxerClassType, MuxerMultistream +from libp2p.stream_muxer.muxer_multistream import MuxerMultistream from libp2p.transport.exceptions import ( HandshakeFailure, MuxerUpgradeFailure, SecurityUpgradeFailure, ) -from libp2p.typing import TProtocol +from libp2p.transport.typing import TMuxerOptions, TSecurityOptions from .listener_interface import IListener from .transport_interface import ITransport @@ -25,8 +22,8 @@ class TransportUpgrader: def __init__( self, - secure_transports_by_protocol: Mapping[TProtocol, ISecureTransport], - muxer_transports_by_protocol: Mapping[TProtocol, MuxerClassType], + secure_transports_by_protocol: TSecurityOptions, + muxer_transports_by_protocol: TMuxerOptions, ): self.security_multistream = SecurityMultistream(secure_transports_by_protocol) self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol) diff --git a/tests/factories.py b/tests/factories.py index af4d529b..dcc9a85d 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -15,6 +15,8 @@ from libp2p.pubsub.pubsub import Pubsub from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport import libp2p.security.secio.transport as secio +from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex +from libp2p.transport.typing import TMuxerOptions from libp2p.typing import TProtocol from tests.configs import LISTEN_MADDR from tests.pubsub.configs import ( @@ -34,10 +36,10 @@ def security_transport_factory( return {secio.ID: secio.Transport(key_pair)} -def SwarmFactory(is_secure: bool) -> Swarm: +def SwarmFactory(is_secure: bool, muxer_opt: TMuxerOptions = None) -> Swarm: key_pair = generate_new_rsa_identity() - sec_opt = security_transport_factory(False, key_pair) - return initialize_default_swarm(key_pair, sec_opt=sec_opt) + sec_opt = security_transport_factory(is_secure, key_pair) + return initialize_default_swarm(key_pair, sec_opt=sec_opt, muxer_opt=muxer_opt) class ListeningSwarmFactory(factory.Factory): @@ -45,17 +47,22 @@ class ListeningSwarmFactory(factory.Factory): model = Swarm @classmethod - async def create_and_listen(cls, is_secure: bool) -> Swarm: - swarm = SwarmFactory(is_secure) + async def create_and_listen( + cls, is_secure: bool, muxer_opt: TMuxerOptions = None + ) -> Swarm: + swarm = SwarmFactory(is_secure, muxer_opt=muxer_opt) await swarm.listen(LISTEN_MADDR) return swarm @classmethod async def create_batch_and_listen( - cls, is_secure: bool, number: int + cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None ) -> Tuple[Swarm, ...]: return await asyncio.gather( - *[cls.create_and_listen(is_secure) for _ in range(number)] + *[ + cls.create_and_listen(is_secure, muxer_opt=muxer_opt) + for _ in range(number) + ] ) @@ -112,8 +119,12 @@ class PubsubFactory(factory.Factory): cache_size = None -async def swarm_pair_factory(is_secure: bool) -> Tuple[Swarm, Swarm]: - swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, 2) +async def swarm_pair_factory( + is_secure: bool, muxer_opt: TMuxerOptions = None +) -> Tuple[Swarm, Swarm]: + swarms = await ListeningSwarmFactory.create_batch_and_listen( + is_secure, 2, muxer_opt=muxer_opt + ) await connect_swarm(swarms[0], swarms[1]) return swarms[0], swarms[1] @@ -130,7 +141,7 @@ async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]: async def swarm_conn_pair_factory( - is_secure + is_secure: bool, muxer_opt: TMuxerOptions = None ) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]: swarms = await swarm_pair_factory(is_secure) conn_0 = swarms[0].connections[swarms[1].get_peer_id()] @@ -138,6 +149,14 @@ async def swarm_conn_pair_factory( return conn_0, swarms[0], conn_1, swarms[1] +async def mplex_conn_pair_factory(is_secure): + muxer_opt = {MPLEX_PROTOCOL_ID: Mplex} + conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory( + is_secure, muxer_opt=muxer_opt + ) + return conn_0.conn, swarm_0, conn_1.conn, swarm_1 + + async def net_stream_pair_factory( is_secure: bool ) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]: diff --git a/tests/network/test_swarm_conn.py b/tests/network/test_swarm_conn.py index f9974e14..2abc7d0f 100644 --- a/tests/network/test_swarm_conn.py +++ b/tests/network/test_swarm_conn.py @@ -41,3 +41,5 @@ async def test_swarm_conn_streams(swarm_conn_pair): assert len(await conn_0.get_streams()) == 1 conn_0.remove_stream(stream_0_1) assert len(await conn_0.get_streams()) == 0 + # Nothing happen if `stream_0_1` is not present or already removed. + conn_0.remove_stream(stream_0_1) diff --git a/tests/stream_muxer/__init__.py b/tests/stream_muxer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/stream_muxer/conftest.py b/tests/stream_muxer/conftest.py new file mode 100644 index 00000000..3695ae47 --- /dev/null +++ b/tests/stream_muxer/conftest.py @@ -0,0 +1,16 @@ +import asyncio + +import pytest + +from tests.factories import mplex_conn_pair_factory + + +@pytest.fixture +async def mplex_conn_pair(is_host_secure): + mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory( + is_host_secure + ) + try: + yield mplex_conn_0, mplex_conn_1 + finally: + await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) diff --git a/tests/stream_muxer/test_mplex_conn.py b/tests/stream_muxer/test_mplex_conn.py new file mode 100644 index 00000000..a85d9f4b --- /dev/null +++ b/tests/stream_muxer/test_mplex_conn.py @@ -0,0 +1,6 @@ +import pytest + + +@pytest.mark.asyncio +async def test_mplex_conn(mplex_conn_pair): + conn_0, conn_1 = mplex_conn_pair