diff --git a/libp2p/__init__.py b/libp2p/__init__.py index b4d2a9a2..08caf256 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -1,5 +1,5 @@ import asyncio -from typing import Mapping, Sequence +from typing import Sequence from libp2p.crypto.keys import KeyPair from libp2p.crypto.rsa import create_new_key_pair @@ -15,10 +15,9 @@ from libp2p.routing.interfaces import IPeerRouting from libp2p.routing.kademlia.kademlia_peer_router import KadmeliaPeerRouter from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport import libp2p.security.secio.transport as secio -from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex -from libp2p.stream_muxer.muxer_multistream import MuxerClassType from libp2p.transport.tcp.tcp import TCP +from libp2p.transport.typing import TMuxerOptions, TSecurityOptions from libp2p.transport.upgrader import TransportUpgrader from libp2p.typing import TProtocol @@ -74,8 +73,8 @@ def initialize_default_swarm( key_pair: KeyPair, id_opt: ID = None, transport_opt: Sequence[str] = None, - muxer_opt: Mapping[TProtocol, MuxerClassType] = None, - sec_opt: Mapping[TProtocol, ISecureTransport] = None, + muxer_opt: TMuxerOptions = None, + sec_opt: TSecurityOptions = None, peerstore_opt: IPeerStore = None, disc_opt: IPeerRouting = None, ) -> Swarm: @@ -114,8 +113,8 @@ async def new_node( key_pair: KeyPair = None, swarm_opt: INetwork = None, transport_opt: Sequence[str] = None, - muxer_opt: Mapping[TProtocol, MuxerClassType] = None, - sec_opt: Mapping[TProtocol, ISecureTransport] = None, + muxer_opt: TMuxerOptions = None, + sec_opt: TSecurityOptions = None, peerstore_opt: IPeerStore = None, disc_opt: IPeerRouting = None, ) -> BasicHost: diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index af4225f4..7c783d93 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -1,4 +1,3 @@ -import asyncio import logging from typing import List, Sequence @@ -107,7 +106,7 @@ class BasicHost(IHost): :return: stream: new stream created """ - net_stream = await self._network.new_stream(peer_id, protocol_ids) + net_stream = await self._network.new_stream(peer_id) # Perform protocol muxing to determine protocol to use try: @@ -157,4 +156,4 @@ class BasicHost(IHost): await net_stream.reset() return net_stream.set_protocol(protocol) - asyncio.ensure_future(handler(net_stream)) + await handler(net_stream) diff --git a/libp2p/network/connection/net_connection_interface.py b/libp2p/network/connection/net_connection_interface.py index c2c62855..e308ad65 100644 --- a/libp2p/network/connection/net_connection_interface.py +++ b/libp2p/network/connection/net_connection_interface.py @@ -7,7 +7,7 @@ from libp2p.stream_muxer.abc import IMuxedConn class INetConn(Closer): - conn: IMuxedConn + muxed_conn: IMuxedConn @abstractmethod async def new_stream(self) -> INetStream: diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 15816fcb..e25d75f0 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -16,15 +16,15 @@ Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee class SwarmConn(INetConn): - conn: IMuxedConn + muxed_conn: IMuxedConn swarm: "Swarm" streams: Set[NetStream] event_closed: asyncio.Event _tasks: List["asyncio.Future[Any]"] - def __init__(self, conn: IMuxedConn, swarm: "Swarm") -> None: - self.conn = conn + def __init__(self, muxed_conn: IMuxedConn, swarm: "Swarm") -> None: + self.muxed_conn = muxed_conn self.swarm = swarm self.streams = set() self.event_closed = asyncio.Event() @@ -37,22 +37,26 @@ class SwarmConn(INetConn): self.event_closed.set() self.swarm.remove_conn(self) - await self.conn.close() + await self.muxed_conn.close() # This is just for cleaning up state. The connection has already been closed. # We *could* optimize this but it really isn't worth it. for stream in self.streams: await stream.reset() - # Schedule `self._notify_disconnected` to make it execute after `close` is finished. - asyncio.ensure_future(self._notify_disconnected()) for task in self._tasks: task.cancel() + try: + await task + except asyncio.CancelledError: + pass + # Schedule `self._notify_disconnected` to make it execute after `close` is finished. + self._notify_disconnected() async def _handle_new_streams(self) -> None: while True: try: - stream = await self.conn.accept_stream() + stream = await self.muxed_conn.accept_stream() except MuxedConnUnavailable: # If there is anything wrong in the MuxedConn, # we should break the loop and close the connection. @@ -62,22 +66,28 @@ class SwarmConn(INetConn): await self.close() - async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None: - net_stream = await self._add_stream(muxed_stream) - if self.swarm.common_stream_handler is not None: - await self.run_task(self.swarm.common_stream_handler(net_stream)) + async def _call_stream_handler(self, net_stream: NetStream) -> None: + try: + await self.swarm.common_stream_handler(net_stream) + # TODO: More exact exceptions + except Exception: + # TODO: Emit logs. + # TODO: Clean up and remove the stream from SwarmConn if there is anything wrong. + self.remove_stream(net_stream) - async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: + async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None: + net_stream = self._add_stream(muxed_stream) + if self.swarm.common_stream_handler is not None: + await self.run_task(self._call_stream_handler(net_stream)) + + def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: net_stream = NetStream(muxed_stream) self.streams.add(net_stream) - # Call notifiers since event occurred - for notifee in self.swarm.notifees: - await notifee.opened_stream(self.swarm, net_stream) + self.swarm.notify_opened_stream(net_stream) return net_stream - async def _notify_disconnected(self) -> None: - for notifee in self.swarm.notifees: - await notifee.disconnected(self.swarm, self.conn) + def _notify_disconnected(self) -> None: + self.swarm.notify_disconnected(self) async def start(self) -> None: await self.run_task(self._handle_new_streams()) @@ -86,8 +96,13 @@ class SwarmConn(INetConn): self._tasks.append(asyncio.ensure_future(coro)) async def new_stream(self) -> NetStream: - muxed_stream = await self.conn.open_stream() - return await self._add_stream(muxed_stream) + muxed_stream = await self.muxed_conn.open_stream() + return self._add_stream(muxed_stream) async def get_streams(self) -> Tuple[NetStream, ...]: return tuple(self.streams) + + def remove_stream(self, stream: NetStream) -> None: + if stream not in self.streams: + return + self.streams.remove(stream) diff --git a/libp2p/network/network_interface.py b/libp2p/network/network_interface.py index 470da1a9..94ddba2c 100644 --- a/libp2p/network/network_interface.py +++ b/libp2p/network/network_interface.py @@ -7,7 +7,7 @@ from libp2p.network.connection.net_connection_interface import INetConn from libp2p.peer.id import ID from libp2p.peer.peerstore_interface import IPeerStore from libp2p.transport.listener_interface import IListener -from libp2p.typing import StreamHandlerFn, TProtocol +from libp2p.typing import StreamHandlerFn from .stream.net_stream_interface import INetStream @@ -38,9 +38,7 @@ class INetwork(ABC): """ @abstractmethod - async def new_stream( - self, peer_id: ID, protocol_ids: Sequence[TProtocol] - ) -> INetStream: + async def new_stream(self, peer_id: ID) -> INetStream: """ :param peer_id: peer_id of destination :param protocol_ids: available protocol ids to use for stream @@ -61,7 +59,7 @@ class INetwork(ABC): """ @abstractmethod - def notify(self, notifee: "INotifee") -> bool: + def register_notifee(self, notifee: "INotifee") -> None: """ :param notifee: object implementing Notifee interface :return: true if notifee registered successfully, false otherwise diff --git a/libp2p/network/notifee_interface.py b/libp2p/network/notifee_interface.py index ef996bfb..c31f4732 100644 --- a/libp2p/network/notifee_interface.py +++ b/libp2p/network/notifee_interface.py @@ -3,8 +3,8 @@ from typing import TYPE_CHECKING from multiaddr import Multiaddr +from libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.stream.net_stream_interface import INetStream -from libp2p.stream_muxer.abc import IMuxedConn if TYPE_CHECKING: from .network_interface import INetwork # noqa: F401 @@ -26,14 +26,14 @@ class INotifee(ABC): """ @abstractmethod - async def connected(self, network: "INetwork", conn: IMuxedConn) -> None: + async def connected(self, network: "INetwork", conn: INetConn) -> None: """ :param network: network the connection was opened on :param conn: connection that was opened """ @abstractmethod - async def disconnected(self, network: "INetwork", conn: IMuxedConn) -> None: + async def disconnected(self, network: "INetwork", conn: INetConn) -> None: """ :param network: network the connection was closed on :param conn: connection that was closed diff --git a/libp2p/network/stream/net_stream.py b/libp2p/network/stream/net_stream.py index d500c088..0142721c 100644 --- a/libp2p/network/stream/net_stream.py +++ b/libp2p/network/stream/net_stream.py @@ -1,4 +1,6 @@ -from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream +from typing import Optional + +from libp2p.stream_muxer.abc import IMuxedStream from libp2p.stream_muxer.exceptions import ( MuxedStreamClosed, MuxedStreamEOF, @@ -16,13 +18,11 @@ from .net_stream_interface import INetStream class NetStream(INetStream): muxed_stream: IMuxedStream - # TODO: Why we expose `mplex_conn` here? - mplex_conn: IMuxedConn - protocol_id: TProtocol + protocol_id: Optional[TProtocol] def __init__(self, muxed_stream: IMuxedStream) -> None: self.muxed_stream = muxed_stream - self.mplex_conn = muxed_stream.mplex_conn + self.muxed_conn = muxed_stream.muxed_conn self.protocol_id = None def get_protocol(self) -> TProtocol: @@ -68,3 +68,7 @@ class NetStream(INetStream): async def reset(self) -> None: await self.muxed_stream.reset() + + # TODO: `remove`: Called by close and write when the stream is in specific states. + # It notifies `ClosedStream` after `SwarmConn.remove_stream` is called. + # Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501 diff --git a/libp2p/network/stream/net_stream_interface.py b/libp2p/network/stream/net_stream_interface.py index d0547890..41bf4239 100644 --- a/libp2p/network/stream/net_stream_interface.py +++ b/libp2p/network/stream/net_stream_interface.py @@ -7,7 +7,7 @@ from libp2p.typing import TProtocol class INetStream(ReadWriteCloser): - mplex_conn: IMuxedConn + muxed_conn: IMuxedConn @abstractmethod def get_protocol(self) -> TProtocol: diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 32997203..9d507fb6 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import Dict, List, Optional, Sequence +from typing import Dict, List, Optional from multiaddr import Multiaddr @@ -18,7 +18,7 @@ from libp2p.transport.exceptions import ( from libp2p.transport.listener_interface import IListener from libp2p.transport.transport_interface import ITransport from libp2p.transport.upgrader import TransportUpgrader -from libp2p.typing import StreamHandlerFn, TProtocol +from libp2p.typing import StreamHandlerFn from .connection.raw_connection import RawConnection from .connection.swarm_connection import SwarmConn @@ -141,20 +141,14 @@ class Swarm(INetwork): return swarm_conn - async def new_stream( - self, peer_id: ID, protocol_ids: Sequence[TProtocol] - ) -> INetStream: + async def new_stream(self, peer_id: ID) -> INetStream: """ :param peer_id: peer_id of destination :param protocol_id: protocol id :raises SwarmException: raised when an error occurs :return: net stream instance """ - logger.debug( - "attempting to open a stream to peer %s, over one of the protocols %s", - peer_id, - protocol_ids, - ) + logger.debug("attempting to open a stream to peer %s", peer_id) swarm_conn = await self.dial_peer(peer_id) @@ -229,8 +223,7 @@ class Swarm(INetwork): await listener.listen(maddr) # Call notifiers since event occurred - for notifee in self.notifees: - await notifee.listen(self, maddr) + self.notify_listen(maddr) return True except IOError: @@ -240,16 +233,6 @@ class Swarm(INetwork): # No maddr succeeded return False - def notify(self, notifee: INotifee) -> bool: - """ - :param notifee: object implementing Notifee interface - :return: true if notifee registered successfully, false otherwise - """ - if isinstance(notifee, INotifee): - self.notifees.append(notifee) - return True - return False - def add_router(self, router: IPeerRouting) -> None: self.router = router @@ -288,9 +271,7 @@ class Swarm(INetwork): # Store muxed_conn with peer id self.connections[muxed_conn.peer_id] = swarm_conn # Call notifiers since event occurred - for notifee in self.notifees: - # TODO: Call with other type of conn? - await notifee.connected(self, muxed_conn) + self.notify_connected(swarm_conn) await swarm_conn.start() return swarm_conn @@ -298,9 +279,38 @@ class Swarm(INetwork): """ Simply remove the connection from Swarm's records, without closing the connection. """ - peer_id = swarm_conn.conn.peer_id + peer_id = swarm_conn.muxed_conn.peer_id if peer_id not in self.connections: return # TODO: Should be changed to remove the exact connection, # if we have several connections per peer in the future. del self.connections[peer_id] + + # Notifee + + # TODO: Remeber the spawn notifying tasks and clean them up when closing. + + def register_notifee(self, notifee: INotifee) -> None: + """ + :param notifee: object implementing Notifee interface + :return: true if notifee registered successfully, false otherwise + """ + self.notifees.append(notifee) + + def notify_opened_stream(self, stream: INetStream) -> None: + asyncio.gather( + *[notifee.opened_stream(self, stream) for notifee in self.notifees] + ) + + # TODO: `notify_closed_stream` + + def notify_connected(self, conn: INetConn) -> None: + asyncio.gather(*[notifee.connected(self, conn) for notifee in self.notifees]) + + def notify_disconnected(self, conn: INetConn) -> None: + asyncio.gather(*[notifee.disconnected(self, conn) for notifee in self.notifees]) + + def notify_listen(self, multiaddr: Multiaddr) -> None: + asyncio.gather(*[notifee.listen(self, multiaddr) for notifee in self.notifees]) + + # TODO: `notify_listen_close` diff --git a/libp2p/protocol_muxer/multiselect_communicator.py b/libp2p/protocol_muxer/multiselect_communicator.py index a66a5642..6f7b715d 100644 --- a/libp2p/protocol_muxer/multiselect_communicator.py +++ b/libp2p/protocol_muxer/multiselect_communicator.py @@ -20,10 +20,10 @@ class MultiselectCommunicator(IMultiselectCommunicator): msg_bytes = encode_delim(msg_str.encode()) try: await self.read_writer.write(msg_bytes) - except IOException: + except IOException as error: raise MultiselectCommunicatorError( "fail to write to multiselect communicator" - ) + ) from error async def read(self) -> str: """ @@ -32,8 +32,8 @@ class MultiselectCommunicator(IMultiselectCommunicator): try: data = await read_delim(self.read_writer) # `IOException` includes `IncompleteReadError` and `StreamError` - except (ParseError, IOException, ValueError): + except (ParseError, IOException) as error: raise MultiselectCommunicatorError( "fail to read from multiselect communicator" - ) + ) from error return data.decode() diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index ffc210be..cceba08a 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -98,7 +98,7 @@ class Pubsub: # Register a notifee self.peer_queue = asyncio.Queue() - self.host.get_network().notify(PubsubNotifee(self.peer_queue)) + self.host.get_network().register_notifee(PubsubNotifee(self.peer_queue)) # Register stream handlers for each pubsub router protocol to handle # the pubsub streams opened on those protocols @@ -154,7 +154,7 @@ class Pubsub: messages from other nodes :param stream: stream to continously read from """ - peer_id = stream.mplex_conn.peer_id + peer_id = stream.muxed_conn.peer_id while True: try: diff --git a/libp2p/pubsub/pubsub_notifee.py b/libp2p/pubsub/pubsub_notifee.py index 6ecab1ab..85c0bd8d 100644 --- a/libp2p/pubsub/pubsub_notifee.py +++ b/libp2p/pubsub/pubsub_notifee.py @@ -2,10 +2,10 @@ from typing import TYPE_CHECKING from multiaddr import Multiaddr +from libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.network_interface import INetwork from libp2p.network.notifee_interface import INotifee from libp2p.network.stream.net_stream_interface import INetStream -from libp2p.stream_muxer.abc import IMuxedConn if TYPE_CHECKING: import asyncio # noqa: F401 @@ -29,16 +29,16 @@ class PubsubNotifee(INotifee): async def closed_stream(self, network: INetwork, stream: INetStream) -> None: pass - async def connected(self, network: INetwork, conn: IMuxedConn) -> None: + async def connected(self, network: INetwork, conn: INetConn) -> None: """ Add peer_id to initiator_peers_queue, so that this peer_id can be used to create a stream and we only want to have one pubsub stream with each peer. :param network: network the connection was opened on :param conn: connection that was opened """ - await self.initiator_peers_queue.put(conn.peer_id) + await self.initiator_peers_queue.put(conn.muxed_conn.peer_id) - async def disconnected(self, network: INetwork, conn: IMuxedConn) -> None: + async def disconnected(self, network: INetwork, conn: INetConn) -> None: pass async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None: diff --git a/libp2p/security/security_multistream.py b/libp2p/security/security_multistream.py index 466d60a8..cff55af3 100644 --- a/libp2p/security/security_multistream.py +++ b/libp2p/security/security_multistream.py @@ -1,6 +1,5 @@ from abc import ABC from collections import OrderedDict -from typing import Mapping from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID @@ -9,6 +8,7 @@ from libp2p.protocol_muxer.multiselect_client import MultiselectClient from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_transport_interface import ISecureTransport +from libp2p.transport.typing import TSecurityOptions from libp2p.typing import TProtocol @@ -31,9 +31,7 @@ class SecurityMultistream(ABC): multiselect: Multiselect multiselect_client: MultiselectClient - def __init__( - self, secure_transports_by_protocol: Mapping[TProtocol, ISecureTransport] - ) -> None: + def __init__(self, secure_transports_by_protocol: TSecurityOptions) -> None: self.transports = OrderedDict() self.multiselect = Multiselect() self.multiselect_client = MultiselectClient() diff --git a/libp2p/stream_muxer/abc.py b/libp2p/stream_muxer/abc.py index 78438f22..4af110b6 100644 --- a/libp2p/stream_muxer/abc.py +++ b/libp2p/stream_muxer/abc.py @@ -55,7 +55,7 @@ class IMuxedConn(ABC): class IMuxedStream(ReadWriteCloser): - mplex_conn: IMuxedConn + muxed_conn: IMuxedConn @abstractmethod async def reset(self) -> None: diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index ea7e47ac..768e66c1 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -31,9 +31,6 @@ class Mplex(IMuxedConn): secured_conn: ISecureConn peer_id: ID - # TODO: `dataIn` in go implementation. Should be size of 8. - # TODO: Also, `dataIn` is closed indicating EOF in Go. We don't have similar strategies - # to let the `MplexStream`s know that EOF arrived (#235). next_channel_id: int streams: Dict[StreamID, MplexStream] streams_lock: asyncio.Lock @@ -43,7 +40,6 @@ class Mplex(IMuxedConn): _tasks: List["asyncio.Future[Any]"] - # TODO: `generic_protocol_handler` should be refactored out of mplex conn. def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None: """ create a new muxed connection diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 8cabccc4..221e238e 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -18,12 +18,13 @@ class MplexStream(IMuxedStream): name: str stream_id: StreamID - mplex_conn: "Mplex" + muxed_conn: "Mplex" read_deadline: int write_deadline: int close_lock: asyncio.Lock + # NOTE: `dataIn` is size of 8 in Go implementation. incoming_data: "asyncio.Queue[bytes]" event_local_closed: asyncio.Event @@ -32,15 +33,15 @@ class MplexStream(IMuxedStream): _buf: bytearray - def __init__(self, name: str, stream_id: StreamID, mplex_conn: "Mplex") -> None: + def __init__(self, name: str, stream_id: StreamID, muxed_conn: "Mplex") -> None: """ create new MuxedStream in muxer :param stream_id: stream id of this stream - :param mplex_conn: muxed connection of this muxed_stream + :param muxed_conn: muxed connection of this muxed_stream """ self.name = name self.stream_id = stream_id - self.mplex_conn = mplex_conn + self.muxed_conn = muxed_conn self.read_deadline = None self.write_deadline = None self.event_local_closed = asyncio.Event() @@ -147,7 +148,7 @@ class MplexStream(IMuxedStream): if self.is_initiator else HeaderTags.MessageReceiver ) - return await self.mplex_conn.send_message(flag, data, self.stream_id) + return await self.muxed_conn.send_message(flag, data, self.stream_id) async def close(self) -> None: """ @@ -163,8 +164,8 @@ class MplexStream(IMuxedStream): flag = ( HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver ) - # TODO: Raise when `mplex_conn.send_message` fails and `Mplex` isn't shutdown. - await self.mplex_conn.send_message(flag, None, self.stream_id) + # TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown. + await self.muxed_conn.send_message(flag, None, self.stream_id) _is_remote_closed: bool async with self.close_lock: @@ -173,8 +174,8 @@ class MplexStream(IMuxedStream): if _is_remote_closed: # Both sides are closed, we can safely remove the buffer from the dict. - async with self.mplex_conn.streams_lock: - del self.mplex_conn.streams[self.stream_id] + async with self.muxed_conn.streams_lock: + del self.muxed_conn.streams[self.stream_id] async def reset(self) -> None: """ @@ -196,19 +197,19 @@ class MplexStream(IMuxedStream): else HeaderTags.ResetReceiver ) asyncio.ensure_future( - self.mplex_conn.send_message(flag, None, self.stream_id) + self.muxed_conn.send_message(flag, None, self.stream_id) ) await asyncio.sleep(0) self.event_local_closed.set() self.event_remote_closed.set() - async with self.mplex_conn.streams_lock: + async with self.muxed_conn.streams_lock: if ( - self.mplex_conn.streams is not None - and self.stream_id in self.mplex_conn.streams + self.muxed_conn.streams is not None + and self.stream_id in self.muxed_conn.streams ): - del self.mplex_conn.streams[self.stream_id] + del self.muxed_conn.streams[self.stream_id] # TODO deadline not in use def set_deadline(self, ttl: int) -> bool: diff --git a/libp2p/stream_muxer/muxer_multistream.py b/libp2p/stream_muxer/muxer_multistream.py index 806c90d6..7f6ee077 100644 --- a/libp2p/stream_muxer/muxer_multistream.py +++ b/libp2p/stream_muxer/muxer_multistream.py @@ -1,5 +1,4 @@ from collections import OrderedDict -from typing import Mapping, Type from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID @@ -7,12 +6,11 @@ from libp2p.protocol_muxer.multiselect import Multiselect from libp2p.protocol_muxer.multiselect_client import MultiselectClient from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator from libp2p.security.secure_conn_interface import ISecureConn +from libp2p.transport.typing import TMuxerClass, TMuxerOptions from libp2p.typing import TProtocol from .abc import IMuxedConn -MuxerClassType = Type[IMuxedConn] - # FIXME: add negotiate timeout to `MuxerMultistream` DEFAULT_NEGOTIATE_TIMEOUT = 60 @@ -24,20 +22,18 @@ class MuxerMultistream: """ # NOTE: Can be changed to `typing.OrderedDict` since Python 3.7.2. - transports: "OrderedDict[TProtocol, MuxerClassType]" + transports: "OrderedDict[TProtocol, TMuxerClass]" multiselect: Multiselect multiselect_client: MultiselectClient - def __init__( - self, muxer_transports_by_protocol: Mapping[TProtocol, MuxerClassType] - ) -> None: + def __init__(self, muxer_transports_by_protocol: TMuxerOptions) -> None: self.transports = OrderedDict() self.multiselect = Multiselect() self.multiselect_client = MultiselectClient() for protocol, transport in muxer_transports_by_protocol.items(): self.add_transport(protocol, transport) - def add_transport(self, protocol: TProtocol, transport: MuxerClassType) -> None: + def add_transport(self, protocol: TProtocol, transport: TMuxerClass) -> None: """ Add a protocol and its corresponding transport to multistream-select(multiselect). The order that a protocol is added is exactly the precedence it is negotiated in @@ -51,7 +47,7 @@ class MuxerMultistream: self.transports[protocol] = transport self.multiselect.add_handler(protocol, None) - async def select_transport(self, conn: IRawConnection) -> MuxerClassType: + async def select_transport(self, conn: IRawConnection) -> TMuxerClass: """ Select a transport that both us and the node on the other end of conn support and agree on diff --git a/libp2p/transport/typing.py b/libp2p/transport/typing.py index 6d0047c5..f9b31dcb 100644 --- a/libp2p/transport/typing.py +++ b/libp2p/transport/typing.py @@ -1,4 +1,11 @@ from asyncio import StreamReader, StreamWriter -from typing import Awaitable, Callable +from typing import Awaitable, Callable, Mapping, Type + +from libp2p.security.secure_transport_interface import ISecureTransport +from libp2p.stream_muxer.abc import IMuxedConn +from libp2p.typing import TProtocol THandler = Callable[[StreamReader, StreamWriter], Awaitable[None]] +TSecurityOptions = Mapping[TProtocol, ISecureTransport] +TMuxerClass = Type[IMuxedConn] +TMuxerOptions = Mapping[TProtocol, TMuxerClass] diff --git a/libp2p/transport/upgrader.py b/libp2p/transport/upgrader.py index 9f8be990..8dda95e7 100644 --- a/libp2p/transport/upgrader.py +++ b/libp2p/transport/upgrader.py @@ -1,16 +1,13 @@ -from typing import Mapping - from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID from libp2p.protocol_muxer.exceptions import MultiselectClientError, MultiselectError from libp2p.security.exceptions import HandshakeFailure from libp2p.security.secure_conn_interface import ISecureConn -from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.security.security_multistream import SecurityMultistream from libp2p.stream_muxer.abc import IMuxedConn -from libp2p.stream_muxer.muxer_multistream import MuxerClassType, MuxerMultistream +from libp2p.stream_muxer.muxer_multistream import MuxerMultistream from libp2p.transport.exceptions import MuxerUpgradeFailure, SecurityUpgradeFailure -from libp2p.typing import TProtocol +from libp2p.transport.typing import TMuxerOptions, TSecurityOptions from .listener_interface import IListener from .transport_interface import ITransport @@ -22,8 +19,8 @@ class TransportUpgrader: def __init__( self, - secure_transports_by_protocol: Mapping[TProtocol, ISecureTransport], - muxer_transports_by_protocol: Mapping[TProtocol, MuxerClassType], + secure_transports_by_protocol: TSecurityOptions, + muxer_transports_by_protocol: TMuxerOptions, ): self.security_multistream = SecurityMultistream(secure_transports_by_protocol) self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol) diff --git a/libp2p/utils.py b/libp2p/utils.py index 0e15b567..39c79e58 100644 --- a/libp2p/utils.py +++ b/libp2p/utils.py @@ -73,9 +73,12 @@ def encode_delim(msg: bytes) -> bytes: async def read_delim(reader: Reader) -> bytes: msg_bytes = await read_varint_prefixed_bytes(reader) - # TODO: Investigate if it is possible to have empty `msg_bytes` - if len(msg_bytes) != 0 and msg_bytes[-1:] != b"\n": - raise ValueError(f'msg_bytes is not delimited by b"\\n": msg_bytes={msg_bytes}') + if len(msg_bytes) == 0: + raise ParseError(f"`len(msg_bytes)` should not be 0") + if msg_bytes[-1:] != b"\n": + raise ParseError( + f'`msg_bytes` is not delimited by b"\\n": `msg_bytes`={msg_bytes}' + ) return msg_bytes[:-1] diff --git a/tests/factories.py b/tests/factories.py index e39b12d3..b4e8be23 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -6,6 +6,7 @@ import factory from libp2p import generate_new_rsa_identity, initialize_default_swarm from libp2p.crypto.keys import KeyPair from libp2p.host.basic_host import BasicHost +from libp2p.network.connection.swarm_connection import SwarmConn from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.swarm import Swarm from libp2p.pubsub.floodsub import FloodSub @@ -14,6 +15,9 @@ from libp2p.pubsub.pubsub import Pubsub from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport import libp2p.security.secio.transport as secio +from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex +from libp2p.stream_muxer.mplex.mplex_stream import MplexStream +from libp2p.transport.typing import TMuxerOptions from libp2p.typing import TProtocol from tests.configs import LISTEN_MADDR from tests.pubsub.configs import ( @@ -33,10 +37,10 @@ def security_transport_factory( return {secio.ID: secio.Transport(key_pair)} -def SwarmFactory(is_secure: bool) -> Swarm: +def SwarmFactory(is_secure: bool, muxer_opt: TMuxerOptions = None) -> Swarm: key_pair = generate_new_rsa_identity() - sec_opt = security_transport_factory(False, key_pair) - return initialize_default_swarm(key_pair, sec_opt=sec_opt) + sec_opt = security_transport_factory(is_secure, key_pair) + return initialize_default_swarm(key_pair, sec_opt=sec_opt, muxer_opt=muxer_opt) class ListeningSwarmFactory(factory.Factory): @@ -44,17 +48,22 @@ class ListeningSwarmFactory(factory.Factory): model = Swarm @classmethod - async def create_and_listen(cls, is_secure: bool) -> Swarm: - swarm = SwarmFactory(is_secure) + async def create_and_listen( + cls, is_secure: bool, muxer_opt: TMuxerOptions = None + ) -> Swarm: + swarm = SwarmFactory(is_secure, muxer_opt=muxer_opt) await swarm.listen(LISTEN_MADDR) return swarm @classmethod async def create_batch_and_listen( - cls, is_secure: bool, number: int + cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None ) -> Tuple[Swarm, ...]: return await asyncio.gather( - *[cls.create_and_listen(is_secure) for _ in range(number)] + *[ + cls.create_and_listen(is_secure, muxer_opt=muxer_opt) + for _ in range(number) + ] ) @@ -111,8 +120,12 @@ class PubsubFactory(factory.Factory): cache_size = None -async def swarm_pair_factory(is_secure: bool) -> Tuple[Swarm, Swarm]: - swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, 2) +async def swarm_pair_factory( + is_secure: bool, muxer_opt: TMuxerOptions = None +) -> Tuple[Swarm, Swarm]: + swarms = await ListeningSwarmFactory.create_batch_and_listen( + is_secure, 2, muxer_opt=muxer_opt + ) await connect_swarm(swarms[0], swarms[1]) return swarms[0], swarms[1] @@ -128,11 +141,37 @@ async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]: return hosts[0], hosts[1] -# async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]: -# host_0, host_1 = await host_pair_factory() -# mplex_conn_0 = host_0.get_network().connections[host_1.get_id()] -# mplex_conn_1 = host_1.get_network().connections[host_0.get_id()] -# return mplex_conn_0, host_0, mplex_conn_1, host_1 +async def swarm_conn_pair_factory( + is_secure: bool, muxer_opt: TMuxerOptions = None +) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]: + swarms = await swarm_pair_factory(is_secure) + conn_0 = swarms[0].connections[swarms[1].get_peer_id()] + conn_1 = swarms[1].connections[swarms[0].get_peer_id()] + return conn_0, swarms[0], conn_1, swarms[1] + + +async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Swarm, Mplex, Swarm]: + muxer_opt = {MPLEX_PROTOCOL_ID: Mplex} + conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory( + is_secure, muxer_opt=muxer_opt + ) + return conn_0.muxed_conn, swarm_0, conn_1.muxed_conn, swarm_1 + + +async def mplex_stream_pair_factory( + is_secure: bool +) -> Tuple[MplexStream, Swarm, MplexStream, Swarm]: + mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory( + is_secure + ) + stream_0 = await mplex_conn_0.open_stream() + await asyncio.sleep(0.01) + stream_1: MplexStream + async with mplex_conn_1.streams_lock: + if len(mplex_conn_1.streams) != 1: + raise Exception("Mplex should not have any stream upon connection") + stream_1 = tuple(mplex_conn_1.streams.values())[0] + return stream_0, swarm_0, stream_1, swarm_1 async def net_stream_pair_factory( diff --git a/tests/libp2p/test_notify.py b/tests/libp2p/test_notify.py deleted file mode 100644 index b9a8707e..00000000 --- a/tests/libp2p/test_notify.py +++ /dev/null @@ -1,351 +0,0 @@ -""" -Test Notify and Notifee by ensuring that the proper events get -called, and that the stream passed into opened_stream is correct - -Note: Listen event does not get hit because MyNotifee is passed -into network after network has already started listening - -TODO: Add tests for closed_stream disconnected, listen_close when those -features are implemented in swarm -""" - -import multiaddr -import pytest - -from libp2p import initialize_default_swarm, new_node -from libp2p.crypto.rsa import create_new_key_pair -from libp2p.host.basic_host import BasicHost -from libp2p.network.notifee_interface import INotifee -from tests.constants import MAX_READ_LEN -from tests.utils import perform_two_host_set_up - -ACK = "ack:" - - -class MyNotifee(INotifee): - def __init__(self, events, val_to_append_to_event): - self.events = events - self.val_to_append_to_event = val_to_append_to_event - - async def opened_stream(self, network, stream): - self.events.append(["opened_stream" + self.val_to_append_to_event, stream]) - - async def closed_stream(self, network, stream): - pass - - async def connected(self, network, conn): - self.events.append(["connected" + self.val_to_append_to_event, conn]) - - async def disconnected(self, network, conn): - pass - - async def listen(self, network, _multiaddr): - self.events.append(["listened" + self.val_to_append_to_event, _multiaddr]) - - async def listen_close(self, network, _multiaddr): - pass - - -class InvalidNotifee: - def __init__(self): - pass - - async def opened_stream(self): - assert False - - async def closed_stream(self): - assert False - - async def connected(self): - assert False - - async def disconnected(self): - assert False - - async def listen(self): - assert False - - -@pytest.mark.asyncio -async def test_one_notifier(): - node_a, node_b = await perform_two_host_set_up() - - # Add notifee for node_a - events = [] - assert node_a.get_network().notify(MyNotifee(events, "0")) - - stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) - - # Ensure the connected and opened_stream events were hit in MyNotifee obj - # and that stream passed into opened_stream matches the stream created on - # node_a - assert events == [["connected0", stream.mplex_conn], ["opened_stream0", stream]] - - messages = ["hello", "hello"] - for message in messages: - expected_resp = ACK + message - await stream.write(message.encode()) - - response = (await stream.read(len(expected_resp))).decode() - - assert response == expected_resp - - # Success, terminate pending tasks. - - -@pytest.mark.asyncio -async def test_one_notifier_on_two_nodes(): - events_b = [] - messages = ["hello", "hello"] - - async def my_stream_handler(stream): - # Ensure the connected and opened_stream events were hit in Notifee obj - # and that the stream passed into opened_stream matches the stream created on - # node_b - assert events_b == [ - ["connectedb", stream.mplex_conn], - ["opened_streamb", stream], - ] - for message in messages: - read_string = (await stream.read(len(message))).decode() - - resp = ACK + read_string - await stream.write(resp.encode()) - - node_a, node_b = await perform_two_host_set_up(my_stream_handler) - - # Add notifee for node_a - events_a = [] - assert node_a.get_network().notify(MyNotifee(events_a, "a")) - - # Add notifee for node_b - assert node_b.get_network().notify(MyNotifee(events_b, "b")) - - stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) - - # Ensure the connected and opened_stream events were hit in MyNotifee obj - # and that stream passed into opened_stream matches the stream created on - # node_a - assert events_a == [["connecteda", stream.mplex_conn], ["opened_streama", stream]] - - for message in messages: - expected_resp = ACK + message - await stream.write(message.encode()) - - response = (await stream.read(len(expected_resp))).decode() - - assert response == expected_resp - - # Success, terminate pending tasks. - - -@pytest.mark.asyncio -async def test_one_notifier_on_two_nodes_with_listen(): - events_b = [] - messages = ["hello", "hello"] - - node_a_key_pair = create_new_key_pair() - node_a_transport_opt = ["/ip4/127.0.0.1/tcp/0"] - node_a = await new_node(node_a_key_pair, transport_opt=node_a_transport_opt) - await node_a.get_network().listen(multiaddr.Multiaddr(node_a_transport_opt[0])) - - # Set up node_b swarm to pass into host - node_b_key_pair = create_new_key_pair() - node_b_transport_opt = ["/ip4/127.0.0.1/tcp/0"] - node_b_multiaddr = multiaddr.Multiaddr(node_b_transport_opt[0]) - node_b_swarm = initialize_default_swarm( - node_b_key_pair, transport_opt=node_b_transport_opt - ) - node_b = BasicHost(node_b_swarm) - - async def my_stream_handler(stream): - # Ensure the listened, connected and opened_stream events were hit in Notifee obj - # and that the stream passed into opened_stream matches the stream created on - # node_b - assert events_b == [ - ["listenedb", node_b_multiaddr], - ["connectedb", stream.mplex_conn], - ["opened_streamb", stream], - ] - for message in messages: - read_string = (await stream.read(len(message))).decode() - resp = ACK + read_string - await stream.write(resp.encode()) - - # Add notifee for node_a - events_a = [] - assert node_a.get_network().notify(MyNotifee(events_a, "a")) - - # Add notifee for node_b - assert node_b.get_network().notify(MyNotifee(events_b, "b")) - - # start listen on node_b_swarm - await node_b.get_network().listen(node_b_multiaddr) - - node_b.set_stream_handler("/echo/1.0.0", my_stream_handler) - # Associate the peer with local ip address (see default parameters of Libp2p()) - node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) - - # Ensure the connected and opened_stream events were hit in MyNotifee obj - # and that stream passed into opened_stream matches the stream created on - # node_a - assert events_a == [["connecteda", stream.mplex_conn], ["opened_streama", stream]] - - for message in messages: - expected_resp = ACK + message - await stream.write(message.encode()) - - response = (await stream.read(len(expected_resp))).decode() - - assert response == expected_resp - - # Success, terminate pending tasks. - - -@pytest.mark.asyncio -async def test_two_notifiers(): - node_a, node_b = await perform_two_host_set_up() - - # Add notifee for node_a - events0 = [] - assert node_a.get_network().notify(MyNotifee(events0, "0")) - - events1 = [] - assert node_a.get_network().notify(MyNotifee(events1, "1")) - - stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) - - # Ensure the connected and opened_stream events were hit in both Notifee objs - # and that the stream passed into opened_stream matches the stream created on - # node_a - assert events0 == [["connected0", stream.mplex_conn], ["opened_stream0", stream]] - assert events1 == [["connected1", stream.mplex_conn], ["opened_stream1", stream]] - - messages = ["hello", "hello"] - for message in messages: - expected_resp = ACK + message - await stream.write(message.encode()) - - response = (await stream.read(len(expected_resp))).decode() - - assert response == expected_resp - - # Success, terminate pending tasks. - - -@pytest.mark.asyncio -async def test_ten_notifiers(): - num_notifiers = 10 - - node_a, node_b = await perform_two_host_set_up() - - # Add notifee for node_a - events_lst = [] - for i in range(num_notifiers): - events_lst.append([]) - assert node_a.get_network().notify(MyNotifee(events_lst[i], str(i))) - - stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) - - # Ensure the connected and opened_stream events were hit in both Notifee objs - # and that the stream passed into opened_stream matches the stream created on - # node_a - for i in range(num_notifiers): - assert events_lst[i] == [ - ["connected" + str(i), stream.mplex_conn], - ["opened_stream" + str(i), stream], - ] - - messages = ["hello", "hello"] - for message in messages: - expected_resp = ACK + message - await stream.write(message.encode()) - - response = (await stream.read(len(expected_resp))).decode() - - assert response == expected_resp - - # Success, terminate pending tasks. - - -@pytest.mark.asyncio -async def test_ten_notifiers_on_two_nodes(): - num_notifiers = 10 - events_lst_b = [] - - async def my_stream_handler(stream): - # Ensure the connected and opened_stream events were hit in all Notifee objs - # and that the stream passed into opened_stream matches the stream created on - # node_b - for i in range(num_notifiers): - assert events_lst_b[i] == [ - ["connectedb" + str(i), stream.mplex_conn], - ["opened_streamb" + str(i), stream], - ] - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() - - resp = ACK + read_string - await stream.write(resp.encode()) - - node_a, node_b = await perform_two_host_set_up(my_stream_handler) - - # Add notifee for node_a and node_b - events_lst_a = [] - for i in range(num_notifiers): - events_lst_a.append([]) - events_lst_b.append([]) - assert node_a.get_network().notify(MyNotifee(events_lst_a[i], "a" + str(i))) - assert node_b.get_network().notify(MyNotifee(events_lst_b[i], "b" + str(i))) - - stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) - - # Ensure the connected and opened_stream events were hit in all Notifee objs - # and that the stream passed into opened_stream matches the stream created on - # node_a - for i in range(num_notifiers): - assert events_lst_a[i] == [ - ["connecteda" + str(i), stream.mplex_conn], - ["opened_streama" + str(i), stream], - ] - - messages = ["hello", "hello"] - for message in messages: - expected_resp = ACK + message - await stream.write(message.encode()) - - response = (await stream.read(len(expected_resp))).decode() - - assert response == expected_resp - - # Success, terminate pending tasks. - - -@pytest.mark.asyncio -async def test_invalid_notifee(): - num_notifiers = 10 - - node_a, node_b = await perform_two_host_set_up() - - # Add notifee for node_a - events_lst = [] - for _ in range(num_notifiers): - events_lst.append([]) - assert not node_a.get_network().notify(InvalidNotifee()) - - stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) - - # If this point is reached, this implies that the InvalidNotifee instance - # did not assert false, i.e. no functions of InvalidNotifee were called (which is correct - # given that InvalidNotifee should not have been added as a notifee) - messages = ["hello", "hello"] - for message in messages: - expected_resp = ACK + message - await stream.write(message.encode()) - - response = (await stream.read(len(expected_resp))).decode() - - assert response == expected_resp - - # Success, terminate pending tasks. diff --git a/tests/network/conftest.py b/tests/network/conftest.py index 47d5c5f0..018e822d 100644 --- a/tests/network/conftest.py +++ b/tests/network/conftest.py @@ -2,7 +2,11 @@ import asyncio import pytest -from tests.factories import net_stream_pair_factory, swarm_pair_factory +from tests.factories import ( + net_stream_pair_factory, + swarm_conn_pair_factory, + swarm_pair_factory, +) @pytest.fixture @@ -21,3 +25,12 @@ async def swarm_pair(is_host_secure): yield swarm_0, swarm_1 finally: await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) + + +@pytest.fixture +async def swarm_conn_pair(is_host_secure): + conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory(is_host_secure) + try: + yield conn_0, conn_1 + finally: + await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) diff --git a/tests/network/test_net_stream.py b/tests/network/test_net_stream.py index 80bed6ce..9229069f 100644 --- a/tests/network/test_net_stream.py +++ b/tests/network/test_net_stream.py @@ -7,9 +7,6 @@ from tests.constants import MAX_READ_LEN DATA = b"data_123" -# TODO: Move `muxed_stream` specific(currently we are using `MplexStream`) tests to its -# own file, after `generic_protocol_handler` is refactored out of `Mplex`. - @pytest.mark.asyncio async def test_net_stream_read_write(net_stream_pair): @@ -56,11 +53,9 @@ async def test_net_stream_read_until_eof(net_stream_pair): @pytest.mark.asyncio async def test_net_stream_read_after_remote_closed(net_stream_pair): stream_0, stream_1 = net_stream_pair - assert not stream_1.muxed_stream.event_remote_closed.is_set() await stream_0.write(DATA) await stream_0.close() await asyncio.sleep(0.01) - assert stream_1.muxed_stream.event_remote_closed.is_set() assert (await stream_1.read(MAX_READ_LEN)) == DATA with pytest.raises(StreamEOF): await stream_1.read(MAX_READ_LEN) diff --git a/tests/network/test_notify.py b/tests/network/test_notify.py new file mode 100644 index 00000000..aaf0ed55 --- /dev/null +++ b/tests/network/test_notify.py @@ -0,0 +1,112 @@ +""" +Test Notify and Notifee by ensuring that the proper events get +called, and that the stream passed into opened_stream is correct + +Note: Listen event does not get hit because MyNotifee is passed +into network after network has already started listening + +TODO: Add tests for closed_stream, listen_close when those +features are implemented in swarm +""" + +import asyncio +import enum + +import pytest + +from libp2p.network.notifee_interface import INotifee +from tests.configs import LISTEN_MADDR +from tests.factories import SwarmFactory +from tests.utils import connect_swarm + + +class Event(enum.Enum): + OpenedStream = 0 + ClosedStream = 1 # Not implemented + Connected = 2 + Disconnected = 3 + Listen = 4 + ListenClose = 5 # Not implemented + + +class MyNotifee(INotifee): + def __init__(self, events): + self.events = events + + async def opened_stream(self, network, stream): + self.events.append(Event.OpenedStream) + + async def closed_stream(self, network, stream): + # TODO: It is not implemented yet. + pass + + async def connected(self, network, conn): + self.events.append(Event.Connected) + + async def disconnected(self, network, conn): + self.events.append(Event.Disconnected) + + async def listen(self, network, _multiaddr): + self.events.append(Event.Listen) + + async def listen_close(self, network, _multiaddr): + # TODO: It is not implemented yet. + pass + + +@pytest.mark.asyncio +async def test_notify(is_host_secure): + swarms = [SwarmFactory(is_host_secure) for _ in range(2)] + + events_0_0 = [] + events_1_0 = [] + events_0_without_listen = [] + swarms[0].register_notifee(MyNotifee(events_0_0)) + swarms[1].register_notifee(MyNotifee(events_1_0)) + # Listen + await asyncio.gather(*[swarm.listen(LISTEN_MADDR) for swarm in swarms]) + + swarms[0].register_notifee(MyNotifee(events_0_without_listen)) + + # Connected + await connect_swarm(swarms[0], swarms[1]) + # OpenedStream: first + await swarms[0].new_stream(swarms[1].get_peer_id()) + # OpenedStream: second + await swarms[0].new_stream(swarms[1].get_peer_id()) + # OpenedStream: third, but different direction. + await swarms[1].new_stream(swarms[0].get_peer_id()) + + await asyncio.sleep(0.01) + + # TODO: Check `ClosedStream` and `ListenClose` events after they are ready. + + # Disconnected + await swarms[0].close_peer(swarms[1].get_peer_id()) + await asyncio.sleep(0.01) + + # Connected again, but different direction. + await connect_swarm(swarms[1], swarms[0]) + await asyncio.sleep(0.01) + + # Disconnected again, but different direction. + await swarms[1].close_peer(swarms[0].get_peer_id()) + await asyncio.sleep(0.01) + + expected_events_without_listen = [ + Event.Connected, + Event.OpenedStream, + Event.OpenedStream, + Event.OpenedStream, + Event.Disconnected, + Event.Connected, + Event.Disconnected, + ] + expected_events = [Event.Listen] + expected_events_without_listen + + assert events_0_0 == expected_events + assert events_1_0 == expected_events + assert events_0_without_listen == expected_events_without_listen + + # Clean up + await asyncio.gather(*[swarm.close() for swarm in swarms]) diff --git a/tests/network/test_swarm_conn.py b/tests/network/test_swarm_conn.py new file mode 100644 index 00000000..2abc7d0f --- /dev/null +++ b/tests/network/test_swarm_conn.py @@ -0,0 +1,45 @@ +import asyncio + +import pytest + + +@pytest.mark.asyncio +async def test_swarm_conn_close(swarm_conn_pair): + conn_0, conn_1 = swarm_conn_pair + + assert not conn_0.event_closed.is_set() + assert not conn_1.event_closed.is_set() + + await conn_0.close() + + await asyncio.sleep(0.01) + + assert conn_0.event_closed.is_set() + assert conn_1.event_closed.is_set() + assert conn_0 not in conn_0.swarm.connections.values() + assert conn_1 not in conn_1.swarm.connections.values() + + +@pytest.mark.asyncio +async def test_swarm_conn_streams(swarm_conn_pair): + conn_0, conn_1 = swarm_conn_pair + + assert len(await conn_0.get_streams()) == 0 + assert len(await conn_1.get_streams()) == 0 + + stream_0_0 = await conn_0.new_stream() + await asyncio.sleep(0.01) + assert len(await conn_0.get_streams()) == 1 + assert len(await conn_1.get_streams()) == 1 + + stream_0_1 = await conn_0.new_stream() + await asyncio.sleep(0.01) + assert len(await conn_0.get_streams()) == 2 + assert len(await conn_1.get_streams()) == 2 + + conn_0.remove_stream(stream_0_0) + assert len(await conn_0.get_streams()) == 1 + conn_0.remove_stream(stream_0_1) + assert len(await conn_0.get_streams()) == 0 + # Nothing happen if `stream_0_1` is not present or already removed. + conn_0.remove_stream(stream_0_1) diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index 34139494..29fdf363 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -233,7 +233,7 @@ class FakeNetStream: class FakeMplexConn(NamedTuple): peer_id: ID = ID(b"\x12\x20" + b"\x00" * 32) - mplex_conn = FakeMplexConn() + muxed_conn = FakeMplexConn() def __init__(self) -> None: self._queue = asyncio.Queue() diff --git a/tests/security/test_security_multistream.py b/tests/security/test_security_multistream.py index 26d31401..a9fe031f 100644 --- a/tests/security/test_security_multistream.py +++ b/tests/security/test_security_multistream.py @@ -53,8 +53,8 @@ async def perform_simple_test( node2_conn = node2.get_network().connections[peer_id_for_node(node1)] # Perform assertion - assertion_func(node1_conn.conn.secured_conn) - assertion_func(node2_conn.conn.secured_conn) + assertion_func(node1_conn.muxed_conn.secured_conn) + assertion_func(node2_conn.muxed_conn.secured_conn) # Success, terminate pending tasks. diff --git a/tests/stream_muxer/__init__.py b/tests/stream_muxer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/stream_muxer/conftest.py b/tests/stream_muxer/conftest.py new file mode 100644 index 00000000..b1d6c114 --- /dev/null +++ b/tests/stream_muxer/conftest.py @@ -0,0 +1,29 @@ +import asyncio + +import pytest + +from tests.factories import mplex_conn_pair_factory, mplex_stream_pair_factory + + +@pytest.fixture +async def mplex_conn_pair(is_host_secure): + mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory( + is_host_secure + ) + assert mplex_conn_0.initiator + assert not mplex_conn_1.initiator + try: + yield mplex_conn_0, mplex_conn_1 + finally: + await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) + + +@pytest.fixture +async def mplex_stream_pair(is_host_secure): + mplex_stream_0, swarm_0, mplex_stream_1, swarm_1 = await mplex_stream_pair_factory( + is_host_secure + ) + try: + yield mplex_stream_0, mplex_stream_1 + finally: + await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) diff --git a/tests/stream_muxer/test_mplex_conn.py b/tests/stream_muxer/test_mplex_conn.py new file mode 100644 index 00000000..6dc98ad6 --- /dev/null +++ b/tests/stream_muxer/test_mplex_conn.py @@ -0,0 +1,50 @@ +import asyncio + +import pytest + + +@pytest.mark.asyncio +async def test_mplex_conn(mplex_conn_pair): + conn_0, conn_1 = mplex_conn_pair + + assert len(conn_0.streams) == 0 + assert len(conn_1.streams) == 0 + assert not conn_0.event_shutting_down.is_set() + assert not conn_1.event_shutting_down.is_set() + assert not conn_0.event_closed.is_set() + assert not conn_1.event_closed.is_set() + + # Test: Open a stream, and both side get 1 more stream. + stream_0 = await conn_0.open_stream() + await asyncio.sleep(0.01) + assert len(conn_0.streams) == 1 + assert len(conn_1.streams) == 1 + # Test: From another side. + stream_1 = await conn_1.open_stream() + await asyncio.sleep(0.01) + assert len(conn_0.streams) == 2 + assert len(conn_1.streams) == 2 + + # Close from one side. + await conn_0.close() + # Sleep for a while for both side to handle `close`. + await asyncio.sleep(0.01) + # Test: Both side is closed. + assert conn_0.event_shutting_down.is_set() + assert conn_0.event_closed.is_set() + assert conn_1.event_shutting_down.is_set() + assert conn_1.event_closed.is_set() + # Test: All streams should have been closed. + assert stream_0.event_remote_closed.is_set() + assert stream_0.event_reset.is_set() + assert stream_0.event_local_closed.is_set() + assert conn_0.streams is None + # Test: All streams on the other side are also closed. + assert stream_1.event_remote_closed.is_set() + assert stream_1.event_reset.is_set() + assert stream_1.event_local_closed.is_set() + assert conn_1.streams is None + + # Test: No effect to close more than once between two side. + await conn_0.close() + await conn_1.close() diff --git a/tests/stream_muxer/test_mplex_stream.py b/tests/stream_muxer/test_mplex_stream.py new file mode 100644 index 00000000..e2bcb244 --- /dev/null +++ b/tests/stream_muxer/test_mplex_stream.py @@ -0,0 +1,182 @@ +import asyncio + +import pytest + +from libp2p.stream_muxer.mplex.exceptions import ( + MplexStreamClosed, + MplexStreamEOF, + MplexStreamReset, +) +from tests.constants import MAX_READ_LEN + +DATA = b"data_123" + + +@pytest.mark.asyncio +async def test_mplex_stream_read_write(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_0.write(DATA) + assert (await stream_1.read(MAX_READ_LEN)) == DATA + + +@pytest.mark.asyncio +async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair): + read_bytes = bytearray() + stream_0, stream_1 = mplex_stream_pair + + async def read_until_eof(): + read_bytes.extend(await stream_1.read()) + + task = asyncio.ensure_future(read_until_eof()) + + expected_data = bytearray() + + # Test: `read` doesn't return before `close` is called. + await stream_0.write(DATA) + expected_data.extend(DATA) + await asyncio.sleep(0.01) + assert len(read_bytes) == 0 + # Test: `read` doesn't return before `close` is called. + await stream_0.write(DATA) + expected_data.extend(DATA) + await asyncio.sleep(0.01) + assert len(read_bytes) == 0 + + # Test: Close the stream, `read` returns, and receive previous sent data. + await stream_0.close() + await asyncio.sleep(0.01) + assert read_bytes == expected_data + + task.cancel() + + +@pytest.mark.asyncio +async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + assert not stream_1.event_remote_closed.is_set() + await stream_0.write(DATA) + await stream_0.close() + await asyncio.sleep(0.01) + assert stream_1.event_remote_closed.is_set() + assert (await stream_1.read(MAX_READ_LEN)) == DATA + with pytest.raises(MplexStreamEOF): + await stream_1.read(MAX_READ_LEN) + + +@pytest.mark.asyncio +async def test_mplex_stream_read_after_local_reset(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_0.reset() + with pytest.raises(MplexStreamReset): + await stream_0.read(MAX_READ_LEN) + + +@pytest.mark.asyncio +async def test_mplex_stream_read_after_remote_reset(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_0.write(DATA) + await stream_0.reset() + # Sleep to let `stream_1` receive the message. + await asyncio.sleep(0.01) + with pytest.raises(MplexStreamReset): + await stream_1.read(MAX_READ_LEN) + + +@pytest.mark.asyncio +async def test_mplex_stream_read_after_remote_closed_and_reset(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_0.write(DATA) + await stream_0.close() + await stream_0.reset() + # Sleep to let `stream_1` receive the message. + await asyncio.sleep(0.01) + assert (await stream_1.read(MAX_READ_LEN)) == DATA + + +@pytest.mark.asyncio +async def test_mplex_stream_write_after_local_closed(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_0.write(DATA) + await stream_0.close() + with pytest.raises(MplexStreamClosed): + await stream_0.write(DATA) + + +@pytest.mark.asyncio +async def test_mplex_stream_write_after_local_reset(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_0.reset() + with pytest.raises(MplexStreamClosed): + await stream_0.write(DATA) + + +@pytest.mark.asyncio +async def test_mplex_stream_write_after_remote_reset(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_1.reset() + await asyncio.sleep(0.01) + with pytest.raises(MplexStreamClosed): + await stream_0.write(DATA) + + +@pytest.mark.asyncio +async def test_mplex_stream_both_close(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + # Flags are not set initially. + assert not stream_0.event_local_closed.is_set() + assert not stream_1.event_local_closed.is_set() + assert not stream_0.event_remote_closed.is_set() + assert not stream_1.event_remote_closed.is_set() + # Streams are present in their `mplex_conn`. + assert stream_0 in stream_0.muxed_conn.streams.values() + assert stream_1 in stream_1.muxed_conn.streams.values() + + # Test: Close one side. + await stream_0.close() + await asyncio.sleep(0.01) + + assert stream_0.event_local_closed.is_set() + assert not stream_1.event_local_closed.is_set() + assert not stream_0.event_remote_closed.is_set() + assert stream_1.event_remote_closed.is_set() + # Streams are still present in their `mplex_conn`. + assert stream_0 in stream_0.muxed_conn.streams.values() + assert stream_1 in stream_1.muxed_conn.streams.values() + + # Test: Close the other side. + await stream_1.close() + await asyncio.sleep(0.01) + # Both sides are closed. + assert stream_0.event_local_closed.is_set() + assert stream_1.event_local_closed.is_set() + assert stream_0.event_remote_closed.is_set() + assert stream_1.event_remote_closed.is_set() + # Streams are removed from their `mplex_conn`. + assert stream_0 not in stream_0.muxed_conn.streams.values() + assert stream_1 not in stream_1.muxed_conn.streams.values() + + # Test: Reset after both close. + await stream_0.reset() + + +@pytest.mark.asyncio +async def test_mplex_stream_reset(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_0.reset() + await asyncio.sleep(0.01) + + # Both sides are closed. + assert stream_0.event_local_closed.is_set() + assert stream_1.event_local_closed.is_set() + assert stream_0.event_remote_closed.is_set() + assert stream_1.event_remote_closed.is_set() + # Streams are removed from their `mplex_conn`. + assert stream_0 not in stream_0.muxed_conn.streams.values() + assert stream_1 not in stream_1.muxed_conn.streams.values() + + # `close` should do nothing. + await stream_0.close() + await stream_1.close() + # `reset` should do nothing as well. + await stream_0.reset() + await stream_1.reset() diff --git a/tests_interop/test_bindings.py b/tests_interop/test_bindings.py index 1e78ff43..dc0a2707 100644 --- a/tests_interop/test_bindings.py +++ b/tests_interop/test_bindings.py @@ -22,6 +22,5 @@ async def test_connect(hosts, p2pds): assert len(host.get_network().connections) == 1 # Test: `disconnect` from Go await p2pd.control.disconnect(host.get_id()) - # FIXME: Failed to handle disconnect await asyncio.sleep(0.01) assert len(host.get_network().connections) == 0