diff --git a/.travis.yml b/.travis.yml index 6b6b3ccc..1c7856d6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,16 +5,16 @@ matrix: - python: 3.6-dev dist: xenial env: TOXENV=py36-test - - python: 3.7-dev + - python: 3.7 dist: xenial env: TOXENV=py37-test - - python: 3.7-dev + - python: 3.7 dist: xenial env: TOXENV=lint - - python: 3.7-dev + - python: 3.7 dist: xenial env: TOXENV=docs - - python: 3.7-dev + - python: 3.7 dist: xenial env: TOXENV=py37-interop sudo: true diff --git a/Makefile b/Makefile index b82dbca0..5620e595 100644 --- a/Makefile +++ b/Makefile @@ -51,7 +51,7 @@ lint: black --check $(FILES_TO_LINT) isort --recursive --check-only --diff $(FILES_TO_LINT) docformatter --pre-summary-newline --check --recursive $(FILES_TO_LINT) - tox -elint # This is probably redundant, but just in case... + tox -e lint # This is probably redundant, but just in case... lint-roll: isort --recursive $(FILES_TO_LINT) diff --git a/docs/libp2p.pubsub.rst b/docs/libp2p.pubsub.rst index d217772b..38dce6ff 100644 --- a/docs/libp2p.pubsub.rst +++ b/docs/libp2p.pubsub.rst @@ -11,6 +11,22 @@ Subpackages Submodules ---------- +libp2p.pubsub.abc module +------------------------ + +.. automodule:: libp2p.pubsub.abc + :members: + :undoc-members: + :show-inheritance: + +libp2p.pubsub.exceptions module +------------------------------- + +.. automodule:: libp2p.pubsub.exceptions + :members: + :undoc-members: + :show-inheritance: + libp2p.pubsub.floodsub module ----------------------------- @@ -51,10 +67,10 @@ libp2p.pubsub.pubsub\_notifee module :undoc-members: :show-inheritance: -libp2p.pubsub.pubsub\_router\_interface module ----------------------------------------------- +libp2p.pubsub.subscription module +--------------------------------- -.. automodule:: libp2p.pubsub.pubsub_router_interface +.. 
automodule:: libp2p.pubsub.subscription :members: :undoc-members: :show-inheritance: diff --git a/examples/chat/chat.py b/examples/chat/chat.py index 24c92699..81f3891b 100755 --- a/examples/chat/chat.py +++ b/examples/chat/chat.py @@ -1,11 +1,10 @@ import argparse -import asyncio import sys -import urllib.request import multiaddr +import trio -from libp2p import new_node +from libp2p import new_host from libp2p.network.stream.net_stream_interface import INetStream from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.typing import TProtocol @@ -26,53 +25,47 @@ async def read_data(stream: INetStream) -> None: async def write_data(stream: INetStream) -> None: - loop = asyncio.get_event_loop() + async_f = trio.wrap_file(sys.stdin) while True: - line = await loop.run_in_executor(None, sys.stdin.readline) + line = await async_f.readline() await stream.write(line.encode()) -async def run(port: int, destination: str, localhost: bool) -> None: - if localhost: - ip = "127.0.0.1" - else: - ip = urllib.request.urlopen("https://v4.ident.me/").read().decode("utf8") - transport_opt = f"/ip4/{ip}/tcp/{port}" - host = await new_node(transport_opt=[transport_opt]) +async def run(port: int, destination: str) -> None: + localhost_ip = "127.0.0.1" + listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") + host = new_host() + async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery: + if not destination: # its the server - await host.get_network().listen(multiaddr.Multiaddr(transport_opt)) + async def stream_handler(stream: INetStream) -> None: + nursery.start_soon(read_data, stream) + nursery.start_soon(write_data, stream) - if not destination: # its the server + host.set_stream_handler(PROTOCOL_ID, stream_handler) - async def stream_handler(stream: INetStream) -> None: - asyncio.ensure_future(read_data(stream)) - asyncio.ensure_future(write_data(stream)) + print( + f"Run 'python ./examples/chat/chat.py " + f"-p {int(port) + 1} " + f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id().pretty()}' " + "on another console." + ) + print("Waiting for incoming connection...") - host.set_stream_handler(PROTOCOL_ID, stream_handler) + else: # its the client + maddr = multiaddr.Multiaddr(destination) + info = info_from_p2p_addr(maddr) + # Associate the peer with local ip address + await host.connect(info) + # Start a stream with the destination. + # Multiaddress of the destination peer is fetched from the peerstore using 'peerId'. + stream = await host.new_stream(info.peer_id, [PROTOCOL_ID]) - localhost_opt = " --localhost" if localhost else "" + nursery.start_soon(read_data, stream) + nursery.start_soon(write_data, stream) + print(f"Connected to peer {info.addrs[0]}") - print( - f"Run 'python ./examples/chat/chat.py" - + localhost_opt - + f" -p {int(port) + 1} -d /ip4/{ip}/tcp/{port}/p2p/{host.get_id().pretty()}'" - + " on another console." - ) - print("Waiting for incoming connection...") - - else: # its the client - maddr = multiaddr.Multiaddr(destination) - info = info_from_p2p_addr(maddr) - # Associate the peer with local ip address - await host.connect(info) - - # Start a stream with the destination. - # Multiaddress of the destination peer is fetched from the peerstore using 'peerId'. 
- stream = await host.new_stream(info.peer_id, [PROTOCOL_ID]) - - asyncio.ensure_future(read_data(stream)) - asyncio.ensure_future(write_data(stream)) - print("Connected to peer %s" % info.addrs[0]) + await trio.sleep_forever() def main() -> None: @@ -86,11 +79,6 @@ def main() -> None: "/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q" ) parser = argparse.ArgumentParser(description=description) - parser.add_argument( - "--debug", - action="store_true", - help="generate the same node ID on every execution", - ) parser.add_argument( "-p", "--port", default=8000, type=int, help="source port number" ) @@ -100,26 +88,15 @@ def main() -> None: type=str, help=f"destination multiaddr string, e.g. {example_maddr}", ) - parser.add_argument( - "-l", - "--localhost", - dest="localhost", - action="store_true", - help="flag indicating if localhost should be used or an external IP", - ) args = parser.parse_args() if not args.port: raise RuntimeError("was not able to determine a local port") - loop = asyncio.get_event_loop() try: - asyncio.ensure_future(run(args.port, args.destination, args.localhost)) - loop.run_forever() + trio.run(run, *(args.port, args.destination)) except KeyboardInterrupt: pass - finally: - loop.close() if __name__ == "__main__": diff --git a/examples/echo/echo.py b/examples/echo/echo.py index 3f3ed33e..5ea8ab4a 100644 --- a/examples/echo/echo.py +++ b/examples/echo/echo.py @@ -1,10 +1,9 @@ import argparse -import asyncio -import urllib.request import multiaddr +import trio -from libp2p import new_node +from libp2p import new_host from libp2p.crypto.secp256k1 import create_new_key_pair from libp2p.network.stream.net_stream_interface import INetStream from libp2p.peer.peerinfo import info_from_p2p_addr @@ -20,12 +19,9 @@ async def _echo_stream_handler(stream: INetStream) -> None: await stream.close() -async def run(port: int, destination: str, localhost: bool, seed: int = None) -> None: - if localhost: - ip = "127.0.0.1" - else: - ip = urllib.request.urlopen("https://v4.ident.me/").read().decode("utf8") - transport_opt = f"/ip4/{ip}/tcp/{port}" +async def run(port: int, destination: str, seed: int = None) -> None: + localhost_ip = "127.0.0.1" + listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") if seed: import random @@ -38,47 +34,43 @@ async def run(port: int, destination: str, localhost: bool, seed: int = None) -> secret = secrets.token_bytes(32) - host = await new_node( - key_pair=create_new_key_pair(secret), transport_opt=[transport_opt] - ) + host = new_host(key_pair=create_new_key_pair(secret)) + async with host.run(listen_addrs=[listen_addr]): - print(f"I am {host.get_id().to_string()}") + print(f"I am {host.get_id().to_string()}") - await host.get_network().listen(multiaddr.Multiaddr(transport_opt)) + if not destination: # its the server - if not destination: # its the server + host.set_stream_handler(PROTOCOL_ID, _echo_stream_handler) - host.set_stream_handler(PROTOCOL_ID, _echo_stream_handler) + print( + f"Run 'python ./examples/echo/echo.py " + f"-p {int(port) + 1} " + f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id().pretty()}' " + "on another console." 
+ ) + print("Waiting for incoming connections...") + await trio.sleep_forever() - localhost_opt = " --localhost" if localhost else "" + else: # its the client + maddr = multiaddr.Multiaddr(destination) + info = info_from_p2p_addr(maddr) + # Associate the peer with local ip address + await host.connect(info) - print( - f"Run 'python ./examples/echo/echo.py" - + localhost_opt - + f" -p {int(port) + 1} -d /ip4/{ip}/tcp/{port}/p2p/{host.get_id().pretty()}'" - + " on another console." - ) - print("Waiting for incoming connections...") + # Start a stream with the destination. + # Multiaddress of the destination peer is fetched from the peerstore using 'peerId'. + stream = await host.new_stream(info.peer_id, [PROTOCOL_ID]) - else: # its the client - maddr = multiaddr.Multiaddr(destination) - info = info_from_p2p_addr(maddr) - # Associate the peer with local ip address - await host.connect(info) + msg = b"hi, there!\n" - # Start a stream with the destination. - # Multiaddress of the destination peer is fetched from the peerstore using 'peerId'. - stream = await host.new_stream(info.peer_id, [PROTOCOL_ID]) + await stream.write(msg) + # Notify the other side about EOF + await stream.close() + response = await stream.read() - msg = b"hi, there!\n" - - await stream.write(msg) - # Notify the other side about EOF - await stream.close() - response = await stream.read() - - print(f"Sent: {msg}") - print(f"Got: {response}") + print(f"Sent: {msg}") + print(f"Got: {response}") def main() -> None: @@ -94,11 +86,6 @@ def main() -> None: "/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q" ) parser = argparse.ArgumentParser(description=description) - parser.add_argument( - "--debug", - action="store_true", - help="generate the same node ID on every execution", - ) parser.add_argument( "-p", "--port", default=8000, type=int, help="source port number" ) @@ -108,13 +95,6 @@ def main() -> None: type=str, help=f"destination multiaddr string, e.g. 
{example_maddr}", ) - parser.add_argument( - "-l", - "--localhost", - dest="localhost", - action="store_true", - help="flag indicating if localhost should be used or an external IP", - ) parser.add_argument( "-s", "--seed", @@ -126,16 +106,10 @@ def main() -> None: if not args.port: raise RuntimeError("was not able to determine a local port") - loop = asyncio.get_event_loop() try: - asyncio.ensure_future( - run(args.port, args.destination, args.localhost, args.seed) - ) - loop.run_forever() + trio.run(run, args.port, args.destination, args.seed) except KeyboardInterrupt: pass - finally: - loop.close() if __name__ == "__main__": diff --git a/libp2p/__init__.py b/libp2p/__init__.py index 8813fd37..4d91b9da 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -1,12 +1,9 @@ -import asyncio -from typing import Sequence - from libp2p.crypto.keys import KeyPair from libp2p.crypto.rsa import create_new_key_pair from libp2p.host.basic_host import BasicHost from libp2p.host.host_interface import IHost from libp2p.host.routed_host import RoutedHost -from libp2p.network.network_interface import INetwork +from libp2p.network.network_interface import INetworkService from libp2p.network.swarm import Swarm from libp2p.peer.id import ID from libp2p.peer.peerstore import PeerStore @@ -21,18 +18,6 @@ from libp2p.transport.upgrader import TransportUpgrader from libp2p.typing import TProtocol -async def cleanup_done_tasks() -> None: - """clean up asyncio done tasks to free up resources.""" - while True: - for task in asyncio.all_tasks(): - if task.done(): - await task - - # Need not run often - # Some sleep necessary to context switch - await asyncio.sleep(3) - - def generate_new_rsa_identity() -> KeyPair: return create_new_key_pair() @@ -42,29 +27,28 @@ def generate_peer_id_from(key_pair: KeyPair) -> ID: return ID.from_pubkey(public_key) -def initialize_default_swarm( - key_pair: KeyPair, - id_opt: ID = None, - transport_opt: Sequence[str] = None, +def new_swarm( + key_pair: KeyPair = None, muxer_opt: TMuxerOptions = None, sec_opt: TSecurityOptions = None, peerstore_opt: IPeerStore = None, -) -> Swarm: +) -> INetworkService: """ - initialize swarm when no swarm is passed in. + Create a swarm instance based on the parameters. - :param id_opt: optional id for host - :param transport_opt: optional choice of transport upgrade + :param key_pair: optional choice of the ``KeyPair`` :param muxer_opt: optional choice of stream muxer :param sec_opt: optional choice of security upgrade :param peerstore_opt: optional peerstore :return: return a default swarm instance """ - if not id_opt: - id_opt = generate_peer_id_from(key_pair) + if key_pair is None: + key_pair = generate_new_rsa_identity() - # TODO: Parse `transport_opt` to determine transport + id_opt = generate_peer_id_from(key_pair) + + # TODO: Parse `listen_addrs` to determine transport transport = TCP() muxer_transports_by_protocol = muxer_opt or {MPLEX_PROTOCOL_ID: Mplex} @@ -80,57 +64,35 @@ def initialize_default_swarm( # Store our key pair in peerstore peerstore.add_key_pair(id_opt, key_pair) - # TODO: Initialize discovery if not presented return Swarm(id_opt, peerstore, upgrader, transport) -async def new_node( +def new_host( key_pair: KeyPair = None, - swarm_opt: INetwork = None, - transport_opt: Sequence[str] = None, muxer_opt: TMuxerOptions = None, sec_opt: TSecurityOptions = None, peerstore_opt: IPeerStore = None, disc_opt: IPeerRouting = None, -) -> BasicHost: +) -> IHost: """ - create new libp2p node. 
+ Create a new libp2p host based on the given parameters. - :param key_pair: key pair for deriving an identity - :param swarm_opt: optional swarm - :param id_opt: optional id for host - :param transport_opt: optional choice of transport upgrade + :param key_pair: optional choice of the ``KeyPair`` :param muxer_opt: optional choice of stream muxer :param sec_opt: optional choice of security upgrade :param peerstore_opt: optional peerstore :param disc_opt: optional discovery :return: return a host instance """ - - if not key_pair: - key_pair = generate_new_rsa_identity() - - id_opt = generate_peer_id_from(key_pair) - - if not swarm_opt: - swarm_opt = initialize_default_swarm( - key_pair=key_pair, - id_opt=id_opt, - transport_opt=transport_opt, - muxer_opt=muxer_opt, - sec_opt=sec_opt, - peerstore_opt=peerstore_opt, - ) - - # TODO enable support for other host type - # TODO routing unimplemented - host: IHost # If not explicitly typed, MyPy raises error + swarm = new_swarm( + key_pair=key_pair, + muxer_opt=muxer_opt, + sec_opt=sec_opt, + peerstore_opt=peerstore_opt, + ) + host: IHost if disc_opt: - host = RoutedHost(swarm_opt, disc_opt) + host = RoutedHost(swarm, disc_opt) else: - host = BasicHost(swarm_opt) - - # Kick off cleanup job - asyncio.ensure_future(cleanup_done_tasks()) - + host = BasicHost(swarm) return host diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 91d24c93..cc5ff8c6 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -1,12 +1,14 @@ import logging -from typing import TYPE_CHECKING, List, Sequence +from typing import TYPE_CHECKING, AsyncIterator, List, Sequence +from async_generator import asynccontextmanager +from async_service import background_trio_service import multiaddr from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.host.defaults import get_default_protocols from libp2p.host.exceptions import StreamFailure -from libp2p.network.network_interface import INetwork +from libp2p.network.network_interface import INetworkService from libp2p.network.stream.net_stream_interface import INetStream from libp2p.peer.id import ID from libp2p.peer.peerinfo import PeerInfo @@ -39,7 +41,7 @@ class BasicHost(IHost): right after a stream is initialized. """ - _network: INetwork + _network: INetworkService peerstore: IPeerStore multiselect: Multiselect @@ -47,7 +49,7 @@ class BasicHost(IHost): def __init__( self, - network: INetwork, + network: INetworkService, default_protocols: "OrderedDict[TProtocol, StreamHandlerFn]" = None, ) -> None: self._network = network @@ -70,7 +72,7 @@ class BasicHost(IHost): def get_private_key(self) -> PrivateKey: return self.peerstore.privkey(self.get_id()) - def get_network(self) -> INetwork: + def get_network(self) -> INetworkService: """ :return: network instance of host """ @@ -101,6 +103,20 @@ class BasicHost(IHost): addrs.append(addr.encapsulate(p2p_part)) return addrs + @asynccontextmanager + async def run( + self, listen_addrs: Sequence[multiaddr.Multiaddr] + ) -> AsyncIterator[None]: + """ + run the host instance and listen to ``listen_addrs``. 
+ + :param listen_addrs: a sequence of multiaddrs that we want to listen to + """ + network = self.get_network() + async with background_trio_service(network): + await network.listen(*listen_addrs) + yield + def set_stream_handler( self, protocol_id: TProtocol, stream_handler: StreamHandlerFn ) -> None: diff --git a/libp2p/host/host_interface.py b/libp2p/host/host_interface.py index 43f4ac40..59146e7f 100644 --- a/libp2p/host/host_interface.py +++ b/libp2p/host/host_interface.py @@ -1,10 +1,10 @@ from abc import ABC, abstractmethod -from typing import Any, List, Sequence +from typing import Any, AsyncContextManager, List, Sequence import multiaddr from libp2p.crypto.keys import PrivateKey, PublicKey -from libp2p.network.network_interface import INetwork +from libp2p.network.network_interface import INetworkService from libp2p.network.stream.net_stream_interface import INetStream from libp2p.peer.id import ID from libp2p.peer.peerinfo import PeerInfo @@ -31,7 +31,7 @@ class IHost(ABC): """ @abstractmethod - def get_network(self) -> INetwork: + def get_network(self) -> INetworkService: """ :return: network instance of host """ @@ -49,6 +49,16 @@ class IHost(ABC): :return: all the multiaddr addresses this host is listening to """ + @abstractmethod + def run( + self, listen_addrs: Sequence[multiaddr.Multiaddr] + ) -> AsyncContextManager[None]: + """ + run the host instance and listen to ``listen_addrs``. + + :param listen_addrs: a sequence of multiaddrs that we want to listen to + """ + @abstractmethod def set_stream_handler( self, protocol_id: TProtocol, stream_handler: StreamHandlerFn diff --git a/libp2p/host/ping.py b/libp2p/host/ping.py index 3144ef4d..01102451 100644 --- a/libp2p/host/ping.py +++ b/libp2p/host/ping.py @@ -1,6 +1,7 @@ -import asyncio import logging +import trio + from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset from libp2p.network.stream.net_stream_interface import INetStream from libp2p.peer.id import ID as PeerID @@ -17,8 +18,9 @@ async def _handle_ping(stream: INetStream, peer_id: PeerID) -> bool: """Return a boolean indicating if we expect more pings from the peer at ``peer_id``.""" try: - payload = await asyncio.wait_for(stream.read(PING_LENGTH), RESP_TIMEOUT) - except asyncio.TimeoutError as error: + with trio.fail_after(RESP_TIMEOUT): + payload = await stream.read(PING_LENGTH) + except trio.TooSlowError as error: logger.debug("Timed out waiting for ping from %s: %s", peer_id, error) raise except StreamEOF: diff --git a/libp2p/host/routed_host.py b/libp2p/host/routed_host.py index 78b6fa54..91264c71 100644 --- a/libp2p/host/routed_host.py +++ b/libp2p/host/routed_host.py @@ -1,6 +1,6 @@ from libp2p.host.basic_host import BasicHost from libp2p.host.exceptions import ConnectionFailure -from libp2p.network.network_interface import INetwork +from libp2p.network.network_interface import INetworkService from libp2p.peer.peerinfo import PeerInfo from libp2p.routing.interfaces import IPeerRouting @@ -10,7 +10,7 @@ from libp2p.routing.interfaces import IPeerRouting class RoutedHost(BasicHost): _router: IPeerRouting - def __init__(self, network: INetwork, router: IPeerRouting): + def __init__(self, network: INetworkService, router: IPeerRouting): super().__init__(network) self._router = router diff --git a/libp2p/io/abc.py b/libp2p/io/abc.py index eea7b72f..8f2c7582 100644 --- a/libp2p/io/abc.py +++ b/libp2p/io/abc.py @@ -8,7 +8,7 @@ class Closer(ABC): class Reader(ABC): @abstractmethod - async def read(self, n: int = -1) -> bytes: + async def 
read(self, n: int = None) -> bytes: ... diff --git a/libp2p/io/msgio.py b/libp2p/io/msgio.py index f60b0ff9..32c0d09e 100644 --- a/libp2p/io/msgio.py +++ b/libp2p/io/msgio.py @@ -54,7 +54,7 @@ class MsgIOReader(ReadCloser): self.read_closer = read_closer self.next_length = None - async def read(self, n: int = -1) -> bytes: + async def read(self, n: int = None) -> bytes: return await self.read_msg() async def read_msg(self) -> bytes: diff --git a/libp2p/io/trio.py b/libp2p/io/trio.py new file mode 100644 index 00000000..465e4eaa --- /dev/null +++ b/libp2p/io/trio.py @@ -0,0 +1,40 @@ +import logging + +import trio + +from libp2p.io.abc import ReadWriteCloser +from libp2p.io.exceptions import IOException + +logger = logging.getLogger("libp2p.io.trio") + + +class TrioTCPStream(ReadWriteCloser): + stream: trio.SocketStream + # NOTE: Add both read and write lock to avoid `trio.BusyResourceError` + read_lock: trio.Lock + write_lock: trio.Lock + + def __init__(self, stream: trio.SocketStream) -> None: + self.stream = stream + self.read_lock = trio.Lock() + self.write_lock = trio.Lock() + + async def write(self, data: bytes) -> None: + """Raise `RawConnError` if the underlying connection breaks.""" + async with self.write_lock: + try: + await self.stream.send_all(data) + except (trio.ClosedResourceError, trio.BrokenResourceError) as error: + raise IOException from error + + async def read(self, n: int = None) -> bytes: + async with self.read_lock: + if n is not None and n == 0: + return b"" + try: + return await self.stream.receive_some(n) + except (trio.ClosedResourceError, trio.BrokenResourceError) as error: + raise IOException from error + + async def close(self) -> None: + await self.stream.aclose() diff --git a/libp2p/network/connection/net_connection_interface.py b/libp2p/network/connection/net_connection_interface.py index e308ad65..f1bcac24 100644 --- a/libp2p/network/connection/net_connection_interface.py +++ b/libp2p/network/connection/net_connection_interface.py @@ -1,6 +1,8 @@ from abc import abstractmethod from typing import Tuple +import trio + from libp2p.io.abc import Closer from libp2p.network.stream.net_stream_interface import INetStream from libp2p.stream_muxer.abc import IMuxedConn @@ -8,11 +10,12 @@ from libp2p.stream_muxer.abc import IMuxedConn class INetConn(Closer): muxed_conn: IMuxedConn + event_started: trio.Event @abstractmethod async def new_stream(self) -> INetStream: ... @abstractmethod - async def get_streams(self) -> Tuple[INetStream, ...]: + def get_streams(self) -> Tuple[INetStream, ...]: ... 
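The new `libp2p/io/trio.py` module above replaces asyncio's `StreamReader`/`StreamWriter` pair with a single lock-protected `ReadWriteCloser`. Below is a minimal sketch of how a raw trio socket could be wrapped for the rest of the stack, using the `TrioTCPStream` type introduced above and the reworked `RawConnection` shown in the next hunk; it is an illustration only, not the library's actual dial path (which goes through the TCP transport).

```python
import trio

from libp2p.io.trio import TrioTCPStream
from libp2p.network.connection.raw_connection import RawConnection


async def open_raw_connection(ip: str, port: int) -> RawConnection:
    # Open a plain TCP connection with trio and wrap it in the new I/O types.
    socket_stream = await trio.open_tcp_stream(ip, port)
    # TrioTCPStream serializes reads and writes with trio.Lock to avoid
    # trio.BusyResourceError; RawConnection maps IOException to RawConnError.
    return RawConnection(TrioTCPStream(socket_stream), True)
```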
diff --git a/libp2p/network/connection/raw_connection.py b/libp2p/network/connection/raw_connection.py index 25211c88..2d8409f7 100644 --- a/libp2p/network/connection/raw_connection.py +++ b/libp2p/network/connection/raw_connection.py @@ -1,46 +1,26 @@ -import asyncio -import sys +from libp2p.io.abc import ReadWriteCloser +from libp2p.io.exceptions import IOException from .exceptions import RawConnError from .raw_connection_interface import IRawConnection class RawConnection(IRawConnection): - reader: asyncio.StreamReader - writer: asyncio.StreamWriter + stream: ReadWriteCloser is_initiator: bool - _drain_lock: asyncio.Lock - - def __init__( - self, - reader: asyncio.StreamReader, - writer: asyncio.StreamWriter, - initiator: bool, - ) -> None: - self.reader = reader - self.writer = writer + def __init__(self, stream: ReadWriteCloser, initiator: bool) -> None: + self.stream = stream self.is_initiator = initiator - self._drain_lock = asyncio.Lock() - async def write(self, data: bytes) -> None: """Raise `RawConnError` if the underlying connection breaks.""" - # Detect if underlying transport is closing before write data to it - # ref: https://github.com/ethereum/trinity/pull/614 - if self.writer.transport.is_closing(): - raise RawConnError("Transport is closing") - self.writer.write(data) - # Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102 # noqa: E501 - # Use a lock to serialize drain() calls. Circumvents this bug: - # https://bugs.python.org/issue29930 - async with self._drain_lock: - try: - await self.writer.drain() - except ConnectionResetError as error: - raise RawConnError() from error + try: + await self.stream.write(data) + except IOException as error: + raise RawConnError from error - async def read(self, n: int = -1) -> bytes: + async def read(self, n: int = None) -> bytes: """ Read up to ``n`` bytes from the underlying stream. This call is delegated directly to the underlying ``self.reader``. @@ -48,18 +28,9 @@ class RawConnection(IRawConnection): Raise `RawConnError` if the underlying connection breaks """ try: - return await self.reader.read(n) - except ConnectionResetError as error: - raise RawConnError() from error + return await self.stream.read(n) + except IOException as error: + raise RawConnError from error async def close(self) -> None: - if self.writer.transport.is_closing(): - return - self.writer.close() - if sys.version_info < (3, 7): - return - try: - await self.writer.wait_closed() - # In case the connection is already reset. 
- except ConnectionResetError: - return + await self.stream.close() diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 29d544eb..baa9df50 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -1,5 +1,6 @@ -import asyncio -from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple +from typing import TYPE_CHECKING, Set, Tuple + +import trio from libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.stream.net_stream import NetStream @@ -19,90 +20,78 @@ class SwarmConn(INetConn): muxed_conn: IMuxedConn swarm: "Swarm" streams: Set[NetStream] - event_closed: asyncio.Event - - _tasks: List["asyncio.Future[Any]"] + event_closed: trio.Event def __init__(self, muxed_conn: IMuxedConn, swarm: "Swarm") -> None: self.muxed_conn = muxed_conn self.swarm = swarm self.streams = set() - self.event_closed = asyncio.Event() + self.event_closed = trio.Event() + self.event_started = trio.Event() - self._tasks = [] + @property + def is_closed(self) -> bool: + return self.event_closed.is_set() async def close(self) -> None: if self.event_closed.is_set(): return self.event_closed.set() + await self._cleanup() + + async def _cleanup(self) -> None: self.swarm.remove_conn(self) await self.muxed_conn.close() # This is just for cleaning up state. The connection has already been closed. # We *could* optimize this but it really isn't worth it. - for stream in self.streams: + for stream in self.streams.copy(): await stream.reset() # Force context switch for stream handlers to process the stream reset event we just emit # before we cancel the stream handler tasks. - await asyncio.sleep(0.1) + await trio.sleep(0.1) - for task in self._tasks: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - # Schedule `self._notify_disconnected` to make it execute after `close` is finished. - self._notify_disconnected() + await self._notify_disconnected() async def _handle_new_streams(self) -> None: - while True: - try: - stream = await self.muxed_conn.accept_stream() - except MuxedConnUnavailable: - # If there is anything wrong in the MuxedConn, - # we should break the loop and close the connection. - break - # Asynchronously handle the accepted stream, to avoid blocking the next stream. - await self.run_task(self._handle_muxed_stream(stream)) - - await self.close() - - async def _call_stream_handler(self, net_stream: NetStream) -> None: - try: - await self.swarm.common_stream_handler(net_stream) - # TODO: More exact exceptions - except Exception: - # TODO: Emit logs. - # TODO: Clean up and remove the stream from SwarmConn if there is anything wrong. - self.remove_stream(net_stream) + self.event_started.set() + async with trio.open_nursery() as nursery: + while True: + try: + stream = await self.muxed_conn.accept_stream() + except MuxedConnUnavailable: + await self.close() + break + # Asynchronously handle the accepted stream, to avoid blocking the next stream. 
+ nursery.start_soon(self._handle_muxed_stream, stream) async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None: - net_stream = self._add_stream(muxed_stream) - if self.swarm.common_stream_handler is not None: - await self.run_task(self._call_stream_handler(net_stream)) + net_stream = await self._add_stream(muxed_stream) + try: + # Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427 + await self.swarm.common_stream_handler(net_stream) # type: ignore + finally: + # Once `common_stream_handler` returns or raises, remove the stream. + self.remove_stream(net_stream) - def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: + async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: net_stream = NetStream(muxed_stream) self.streams.add(net_stream) - self.swarm.notify_opened_stream(net_stream) + await self.swarm.notify_opened_stream(net_stream) return net_stream - def _notify_disconnected(self) -> None: - self.swarm.notify_disconnected(self) + async def _notify_disconnected(self) -> None: + await self.swarm.notify_disconnected(self) async def start(self) -> None: - await self.run_task(self._handle_new_streams()) - - async def run_task(self, coro: Awaitable[Any]) -> None: - self._tasks.append(asyncio.ensure_future(coro)) + await self._handle_new_streams() async def new_stream(self) -> NetStream: muxed_stream = await self.muxed_conn.open_stream() - return self._add_stream(muxed_stream) + return await self._add_stream(muxed_stream) - async def get_streams(self) -> Tuple[NetStream, ...]: + def get_streams(self) -> Tuple[NetStream, ...]: return tuple(self.streams) def remove_stream(self, stream: NetStream) -> None: diff --git a/libp2p/network/network_interface.py b/libp2p/network/network_interface.py index 9e942831..70fb7295 100644 --- a/libp2p/network/network_interface.py +++ b/libp2p/network/network_interface.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Dict, Sequence +from async_service import ServiceAPI from multiaddr import Multiaddr from libp2p.network.connection.net_connection_interface import INetConn @@ -70,3 +71,7 @@ class INetwork(ABC): @abstractmethod async def close_peer(self, peer_id: ID) -> None: pass + + +class INetworkService(INetwork, ServiceAPI): + pass diff --git a/libp2p/network/stream/net_stream.py b/libp2p/network/stream/net_stream.py index 74f1c80e..72d5c6a7 100644 --- a/libp2p/network/stream/net_stream.py +++ b/libp2p/network/stream/net_stream.py @@ -37,7 +37,7 @@ class NetStream(INetStream): """ self.protocol_id = protocol_id - async def read(self, n: int = -1) -> bytes: + async def read(self, n: int = None) -> bytes: """ reads from stream.
diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 99f09e36..2c870dd0 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -1,9 +1,11 @@ -import asyncio import logging from typing import Dict, List, Optional +from async_service import Service from multiaddr import Multiaddr +import trio +from libp2p.io.abc import ReadWriteCloser from libp2p.network.connection.net_connection_interface import INetConn from libp2p.peer.id import ID from libp2p.peer.peerstore import PeerStoreError @@ -23,14 +25,21 @@ from ..exceptions import MultiError from .connection.raw_connection import RawConnection from .connection.swarm_connection import SwarmConn from .exceptions import SwarmException -from .network_interface import INetwork +from .network_interface import INetworkService from .notifee_interface import INotifee from .stream.net_stream_interface import INetStream logger = logging.getLogger("libp2p.network.swarm") -class Swarm(INetwork): +def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn: + async def stream_handler(stream: INetStream) -> None: + await network.get_manager().wait_finished() + + return stream_handler + + +class Swarm(Service, INetworkService): self_id: ID peerstore: IPeerStore @@ -40,7 +49,9 @@ class Swarm(INetwork): # whereas in Go one `peer_id` may point to multiple connections. connections: Dict[ID, INetConn] listeners: Dict[str, IListener] - common_stream_handler: Optional[StreamHandlerFn] + common_stream_handler: StreamHandlerFn + listener_nursery: Optional[trio.Nursery] + event_listener_nursery_created: trio.Event notifees: List[INotifee] @@ -61,13 +72,31 @@ class Swarm(INetwork): # Create Notifee array self.notifees = [] - self.common_stream_handler = None + # Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427 + self.common_stream_handler = create_default_stream_handler(self) # type: ignore + + self.listener_nursery = None + self.event_listener_nursery_created = trio.Event() + + async def run(self) -> None: + async with trio.open_nursery() as nursery: + # Create a nursery for listener tasks. + self.listener_nursery = nursery + self.event_listener_nursery_created.set() + try: + await self.manager.wait_finished() + finally: + # The service ended. Cancel listener tasks. + nursery.cancel_scope.cancel() + # Indicate that the nursery has been cancelled. + self.listener_nursery = None def get_peer_id(self) -> ID: return self.self_id def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None: - self.common_stream_handler = stream_handler + # Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427 + self.common_stream_handler = stream_handler # type: ignore async def dial_peer(self, peer_id: ID) -> INetConn: """ @@ -195,19 +224,15 @@ class Swarm(INetwork): - Call listener listen with the multiaddr - Map multiaddr to listener """ + # We need to wait until `self.listener_nursery` is created. 
+ await self.event_listener_nursery_created.wait() + for maddr in multiaddrs: if str(maddr) in self.listeners: return True - async def conn_handler( - reader: asyncio.StreamReader, writer: asyncio.StreamWriter - ) -> None: - connection_info = writer.get_extra_info("peername") - # TODO make a proper multiaddr - peer_addr = f"/ip4/{connection_info[0]}/tcp/{connection_info[1]}" - logger.debug("inbound connection at %s", peer_addr) - # logger.debug("inbound connection request", peer_id) - raw_conn = RawConnection(reader, writer, False) + async def conn_handler(read_write_closer: ReadWriteCloser) -> None: + raw_conn = RawConnection(read_write_closer, False) # Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure # the conn and then mux the conn @@ -217,16 +242,13 @@ class Swarm(INetwork): raw_conn, ID(b""), False ) except SecurityUpgradeFailure as error: - logger.debug("failed to upgrade security for peer at %s", peer_addr) + logger.debug("failed to upgrade security for peer at %s", maddr) await raw_conn.close() raise SwarmException( - f"failed to upgrade security for peer at {peer_addr}" + f"failed to upgrade security for peer at {maddr}" ) from error peer_id = secured_conn.get_remote_peer() - logger.debug("upgraded security for peer at %s", peer_addr) - logger.debug("identified peer at %s as %s", peer_addr, peer_id) - try: muxed_conn = await self.upgrader.upgrade_connection( secured_conn, peer_id @@ -240,17 +262,24 @@ class Swarm(INetwork): logger.debug("upgraded mux for peer %s", peer_id) await self.add_conn(muxed_conn) - logger.debug("successfully opened connection to peer %s", peer_id) + # NOTE: This is a intentional barrier to prevent from the handler exiting and + # closing the connection. + await self.manager.wait_finished() + try: # Success listener = self.transport.create_listener(conn_handler) self.listeners[str(maddr)] = listener - await listener.listen(maddr) + # TODO: `listener.listen` is not bounded with nursery. If we want to be + # I/O agnostic, we should change the API. + if self.listener_nursery is None: + raise SwarmException("swarm instance hasn't been run") + await listener.listen(maddr, self.listener_nursery) # Call notifiers since event occurred - self.notify_listen(maddr) + await self.notify_listen(maddr) return True except IOError: @@ -261,26 +290,12 @@ class Swarm(INetwork): return False async def close(self) -> None: - # TODO: Prevent from new listeners and conns being added. - # Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501 - - # Close listeners - await asyncio.gather( - *[listener.close() for listener in self.listeners.values()] - ) - - # Close connections - await asyncio.gather( - *[connection.close() for connection in self.connections.values()] - ) - + await self.manager.stop() logger.debug("swarm successfully closed") async def close_peer(self, peer_id: ID) -> None: if peer_id not in self.connections: return - # TODO: Should be changed to close multisple connections, - # if we have several connections per peer in the future. connection = self.connections[peer_id] # NOTE: `connection.close` will delete `peer_id` from `self.connections` # and `notify_disconnected` for us. 
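Since `Swarm` is now an `async_service.Service`, `listen()` only works while the service is running: it waits for `event_listener_nursery_created`, which `Swarm.run()` sets after opening the listener nursery. `BasicHost.run` drives this with `background_trio_service`; the sketch below drives a bare swarm the same way, assuming the `new_swarm` factory from `libp2p/__init__.py` above. It is an illustration, not part of the patch.

```python
import multiaddr
import trio
from async_service import background_trio_service

from libp2p import new_swarm


async def main() -> None:
    swarm = new_swarm()
    # background_trio_service starts Swarm.run(), which opens the listener
    # nursery that Swarm.listen() waits for before creating TCP listeners.
    async with background_trio_service(swarm):
        await swarm.listen(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/8000"))
        await trio.sleep_forever()


trio.run(main)
```

In normal use this plumbing is hidden behind `new_host()` and `async with host.run(listen_addrs=...)`, as in the updated chat and echo examples earlier in this patch.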
@@ -293,11 +308,14 @@ class Swarm(INetwork): and start to monitor the connection for its new streams and disconnection.""" swarm_conn = SwarmConn(muxed_conn, self) + self.manager.run_task(muxed_conn.start) + await muxed_conn.event_started.wait() + self.manager.run_task(swarm_conn.start) + await swarm_conn.event_started.wait() # Store muxed_conn with peer id self.connections[muxed_conn.peer_id] = swarm_conn # Call notifiers since event occurred - self.notify_connected(swarm_conn) - await swarm_conn.start() + await self.notify_connected(swarm_conn) return swarm_conn def remove_conn(self, swarm_conn: SwarmConn) -> None: @@ -306,14 +324,10 @@ class Swarm(INetwork): peer_id = swarm_conn.muxed_conn.peer_id if peer_id not in self.connections: return - # TODO: Should be changed to remove the exact connection, - # if we have several connections per peer in the future. del self.connections[peer_id] # Notifee - # TODO: Remeber the spawn notifying tasks and clean them up when closing. - def register_notifee(self, notifee: INotifee) -> None: """ :param notifee: object implementing Notifee interface @@ -321,20 +335,28 @@ class Swarm(INetwork): """ self.notifees.append(notifee) - def notify_opened_stream(self, stream: INetStream) -> None: - asyncio.gather( - *[notifee.opened_stream(self, stream) for notifee in self.notifees] - ) + async def notify_opened_stream(self, stream: INetStream) -> None: + async with trio.open_nursery() as nursery: + for notifee in self.notifees: + nursery.start_soon(notifee.opened_stream, self, stream) - # TODO: `notify_closed_stream` + async def notify_connected(self, conn: INetConn) -> None: + async with trio.open_nursery() as nursery: + for notifee in self.notifees: + nursery.start_soon(notifee.connected, self, conn) - def notify_connected(self, conn: INetConn) -> None: - asyncio.gather(*[notifee.connected(self, conn) for notifee in self.notifees]) + async def notify_disconnected(self, conn: INetConn) -> None: + async with trio.open_nursery() as nursery: + for notifee in self.notifees: + nursery.start_soon(notifee.disconnected, self, conn) - def notify_disconnected(self, conn: INetConn) -> None: - asyncio.gather(*[notifee.disconnected(self, conn) for notifee in self.notifees]) + async def notify_listen(self, multiaddr: Multiaddr) -> None: + async with trio.open_nursery() as nursery: + for notifee in self.notifees: + nursery.start_soon(notifee.listen, self, multiaddr) - def notify_listen(self, multiaddr: Multiaddr) -> None: - asyncio.gather(*[notifee.listen(self, multiaddr) for notifee in self.notifees]) + async def notify_closed_stream(self, stream: INetStream) -> None: + raise NotImplementedError - # TODO: `notify_listen_close` + async def notify_listen_close(self, multiaddr: Multiaddr) -> None: + raise NotImplementedError diff --git a/libp2p/peer/peerinfo.py b/libp2p/peer/peerinfo.py index 4015ef97..889b6f61 100644 --- a/libp2p/peer/peerinfo.py +++ b/libp2p/peer/peerinfo.py @@ -25,9 +25,6 @@ def info_from_p2p_addr(addr: multiaddr.Multiaddr) -> PeerInfo: if not addr: raise InvalidAddrError("`addr` should not be `None`") - if not isinstance(addr, multiaddr.Multiaddr): - raise InvalidAddrError(f"`addr`={addr} should be of type `Multiaddr`") - parts = addr.split() if not parts: raise InvalidAddrError( diff --git a/libp2p/pubsub/pubsub_router_interface.py b/libp2p/pubsub/abc.py similarity index 64% rename from libp2p/pubsub/pubsub_router_interface.py rename to libp2p/pubsub/abc.py index 99a9be75..da37b6a1 100644 --- a/libp2p/pubsub/pubsub_router_interface.py +++ 
b/libp2p/pubsub/abc.py @@ -1,15 +1,37 @@ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List +from typing import ( + TYPE_CHECKING, + AsyncContextManager, + AsyncIterable, + KeysView, + List, + Tuple, +) + +from async_service import ServiceAPI from libp2p.peer.id import ID from libp2p.typing import TProtocol from .pb import rpc_pb2 +from .typing import ValidatorFn if TYPE_CHECKING: from .pubsub import Pubsub # noqa: F401 +class ISubscriptionAPI( + AsyncContextManager["ISubscriptionAPI"], AsyncIterable[rpc_pb2.Message] +): + @abstractmethod + async def unsubscribe(self) -> None: + ... + + @abstractmethod + async def get(self) -> rpc_pb2.Message: + ... + + class IPubsubRouter(ABC): @abstractmethod def get_protocols(self) -> List[TProtocol]: @@ -53,7 +75,6 @@ class IPubsubRouter(ABC): :param rpc: rpc message """ - # FIXME: Should be changed to type 'peer.ID' @abstractmethod async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None: """ @@ -80,3 +101,46 @@ class IPubsubRouter(ABC): :param topic: topic to leave """ + + +class IPubsub(ServiceAPI): + @property + @abstractmethod + def my_id(self) -> ID: + ... + + @property + @abstractmethod + def protocols(self) -> Tuple[TProtocol, ...]: + ... + + @property + @abstractmethod + def topic_ids(self) -> KeysView[str]: + ... + + @abstractmethod + def set_topic_validator( + self, topic: str, validator: ValidatorFn, is_async_validator: bool + ) -> None: + ... + + @abstractmethod + def remove_topic_validator(self, topic: str) -> None: + ... + + @abstractmethod + async def wait_until_ready(self) -> None: + ... + + @abstractmethod + async def subscribe(self, topic_id: str) -> ISubscriptionAPI: + ... + + @abstractmethod + async def unsubscribe(self, topic_id: str) -> None: + ... + + @abstractmethod + async def publish(self, topic_id: str, data: bytes) -> None: + ... 
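`ISubscriptionAPI` replaces the bare `asyncio.Queue` that `Pubsub.subscribe` used to return. A rough sketch of consuming a topic under the new interface follows; the `pubsub` argument and topic name are placeholders, and the concrete `TrioSubscriptionAPI` implementation lives in `libp2p/pubsub/subscription.py`, which is not shown in this section.

```python
from libp2p.pubsub.pubsub import Pubsub


async def wait_for_one_message(pubsub: Pubsub, topic_id: str) -> bytes:
    # subscribe() now returns an ISubscriptionAPI instead of an asyncio.Queue.
    subscription = await pubsub.subscribe(topic_id)
    # get() blocks until the next message published on the topic is delivered.
    message = await subscription.get()
    await pubsub.unsubscribe(topic_id)
    return message.data
```

Because `ISubscriptionAPI` is declared as both an async context manager and an async iterable of `rpc_pb2.Message`, a long-lived consumer can instead write `async with subscription:` and `async for message in subscription:` rather than calling `get()` in a loop.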
diff --git a/libp2p/pubsub/exceptions.py b/libp2p/pubsub/exceptions.py new file mode 100644 index 00000000..55afde04 --- /dev/null +++ b/libp2p/pubsub/exceptions.py @@ -0,0 +1,9 @@ +from libp2p.exceptions import BaseLibp2pError + + +class PubsubRouterError(BaseLibp2pError): + pass + + +class NoPubsubAttached(PubsubRouterError): + pass diff --git a/libp2p/pubsub/floodsub.py b/libp2p/pubsub/floodsub.py index bae2bf27..dd9ae2fd 100644 --- a/libp2p/pubsub/floodsub.py +++ b/libp2p/pubsub/floodsub.py @@ -1,14 +1,16 @@ import logging from typing import Iterable, List, Sequence +import trio + from libp2p.network.stream.exceptions import StreamClosed from libp2p.peer.id import ID from libp2p.typing import TProtocol from libp2p.utils import encode_varint_prefixed +from .abc import IPubsubRouter from .pb import rpc_pb2 from .pubsub import Pubsub -from .pubsub_router_interface import IPubsubRouter PROTOCOL_ID = TProtocol("/floodsub/1.0.0") @@ -61,6 +63,8 @@ class FloodSub(IPubsubRouter): :param rpc: rpc message """ + # Checkpoint + await trio.hazmat.checkpoint() async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None: """ @@ -107,6 +111,8 @@ class FloodSub(IPubsubRouter): :param topic: topic to join """ + # Checkpoint + await trio.hazmat.checkpoint() async def leave(self, topic: str) -> None: """ @@ -115,6 +121,8 @@ class FloodSub(IPubsubRouter): :param topic: topic to leave """ + # Checkpoint + await trio.hazmat.checkpoint() def _get_peers_to_send( self, topic_ids: Iterable[str], msg_forwarder: ID, origin: ID diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py index 05efe84e..4d25c254 100644 --- a/libp2p/pubsub/gossipsub.py +++ b/libp2p/pubsub/gossipsub.py @@ -1,28 +1,30 @@ from ast import literal_eval -import asyncio from collections import defaultdict import logging import random from typing import Any, DefaultDict, Dict, Iterable, List, Sequence, Set, Tuple +from async_service import Service +import trio + from libp2p.network.stream.exceptions import StreamClosed from libp2p.peer.id import ID from libp2p.pubsub import floodsub from libp2p.typing import TProtocol from libp2p.utils import encode_varint_prefixed +from .abc import IPubsubRouter +from .exceptions import NoPubsubAttached from .mcache import MessageCache from .pb import rpc_pb2 from .pubsub import Pubsub -from .pubsub_router_interface import IPubsubRouter PROTOCOL_ID = TProtocol("/meshsub/1.0.0") logger = logging.getLogger("libp2p.pubsub.gossipsub") -class GossipSub(IPubsubRouter): - +class GossipSub(IPubsubRouter, Service): protocols: List[TProtocol] pubsub: Pubsub @@ -38,7 +40,8 @@ class GossipSub(IPubsubRouter): # The protocol peer supports peer_protocol: Dict[ID, TProtocol] - time_since_last_publish: Dict[str, int] + # TODO: Add `time_since_last_publish` + # Create topic --> time since last publish map. 
mcache: MessageCache @@ -75,9 +78,6 @@ class GossipSub(IPubsubRouter): # Create peer --> protocol mapping self.peer_protocol = {} - # Create topic --> time since last publish map - self.time_since_last_publish = {} - # Create message cache self.mcache = MessageCache(gossip_window, gossip_history) @@ -85,6 +85,12 @@ class GossipSub(IPubsubRouter): self.heartbeat_initial_delay = heartbeat_initial_delay self.heartbeat_interval = heartbeat_interval + async def run(self) -> None: + if self.pubsub is None: + raise NoPubsubAttached + self.manager.run_daemon_task(self.heartbeat) + await self.manager.wait_finished() + # Interface functions def get_protocols(self) -> List[TProtocol]: @@ -104,9 +110,6 @@ class GossipSub(IPubsubRouter): logger.debug("attached to pusub") - # Start heartbeat now that we have a pubsub instance - asyncio.ensure_future(self.heartbeat()) - def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None: """ Notifies the router that a new peer has been connected. @@ -370,7 +373,7 @@ class GossipSub(IPubsubRouter): state changes in the preceding heartbeat """ # Start after a delay. Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/01b9825fbee1848751d90a8469e3f5f43bac8466/gossipsub.go#L410 # Noqa: E501 - await asyncio.sleep(self.heartbeat_initial_delay) + await trio.sleep(self.heartbeat_initial_delay) while True: # Maintain mesh and keep track of which peers to send GRAFT or PRUNE to peers_to_graft, peers_to_prune = self.mesh_heartbeat() @@ -385,7 +388,7 @@ class GossipSub(IPubsubRouter): self.mcache.shift() - await asyncio.sleep(self.heartbeat_interval) + await trio.sleep(self.heartbeat_interval) def mesh_heartbeat( self @@ -413,7 +416,7 @@ class GossipSub(IPubsubRouter): if num_mesh_peers_in_topic > self.degree_high: # Select |mesh[topic]| - D peers from mesh[topic] - selected_peers = GossipSub.select_from_minus( + selected_peers = self.select_from_minus( num_mesh_peers_in_topic - self.degree, self.mesh[topic], set() ) for peer in selected_peers: @@ -428,15 +431,10 @@ class GossipSub(IPubsubRouter): # Note: the comments here are the exact pseudocode from the spec for topic in self.fanout: # Delete topic entry if it's not in `pubsub.peer_topics` - # or if it's time-since-last-published > ttl - # TODO: there's no way time_since_last_publish gets set anywhere yet - if ( - topic not in self.pubsub.peer_topics - or self.time_since_last_publish[topic] > self.time_to_live - ): + # or (TODO) if it's time-since-last-published > ttl + if topic not in self.pubsub.peer_topics: # Remove topic from fanout del self.fanout[topic] - del self.time_since_last_publish[topic] else: # Check if fanout peers are still in the topic and remove the ones that are not # ref: https://github.com/libp2p/go-libp2p-pubsub/blob/01b9825fbee1848751d90a8469e3f5f43bac8466/gossipsub.go#L498-L504 # noqa: E501 diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 44092f94..2045d487 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -1,21 +1,12 @@ -import asyncio +import functools import logging import time -from typing import ( - TYPE_CHECKING, - Awaitable, - Callable, - Dict, - List, - NamedTuple, - Set, - Tuple, - Union, - cast, -) +from typing import TYPE_CHECKING, Dict, KeysView, List, NamedTuple, Set, Tuple, cast +from async_service import Service import base58 from lru import LRU +import trio from libp2p.crypto.keys import PrivateKey from libp2p.exceptions import ParseError, ValidationError @@ -28,15 +19,21 @@ from libp2p.peer.id import ID from libp2p.typing import 
TProtocol from libp2p.utils import encode_varint_prefixed, read_varint_prefixed_bytes +from .abc import IPubsub, ISubscriptionAPI from .pb import rpc_pb2 from .pubsub_notifee import PubsubNotifee +from .subscription import TrioSubscriptionAPI +from .typing import AsyncValidatorFn, SyncValidatorFn, ValidatorFn from .validators import PUBSUB_SIGNING_PREFIX, signature_validator if TYPE_CHECKING: - from .pubsub_router_interface import IPubsubRouter # noqa: F401 + from .abc import IPubsubRouter # noqa: F401 from typing import Any # noqa: F401 +# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/40e1c94708658b155f30cf99e4574f384756d83c/topic.go#L97 # noqa: E501 +SUBSCRIPTION_CHANNEL_SIZE = 32 + logger = logging.getLogger("libp2p.pubsub") @@ -45,34 +42,24 @@ def get_msg_id(msg: rpc_pb2.Message) -> Tuple[bytes, bytes]: return (msg.seqno, msg.from_id) -SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool] -AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]] -ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn] - - class TopicValidator(NamedTuple): validator: ValidatorFn is_async: bool -class Pubsub: +class Pubsub(Service, IPubsub): host: IHost - my_id: ID router: "IPubsubRouter" - peer_queue: "asyncio.Queue[ID]" - dead_peer_queue: "asyncio.Queue[ID]" - - protocols: List[TProtocol] - - incoming_msgs_from_peers: "asyncio.Queue[rpc_pb2.Message]" - outgoing_messages: "asyncio.Queue[rpc_pb2.Message]" + peer_receive_channel: "trio.MemoryReceiveChannel[ID]" + dead_peer_receive_channel: "trio.MemoryReceiveChannel[ID]" seen_messages: LRU - my_topics: Dict[str, "asyncio.Queue[rpc_pb2.Message]"] + subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"] + subscribed_topics_receive: Dict[str, "TrioSubscriptionAPI"] peer_topics: Dict[str, Set[ID]] peers: Dict[ID, INetStream] @@ -81,17 +68,17 @@ class Pubsub: counter: int # uint64 - _tasks: List["asyncio.Future[Any]"] - # Indicate if we should enforce signature verification strict_signing: bool sign_key: PrivateKey + event_handle_peer_queue_started: trio.Event + event_handle_dead_peer_queue_started: trio.Event + def __init__( self, host: IHost, router: "IPubsubRouter", - my_id: ID, cache_size: int = None, strict_signing: bool = True, ) -> None: @@ -107,39 +94,44 @@ class Pubsub: """ self.host = host self.router = router - self.my_id = my_id # Attach this new Pubsub object to the router self.router.attach(self) + peer_send, peer_receive = trio.open_memory_channel[ID](0) + dead_peer_send, dead_peer_receive = trio.open_memory_channel[ID](0) + # Only keep the receive channels in `Pubsub`. + # Therefore, we can only close from the receive side. 
+ self.peer_receive_channel = peer_receive + self.dead_peer_receive_channel = dead_peer_receive # Register a notifee - self.peer_queue = asyncio.Queue() - self.dead_peer_queue = asyncio.Queue() self.host.get_network().register_notifee( - PubsubNotifee(self.peer_queue, self.dead_peer_queue) + PubsubNotifee(peer_send, dead_peer_send) ) # Register stream handlers for each pubsub router protocol to handle # the pubsub streams opened on those protocols - self.protocols = self.router.get_protocols() - for protocol in self.protocols: + for protocol in router.get_protocols(): self.host.set_stream_handler(protocol, self.stream_handler) - # Use asyncio queues for proper context switching - self.incoming_msgs_from_peers = asyncio.Queue() - self.outgoing_messages = asyncio.Queue() - # keeps track of seen messages as LRU cache if cache_size is None: self.cache_size = 128 else: self.cache_size = cache_size + self.strict_signing = strict_signing + if strict_signing: + self.sign_key = self.host.get_private_key() + else: + self.sign_key = None + self.seen_messages = LRU(self.cache_size) # Map of topics we are subscribed to blocking queues # for when the given topic receives a message - self.my_topics = {} + self.subscribed_topics_send = {} + self.subscribed_topics_receive = {} # Map of topic to peers to keep track of what peers are subscribed to self.peer_topics = {} @@ -152,22 +144,31 @@ class Pubsub: self.counter = int(time.time()) - self._tasks = [] - # Call handle peer to keep waiting for updates to peer queue - self._tasks.append(asyncio.ensure_future(self.handle_peer_queue())) - self._tasks.append(asyncio.ensure_future(self.handle_dead_peer_queue())) + self.event_handle_peer_queue_started = trio.Event() + self.event_handle_dead_peer_queue_started = trio.Event() - self.strict_signing = strict_signing - if strict_signing: - self.sign_key = self.host.get_private_key() - else: - self.sign_key = None + async def run(self) -> None: + self.manager.run_daemon_task(self.handle_peer_queue) + self.manager.run_daemon_task(self.handle_dead_peer_queue) + await self.manager.wait_finished() + + @property + def my_id(self) -> ID: + return self.host.get_id() + + @property + def protocols(self) -> Tuple[TProtocol, ...]: + return tuple(self.router.get_protocols()) + + @property + def topic_ids(self) -> KeysView[str]: + return self.subscribed_topics_receive.keys() def get_hello_packet(self) -> rpc_pb2.RPC: """Generate subscription message with all topics we are subscribed to only send hello packet if we have subscribed topics.""" packet = rpc_pb2.RPC() - for topic_id in self.my_topics: + for topic_id in self.topic_ids: packet.subscriptions.extend( [rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)] ) @@ -182,7 +183,7 @@ class Pubsub: """ peer_id = stream.muxed_conn.peer_id - while True: + while self.manager.is_running: incoming: bytes = await read_varint_prefixed_bytes(stream) rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC() rpc_incoming.ParseFromString(incoming) @@ -194,11 +195,7 @@ class Pubsub: logger.debug( "received `publish` message %s from peer %s", msg, peer_id ) - self._tasks.append( - asyncio.ensure_future( - self.push_msg(msg_forwarder=peer_id, msg=msg) - ) - ) + self.manager.run_task(self.push_msg, peer_id, msg) if rpc_incoming.subscriptions: # deal with RPC.subscriptions @@ -226,9 +223,6 @@ class Pubsub: ) await self.router.handle_rpc(rpc_incoming, peer_id) - # Force context switch - await asyncio.sleep(0) - def set_topic_validator( self, topic: str, validator: ValidatorFn, is_async_validator: bool ) -> None: 
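`set_topic_validator` keeps its shape, but the validator aliases (`SyncValidatorFn`, `AsyncValidatorFn`, `ValidatorFn`) now come from `libp2p.pubsub.typing` rather than being defined inline in `pubsub.py`. A small illustrative registration follows; the validator, topic name, and `install_validator` helper are made up for the example.

```python
from libp2p.peer.id import ID
from libp2p.pubsub.pb import rpc_pb2
from libp2p.pubsub.pubsub import Pubsub


def reject_empty_messages(msg_forwarder: ID, msg: rpc_pb2.Message) -> bool:
    # A SyncValidatorFn: returning False makes validation fail for this message.
    return len(msg.data) > 0


def install_validator(pubsub: Pubsub, topic_id: str) -> None:
    # Register the synchronous validator for a single topic.
    pubsub.set_topic_validator(topic_id, reject_empty_messages, is_async_validator=False)
```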
@@ -283,6 +277,10 @@ class Pubsub: await stream.reset() self._handle_dead_peer(peer_id) + async def wait_until_ready(self) -> None: + await self.event_handle_peer_queue_started.wait() + await self.event_handle_dead_peer_queue_started.wait() + async def _handle_new_peer(self, peer_id: ID) -> None: try: stream: INetStream = await self.host.new_stream(peer_id, self.protocols) @@ -325,18 +323,21 @@ class Pubsub: """Continuously read from peer queue and each time a new peer is found, open a stream to the peer using a supported pubsub protocol pubsub protocols we support.""" - while True: - peer_id: ID = await self.peer_queue.get() - # Add Peer - self._tasks.append(asyncio.ensure_future(self._handle_new_peer(peer_id))) + async with self.peer_receive_channel: + self.event_handle_peer_queue_started.set() + async for peer_id in self.peer_receive_channel: + # Add Peer + self.manager.run_task(self._handle_new_peer, peer_id) async def handle_dead_peer_queue(self) -> None: - """Continuously read from dead peer queue and close the stream between - that peer and remove peer info from pubsub and pubsub router.""" - while True: - peer_id: ID = await self.dead_peer_queue.get() - # Remove Peer - self._handle_dead_peer(peer_id) + """Continuously read from dead peer channel and close the stream + between that peer and remove peer info from pubsub and pubsub + router.""" + async with self.dead_peer_receive_channel: + self.event_handle_dead_peer_queue_started.set() + async for peer_id in self.dead_peer_receive_channel: + # Remove Peer + self._handle_dead_peer(peer_id) def handle_subscription( self, origin_id: ID, sub_message: rpc_pb2.RPC.SubOpts @@ -360,8 +361,7 @@ class Pubsub: if origin_id in self.peer_topics[sub_message.topicid]: self.peer_topics[sub_message.topicid].discard(origin_id) - # FIXME(mhchia): Change the function name? - async def handle_talk(self, publish_message: rpc_pb2.Message) -> None: + def notify_subscriptions(self, publish_message: rpc_pb2.Message) -> None: """ Put incoming message from a peer onto my blocking queue. @@ -370,13 +370,19 @@ class Pubsub: # Check if this message has any topics that we are subscribed to for topic in publish_message.topicIDs: - if topic in self.my_topics: + if topic in self.topic_ids: # we are subscribed to a topic this message was sent for, # so add message to the subscription output queue # for each topic - await self.my_topics[topic].put(publish_message) + try: + self.subscribed_topics_send[topic].send_nowait(publish_message) + except trio.WouldBlock: + # Channel is full, ignore this message. + logger.warning( + "fail to deliver message to subscription for topic %s", topic + ) - async def subscribe(self, topic_id: str) -> "asyncio.Queue[rpc_pb2.Message]": + async def subscribe(self, topic_id: str) -> ISubscriptionAPI: """ Subscribe ourself to a topic. 
@@ -386,11 +392,19 @@ class Pubsub: logger.debug("subscribing to topic %s", topic_id) # Already subscribed - if topic_id in self.my_topics: - return self.my_topics[topic_id] + if topic_id in self.topic_ids: + return self.subscribed_topics_receive[topic_id] - # Map topic_id to blocking queue - self.my_topics[topic_id] = asyncio.Queue() + send_channel, receive_channel = trio.open_memory_channel[rpc_pb2.Message]( + SUBSCRIPTION_CHANNEL_SIZE + ) + + subscription = TrioSubscriptionAPI( + receive_channel, + unsubscribe_fn=functools.partial(self.unsubscribe, topic_id), + ) + self.subscribed_topics_send[topic_id] = send_channel + self.subscribed_topics_receive[topic_id] = subscription # Create subscribe message packet: rpc_pb2.RPC = rpc_pb2.RPC() @@ -404,8 +418,8 @@ class Pubsub: # Tell router we are joining this topic await self.router.join(topic_id) - # Return the asyncio queue for messages on this topic - return self.my_topics[topic_id] + # Return the subscription for messages on this topic + return subscription async def unsubscribe(self, topic_id: str) -> None: """ @@ -417,10 +431,14 @@ class Pubsub: logger.debug("unsubscribing from topic %s", topic_id) # Return if we already unsubscribed from the topic - if topic_id not in self.my_topics: + if topic_id not in self.topic_ids: return - # Remove topic_id from map if present - del self.my_topics[topic_id] + # Remove topic_id from the maps before yielding + send_channel = self.subscribed_topics_send[topic_id] + del self.subscribed_topics_send[topic_id] + del self.subscribed_topics_receive[topic_id] + # Only close the send side + await send_channel.aclose() # Create unsubscribe message packet: rpc_pb2.RPC = rpc_pb2.RPC() @@ -462,7 +480,7 @@ class Pubsub: data=data, topicIDs=[topic_id], # Origin is ourself. - from_id=self.host.get_id().to_bytes(), + from_id=self.my_id.to_bytes(), seqno=self._next_seqno(), ) @@ -474,7 +492,7 @@ class Pubsub: msg.key = self.host.get_public_key().serialize() msg.signature = signature - await self.push_msg(self.host.get_id(), msg) + await self.push_msg(self.my_id, msg) logger.debug("successfully published message %s", msg) @@ -485,12 +503,12 @@ class Pubsub: :param msg_forwarder: the peer who forward us the message. :param msg: the message. 
""" - sync_topic_validators = [] - async_topic_validator_futures: List[Awaitable[bool]] = [] + sync_topic_validators: List[SyncValidatorFn] = [] + async_topic_validators: List[AsyncValidatorFn] = [] for topic_validator in self.get_msg_validators(msg): if topic_validator.is_async: - async_topic_validator_futures.append( - cast(Awaitable[bool], topic_validator.validator(msg_forwarder, msg)) + async_topic_validators.append( + cast(AsyncValidatorFn, topic_validator.validator) ) else: sync_topic_validators.append( @@ -503,9 +521,20 @@ class Pubsub: # TODO: Implement throttle on async validators - if len(async_topic_validator_futures) > 0: - results = await asyncio.gather(*async_topic_validator_futures) - if not all(results): + if len(async_topic_validators) > 0: + # TODO: Use a better pattern + final_result: bool = True + + async def run_async_validator(func: AsyncValidatorFn) -> None: + nonlocal final_result + result = await func(msg_forwarder, msg) + final_result = final_result and result + + async with trio.open_nursery() as nursery: + for async_validator in async_topic_validators: + nursery.start_soon(run_async_validator, async_validator) + + if not final_result: raise ValidationError(f"Validation failed for msg={msg}") async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None: @@ -548,7 +577,7 @@ class Pubsub: return self._mark_msg_seen(msg) - await self.handle_talk(msg) + self.notify_subscriptions(msg) await self.router.publish(msg_forwarder, msg) def _next_seqno(self) -> bytes: @@ -567,14 +596,4 @@ class Pubsub: self.seen_messages[msg_id] = 1 def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool: - if not self.my_topics: - return False - return any(topic in self.my_topics for topic in msg.topicIDs) - - async def close(self) -> None: - for task in self._tasks: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass + return any(topic in self.topic_ids for topic in msg.topicIDs) diff --git a/libp2p/pubsub/pubsub_notifee.py b/libp2p/pubsub/pubsub_notifee.py index 6afa9ad2..cf728843 100644 --- a/libp2p/pubsub/pubsub_notifee.py +++ b/libp2p/pubsub/pubsub_notifee.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING from multiaddr import Multiaddr +import trio from libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.network_interface import INetwork @@ -8,19 +9,18 @@ from libp2p.network.notifee_interface import INotifee from libp2p.network.stream.net_stream_interface import INetStream if TYPE_CHECKING: - import asyncio # noqa: F401 from libp2p.peer.id import ID # noqa: F401 class PubsubNotifee(INotifee): - initiator_peers_queue: "asyncio.Queue[ID]" - dead_peers_queue: "asyncio.Queue[ID]" + initiator_peers_queue: "trio.MemorySendChannel[ID]" + dead_peers_queue: "trio.MemorySendChannel[ID]" def __init__( self, - initiator_peers_queue: "asyncio.Queue[ID]", - dead_peers_queue: "asyncio.Queue[ID]", + initiator_peers_queue: "trio.MemorySendChannel[ID]", + dead_peers_queue: "trio.MemorySendChannel[ID]", ) -> None: """ :param initiator_peers_queue: queue to add new peers to so that pubsub @@ -32,10 +32,10 @@ class PubsubNotifee(INotifee): self.dead_peers_queue = dead_peers_queue async def opened_stream(self, network: INetwork, stream: INetStream) -> None: - pass + await trio.hazmat.checkpoint() async def closed_stream(self, network: INetwork, stream: INetStream) -> None: - pass + await trio.hazmat.checkpoint() async def connected(self, network: INetwork, conn: INetConn) -> None: """ @@ -46,7 +46,11 @@ class 
PubsubNotifee(INotifee): :param network: network the connection was opened on :param conn: connection that was opened """ - await self.initiator_peers_queue.put(conn.muxed_conn.peer_id) + try: + await self.initiator_peers_queue.send(conn.muxed_conn.peer_id) + except trio.BrokenResourceError: + # The receive channel is closed by Pubsub. We should do nothing here. + pass async def disconnected(self, network: INetwork, conn: INetConn) -> None: """ @@ -56,10 +60,14 @@ class PubsubNotifee(INotifee): :param network: network the connection was opened on :param conn: connection that was opened """ - await self.dead_peers_queue.put(conn.muxed_conn.peer_id) + try: + await self.dead_peers_queue.send(conn.muxed_conn.peer_id) + except trio.BrokenResourceError: + # The receive channel is closed by Pubsub. We should do nothing here. + pass async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None: - pass + await trio.hazmat.checkpoint() async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None: - pass + await trio.hazmat.checkpoint() diff --git a/libp2p/pubsub/subscription.py b/libp2p/pubsub/subscription.py new file mode 100644 index 00000000..e3c926cc --- /dev/null +++ b/libp2p/pubsub/subscription.py @@ -0,0 +1,46 @@ +from types import TracebackType +from typing import AsyncIterator, Optional, Type + +import trio + +from .abc import ISubscriptionAPI +from .pb import rpc_pb2 +from .typing import UnsubscribeFn + + +class BaseSubscriptionAPI(ISubscriptionAPI): + async def __aenter__(self) -> "BaseSubscriptionAPI": + await trio.hazmat.checkpoint() + return self + + async def __aexit__( + self, + exc_type: "Optional[Type[BaseException]]", + exc_value: "Optional[BaseException]", + traceback: "Optional[TracebackType]", + ) -> None: + await self.unsubscribe() + + +class TrioSubscriptionAPI(BaseSubscriptionAPI): + receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]" + unsubscribe_fn: UnsubscribeFn + + def __init__( + self, + receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]", + unsubscribe_fn: UnsubscribeFn, + ) -> None: + self.receive_channel = receive_channel + # Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427 + self.unsubscribe_fn = unsubscribe_fn # type: ignore + + async def unsubscribe(self) -> None: + # Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427 + await self.unsubscribe_fn() # type: ignore + + def __aiter__(self) -> AsyncIterator[rpc_pb2.Message]: + return self.receive_channel.__aiter__() + + async def get(self) -> rpc_pb2.Message: + return await self.receive_channel.receive() diff --git a/libp2p/pubsub/typing.py b/libp2p/pubsub/typing.py new file mode 100644 index 00000000..33297a9f --- /dev/null +++ b/libp2p/pubsub/typing.py @@ -0,0 +1,11 @@ +from typing import Awaitable, Callable, Union + +from libp2p.peer.id import ID + +from .pb import rpc_pb2 + +SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool] +AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]] +ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn] + +UnsubscribeFn = Callable[[], Awaitable[None]] diff --git a/libp2p/security/insecure/transport.py b/libp2p/security/insecure/transport.py index b8cea7f0..103a580f 100644 --- a/libp2p/security/insecure/transport.py +++ b/libp2p/security/insecure/transport.py @@ -39,7 +39,7 @@ class InsecureSession(BaseSession): await self.conn.write(data) return len(data) - async def read(self, n: int = -1) -> bytes: + async def read(self, n: int = None) -> bytes: return 
await self.conn.read(n) async def close(self) -> None: diff --git a/libp2p/security/secio/transport.py b/libp2p/security/secio/transport.py index b11585bf..9e98873b 100644 --- a/libp2p/security/secio/transport.py +++ b/libp2p/security/secio/transport.py @@ -94,7 +94,7 @@ class SecureSession(BaseSession): data = self.buf.getbuffer()[self.low_watermark : self.high_watermark] - if n < 0: + if n is None: n = len(data) result = data[:n].tobytes() self.low_watermark += len(result) @@ -111,7 +111,7 @@ class SecureSession(BaseSession): self.low_watermark = 0 self.high_watermark = len(msg) - async def read(self, n: int = -1) -> bytes: + async def read(self, n: int = None) -> bytes: if n == 0: return bytes() diff --git a/libp2p/stream_muxer/abc.py b/libp2p/stream_muxer/abc.py index 71704c1e..82140ff4 100644 --- a/libp2p/stream_muxer/abc.py +++ b/libp2p/stream_muxer/abc.py @@ -1,5 +1,7 @@ from abc import ABC, abstractmethod +import trio + from libp2p.io.abc import ReadWriteCloser from libp2p.peer.id import ID from libp2p.security.secure_conn_interface import ISecureConn @@ -11,6 +13,7 @@ class IMuxedConn(ABC): """ peer_id: ID + event_started: trio.Event @abstractmethod def __init__(self, conn: ISecureConn, peer_id: ID) -> None: @@ -25,12 +28,17 @@ class IMuxedConn(ABC): @property @abstractmethod def is_initiator(self) -> bool: - pass + """if this connection is the initiator.""" + + @abstractmethod + async def start(self) -> None: + """start the multiplexer.""" @abstractmethod async def close(self) -> None: """close connection.""" + @property @abstractmethod def is_closed(self) -> bool: """ diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index f70cae20..4f62e152 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -1,7 +1,7 @@ -import asyncio import logging -from typing import Any # noqa: F401 -from typing import Awaitable, Dict, List, Optional, Tuple +from typing import Dict, Optional, Tuple + +import trio from libp2p.exceptions import ParseError from libp2p.io.exceptions import IncompleteReadError @@ -23,6 +23,8 @@ from .exceptions import MplexUnavailable from .mplex_stream import MplexStream MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0") +# Ref: https://github.com/libp2p/go-mplex/blob/414db61813d9ad3e6f4a7db5c1b1612de343ace9/multiplex.go#L115 # noqa: E501 +MPLEX_MESSAGE_CHANNEL_SIZE = 8 logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex") @@ -36,12 +38,14 @@ class Mplex(IMuxedConn): peer_id: ID next_channel_id: int streams: Dict[StreamID, MplexStream] - streams_lock: asyncio.Lock - new_stream_queue: "asyncio.Queue[IMuxedStream]" - event_shutting_down: asyncio.Event - event_closed: asyncio.Event + streams_lock: trio.Lock + streams_msg_channels: Dict[StreamID, "trio.MemorySendChannel[bytes]"] + new_stream_send_channel: "trio.MemorySendChannel[IMuxedStream]" + new_stream_receive_channel: "trio.MemoryReceiveChannel[IMuxedStream]" - _tasks: List["asyncio.Future[Any]"] + event_shutting_down: trio.Event + event_closed: trio.Event + event_started: trio.Event def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None: """ @@ -61,15 +65,16 @@ class Mplex(IMuxedConn): # Mapping from stream ID -> buffer of messages for that stream self.streams = {} - self.streams_lock = asyncio.Lock() - self.new_stream_queue = asyncio.Queue() - self.event_shutting_down = asyncio.Event() - self.event_closed = asyncio.Event() + self.streams_lock = trio.Lock() + self.streams_msg_channels = {} + channels = 
trio.open_memory_channel[IMuxedStream](0) + self.new_stream_send_channel, self.new_stream_receive_channel = channels + self.event_shutting_down = trio.Event() + self.event_closed = trio.Event() + self.event_started = trio.Event() - self._tasks = [] - - # Kick off reading - self._tasks.append(asyncio.ensure_future(self.handle_incoming())) + async def start(self) -> None: + await self.handle_incoming() @property def is_initiator(self) -> bool: @@ -85,6 +90,7 @@ class Mplex(IMuxedConn): # Blocked until `close` is finally set. await self.event_closed.wait() + @property def is_closed(self) -> bool: """ check connection is fully closed. @@ -104,9 +110,13 @@ class Mplex(IMuxedConn): return next_id async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream: - stream = MplexStream(name, stream_id, self) + send_channel, receive_channel = trio.open_memory_channel[bytes]( + MPLEX_MESSAGE_CHANNEL_SIZE + ) + stream = MplexStream(name, stream_id, self, receive_channel) async with self.streams_lock: self.streams[stream_id] = stream + self.streams_msg_channels[stream_id] = send_channel return stream async def open_stream(self) -> IMuxedStream: @@ -123,27 +133,12 @@ class Mplex(IMuxedConn): await self.send_message(HeaderTags.NewStream, name.encode(), stream_id) return stream - async def _wait_until_shutting_down_or_closed(self, coro: Awaitable[Any]) -> Any: - task_coro = asyncio.ensure_future(coro) - task_wait_closed = asyncio.ensure_future(self.event_closed.wait()) - task_wait_shutting_down = asyncio.ensure_future(self.event_shutting_down.wait()) - done, pending = await asyncio.wait( - [task_coro, task_wait_closed, task_wait_shutting_down], - return_when=asyncio.FIRST_COMPLETED, - ) - for fut in pending: - fut.cancel() - if task_wait_closed in done: - raise MplexUnavailable("Mplex is closed") - if task_wait_shutting_down in done: - raise MplexUnavailable("Mplex is shutting down") - return task_coro.result() - async def accept_stream(self) -> IMuxedStream: """accepts a muxed stream opened by the other end.""" - return await self._wait_until_shutting_down_or_closed( - self.new_stream_queue.get() - ) + try: + return await self.new_stream_receive_channel.receive() + except trio.EndOfChannel: + raise MplexUnavailable async def send_message( self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID @@ -151,7 +146,7 @@ class Mplex(IMuxedConn): """ sends a message over the connection. 
- :param header: header to use + :param flag: header to use :param data: data to send in the message :param stream_id: stream the message is in """ @@ -163,9 +158,7 @@ class Mplex(IMuxedConn): _bytes = header + encode_varint_prefixed(data) - return await self._wait_until_shutting_down_or_closed( - self.write_to_stream(_bytes) - ) + return await self.write_to_stream(_bytes) async def write_to_stream(self, _bytes: bytes) -> int: """ @@ -174,21 +167,25 @@ class Mplex(IMuxedConn): :param _bytes: byte array to write :return: length written """ - await self.secured_conn.write(_bytes) + try: + await self.secured_conn.write(_bytes) + except RawConnError as e: + raise MplexUnavailable( + "failed to write message to the underlying connection" + ) from e + return len(_bytes) async def handle_incoming(self) -> None: """Read a message off of the secured connection and add it to the corresponding message buffer.""" - + self.event_started.set() while True: try: await self._handle_incoming_message() except MplexUnavailable as e: logger.debug("mplex unavailable while waiting for incoming: %s", e) break - # Force context switch - await asyncio.sleep(0) # If we enter here, it means this connection is shutting down. # We should clean things up. await self._cleanup() @@ -200,20 +197,19 @@ class Mplex(IMuxedConn): :return: stream_id, flag, message contents """ - # FIXME: No timeout is used in Go implementation. try: header = await decode_uvarint_from_stream(self.secured_conn) - message = await asyncio.wait_for( - read_varint_prefixed_bytes(self.secured_conn), timeout=5 - ) except (ParseError, RawConnError, IncompleteReadError) as error: raise MplexUnavailable( - "failed to read messages correctly from the underlying connection" - ) from error - except asyncio.TimeoutError as error: + f"failed to read the header correctly from the underlying connection: {error}" + ) + try: + message = await read_varint_prefixed_bytes(self.secured_conn) + except (ParseError, RawConnError, IncompleteReadError) as error: raise MplexUnavailable( - "failed to read more message body within the timeout" - ) from error + "failed to read the message body correctly from the underlying connection: " + f"{error}" + ) flag = header & 0x07 channel_id = header >> 3 @@ -226,9 +222,7 @@ class Mplex(IMuxedConn): :raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down. """ - channel_id, flag, message = await self._wait_until_shutting_down_or_closed( - self.read_message() - ) + channel_id, flag, message = await self.read_message() stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1)) if flag == HeaderTags.NewStream.value: @@ -258,9 +252,10 @@ class Mplex(IMuxedConn): f"received NewStream message for existing stream: {stream_id}" ) mplex_stream = await self._initialize_stream(stream_id, message.decode()) - await self._wait_until_shutting_down_or_closed( - self.new_stream_queue.put(mplex_stream) - ) + try: + await self.new_stream_send_channel.send(mplex_stream) + except trio.ClosedResourceError: + raise MplexUnavailable async def _handle_message(self, stream_id: StreamID, message: bytes) -> None: async with self.streams_lock: @@ -270,13 +265,21 @@ class Mplex(IMuxedConn): # TODO: Warn and emit logs about this. return stream = self.streams[stream_id] + send_channel = self.streams_msg_channels[stream_id] async with stream.close_lock: if stream.event_remote_closed.is_set(): # TODO: Warn "Received data from remote after stream was closed by them. 
(len = %d)" # noqa: E501 return - await self._wait_until_shutting_down_or_closed( - stream.incoming_data.put(message) - ) + try: + send_channel.send_nowait(message) + except (trio.BrokenResourceError, trio.ClosedResourceError): + raise MplexUnavailable + except trio.WouldBlock: + # `send_channel` is full, reset this stream. + logger.warning( + "message channel of stream %s is full: stream is reset", stream_id + ) + await stream.reset() async def _handle_close(self, stream_id: StreamID) -> None: async with self.streams_lock: @@ -284,6 +287,8 @@ class Mplex(IMuxedConn): # Ignore unmatched messages for now. return stream = self.streams[stream_id] + send_channel = self.streams_msg_channels[stream_id] + await send_channel.aclose() # NOTE: If remote is already closed, then return: Technically a bug # on the other side. We should consider killing the connection. async with stream.close_lock: @@ -305,27 +310,30 @@ class Mplex(IMuxedConn): # This is *ok*. We forget the stream on reset. return stream = self.streams[stream_id] - + send_channel = self.streams_msg_channels[stream_id] + await send_channel.aclose() async with stream.close_lock: if not stream.event_remote_closed.is_set(): stream.event_reset.set() - stream.event_remote_closed.set() # If local is not closed, we should close it. if not stream.event_local_closed.is_set(): stream.event_local_closed.set() async with self.streams_lock: self.streams.pop(stream_id, None) + self.streams_msg_channels.pop(stream_id, None) async def _cleanup(self) -> None: if not self.event_shutting_down.is_set(): self.event_shutting_down.set() async with self.streams_lock: - for stream in self.streams.values(): + for stream_id, stream in self.streams.items(): async with stream.close_lock: if not stream.event_remote_closed.is_set(): stream.event_remote_closed.set() stream.event_reset.set() stream.event_local_closed.set() - self.streams = None + send_channel = self.streams_msg_channels[stream_id] + await send_channel.aclose() self.event_closed.set() + await self.new_stream_send_channel.aclose() diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 21d2749f..79675749 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -1,7 +1,9 @@ -import asyncio from typing import TYPE_CHECKING +import trio + from libp2p.stream_muxer.abc import IMuxedStream +from libp2p.stream_muxer.exceptions import MuxedConnUnavailable from .constants import HeaderTags from .datastructures import StreamID @@ -22,18 +24,25 @@ class MplexStream(IMuxedStream): read_deadline: int write_deadline: int - close_lock: asyncio.Lock + # TODO: Add lock for read/write to avoid interleaving receiving messages? + close_lock: trio.Lock # NOTE: `dataIn` is size of 8 in Go implementation. - incoming_data: "asyncio.Queue[bytes]" + incoming_data_channel: "trio.MemoryReceiveChannel[bytes]" - event_local_closed: asyncio.Event - event_remote_closed: asyncio.Event - event_reset: asyncio.Event + event_local_closed: trio.Event + event_remote_closed: trio.Event + event_reset: trio.Event _buf: bytearray - def __init__(self, name: str, stream_id: StreamID, muxed_conn: "Mplex") -> None: + def __init__( + self, + name: str, + stream_id: StreamID, + muxed_conn: "Mplex", + incoming_data_channel: "trio.MemoryReceiveChannel[bytes]", + ) -> None: """ create new MuxedStream in muxer. 
@@ -45,99 +54,82 @@ class MplexStream(IMuxedStream): self.muxed_conn = muxed_conn self.read_deadline = None self.write_deadline = None - self.event_local_closed = asyncio.Event() - self.event_remote_closed = asyncio.Event() - self.event_reset = asyncio.Event() - self.close_lock = asyncio.Lock() - self.incoming_data = asyncio.Queue() + self.event_local_closed = trio.Event() + self.event_remote_closed = trio.Event() + self.event_reset = trio.Event() + self.close_lock = trio.Lock() + self.incoming_data_channel = incoming_data_channel self._buf = bytearray() @property def is_initiator(self) -> bool: return self.stream_id.is_initiator - async def _wait_for_data(self) -> None: - task_event_reset = asyncio.ensure_future(self.event_reset.wait()) - task_incoming_data_get = asyncio.ensure_future(self.incoming_data.get()) - task_event_remote_closed = asyncio.ensure_future( - self.event_remote_closed.wait() - ) - done, pending = await asyncio.wait( # type: ignore - [ # type: ignore - task_event_reset, - task_incoming_data_get, - task_event_remote_closed, - ], - return_when=asyncio.FIRST_COMPLETED, - ) - for fut in pending: - fut.cancel() - - if task_event_reset in done: - if self.event_reset.is_set(): - raise MplexStreamReset() - else: - # However, it is abnormal that `Event.wait` is unblocked without any of the flag - # is set. The task is probably cancelled. - raise Exception( - "Should not enter here. " - f"It is probably because {task_event_remote_closed} is cancelled." - ) - - if task_incoming_data_get in done: - data = task_incoming_data_get.result() - self._buf.extend(data) - return - - if task_event_remote_closed in done: - if self.event_remote_closed.is_set(): - raise MplexStreamEOF() - else: - # However, it is abnormal that `Event.wait` is unblocked without any of the flag - # is set. The task is probably cancelled. - raise Exception( - "Should not enter here. " - f"It is probably because {task_event_remote_closed} is cancelled." - ) - - # TODO: Handle timeout when deadline is used. - async def _read_until_eof(self) -> bytes: - while True: - try: - await self._wait_for_data() - except MplexStreamEOF: - break + async for data in self.incoming_data_channel: + self._buf.extend(data) payload = self._buf self._buf = self._buf[len(payload) :] return bytes(payload) - async def read(self, n: int = -1) -> bytes: + def _read_return_when_blocked(self) -> bytes: + buf = bytearray() + while True: + try: + data = self.incoming_data_channel.receive_nowait() + buf.extend(data) + except (trio.WouldBlock, trio.EndOfChannel): + break + return buf + + async def read(self, n: int = None) -> bytes: """ Read up to n bytes. Read possibly returns fewer than `n` bytes, if - there are not enough bytes in the Mplex buffer. If `n == -1`, read + there are not enough bytes in the Mplex buffer. If `n is None`, read until EOF. :param n: number of bytes to read :return: bytes actually read """ - if n < 0 and n != -1: + if n is not None and n < 0: raise ValueError( - f"the number of bytes to read `n` must be positive or -1 to indicate read until EOF" + f"the number of bytes to read `n` must be non-negative or " + "`None` to indicate read until EOF" ) if self.event_reset.is_set(): - raise MplexStreamReset() - if n == -1: + raise MplexStreamReset + if n is None: return await self._read_until_eof() - if len(self._buf) == 0 and self.incoming_data.empty(): - await self._wait_for_data() - # Now we are sure we have something to read. - # Try to put enough incoming data into `self._buf`. 
- while len(self._buf) < n: + if len(self._buf) == 0: + data: bytes + # Peek whether there is data available. If yes, we just read until there is no data, + # and then return. try: - self._buf.extend(self.incoming_data.get_nowait()) - except asyncio.QueueEmpty: - break + data = self.incoming_data_channel.receive_nowait() + self._buf.extend(data) + except trio.EndOfChannel: + raise MplexStreamEOF + except trio.WouldBlock: + # We know `receive` will be blocked here. Wait for data here with `receive` and + # catch all kinds of errors here. + try: + data = await self.incoming_data_channel.receive() + self._buf.extend(data) + except trio.EndOfChannel: + if self.event_reset.is_set(): + raise MplexStreamReset + if self.event_remote_closed.is_set(): + raise MplexStreamEOF + except trio.ClosedResourceError as error: + # Probably `incoming_data_channel` is closed in `reset` when we are waiting + # for `receive`. + if self.event_reset.is_set(): + raise MplexStreamReset + raise Exception( + "`incoming_data_channel` is closed but stream is not reset. " + "This should never happen." + ) from error + self._buf.extend(self._read_return_when_blocked()) payload = self._buf[:n] self._buf = self._buf[len(payload) :] return bytes(payload) @@ -198,14 +190,17 @@ class MplexStream(IMuxedStream): if self.is_initiator else HeaderTags.ResetReceiver ) - asyncio.ensure_future( - self.muxed_conn.send_message(flag, None, self.stream_id) - ) - await asyncio.sleep(0) + # Try to send reset message to the other side. Ignore if there is anything wrong. + try: + await self.muxed_conn.send_message(flag, None, self.stream_id) + except MuxedConnUnavailable: + pass self.event_local_closed.set() self.event_remote_closed.set() + await self.incoming_data_channel.aclose() + async with self.muxed_conn.streams_lock: if self.muxed_conn.streams is not None: self.muxed_conn.streams.pop(self.stream_id, None) diff --git a/libp2p/tools/constants.py b/libp2p/tools/constants.py index 8c22d151..b1ad2652 100644 --- a/libp2p/tools/constants.py +++ b/libp2p/tools/constants.py @@ -7,7 +7,7 @@ from libp2p.pubsub import floodsub, gossipsub # Just a arbitrary large number. # It is used when calling `MplexStream.read(MAX_READ_LEN)`, # to avoid `MplexStream.read()`, which blocking reads until EOF. -MAX_READ_LEN = 2 ** 32 - 1 +MAX_READ_LEN = 65535 LISTEN_MADDR = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0") diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index 2d1f99ad..67e26519 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -1,40 +1,55 @@ -import asyncio -from typing import Any, AsyncIterator, Dict, Tuple, cast +from typing import Any, AsyncIterator, Dict, List, Sequence, Tuple, cast -# NOTE: import ``asynccontextmanager`` from ``contextlib`` when support for python 3.6 is dropped. 
+from async_exit_stack import AsyncExitStack from async_generator import asynccontextmanager +from async_service import background_trio_service import factory +from multiaddr import Multiaddr +import trio from libp2p import generate_new_rsa_identity, generate_peer_id_from from libp2p.crypto.keys import KeyPair from libp2p.host.basic_host import BasicHost +from libp2p.host.host_interface import IHost +from libp2p.host.routed_host import RoutedHost +from libp2p.io.abc import ReadWriteCloser +from libp2p.network.connection.raw_connection import RawConnection +from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.swarm_connection import SwarmConn from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.swarm import Swarm from libp2p.peer.id import ID +from libp2p.peer.peerinfo import PeerInfo from libp2p.peer.peerstore import PeerStore +from libp2p.pubsub.abc import IPubsubRouter from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.pubsub import Pubsub +from libp2p.routing.interfaces import IPeerRouting from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport import libp2p.security.secio.transport as secio from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex from libp2p.stream_muxer.mplex.mplex_stream import MplexStream +from libp2p.tools.constants import GOSSIPSUB_PARAMS from libp2p.transport.tcp.tcp import TCP from libp2p.transport.typing import TMuxerOptions from libp2p.transport.upgrader import TransportUpgrader from libp2p.typing import TProtocol -from .constants import ( - FLOODSUB_PROTOCOL_ID, - GOSSIPSUB_PARAMS, - GOSSIPSUB_PROTOCOL_ID, - LISTEN_MADDR, -) +from .constants import FLOODSUB_PROTOCOL_ID, GOSSIPSUB_PROTOCOL_ID, LISTEN_MADDR from .utils import connect, connect_swarm +class IDFactory(factory.Factory): + class Meta: + model = ID + + peer_id_bytes = factory.LazyFunction( + lambda: generate_peer_id_from(generate_new_rsa_identity()) + ) + + def initialize_peerstore_with_our_keypair(self_id: ID, key_pair: KeyPair) -> PeerStore: peer_store = PeerStore() peer_store.add_key_pair(self_id, key_pair) @@ -50,6 +65,29 @@ def security_transport_factory( return {secio.ID: secio.Transport(key_pair)} +@asynccontextmanager +async def raw_conn_factory( + nursery: trio.Nursery +) -> AsyncIterator[Tuple[IRawConnection, IRawConnection]]: + conn_0 = None + conn_1 = None + event = trio.Event() + + async def tcp_stream_handler(stream: ReadWriteCloser) -> None: + nonlocal conn_1 + conn_1 = RawConnection(stream, initiator=False) + event.set() + await trio.sleep_forever() + + tcp_transport = TCP() + listener = tcp_transport.create_listener(tcp_stream_handler) + await listener.listen(LISTEN_MADDR, nursery) + listening_maddr = listener.get_addrs()[0] + conn_0 = await tcp_transport.dial(listening_maddr) + await event.wait() + yield conn_0, conn_1 + + class SwarmFactory(factory.Factory): class Meta: model = Swarm @@ -71,9 +109,10 @@ class SwarmFactory(factory.Factory): transport = factory.LazyFunction(TCP) @classmethod + @asynccontextmanager async def create_and_listen( cls, is_secure: bool, key_pair: KeyPair = None, muxer_opt: TMuxerOptions = None - ) -> Swarm: + ) -> AsyncIterator[Swarm]: # `factory.Factory.__init__` does *not* prepare a *default value* if we pass # an argument explicitly with `None`. 
If an argument is `None`, we don't pass it to # `factory.Factory.__init__`, in order to let the function initialize it. @@ -83,20 +122,23 @@ class SwarmFactory(factory.Factory): if muxer_opt is not None: optional_kwargs["muxer_opt"] = muxer_opt swarm = cls(is_secure=is_secure, **optional_kwargs) - await swarm.listen(LISTEN_MADDR) - return swarm + async with background_trio_service(swarm): + await swarm.listen(LISTEN_MADDR) + yield swarm @classmethod + @asynccontextmanager async def create_batch_and_listen( cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None - ) -> Tuple[Swarm, ...]: - # Ignore typing since we are removing asyncio soon - return await asyncio.gather( # type: ignore - *[ - cls.create_and_listen(is_secure=is_secure, muxer_opt=muxer_opt) + ) -> AsyncIterator[Tuple[Swarm, ...]]: + async with AsyncExitStack() as stack: + ctx_mgrs = [ + await stack.enter_async_context( + cls.create_and_listen(is_secure=is_secure, muxer_opt=muxer_opt) + ) for _ in range(number) ] - ) + yield tuple(ctx_mgrs) class HostFactory(factory.Factory): @@ -107,22 +149,57 @@ class HostFactory(factory.Factory): is_secure = False key_pair = factory.LazyFunction(generate_new_rsa_identity) - network = factory.LazyAttribute( - lambda o: SwarmFactory(is_secure=o.is_secure, key_pair=o.key_pair) - ) + network = factory.LazyAttribute(lambda o: SwarmFactory(is_secure=o.is_secure)) @classmethod + @asynccontextmanager async def create_batch_and_listen( cls, is_secure: bool, number: int - ) -> Tuple[BasicHost, ...]: - key_pairs = [generate_new_rsa_identity() for _ in range(number)] - swarms = await asyncio.gather( - *[ - SwarmFactory.create_and_listen(is_secure, key_pair) - for key_pair in key_pairs - ] - ) - return tuple(BasicHost(swarm) for swarm in swarms) + ) -> AsyncIterator[Tuple[BasicHost, ...]]: + async with SwarmFactory.create_batch_and_listen(is_secure, number) as swarms: + hosts = tuple(BasicHost(swarm) for swarm in swarms) + yield hosts + + +class DummyRouter(IPeerRouting): + _routing_table: Dict[ID, PeerInfo] + + def __init__(self) -> None: + self._routing_table = dict() + + def _add_peer(self, peer_id: ID, addrs: List[Multiaddr]) -> None: + self._routing_table[peer_id] = PeerInfo(peer_id, addrs) + + async def find_peer(self, peer_id: ID) -> PeerInfo: + await trio.hazmat.checkpoint() + return self._routing_table.get(peer_id, None) + + +class RoutedHostFactory(factory.Factory): + class Meta: + model = RoutedHost + + class Params: + is_secure = False + + network = factory.LazyAttribute( + lambda o: HostFactory(is_secure=o.is_secure).get_network() + ) + router = factory.LazyFunction(DummyRouter) + + @classmethod + @asynccontextmanager + async def create_batch_and_listen( + cls, is_secure: bool, number: int + ) -> AsyncIterator[Tuple[RoutedHost, ...]]: + routing_table = DummyRouter() + async with HostFactory.create_batch_and_listen(is_secure, number) as hosts: + for host in hosts: + routing_table._add_peer(host.get_id(), host.get_addrs()) + routed_hosts = tuple( + RoutedHost(host.get_network(), routing_table) for host in hosts + ) + yield routed_hosts class FloodsubFactory(factory.Factory): @@ -153,89 +230,192 @@ class PubsubFactory(factory.Factory): host = factory.SubFactory(HostFactory) router = None - my_id = factory.LazyAttribute(lambda obj: obj.host.get_id()) cache_size = None strict_signing = False + @classmethod + @asynccontextmanager + async def create_and_start( + cls, host: IHost, router: IPubsubRouter, cache_size: int, strict_signing: bool + ) -> AsyncIterator[Pubsub]: + pubsub = cls( + 
host=host, + router=router, + cache_size=cache_size, + strict_signing=strict_signing, + ) + async with background_trio_service(pubsub): + await pubsub.wait_until_ready() + yield pubsub + @classmethod + @asynccontextmanager + async def _create_batch_with_router( + cls, + number: int, + routers: Sequence[IPubsubRouter], + is_secure: bool = False, + cache_size: int = None, + strict_signing: bool = False, + ) -> AsyncIterator[Tuple[Pubsub, ...]]: + async with HostFactory.create_batch_and_listen(is_secure, number) as hosts: + # Pubsubs should exit before hosts + async with AsyncExitStack() as stack: + pubsubs = [ + await stack.enter_async_context( + cls.create_and_start(host, router, cache_size, strict_signing) + ) + for host, router in zip(hosts, routers) + ] + yield tuple(pubsubs) + + @classmethod + @asynccontextmanager + async def create_batch_with_floodsub( + cls, + number: int, + is_secure: bool = False, + cache_size: int = None, + strict_signing: bool = False, + protocols: Sequence[TProtocol] = None, + ) -> AsyncIterator[Tuple[Pubsub, ...]]: + if protocols is not None: + floodsubs = FloodsubFactory.create_batch(number, protocols=list(protocols)) + else: + floodsubs = FloodsubFactory.create_batch(number) + async with cls._create_batch_with_router( + number, floodsubs, is_secure, cache_size, strict_signing + ) as pubsubs: + yield pubsubs + + @classmethod + @asynccontextmanager + async def create_batch_with_gossipsub( + cls, + number: int, + *, + is_secure: bool = False, + cache_size: int = None, + strict_signing: bool = False, + protocols: Sequence[TProtocol] = None, + degree: int = GOSSIPSUB_PARAMS.degree, + degree_low: int = GOSSIPSUB_PARAMS.degree_low, + degree_high: int = GOSSIPSUB_PARAMS.degree_high, + time_to_live: int = GOSSIPSUB_PARAMS.time_to_live, + gossip_window: int = GOSSIPSUB_PARAMS.gossip_window, + gossip_history: int = GOSSIPSUB_PARAMS.gossip_history, + heartbeat_interval: float = GOSSIPSUB_PARAMS.heartbeat_interval, + heartbeat_initial_delay: float = GOSSIPSUB_PARAMS.heartbeat_initial_delay, + ) -> AsyncIterator[Tuple[Pubsub, ...]]: + if protocols is not None: + gossipsubs = GossipsubFactory.create_batch( + number, + protocols=protocols, + degree=degree, + degree_low=degree_low, + degree_high=degree_high, + time_to_live=time_to_live, + gossip_window=gossip_window, + heartbeat_interval=heartbeat_interval, + ) + else: + gossipsubs = GossipsubFactory.create_batch( + number, + degree=degree, + degree_low=degree_low, + degree_high=degree_high, + time_to_live=time_to_live, + gossip_window=gossip_window, + heartbeat_interval=heartbeat_interval, + ) + + async with cls._create_batch_with_router( + number, gossipsubs, is_secure, cache_size, strict_signing + ) as pubsubs: + async with AsyncExitStack() as stack: + for router in gossipsubs: + await stack.enter_async_context(background_trio_service(router)) + yield pubsubs + + +@asynccontextmanager async def swarm_pair_factory( is_secure: bool, muxer_opt: TMuxerOptions = None -) -> Tuple[Swarm, Swarm]: - swarms = await SwarmFactory.create_batch_and_listen( +) -> AsyncIterator[Tuple[Swarm, Swarm]]: + async with SwarmFactory.create_batch_and_listen( is_secure, 2, muxer_opt=muxer_opt - ) - await connect_swarm(swarms[0], swarms[1]) - return swarms[0], swarms[1] + ) as swarms: + await connect_swarm(swarms[0], swarms[1]) + yield swarms[0], swarms[1] -async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]: - hosts = await HostFactory.create_batch_and_listen(is_secure, 2) - await connect(hosts[0], hosts[1]) - return 
hosts[0], hosts[1] - - -@asynccontextmanager # type: ignore -async def pair_of_connected_hosts( - is_secure: bool = True +@asynccontextmanager +async def host_pair_factory( + is_secure: bool ) -> AsyncIterator[Tuple[BasicHost, BasicHost]]: - a, b = await host_pair_factory(is_secure) - yield a, b - close_tasks = (a.close(), b.close()) - await asyncio.gather(*close_tasks) + async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts: + await connect(hosts[0], hosts[1]) + yield hosts[0], hosts[1] +@asynccontextmanager async def swarm_conn_pair_factory( is_secure: bool, muxer_opt: TMuxerOptions = None -) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]: - swarms = await swarm_pair_factory(is_secure) - conn_0 = swarms[0].connections[swarms[1].get_peer_id()] - conn_1 = swarms[1].connections[swarms[0].get_peer_id()] - return cast(SwarmConn, conn_0), swarms[0], cast(SwarmConn, conn_1), swarms[1] +) -> AsyncIterator[Tuple[SwarmConn, SwarmConn]]: + async with swarm_pair_factory(is_secure) as swarms: + conn_0 = swarms[0].connections[swarms[1].get_peer_id()] + conn_1 = swarms[1].connections[swarms[0].get_peer_id()] + yield cast(SwarmConn, conn_0), cast(SwarmConn, conn_1) -async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Swarm, Mplex, Swarm]: +@asynccontextmanager +async def mplex_conn_pair_factory( + is_secure: bool +) -> AsyncIterator[Tuple[Mplex, Mplex]]: muxer_opt = {MPLEX_PROTOCOL_ID: Mplex} - conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory( - is_secure, muxer_opt=muxer_opt - ) - return ( - cast(Mplex, conn_0.muxed_conn), - swarm_0, - cast(Mplex, conn_1.muxed_conn), - swarm_1, - ) + async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair: + yield ( + cast(Mplex, swarm_pair[0].muxed_conn), + cast(Mplex, swarm_pair[1].muxed_conn), + ) +@asynccontextmanager async def mplex_stream_pair_factory( is_secure: bool -) -> Tuple[MplexStream, Swarm, MplexStream, Swarm]: - mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory( - is_secure - ) - stream_0 = await mplex_conn_0.open_stream() - await asyncio.sleep(0.01) - stream_1: MplexStream - async with mplex_conn_1.streams_lock: - if len(mplex_conn_1.streams) != 1: - raise Exception("Mplex should not have any stream upon connection") - stream_1 = tuple(mplex_conn_1.streams.values())[0] - return cast(MplexStream, stream_0), swarm_0, stream_1, swarm_1 +) -> AsyncIterator[Tuple[MplexStream, MplexStream]]: + async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info: + mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info + stream_0 = cast(MplexStream, await mplex_conn_0.open_stream()) + await trio.sleep(0.01) + stream_1: MplexStream + async with mplex_conn_1.streams_lock: + if len(mplex_conn_1.streams) != 1: + raise Exception("Mplex should not have any other stream") + stream_1 = tuple(mplex_conn_1.streams.values())[0] + yield stream_0, stream_1 +@asynccontextmanager async def net_stream_pair_factory( is_secure: bool -) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]: +) -> AsyncIterator[Tuple[INetStream, INetStream]]: protocol_id = TProtocol("/example/id/1") stream_1: INetStream - # Just a proxy, we only care about the stream - def handler(stream: INetStream) -> None: + # Just a proxy, we only care about the stream. + # Add a barrier to avoid stream being removed. 
+ event_handler_finished = trio.Event() + + async def handler(stream: INetStream) -> None: nonlocal stream_1 stream_1 = stream + await event_handler_finished.wait() - host_0, host_1 = await host_pair_factory(is_secure) - host_1.set_stream_handler(protocol_id, handler) + async with host_pair_factory(is_secure) as hosts: + hosts[1].set_stream_handler(protocol_id, handler) - stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id]) - return stream_0, host_0, stream_1, host_1 + stream_0 = await hosts[0].new_stream(hosts[1].get_id(), [protocol_id]) + yield stream_0, stream_1 + event_handler_finished.set() diff --git a/libp2p/tools/interop/constants.py b/libp2p/tools/interop/constants.py index 331e2843..8f039f1c 100644 --- a/libp2p/tools/interop/constants.py +++ b/libp2p/tools/interop/constants.py @@ -1,2 +1 @@ LOCALHOST_IP = "127.0.0.1" -PEXPECT_NEW_LINE = "\r\n" diff --git a/libp2p/tools/interop/daemon.py b/libp2p/tools/interop/daemon.py index 9398ab73..3344255a 100644 --- a/libp2p/tools/interop/daemon.py +++ b/libp2p/tools/interop/daemon.py @@ -1,52 +1,22 @@ -import asyncio -import time -from typing import Any, Awaitable, Callable, List +from typing import AsyncIterator +from async_generator import asynccontextmanager import multiaddr from multiaddr import Multiaddr from p2pclient import Client -import pytest +import trio from libp2p.peer.id import ID from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr from .constants import LOCALHOST_IP from .envs import GO_BIN_PATH +from .process import BaseInteractiveProcess P2PD_PATH = GO_BIN_PATH / "p2pd" -TIMEOUT_DURATION = 30 - - -async def try_until_success( - coro_func: Callable[[], Awaitable[Any]], timeout: int = TIMEOUT_DURATION -) -> None: - """ - Keep running ``coro_func`` until either it succeed or time is up. - - All arguments of ``coro_func`` should be filled, i.e. it should be - called without arguments. - """ - t_start = time.monotonic() - while True: - result = await coro_func() - if result: - break - if (time.monotonic() - t_start) >= timeout: - # timeout - pytest.fail(f"{coro_func} is still failing after `{timeout}` seconds") - await asyncio.sleep(0.01) - - -class P2PDProcess: - proc: asyncio.subprocess.Process - cmd: str = str(P2PD_PATH) - args: List[Any] - is_proc_running: bool - - _tasks: List["asyncio.Future[Any]"] - +class P2PDProcess(BaseInteractiveProcess): def __init__( self, control_maddr: Multiaddr, @@ -75,74 +45,21 @@ class P2PDProcess: # - gossipsubHeartbeatInterval: GossipSubHeartbeatInitialDelay = 100 * time.Millisecond # noqa: E501 # - gossipsubHeartbeatInitialDelay: GossipSubHeartbeatInterval = 1 * time.Second # Referece: https://github.com/libp2p/go-libp2p-daemon/blob/b95e77dbfcd186ccf817f51e95f73f9fd5982600/p2pd/main.go#L348-L353 # noqa: E501 + self.proc = None + self.cmd = str(P2PD_PATH) self.args = args - self.is_proc_running = False - - self._tasks = [] - - async def wait_until_ready(self) -> None: - lines_head_pattern = (b"Control socket:", b"Peer ID:", b"Peer Addrs:") - lines_head_occurred = {line: False for line in lines_head_pattern} - - async def read_from_daemon_and_check() -> bool: - line = await self.proc.stdout.readline() - for head_pattern in lines_head_occurred: - if line.startswith(head_pattern): - lines_head_occurred[head_pattern] = True - return all([value for value in lines_head_occurred.values()]) - - await try_until_success(read_from_daemon_and_check) - # Sleep a little bit to ensure the listener is up after logs are emitted. 
- await asyncio.sleep(0.01) - - async def start_printing_logs(self) -> None: - async def _print_from_stream( - src_name: str, reader: asyncio.StreamReader - ) -> None: - while True: - line = await reader.readline() - if line != b"": - print(f"{src_name}\t: {line.rstrip().decode()}") - await asyncio.sleep(0.01) - - self._tasks.append( - asyncio.ensure_future(_print_from_stream("out", self.proc.stdout)) - ) - self._tasks.append( - asyncio.ensure_future(_print_from_stream("err", self.proc.stderr)) - ) - await asyncio.sleep(0) - - async def start(self) -> None: - if self.is_proc_running: - return - self.proc = await asyncio.subprocess.create_subprocess_exec( - self.cmd, - *self.args, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - bufsize=0, - ) - self.is_proc_running = True - await self.wait_until_ready() - await self.start_printing_logs() - - async def close(self) -> None: - if self.is_proc_running: - self.proc.terminate() - await self.proc.wait() - self.is_proc_running = False - for task in self._tasks: - task.cancel() + self.patterns = (b"Control socket:", b"Peer ID:", b"Peer Addrs:") + self.bytes_read = bytearray() + self.event_ready = trio.Event() class Daemon: - p2pd_proc: P2PDProcess + p2pd_proc: BaseInteractiveProcess control: Client peer_info: PeerInfo def __init__( - self, p2pd_proc: P2PDProcess, control: Client, peer_info: PeerInfo + self, p2pd_proc: BaseInteractiveProcess, control: Client, peer_info: PeerInfo ) -> None: self.p2pd_proc = p2pd_proc self.control = control @@ -164,6 +81,7 @@ class Daemon: await self.control.close() +@asynccontextmanager async def make_p2pd( daemon_control_port: int, client_callback_port: int, @@ -172,7 +90,7 @@ async def make_p2pd( is_gossipsub: bool = True, is_pubsub_signing: bool = False, is_pubsub_signing_strict: bool = False, -) -> Daemon: +) -> AsyncIterator[Daemon]: control_maddr = Multiaddr(f"/ip4/{LOCALHOST_IP}/tcp/{daemon_control_port}") p2pd_proc = P2PDProcess( control_maddr, @@ -185,21 +103,22 @@ async def make_p2pd( await p2pd_proc.start() client_callback_maddr = Multiaddr(f"/ip4/{LOCALHOST_IP}/tcp/{client_callback_port}") p2pc = Client(control_maddr, client_callback_maddr) - await p2pc.listen() - peer_id, maddrs = await p2pc.identify() - listen_maddr: Multiaddr = None - for maddr in maddrs: - try: - ip = maddr.value_for_protocol(multiaddr.protocols.P_IP4) - # NOTE: Check if this `maddr` uses `tcp`. - maddr.value_for_protocol(multiaddr.protocols.P_TCP) - except multiaddr.exceptions.ProtocolLookupError: - continue - if ip == LOCALHOST_IP: - listen_maddr = maddr - break - assert listen_maddr is not None, "no loopback maddr is found" - peer_info = info_from_p2p_addr( - listen_maddr.encapsulate(Multiaddr(f"/p2p/{peer_id.to_string()}")) - ) - return Daemon(p2pd_proc, p2pc, peer_info) + + async with p2pc.listen(): + peer_id, maddrs = await p2pc.identify() + listen_maddr: Multiaddr = None + for maddr in maddrs: + try: + ip = maddr.value_for_protocol(multiaddr.protocols.P_IP4) + # NOTE: Check if this `maddr` uses `tcp`. 
+ maddr.value_for_protocol(multiaddr.protocols.P_TCP) + except multiaddr.exceptions.ProtocolLookupError: + continue + if ip == LOCALHOST_IP: + listen_maddr = maddr + break + assert listen_maddr is not None, "no loopback maddr is found" + peer_info = info_from_p2p_addr( + listen_maddr.encapsulate(Multiaddr(f"/p2p/{peer_id.to_string()}")) + ) + yield Daemon(p2pd_proc, p2pc, peer_info) diff --git a/libp2p/tools/interop/process.py b/libp2p/tools/interop/process.py new file mode 100644 index 00000000..0c17e51b --- /dev/null +++ b/libp2p/tools/interop/process.py @@ -0,0 +1,66 @@ +from abc import ABC, abstractmethod +import subprocess +from typing import Iterable, List + +import trio + +TIMEOUT_DURATION = 30 + + +class AbstractInterativeProcess(ABC): + @abstractmethod + async def start(self) -> None: + ... + + @abstractmethod + async def close(self) -> None: + ... + + +class BaseInteractiveProcess(AbstractInterativeProcess): + proc: trio.Process = None + cmd: str + args: List[str] + bytes_read: bytearray + patterns: Iterable[bytes] = None + event_ready: trio.Event + + async def wait_until_ready(self) -> None: + patterns_occurred = {pat: False for pat in self.patterns} + + async def read_from_daemon_and_check() -> None: + async for data in self.proc.stdout: + # TODO: It takes O(n^2), which is quite bad. + # But it should succeed in a few seconds. + self.bytes_read.extend(data) + for pat, occurred in patterns_occurred.items(): + if occurred: + continue + if pat in self.bytes_read: + patterns_occurred[pat] = True + if all([value for value in patterns_occurred.values()]): + return + + with trio.fail_after(TIMEOUT_DURATION): + await read_from_daemon_and_check() + self.event_ready.set() + # Sleep a little bit to ensure the listener is up after logs are emitted. + await trio.sleep(0.01) + + async def start(self) -> None: + if self.proc is not None: + return + # NOTE: Ignore type checks here since mypy complains about bufsize=0 + self.proc = await trio.open_process( # type: ignore + [self.cmd] + self.args, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, # Redirect stderr to stdout, which makes parsing easier + bufsize=0, + ) + await self.wait_until_ready() + + async def close(self) -> None: + if self.proc is None: + return + self.proc.terminate() + await self.proc.wait() diff --git a/libp2p/tools/interop/utils.py b/libp2p/tools/interop/utils.py index c9174179..ce05c8fb 100644 --- a/libp2p/tools/interop/utils.py +++ b/libp2p/tools/interop/utils.py @@ -1,7 +1,7 @@ -import asyncio from typing import Union from multiaddr import Multiaddr +import trio from libp2p.host.host_interface import IHost from libp2p.peer.id import ID @@ -50,7 +50,7 @@ async def connect(a: TDaemonOrHost, b: TDaemonOrHost) -> None: else: # isinstance(b, IHost) await a.connect(b_peer_info) # Allow additional sleep for both side to establish the connection. 
- await asyncio.sleep(0.1) + await trio.sleep(0.1) a_peer_info = _get_peer_info(a) diff --git a/libp2p/tools/pubsub/dummy_account_node.py b/libp2p/tools/pubsub/dummy_account_node.py index b5b18f5d..32f78851 100644 --- a/libp2p/tools/pubsub/dummy_account_node.py +++ b/libp2p/tools/pubsub/dummy_account_node.py @@ -1,12 +1,12 @@ -import asyncio -from typing import Dict -import uuid +from typing import AsyncIterator, Dict, Tuple + +from async_exit_stack import AsyncExitStack +from async_generator import asynccontextmanager +from async_service import Service, background_trio_service from libp2p.host.host_interface import IHost -from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.pubsub import Pubsub -from libp2p.tools.constants import LISTEN_MADDR -from libp2p.tools.factories import FloodsubFactory, PubsubFactory +from libp2p.tools.factories import PubsubFactory CRYPTO_TOPIC = "ethereum" @@ -18,7 +18,7 @@ CRYPTO_TOPIC = "ethereum" # Determine message type by looking at first item before first comma -class DummyAccountNode: +class DummyAccountNode(Service): """ Node which has an internal balance mapping, meant to serve as a dummy crypto blockchain. @@ -27,19 +27,24 @@ class DummyAccountNode: crypto each user in the mappings holds """ - libp2p_node: IHost pubsub: Pubsub - floodsub: FloodSub - def __init__(self, libp2p_node: IHost, pubsub: Pubsub, floodsub: FloodSub): - self.libp2p_node = libp2p_node + def __init__(self, pubsub: Pubsub) -> None: self.pubsub = pubsub - self.floodsub = floodsub self.balances: Dict[str, int] = {} - self.node_id = str(uuid.uuid1()) + + @property + def host(self) -> IHost: + return self.pubsub.host + + async def run(self) -> None: + self.subscription = await self.pubsub.subscribe(CRYPTO_TOPIC) + self.manager.run_daemon_task(self.handle_incoming_msgs) + await self.manager.wait_finished() @classmethod - async def create(cls) -> "DummyAccountNode": + @asynccontextmanager + async def create(cls, number: int) -> AsyncIterator[Tuple["DummyAccountNode", ...]]: """ Create a new DummyAccountNode and attach a libp2p node, a floodsub, and a pubsub instance to this new node. 
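Both Pubsub.run and DummyAccountNode.run above follow the same async-service pattern: spawn daemon tasks on the service manager, then block on wait_finished() until the service is cancelled. A compact, self-contained sketch of that lifecycle (TickerService is an illustrative stand-in, not part of the library):

import trio
from async_service import Service, background_trio_service


class TickerService(Service):
    # Illustrative stand-in for DummyAccountNode: one daemon task, with
    # wait_finished() keeping run() alive until the service is cancelled.
    def __init__(self) -> None:
        self.ticks = 0

    async def run(self) -> None:
        self.manager.run_daemon_task(self._tick)
        await self.manager.wait_finished()

    async def _tick(self) -> None:
        while True:
            self.ticks += 1
            await trio.sleep(0.01)


async def main() -> None:
    service = TickerService()
    async with background_trio_service(service):
        await trio.sleep(0.05)
        print("ticks so far:", service.ticks)
    # Leaving the block cancels the daemon task and waits for a clean shutdown.


trio.run(main)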
@@ -47,15 +52,17 @@ class DummyAccountNode: We use create as this serves as a factory function and allows us to use async await, unlike the init function """ - - pubsub = PubsubFactory(router=FloodsubFactory()) - await pubsub.host.get_network().listen(LISTEN_MADDR) - return cls(libp2p_node=pubsub.host, pubsub=pubsub, floodsub=pubsub.router) + async with PubsubFactory.create_batch_with_floodsub(number) as pubsubs: + async with AsyncExitStack() as stack: + dummy_acount_nodes = tuple(cls(pubsub) for pubsub in pubsubs) + for node in dummy_acount_nodes: + await stack.enter_async_context(background_trio_service(node)) + yield dummy_acount_nodes async def handle_incoming_msgs(self) -> None: """Handle all incoming messages on the CRYPTO_TOPIC from peers.""" while True: - incoming = await self.q.get() + incoming = await self.subscription.get() msg_comps = incoming.data.decode("utf-8").split(",") if msg_comps[0] == "send": @@ -63,13 +70,6 @@ class DummyAccountNode: elif msg_comps[0] == "set": self.handle_set_crypto(msg_comps[1], int(msg_comps[2])) - async def setup_crypto_networking(self) -> None: - """Subscribe to CRYPTO_TOPIC and perform call to function that handles - all incoming messages on said topic.""" - self.q = await self.pubsub.subscribe(CRYPTO_TOPIC) - - asyncio.ensure_future(self.handle_incoming_msgs()) - async def publish_send_crypto( self, source_user: str, dest_user: str, amount: int ) -> None: diff --git a/libp2p/tools/pubsub/floodsub_integration_test_settings.py b/libp2p/tools/pubsub/floodsub_integration_test_settings.py index d6b5b678..3d3325fc 100644 --- a/libp2p/tools/pubsub/floodsub_integration_test_settings.py +++ b/libp2p/tools/pubsub/floodsub_integration_test_settings.py @@ -1,12 +1,10 @@ # type: ignore # To add typing to this module, it's better to do it after refactoring test cases into classes -import asyncio - import pytest +import trio -from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID, LISTEN_MADDR -from libp2p.tools.factories import PubsubFactory +from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID from libp2p.tools.utils import connect SUPPORTED_PROTOCOLS = [FLOODSUB_PROTOCOL_ID] @@ -15,6 +13,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "simple_two_nodes", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["A", "B"], "adj_list": {"A": ["B"]}, "topic_map": {"topic1": ["B"]}, "messages": [{"topics": ["topic1"], "data": b"foo", "node_id": "A"}], @@ -22,6 +21,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "three_nodes_two_topics", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["A", "B", "C"], "adj_list": {"A": ["B"], "B": ["C"]}, "topic_map": {"topic1": ["B", "C"], "topic2": ["B", "C"]}, "messages": [ @@ -32,6 +32,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "two_nodes_one_topic_single_subscriber_is_sender", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["A", "B"], "adj_list": {"A": ["B"]}, "topic_map": {"topic1": ["B"]}, "messages": [{"topics": ["topic1"], "data": b"Alex is tall", "node_id": "B"}], @@ -39,6 +40,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "two_nodes_one_topic_two_msgs", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["A", "B"], "adj_list": {"A": ["B"]}, "topic_map": {"topic1": ["B"]}, "messages": [ @@ -49,6 +51,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "seven_nodes_tree_one_topics", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["1", "2", "3", "4", "5", "6", "7"], "adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]}, "topic_map": {"astrophysics": ["2", "3", "4", "5", "6", 
"7"]}, "messages": [{"topics": ["astrophysics"], "data": b"e=mc^2", "node_id": "1"}], @@ -56,6 +59,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "seven_nodes_tree_three_topics", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["1", "2", "3", "4", "5", "6", "7"], "adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]}, "topic_map": { "astrophysics": ["2", "3", "4", "5", "6", "7"], @@ -71,6 +75,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "seven_nodes_tree_three_topics_diff_origin", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["1", "2", "3", "4", "5", "6", "7"], "adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]}, "topic_map": { "astrophysics": ["1", "2", "3", "4", "5", "6", "7"], @@ -86,6 +91,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "three_nodes_clique_two_topic_diff_origin", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["1", "2", "3"], "adj_list": {"1": ["2", "3"], "2": ["3"]}, "topic_map": {"astrophysics": ["1", "2", "3"], "school": ["1", "2", "3"]}, "messages": [ @@ -97,6 +103,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "four_nodes_clique_two_topic_diff_origin_many_msgs", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["1", "2", "3", "4"], "adj_list": { "1": ["2", "3", "4"], "2": ["1", "3", "4"], @@ -120,6 +127,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "five_nodes_ring_two_topic_diff_origin_many_msgs", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["1", "2", "3", "4", "5"], "adj_list": {"1": ["2"], "2": ["3"], "3": ["4"], "4": ["5"], "5": ["1"]}, "topic_map": { "astrophysics": ["1", "2", "3", "4", "5"], @@ -143,15 +151,7 @@ floodsub_protocol_pytest_params = [ ] -def _collect_node_ids(adj_list): - node_ids = set() - for node, neighbors in adj_list.items(): - node_ids.add(node) - node_ids.update(set(neighbors)) - return node_ids - - -async def perform_test_from_obj(obj, router_factory) -> None: +async def perform_test_from_obj(obj, pubsub_factory) -> None: """ Perform pubsub tests from a test object, which is composed as follows: @@ -185,68 +185,75 @@ async def perform_test_from_obj(obj, router_factory) -> None: # Step 1) Create graph adj_list = obj["adj_list"] + node_list = obj["nodes"] node_map = {} pubsub_map = {} - async def add_node(node_id_str: str): - pubsub_router = router_factory(protocols=obj["supported_protocols"]) - pubsub = PubsubFactory(router=pubsub_router) - await pubsub.host.get_network().listen(LISTEN_MADDR) - node_map[node_id_str] = pubsub.host - pubsub_map[node_id_str] = pubsub + async with pubsub_factory( + number=len(node_list), protocols=obj["supported_protocols"] + ) as pubsubs: + for node_id_str, pubsub in zip(node_list, pubsubs): + node_map[node_id_str] = pubsub.host + pubsub_map[node_id_str] = pubsub - all_node_ids = _collect_node_ids(adj_list) + # Connect nodes and wait at least for 2 seconds + async with trio.open_nursery() as nursery: + for start_node_id in adj_list: + # For each neighbor of start_node, create if does not yet exist, + # then connect start_node to neighbor + for neighbor_id in adj_list[start_node_id]: + nursery.start_soon( + connect, node_map[start_node_id], node_map[neighbor_id] + ) + nursery.start_soon(trio.sleep, 2) - for node in all_node_ids: - await add_node(node) + # Step 2) Subscribe to topics + queues_map = {} + topic_map = obj["topic_map"] - for node, neighbors in adj_list.items(): - for neighbor_id in neighbors: - await connect(node_map[node], node_map[neighbor_id]) - - # NOTE: the test using this routine will fail w/o these sleeps... 
- await asyncio.sleep(1) - - # Step 2) Subscribe to topics - queues_map = {} - topic_map = obj["topic_map"] - - for topic, node_ids in topic_map.items(): - for node_id in node_ids: - queue = await pubsub_map[node_id].subscribe(topic) + async def subscribe_node(node_id, topic): if node_id not in queues_map: queues_map[node_id] = {} - # Store queue in topic-queue map for node - queues_map[node_id][topic] = queue + # Avoid repeated work + if topic in queues_map[node_id]: + # Checkpoint + await trio.hazmat.checkpoint() + return + sub = await pubsub_map[node_id].subscribe(topic) + queues_map[node_id][topic] = sub - # NOTE: the test using this routine will fail w/o these sleeps... - await asyncio.sleep(1) + async with trio.open_nursery() as nursery: + for topic, node_ids in topic_map.items(): + for node_id in node_ids: + nursery.start_soon(subscribe_node, node_id, topic) + nursery.start_soon(trio.sleep, 2) - # Step 3) Publish messages - topics_in_msgs_ordered = [] - messages = obj["messages"] + # Step 3) Publish messages + topics_in_msgs_ordered = [] + messages = obj["messages"] - for msg in messages: - topics = msg["topics"] - data = msg["data"] - node_id = msg["node_id"] + for msg in messages: + topics = msg["topics"] + data = msg["data"] + node_id = msg["node_id"] + + # Publish message + # TODO: Should be single RPC package with several topics + for topic in topics: + await pubsub_map[node_id].publish(topic, data) - # Publish message - # TODO: Should be single RPC package with several topics - for topic in topics: - await pubsub_map[node_id].publish(topic, data) # For each topic in topics, add (topic, node_id, data) tuple to ordered test list - topics_in_msgs_ordered.append((topic, node_id, data)) + for topic in topics: + topics_in_msgs_ordered.append((topic, node_id, data)) + # Allow time for publishing before continuing + await trio.sleep(1) - # Step 4) Check that all messages were received correctly. - for topic, origin_node_id, data in topics_in_msgs_ordered: - # Look at each node in each topic - for node_id in topic_map[topic]: - # Get message from subscription queue - queue = queues_map[node_id][topic] - msg = await queue.get() - assert data == msg.data - # Check the message origin - assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id - - # Success, terminate pending tasks. + # Step 4) Check that all messages were received correctly.
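+ # Every node subscribed to a topic should find each message published to that topic in its subscription, with the publisher recorded as the message origin.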
+ for topic, origin_node_id, data in topics_in_msgs_ordered: + # Look at each node in each topic + for node_id in topic_map[topic]: + # Get message from subscription queue + msg = await queues_map[node_id][topic].get() + assert data == msg.data + # Check the message origin + assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id diff --git a/libp2p/tools/utils.py b/libp2p/tools/utils.py index b53e4dfb..5a262b3b 100644 --- a/libp2p/tools/utils.py +++ b/libp2p/tools/utils.py @@ -1,17 +1,10 @@ -from typing import Dict, Sequence, Tuple, cast +from typing import Awaitable, Callable -import multiaddr - -from libp2p import new_node -from libp2p.host.basic_host import BasicHost from libp2p.host.host_interface import IHost -from libp2p.host.routed_host import RoutedHost +from libp2p.network.stream.exceptions import StreamError from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.swarm import Swarm -from libp2p.peer.id import ID -from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr -from libp2p.routing.interfaces import IPeerRouting -from libp2p.typing import StreamHandlerFn, TProtocol +from libp2p.peer.peerinfo import info_from_p2p_addr from .constants import MAX_READ_LEN @@ -36,63 +29,20 @@ async def connect(node1: IHost, node2: IHost) -> None: await node1.connect(info) -async def set_up_nodes_by_transport_opt( - transport_opt_list: Sequence[Sequence[str]] -) -> Tuple[BasicHost, ...]: - nodes_list = [] - for transport_opt in transport_opt_list: - node = await new_node(transport_opt=transport_opt) - await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0])) - nodes_list.append(node) - return tuple(nodes_list) +def create_echo_stream_handler( + ack_prefix: str +) -> Callable[[INetStream], Awaitable[None]]: + async def echo_stream_handler(stream: INetStream) -> None: + while True: + try: + read_string = (await stream.read(MAX_READ_LEN)).decode() + except StreamError: + break + resp = ack_prefix + read_string + try: + await stream.write(resp.encode()) + except StreamError: + break -async def echo_stream_handler(stream: INetStream) -> None: - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() - - resp = f"ack:{read_string}" - await stream.write(resp.encode()) - - -async def perform_two_host_set_up( - handler: StreamHandlerFn = echo_stream_handler -) -> Tuple[BasicHost, BasicHost]: - transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list) - - node_b.set_stream_handler(TProtocol("/echo/1.0.0"), handler) - - # Associate the peer with local ip address (see default parameters of Libp2p()) - node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - return node_a, node_b - - -class DummyRouter(IPeerRouting): - _routing_table: Dict[ID, PeerInfo] - - def __init__(self) -> None: - self._routing_table = dict() - - async def find_peer(self, peer_id: ID) -> PeerInfo: - return self._routing_table.get(peer_id, None) - - -async def set_up_routed_hosts() -> Tuple[RoutedHost, RoutedHost]: - router_a, router_b = DummyRouter(), DummyRouter() - transport = "/ip4/127.0.0.1/tcp/0" - host_a = await new_node(transport_opt=[transport], disc_opt=router_a) - host_b = await new_node(transport_opt=[transport], disc_opt=router_b) - - address = multiaddr.Multiaddr(transport) - await host_a.get_network().listen(address) - await host_b.get_network().listen(address) - - mock_routing_table = { - host_a.get_id(): PeerInfo(host_a.get_id(), 
host_a.get_addrs()), - host_b.get_id(): PeerInfo(host_b.get_id(), host_b.get_addrs()), - } - - router_a._routing_table = router_b._routing_table = mock_routing_table - - return cast(RoutedHost, host_a), cast(RoutedHost, host_b) + return echo_stream_handler diff --git a/libp2p/transport/listener_interface.py b/libp2p/transport/listener_interface.py index 1b22531b..d170d1de 100644 --- a/libp2p/transport/listener_interface.py +++ b/libp2p/transport/listener_interface.py @@ -1,12 +1,13 @@ from abc import ABC, abstractmethod -from typing import List +from typing import Tuple from multiaddr import Multiaddr +import trio class IListener(ABC): @abstractmethod - async def listen(self, maddr: Multiaddr) -> bool: + async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool: """ put listener in listening mode and wait for incoming connections. @@ -15,7 +16,7 @@ class IListener(ABC): """ @abstractmethod - def get_addrs(self) -> List[Multiaddr]: + def get_addrs(self) -> Tuple[Multiaddr, ...]: """ retrieve list of addresses the listener is listening on. @@ -24,5 +25,4 @@ class IListener(ABC): @abstractmethod async def close(self) -> None: - """close the listener such that no more connections can be open on this - transport instance.""" + ... diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index 28ff0532..1004e288 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -1,10 +1,11 @@ -import asyncio -from socket import socket -import sys -from typing import List +import logging +from typing import Awaitable, Callable, List, Sequence, Tuple from multiaddr import Multiaddr +import trio +from trio_typing import TaskStatus +from libp2p.io.trio import TrioTCPStream from libp2p.network.connection.raw_connection import RawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.transport.exceptions import OpenConnectionError @@ -12,53 +13,61 @@ from libp2p.transport.listener_interface import IListener from libp2p.transport.transport_interface import ITransport from libp2p.transport.typing import THandler +logger = logging.getLogger("libp2p.transport.tcp") + class TCPListener(IListener): - multiaddrs: List[Multiaddr] - server = None + listeners: List[trio.SocketListener] def __init__(self, handler_function: THandler) -> None: - self.multiaddrs = [] - self.server = None + self.listeners = [] self.handler = handler_function - async def listen(self, maddr: Multiaddr) -> bool: + # TODO: Get rid of `nursery`? + async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> None: """ put listener in listening mode and wait for incoming connections. 
:param maddr: maddr of peer :return: return True if successful """ - self.server = await asyncio.start_server( - self.handler, + + async def serve_tcp( + handler: Callable[[trio.SocketStream], Awaitable[None]], + port: int, + host: str, + task_status: TaskStatus[Sequence[trio.SocketListener]] = None, + ) -> None: + """Just a proxy function to add logging here.""" + logger.debug("serve_tcp %s %s", host, port) + await trio.serve_tcp(handler, port, host=host, task_status=task_status) + + async def handler(stream: trio.SocketStream) -> None: + tcp_stream = TrioTCPStream(stream) + await self.handler(tcp_stream) + + listeners = await nursery.start( + serve_tcp, + handler, + int(maddr.value_for_protocol("tcp")), maddr.value_for_protocol("ip4"), - maddr.value_for_protocol("tcp"), ) - socket = self.server.sockets[0] - self.multiaddrs.append(_multiaddr_from_socket(socket)) + self.listeners.extend(listeners) - return True - - def get_addrs(self) -> List[Multiaddr]: + def get_addrs(self) -> Tuple[Multiaddr, ...]: """ retrieve list of addresses the listener is listening on. :return: return list of addrs """ - # TODO check if server is listening - return self.multiaddrs + return tuple( + _multiaddr_from_socket(listener.socket) for listener in self.listeners + ) async def close(self) -> None: - """close the listener such that no more connections can be open on this - transport instance.""" - if self.server is None: - return - self.server.close() - server = self.server - self.server = None - if sys.version_info < (3, 7): - return - await server.wait_closed() + async with trio.open_nursery() as nursery: + for listener in self.listeners: + nursery.start_soon(listener.aclose) class TCP(ITransport): @@ -74,11 +83,12 @@ class TCP(ITransport): self.port = int(maddr.value_for_protocol("tcp")) try: - reader, writer = await asyncio.open_connection(self.host, self.port) - except (ConnectionAbortedError, ConnectionRefusedError) as error: - raise OpenConnectionError(error) + stream = await trio.open_tcp_stream(self.host, self.port) + except OSError as error: + raise OpenConnectionError from error + read_write_closer = TrioTCPStream(stream) - return RawConnection(reader, writer, True) + return RawConnection(read_write_closer, True) def create_listener(self, handler_function: THandler) -> TCPListener: """ @@ -91,6 +101,6 @@ class TCP(ITransport): return TCPListener(handler_function) -def _multiaddr_from_socket(socket: socket) -> Multiaddr: - addr, port = socket.getsockname()[:2] - return Multiaddr(f"/ip4/{addr}/tcp/{port}") +def _multiaddr_from_socket(socket: trio.socket.SocketType) -> Multiaddr: + ip, port = socket.getsockname() # type: ignore + return Multiaddr(f"/ip4/{ip}/tcp/{port}") diff --git a/libp2p/transport/typing.py b/libp2p/transport/typing.py index f9b31dcb..d68a8aa4 100644 --- a/libp2p/transport/typing.py +++ b/libp2p/transport/typing.py @@ -1,11 +1,11 @@ -from asyncio import StreamReader, StreamWriter from typing import Awaitable, Callable, Mapping, Type +from libp2p.io.abc import ReadWriteCloser from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.stream_muxer.abc import IMuxedConn from libp2p.typing import TProtocol -THandler = Callable[[StreamReader, StreamWriter], Awaitable[None]] +THandler = Callable[[ReadWriteCloser], Awaitable[None]] TSecurityOptions = Mapping[TProtocol, ISecureTransport] TMuxerClass = Type[IMuxedConn] TMuxerOptions = Mapping[TProtocol, TMuxerClass] diff --git a/setup.py b/setup.py index 2d0b6f35..e808b959 100644 --- a/setup.py +++ b/setup.py @@ -7,8 
+7,8 @@ from setuptools import find_packages, setup extras_require = { "test": [ "pytest>=4.6.3,<5.0.0", - "pytest-xdist>=1.30.0,<2", - "pytest-asyncio>=0.10.0,<1.0.0", + "pytest-xdist>=1.30.0", + "pytest-trio>=0.5.2", "factory-boy>=2.12.0,<3.0.0", ], "lint": [ @@ -74,6 +74,10 @@ install_requires = [ "pynacl==1.3.0", "dataclasses>=0.7, <1;python_version<'3.7'", "async_generator==1.10", + "trio>=0.13.0", + "async-service>=0.1.0a6", + "async-exit-stack==1.0.1", + "trio-typing>=0.3.0,<0.4.0", ] diff --git a/tests/conftest.py b/tests/conftest.py index 746fb026..48d705c2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,5 @@ -import asyncio - import pytest -from libp2p.tools.constants import LISTEN_MADDR from libp2p.tools.factories import HostFactory @@ -17,17 +14,6 @@ def num_hosts(): @pytest.fixture -async def hosts(num_hosts, is_host_secure): - _hosts = HostFactory.create_batch(num_hosts, is_secure=is_host_secure) - await asyncio.gather( - *[_host.get_network().listen(LISTEN_MADDR) for _host in _hosts] - ) - try: +async def hosts(num_hosts, is_host_secure, nursery): + async with HostFactory.create_batch_and_listen(is_host_secure, num_hosts) as _hosts: yield _hosts - finally: - # TODO: It's possible that `close` raises exceptions currently, - # due to the connection reset things. Though we don't care much about that when - # cleaning up the tasks, it is probably better to handle the exceptions properly. - await asyncio.gather( - *[_host.close() for _host in _hosts], return_exceptions=True - ) diff --git a/tests/examples/test_chat.py b/tests/examples/test_examples.py similarity index 85% rename from tests/examples/test_chat.py rename to tests/examples/test_examples.py index 536b5a05..c8ce9ed8 100644 --- a/tests/examples/test_chat.py +++ b/tests/examples/test_examples.py @@ -1,10 +1,10 @@ -import asyncio - import pytest +import trio from libp2p.host.exceptions import StreamFailure from libp2p.peer.peerinfo import info_from_p2p_addr -from libp2p.tools.utils import set_up_nodes_by_transport_opt +from libp2p.tools.factories import HostFactory +from libp2p.tools.utils import MAX_READ_LEN PROTOCOL_ID = "/chat/1.0.0" @@ -25,7 +25,7 @@ async def hello_world(host_a, host_b): # Multiaddress of the destination peer is fetched from the peerstore using 'peerId'. 
stream = await host_b.new_stream(host_a.get_id(), [PROTOCOL_ID]) await stream.write(hello_world_from_host_b) - read = await stream.read() + read = await stream.read(MAX_READ_LEN) assert read == hello_world_from_host_a await stream.close() @@ -47,7 +47,7 @@ async def connect_write(host_a, host_b): await stream.write(message.encode()) # Reader needs time due to async reads - await asyncio.sleep(2) + await trio.sleep(2) await stream.close() assert received == messages @@ -88,16 +88,14 @@ async def no_common_protocol(host_a, host_b): await host_b.new_stream(host_a.get_id(), ["/fakeproto/0.0.1"]) -@pytest.mark.asyncio @pytest.mark.parametrize( "test", [(hello_world), (connect_write), (connect_read), (no_common_protocol)] ) -async def test_chat(test): - transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - (host_a, host_b) = await set_up_nodes_by_transport_opt(transport_opt_list) +@pytest.mark.trio +async def test_chat(test, is_host_secure): + async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: + addr = hosts[0].get_addrs()[0] + info = info_from_p2p_addr(addr) + await hosts[1].connect(info) - addr = host_a.get_addrs()[0] - info = info_from_p2p_addr(addr) - await host_b.connect(info) - - await test(host_a, host_b) + await test(hosts[0], hosts[1]) diff --git a/tests/host/test_basic_host.py b/tests/host/test_basic_host.py index 1eec04a8..55605ed5 100644 --- a/tests/host/test_basic_host.py +++ b/tests/host/test_basic_host.py @@ -1,4 +1,4 @@ -from libp2p import initialize_default_swarm +from libp2p import new_swarm from libp2p.crypto.rsa import create_new_key_pair from libp2p.host.basic_host import BasicHost from libp2p.host.defaults import get_default_protocols @@ -6,7 +6,7 @@ from libp2p.host.defaults import get_default_protocols def test_default_protocols(): key_pair = create_new_key_pair() - swarm = initialize_default_swarm(key_pair) + swarm = new_swarm(key_pair) host = BasicHost(swarm) mux = host.get_mux() diff --git a/tests/host/test_ping.py b/tests/host/test_ping.py index fcc5a850..7a0f8db5 100644 --- a/tests/host/test_ping.py +++ b/tests/host/test_ping.py @@ -1,18 +1,19 @@ -import asyncio import secrets import pytest +import trio from libp2p.host.ping import ID, PING_LENGTH -from libp2p.tools.factories import pair_of_connected_hosts +from libp2p.tools.factories import host_pair_factory -@pytest.mark.asyncio -async def test_ping_once(): - async with pair_of_connected_hosts() as (host_a, host_b): +@pytest.mark.trio +async def test_ping_once(is_host_secure): + async with host_pair_factory(is_host_secure) as (host_a, host_b): stream = await host_b.new_stream(host_a.get_id(), (ID,)) some_ping = secrets.token_bytes(PING_LENGTH) await stream.write(some_ping) + await trio.sleep(0.01) some_pong = await stream.read(PING_LENGTH) assert some_ping == some_pong await stream.close() @@ -21,9 +22,9 @@ async def test_ping_once(): SOME_PING_COUNT = 3 -@pytest.mark.asyncio -async def test_ping_several(): - async with pair_of_connected_hosts() as (host_a, host_b): +@pytest.mark.trio +async def test_ping_several(is_host_secure): + async with host_pair_factory(is_host_secure) as (host_a, host_b): stream = await host_b.new_stream(host_a.get_id(), (ID,)) for _ in range(SOME_PING_COUNT): some_ping = secrets.token_bytes(PING_LENGTH) @@ -33,5 +34,5 @@ async def test_ping_several(): # NOTE: simulate some time to sleep to mirror a real # world usage where a peer sends pings on some periodic interval # NOTE: this interval can be `0` for this test. 
- await asyncio.sleep(0) + await trio.sleep(0) await stream.close() diff --git a/tests/host/test_routed_host.py b/tests/host/test_routed_host.py index 271246cb..4cfed6bf 100644 --- a/tests/host/test_routed_host.py +++ b/tests/host/test_routed_host.py @@ -1,33 +1,26 @@ -import asyncio - import pytest from libp2p.host.exceptions import ConnectionFailure from libp2p.peer.peerinfo import PeerInfo -from libp2p.tools.utils import set_up_nodes_by_transport_opt, set_up_routed_hosts +from libp2p.tools.factories import HostFactory, RoutedHostFactory -@pytest.mark.asyncio +@pytest.mark.trio async def test_host_routing_success(): - host_a, host_b = await set_up_routed_hosts() - # forces to use routing as no addrs are provided - await host_a.connect(PeerInfo(host_b.get_id(), [])) - await host_b.connect(PeerInfo(host_a.get_id(), [])) - - # Clean up - await asyncio.gather(*[host_a.close(), host_b.close()]) + async with RoutedHostFactory.create_batch_and_listen(False, 2) as hosts: + # forces to use routing as no addrs are provided + await hosts[0].connect(PeerInfo(hosts[1].get_id(), [])) + await hosts[1].connect(PeerInfo(hosts[0].get_id(), [])) -@pytest.mark.asyncio +@pytest.mark.trio async def test_host_routing_fail(): - host_a, host_b = await set_up_routed_hosts() - basic_host_c = (await set_up_nodes_by_transport_opt([["/ip4/127.0.0.1/tcp/0"]]))[0] - - # routing fails because host_c does not use routing - with pytest.raises(ConnectionFailure): - await host_a.connect(PeerInfo(basic_host_c.get_id(), [])) - with pytest.raises(ConnectionFailure): - await host_b.connect(PeerInfo(basic_host_c.get_id(), [])) - - # Clean up - await asyncio.gather(*[host_a.close(), host_b.close(), basic_host_c.close()]) + is_secure = False + async with RoutedHostFactory.create_batch_and_listen( + is_secure, 2 + ) as routed_hosts, HostFactory.create_batch_and_listen(is_secure, 1) as basic_hosts: + # routing fails because host_c does not use routing + with pytest.raises(ConnectionFailure): + await routed_hosts[0].connect(PeerInfo(basic_hosts[0].get_id(), [])) + with pytest.raises(ConnectionFailure): + await routed_hosts[1].connect(PeerInfo(basic_hosts[0].get_id(), [])) diff --git a/tests/identity/identify/test_protocol.py b/tests/identity/identify/test_protocol.py index fab78ec1..4bbdbcba 100644 --- a/tests/identity/identify/test_protocol.py +++ b/tests/identity/identify/test_protocol.py @@ -2,12 +2,12 @@ import pytest from libp2p.identity.identify.pb.identify_pb2 import Identify from libp2p.identity.identify.protocol import ID, _mk_identify_protobuf -from libp2p.tools.factories import pair_of_connected_hosts +from libp2p.tools.factories import host_pair_factory -@pytest.mark.asyncio -async def test_identify_protocol(): - async with pair_of_connected_hosts() as (host_a, host_b): +@pytest.mark.trio +async def test_identify_protocol(is_host_secure): + async with host_pair_factory(is_host_secure) as (host_a, host_b): stream = await host_b.new_stream(host_a.get_id(), (ID,)) response = await stream.read() await stream.close() diff --git a/tests/libp2p/test_libp2p.py b/tests/libp2p/test_libp2p.py index 541c1733..99a60bd5 100644 --- a/tests/libp2p/test_libp2p.py +++ b/tests/libp2p/test_libp2p.py @@ -1,350 +1,285 @@ import multiaddr import pytest -from libp2p.peer.peerinfo import info_from_p2p_addr +from libp2p.network.stream.exceptions import StreamError from libp2p.tools.constants import MAX_READ_LEN -from libp2p.tools.utils import set_up_nodes_by_transport_opt +from libp2p.tools.factories import HostFactory +from libp2p.tools.utils 
import connect, create_echo_stream_handler +from libp2p.typing import TProtocol + +PROTOCOL_ID_0 = TProtocol("/echo/0") +PROTOCOL_ID_1 = TProtocol("/echo/1") +PROTOCOL_ID_2 = TProtocol("/echo/2") +PROTOCOL_ID_3 = TProtocol("/echo/3") + +ACK_STR_0 = "ack_0:" +ACK_STR_1 = "ack_1:" +ACK_STR_2 = "ack_2:" +ACK_STR_3 = "ack_3:" -@pytest.mark.asyncio -async def test_simple_messages(): - transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list) - - async def stream_handler(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() - - response = "ack:" + read_string - await stream.write(response.encode()) - - node_b.set_stream_handler("/echo/1.0.0", stream_handler) - - # Associate the peer with local ip address (see default parameters of Libp2p()) - node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - - stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) - - messages = ["hello" + str(x) for x in range(10)] - for message in messages: - await stream.write(message.encode()) - - response = (await stream.read(MAX_READ_LEN)).decode() - - assert response == ("ack:" + message) - - # Success, terminate pending tasks. - - -@pytest.mark.asyncio -async def test_double_response(): - transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list) - - async def stream_handler(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() - - response = "ack1:" + read_string - await stream.write(response.encode()) - - response = "ack2:" + read_string - await stream.write(response.encode()) - - node_b.set_stream_handler("/echo/1.0.0", stream_handler) - - # Associate the peer with local ip address (see default parameters of Libp2p()) - node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) - - messages = ["hello" + str(x) for x in range(10)] - for message in messages: - await stream.write(message.encode()) - - response1 = (await stream.read(MAX_READ_LEN)).decode() - assert response1 == ("ack1:" + message) - - response2 = (await stream.read(MAX_READ_LEN)).decode() - assert response2 == ("ack2:" + message) - - # Success, terminate pending tasks. - - -@pytest.mark.asyncio -async def test_multiple_streams(): - # Node A should be able to open a stream with node B and then vice versa. 
- # Stream IDs should be generated uniquely so that the stream state is not overwritten - transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list) - - async def stream_handler_a(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() - - response = "ack_a:" + read_string - await stream.write(response.encode()) - - async def stream_handler_b(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() - - response = "ack_b:" + read_string - await stream.write(response.encode()) - - node_a.set_stream_handler("/echo_a/1.0.0", stream_handler_a) - node_b.set_stream_handler("/echo_b/1.0.0", stream_handler_b) - - # Associate the peer with local ip address (see default parameters of Libp2p()) - node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - node_b.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10) - - stream_a = await node_a.new_stream(node_b.get_id(), ["/echo_b/1.0.0"]) - stream_b = await node_b.new_stream(node_a.get_id(), ["/echo_a/1.0.0"]) - - # A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b - messages = ["hello" + str(x) for x in range(10)] - for message in messages: - a_message = message + "_a" - b_message = message + "_b" - - await stream_a.write(a_message.encode()) - await stream_b.write(b_message.encode()) - - response_a = (await stream_a.read(MAX_READ_LEN)).decode() - response_b = (await stream_b.read(MAX_READ_LEN)).decode() - - assert response_a == ("ack_b:" + a_message) and response_b == ( - "ack_a:" + b_message +@pytest.mark.trio +async def test_simple_messages(is_host_secure): + async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: + hosts[1].set_stream_handler( + PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0) ) - # Success, terminate pending tasks. 
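+ # hosts[1] echoes every message back prefixed with ACK_STR_0, which the assertions below rely on.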
+ # Associate the peer with local ip address (see default parameters of Libp2p()) + hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10) + + stream = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0]) + + messages = ["hello" + str(x) for x in range(10)] + for message in messages: + await stream.write(message.encode()) + response = (await stream.read(MAX_READ_LEN)).decode() + assert response == (ACK_STR_0 + message) -@pytest.mark.asyncio -async def test_multiple_streams_same_initiator_different_protocols(): - transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list) +@pytest.mark.trio +async def test_double_response(is_host_secure): + async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: - async def stream_handler_a1(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() + async def double_response_stream_handler(stream): + while True: + try: + read_string = (await stream.read(MAX_READ_LEN)).decode() + except StreamError: + break - response = "ack_a1:" + read_string - await stream.write(response.encode()) + response = ACK_STR_0 + read_string + try: + await stream.write(response.encode()) + except StreamError: + break - async def stream_handler_a2(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() + response = ACK_STR_1 + read_string + try: + await stream.write(response.encode()) + except StreamError: + break - response = "ack_a2:" + read_string - await stream.write(response.encode()) + hosts[1].set_stream_handler(PROTOCOL_ID_0, double_response_stream_handler) - async def stream_handler_a3(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() + # Associate the peer with local ip address (see default parameters of Libp2p()) + hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10) + stream = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0]) - response = "ack_a3:" + read_string - await stream.write(response.encode()) - - node_b.set_stream_handler("/echo_a1/1.0.0", stream_handler_a1) - node_b.set_stream_handler("/echo_a2/1.0.0", stream_handler_a2) - node_b.set_stream_handler("/echo_a3/1.0.0", stream_handler_a3) - - # Associate the peer with local ip address (see default parameters of Libp2p()) - node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - node_b.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10) - - # Open streams to node_b over echo_a1 echo_a2 echo_a3 protocols - stream_a1 = await node_a.new_stream(node_b.get_id(), ["/echo_a1/1.0.0"]) - stream_a2 = await node_a.new_stream(node_b.get_id(), ["/echo_a2/1.0.0"]) - stream_a3 = await node_a.new_stream(node_b.get_id(), ["/echo_a3/1.0.0"]) - - messages = ["hello" + str(x) for x in range(10)] - for message in messages: - a1_message = message + "_a1" - a2_message = message + "_a2" - a3_message = message + "_a3" - - await stream_a1.write(a1_message.encode()) - await stream_a2.write(a2_message.encode()) - await stream_a3.write(a3_message.encode()) - - response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode() - response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode() - response_a3 = (await stream_a3.read(MAX_READ_LEN)).decode() - - assert ( - response_a1 == ("ack_a1:" + a1_message) - and response_a2 == ("ack_a2:" + a2_message) - and response_a3 == ("ack_a3:" + a3_message) - ) - - # Success, terminate pending tasks. 
- - -@pytest.mark.asyncio -async def test_multiple_streams_two_initiators(): - transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list) - - async def stream_handler_a1(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() - - response = "ack_a1:" + read_string - await stream.write(response.encode()) - - async def stream_handler_a2(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() - - response = "ack_a2:" + read_string - await stream.write(response.encode()) - - async def stream_handler_b1(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() - - response = "ack_b1:" + read_string - await stream.write(response.encode()) - - async def stream_handler_b2(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() - - response = "ack_b2:" + read_string - await stream.write(response.encode()) - - node_a.set_stream_handler("/echo_b1/1.0.0", stream_handler_b1) - node_a.set_stream_handler("/echo_b2/1.0.0", stream_handler_b2) - - node_b.set_stream_handler("/echo_a1/1.0.0", stream_handler_a1) - node_b.set_stream_handler("/echo_a2/1.0.0", stream_handler_a2) - - # Associate the peer with local ip address (see default parameters of Libp2p()) - node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - node_b.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10) - - stream_a1 = await node_a.new_stream(node_b.get_id(), ["/echo_a1/1.0.0"]) - stream_a2 = await node_a.new_stream(node_b.get_id(), ["/echo_a2/1.0.0"]) - - stream_b1 = await node_b.new_stream(node_a.get_id(), ["/echo_b1/1.0.0"]) - stream_b2 = await node_b.new_stream(node_a.get_id(), ["/echo_b2/1.0.0"]) - - # A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b - messages = ["hello" + str(x) for x in range(10)] - for message in messages: - a1_message = message + "_a1" - a2_message = message + "_a2" - - b1_message = message + "_b1" - b2_message = message + "_b2" - - await stream_a1.write(a1_message.encode()) - await stream_a2.write(a2_message.encode()) - - await stream_b1.write(b1_message.encode()) - await stream_b2.write(b2_message.encode()) - - response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode() - response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode() - - response_b1 = (await stream_b1.read(MAX_READ_LEN)).decode() - response_b2 = (await stream_b2.read(MAX_READ_LEN)).decode() - - assert ( - response_a1 == ("ack_a1:" + a1_message) - and response_a2 == ("ack_a2:" + a2_message) - and response_b1 == ("ack_b1:" + b1_message) - and response_b2 == ("ack_b2:" + b2_message) - ) - - # Success, terminate pending tasks. 
- - -@pytest.mark.asyncio -async def test_triangle_nodes_connection(): - transport_opt_list = [ - ["/ip4/127.0.0.1/tcp/0"], - ["/ip4/127.0.0.1/tcp/0"], - ["/ip4/127.0.0.1/tcp/0"], - ] - (node_a, node_b, node_c) = await set_up_nodes_by_transport_opt(transport_opt_list) - - async def stream_handler(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() - - response = "ack:" + read_string - await stream.write(response.encode()) - - node_a.set_stream_handler("/echo/1.0.0", stream_handler) - node_b.set_stream_handler("/echo/1.0.0", stream_handler) - node_c.set_stream_handler("/echo/1.0.0", stream_handler) - - # Associate the peer with local ip address (see default parameters of Libp2p()) - # Associate all permutations - node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - node_a.get_peerstore().add_addrs(node_c.get_id(), node_c.get_addrs(), 10) - - node_b.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10) - node_b.get_peerstore().add_addrs(node_c.get_id(), node_c.get_addrs(), 10) - - node_c.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10) - node_c.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - - stream_a_to_b = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) - stream_a_to_c = await node_a.new_stream(node_c.get_id(), ["/echo/1.0.0"]) - - stream_b_to_a = await node_b.new_stream(node_a.get_id(), ["/echo/1.0.0"]) - stream_b_to_c = await node_b.new_stream(node_c.get_id(), ["/echo/1.0.0"]) - - stream_c_to_a = await node_c.new_stream(node_a.get_id(), ["/echo/1.0.0"]) - stream_c_to_b = await node_c.new_stream(node_b.get_id(), ["/echo/1.0.0"]) - - messages = ["hello" + str(x) for x in range(5)] - streams = [ - stream_a_to_b, - stream_a_to_c, - stream_b_to_a, - stream_b_to_c, - stream_c_to_a, - stream_c_to_b, - ] - - for message in messages: - for stream in streams: + messages = ["hello" + str(x) for x in range(10)] + for message in messages: await stream.write(message.encode()) - response = (await stream.read(MAX_READ_LEN)).decode() + response1 = (await stream.read(MAX_READ_LEN)).decode() + assert response1 == (ACK_STR_0 + message) - assert response == ("ack:" + message) - - # Success, terminate pending tasks. + response2 = (await stream.read(MAX_READ_LEN)).decode() + assert response2 == (ACK_STR_1 + message) -@pytest.mark.asyncio -async def test_host_connect(): - transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list) +@pytest.mark.trio +async def test_multiple_streams(is_host_secure): + # hosts[0] should be able to open a stream with hosts[1] and then vice versa. 
+ # Stream IDs should be generated uniquely so that the stream state is not overwritten - # Only our peer ID is stored in peer store - assert len(node_a.get_peerstore().peer_ids()) == 1 + async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: + hosts[0].set_stream_handler( + PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0) + ) + hosts[1].set_stream_handler( + PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1) + ) - addr = node_b.get_addrs()[0] - info = info_from_p2p_addr(addr) - await node_a.connect(info) + # Associate the peer with local ip address (see default parameters of Libp2p()) + hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10) + hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10) - assert len(node_a.get_peerstore().peer_ids()) == 2 + stream_a = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1]) + stream_b = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_0]) - await node_a.connect(info) + # A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b + messages = ["hello" + str(x) for x in range(10)] + for message in messages: + a_message = message + "_a" + b_message = message + "_b" - # make sure we don't do double connection - assert len(node_a.get_peerstore().peer_ids()) == 2 + await stream_a.write(a_message.encode()) + await stream_b.write(b_message.encode()) - assert node_b.get_id() in node_a.get_peerstore().peer_ids() - ma_node_b = multiaddr.Multiaddr("/p2p/%s" % node_b.get_id().pretty()) - for addr in node_a.get_peerstore().addrs(node_b.get_id()): - assert addr.encapsulate(ma_node_b) in node_b.get_addrs() + response_a = (await stream_a.read(MAX_READ_LEN)).decode() + response_b = (await stream_b.read(MAX_READ_LEN)).decode() - # Success, terminate pending tasks. 
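+ # Each reply should carry the ACK prefix of the protocol its stream was opened with (ACK_STR_1 for stream_a, ACK_STR_0 for stream_b).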
+ assert response_a == (ACK_STR_1 + a_message) and response_b == ( + ACK_STR_0 + b_message + ) + + +@pytest.mark.trio +async def test_multiple_streams_same_initiator_different_protocols(is_host_secure): + async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: + + hosts[1].set_stream_handler( + PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0) + ) + hosts[1].set_stream_handler( + PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1) + ) + hosts[1].set_stream_handler( + PROTOCOL_ID_2, create_echo_stream_handler(ACK_STR_2) + ) + + # Associate the peer with local ip address (see default parameters of Libp2p()) + hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10) + hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10) + + # Open three streams to hosts[1], one per echo protocol + stream_a1 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0]) + stream_a2 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1]) + stream_a3 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_2]) + + messages = ["hello" + str(x) for x in range(10)] + for message in messages: + a1_message = message + "_a1" + a2_message = message + "_a2" + a3_message = message + "_a3" + + await stream_a1.write(a1_message.encode()) + await stream_a2.write(a2_message.encode()) + await stream_a3.write(a3_message.encode()) + + response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode() + response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode() + response_a3 = (await stream_a3.read(MAX_READ_LEN)).decode() + + assert ( + response_a1 == (ACK_STR_0 + a1_message) + and response_a2 == (ACK_STR_1 + a2_message) + and response_a3 == (ACK_STR_2 + a3_message) + ) +
+ + +@pytest.mark.trio +async def test_multiple_streams_two_initiators(is_host_secure): + async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: + hosts[0].set_stream_handler( + PROTOCOL_ID_2, create_echo_stream_handler(ACK_STR_2) + ) + hosts[0].set_stream_handler( + PROTOCOL_ID_3, create_echo_stream_handler(ACK_STR_3) + ) + + hosts[1].set_stream_handler( + PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0) + ) + hosts[1].set_stream_handler( + PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1) + ) + + # Associate the peer with local ip address (see default parameters of Libp2p()) + hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10) + hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10) + + stream_a1 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0]) + stream_a2 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1]) + + stream_b1 = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_2]) + stream_b2 = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_3]) + + # A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b + messages = ["hello" + str(x) for x in range(10)] + for message in messages: + a1_message = message + "_a1" + a2_message = message + "_a2" + + b1_message = message + "_b1" + b2_message = message + "_b2" + + await stream_a1.write(a1_message.encode()) + await stream_a2.write(a2_message.encode()) + + await stream_b1.write(b1_message.encode()) + await stream_b2.write(b2_message.encode()) + + response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode() + response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode() + + response_b1 = (await stream_b1.read(MAX_READ_LEN)).decode() + response_b2 = (await stream_b2.read(MAX_READ_LEN)).decode() + + assert ( + response_a1 == (ACK_STR_0 + a1_message) + and response_a2 == (ACK_STR_1 + a2_message) + and response_b1 == (ACK_STR_2 + b1_message) + and response_b2 == (ACK_STR_3 + b2_message) + ) + + +@pytest.mark.trio +async def test_triangle_nodes_connection(is_host_secure): + async with HostFactory.create_batch_and_listen(is_host_secure, 3) as hosts: + + hosts[0].set_stream_handler( + PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0) + ) + hosts[1].set_stream_handler( + PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0) + ) + hosts[2].set_stream_handler( + PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0) + ) + + # Associate the peer with local ip address (see default parameters of Libp2p()) + # Associate all permutations + hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10) + hosts[0].get_peerstore().add_addrs(hosts[2].get_id(), hosts[2].get_addrs(), 10) + + hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10) + hosts[1].get_peerstore().add_addrs(hosts[2].get_id(), hosts[2].get_addrs(), 10) + + hosts[2].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10) + hosts[2].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10) + + stream_0_to_1 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0]) + stream_0_to_2 = await hosts[0].new_stream(hosts[2].get_id(), [PROTOCOL_ID_0]) + + stream_1_to_0 = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_0]) + stream_1_to_2 = await hosts[1].new_stream(hosts[2].get_id(), [PROTOCOL_ID_0]) + + stream_2_to_0 = await hosts[2].new_stream(hosts[0].get_id(), [PROTOCOL_ID_0]) + stream_2_to_1 = await hosts[2].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0]) + + messages = ["hello" + 
str(x) for x in range(5)] + streams = [ + stream_0_to_1, + stream_0_to_2, + stream_1_to_0, + stream_1_to_2, + stream_2_to_0, + stream_2_to_1, + ] + + for message in messages: + for stream in streams: + await stream.write(message.encode()) + response = (await stream.read(MAX_READ_LEN)).decode() + assert response == (ACK_STR_0 + message) + + +@pytest.mark.trio +async def test_host_connect(is_host_secure): + async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: + assert len(hosts[0].get_peerstore().peer_ids()) == 1 + + await connect(hosts[0], hosts[1]) + assert len(hosts[0].get_peerstore().peer_ids()) == 2 + + await connect(hosts[0], hosts[1]) + # make sure we don't do double connection + assert len(hosts[0].get_peerstore().peer_ids()) == 2 + + assert hosts[1].get_id() in hosts[0].get_peerstore().peer_ids() + ma_node_b = multiaddr.Multiaddr("/p2p/%s" % hosts[1].get_id().pretty()) + for addr in hosts[0].get_peerstore().addrs(hosts[1].get_id()): + assert addr.encapsulate(ma_node_b) in hosts[1].get_addrs() diff --git a/tests/network/conftest.py b/tests/network/conftest.py index 6b75b756..5aad36c9 100644 --- a/tests/network/conftest.py +++ b/tests/network/conftest.py @@ -1,5 +1,3 @@ -import asyncio - import pytest from libp2p.tools.factories import ( @@ -11,26 +9,17 @@ from libp2p.tools.factories import ( @pytest.fixture async def net_stream_pair(is_host_secure): - stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory(is_host_secure) - try: - yield stream_0, stream_1 - finally: - await asyncio.gather(*[host_0.close(), host_1.close()]) + async with net_stream_pair_factory(is_host_secure) as net_stream_pair: + yield net_stream_pair @pytest.fixture async def swarm_pair(is_host_secure): - swarm_0, swarm_1 = await swarm_pair_factory(is_host_secure) - try: - yield swarm_0, swarm_1 - finally: - await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) + async with swarm_pair_factory(is_host_secure) as swarms: + yield swarms @pytest.fixture async def swarm_conn_pair(is_host_secure): - conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory(is_host_secure) - try: - yield conn_0, conn_1 - finally: - await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) + async with swarm_conn_pair_factory(is_host_secure) as swarm_conn_pair: + yield swarm_conn_pair diff --git a/tests/network/test_net_stream.py b/tests/network/test_net_stream.py index d0fea932..b558f1dd 100644 --- a/tests/network/test_net_stream.py +++ b/tests/network/test_net_stream.py @@ -1,6 +1,5 @@ -import asyncio - import pytest +import trio from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset from libp2p.tools.constants import MAX_READ_LEN @@ -8,7 +7,7 @@ from libp2p.tools.constants import MAX_READ_LEN DATA = b"data_123" -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_read_write(net_stream_pair): stream_0, stream_1 = net_stream_pair assert ( @@ -19,7 +18,7 @@ async def test_net_stream_read_write(net_stream_pair): assert (await stream_1.read(MAX_READ_LEN)) == DATA -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_read_until_eof(net_stream_pair): read_bytes = bytearray() stream_0, stream_1 = net_stream_pair @@ -27,41 +26,39 @@ async def test_net_stream_read_until_eof(net_stream_pair): async def read_until_eof(): read_bytes.extend(await stream_1.read()) - task = asyncio.ensure_future(read_until_eof()) + async with trio.open_nursery() as nursery: + nursery.start_soon(read_until_eof) + expected_data = bytearray() - expected_data = bytearray() + # 
Test: `read` doesn't return before `close` is called. + await stream_0.write(DATA) + expected_data.extend(DATA) + await trio.sleep(0.01) + assert len(read_bytes) == 0 + # Test: `read` doesn't return before `close` is called. + await stream_0.write(DATA) + expected_data.extend(DATA) + await trio.sleep(0.01) + assert len(read_bytes) == 0 - # Test: `read` doesn't return before `close` is called. - await stream_0.write(DATA) - expected_data.extend(DATA) - await asyncio.sleep(0.01) - assert len(read_bytes) == 0 - # Test: `read` doesn't return before `close` is called. - await stream_0.write(DATA) - expected_data.extend(DATA) - await asyncio.sleep(0.01) - assert len(read_bytes) == 0 - - # Test: Close the stream, `read` returns, and receive previous sent data. - await stream_0.close() - await asyncio.sleep(0.01) - assert read_bytes == expected_data - - task.cancel() + # Test: Close the stream, `read` returns, and receive previous sent data. + await stream_0.close() + await trio.sleep(0.01) + assert read_bytes == expected_data -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_read_after_remote_closed(net_stream_pair): stream_0, stream_1 = net_stream_pair await stream_0.write(DATA) await stream_0.close() - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert (await stream_1.read(MAX_READ_LEN)) == DATA with pytest.raises(StreamEOF): await stream_1.read(MAX_READ_LEN) -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_read_after_local_reset(net_stream_pair): stream_0, stream_1 = net_stream_pair await stream_0.reset() @@ -69,29 +66,29 @@ async def test_net_stream_read_after_local_reset(net_stream_pair): await stream_0.read(MAX_READ_LEN) -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_read_after_remote_reset(net_stream_pair): stream_0, stream_1 = net_stream_pair await stream_0.write(DATA) await stream_0.reset() # Sleep to let `stream_1` receive the message. - await asyncio.sleep(0.01) + await trio.sleep(0.01) with pytest.raises(StreamReset): await stream_1.read(MAX_READ_LEN) -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_read_after_remote_closed_and_reset(net_stream_pair): stream_0, stream_1 = net_stream_pair await stream_0.write(DATA) await stream_0.close() await stream_0.reset() # Sleep to let `stream_1` receive the message. 
- await asyncio.sleep(0.01) + await trio.sleep(0.01) assert (await stream_1.read(MAX_READ_LEN)) == DATA -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_write_after_local_closed(net_stream_pair): stream_0, stream_1 = net_stream_pair await stream_0.write(DATA) @@ -100,7 +97,7 @@ async def test_net_stream_write_after_local_closed(net_stream_pair): await stream_0.write(DATA) -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_write_after_local_reset(net_stream_pair): stream_0, stream_1 = net_stream_pair await stream_0.reset() @@ -108,10 +105,10 @@ async def test_net_stream_write_after_local_reset(net_stream_pair): await stream_0.write(DATA) -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_write_after_remote_reset(net_stream_pair): stream_0, stream_1 = net_stream_pair await stream_1.reset() - await asyncio.sleep(0.01) + await trio.sleep(0.01) with pytest.raises(StreamClosed): await stream_0.write(DATA) diff --git a/tests/network/test_notify.py b/tests/network/test_notify.py index f8187b1e..328ff128 100644 --- a/tests/network/test_notify.py +++ b/tests/network/test_notify.py @@ -8,11 +8,11 @@ into network after network has already started listening TODO: Add tests for closed_stream, listen_close when those features are implemented in swarm """ - -import asyncio import enum +from async_service import background_trio_service import pytest +import trio from libp2p.network.notifee_interface import INotifee from libp2p.tools.constants import LISTEN_MADDR @@ -54,59 +54,63 @@ class MyNotifee(INotifee): pass -@pytest.mark.asyncio +@pytest.mark.trio async def test_notify(is_host_secure): swarms = [SwarmFactory(is_secure=is_host_secure) for _ in range(2)] events_0_0 = [] events_1_0 = [] events_0_without_listen = [] - swarms[0].register_notifee(MyNotifee(events_0_0)) - swarms[1].register_notifee(MyNotifee(events_1_0)) - # Listen - await asyncio.gather(*[swarm.listen(LISTEN_MADDR) for swarm in swarms]) + # Run swarms. + async with background_trio_service(swarms[0]), background_trio_service(swarms[1]): + # Register events before listening, to allow `MyNotifee` is notified with the event + # `listen`. + swarms[0].register_notifee(MyNotifee(events_0_0)) + swarms[1].register_notifee(MyNotifee(events_1_0)) - swarms[0].register_notifee(MyNotifee(events_0_without_listen)) + # Listen + async with trio.open_nursery() as nursery: + nursery.start_soon(swarms[0].listen, LISTEN_MADDR) + nursery.start_soon(swarms[1].listen, LISTEN_MADDR) - # Connected - await connect_swarm(swarms[0], swarms[1]) - # OpenedStream: first - await swarms[0].new_stream(swarms[1].get_peer_id()) - # OpenedStream: second - await swarms[0].new_stream(swarms[1].get_peer_id()) - # OpenedStream: third, but different direction. - await swarms[1].new_stream(swarms[0].get_peer_id()) + swarms[0].register_notifee(MyNotifee(events_0_without_listen)) - await asyncio.sleep(0.01) + # Connected + await connect_swarm(swarms[0], swarms[1]) + # OpenedStream: first + await swarms[0].new_stream(swarms[1].get_peer_id()) + # OpenedStream: second + await swarms[0].new_stream(swarms[1].get_peer_id()) + # OpenedStream: third, but different direction. + await swarms[1].new_stream(swarms[0].get_peer_id()) - # TODO: Check `ClosedStream` and `ListenClose` events after they are ready. + await trio.sleep(0.01) - # Disconnected - await swarms[0].close_peer(swarms[1].get_peer_id()) - await asyncio.sleep(0.01) + # TODO: Check `ClosedStream` and `ListenClose` events after they are ready. 
- # Connected again, but different direction. - await connect_swarm(swarms[1], swarms[0]) - await asyncio.sleep(0.01) + # Disconnected + await swarms[0].close_peer(swarms[1].get_peer_id()) + await trio.sleep(0.01) - # Disconnected again, but different direction. - await swarms[1].close_peer(swarms[0].get_peer_id()) - await asyncio.sleep(0.01) + # Connected again, but different direction. + await connect_swarm(swarms[1], swarms[0]) + await trio.sleep(0.01) - expected_events_without_listen = [ - Event.Connected, - Event.OpenedStream, - Event.OpenedStream, - Event.OpenedStream, - Event.Disconnected, - Event.Connected, - Event.Disconnected, - ] - expected_events = [Event.Listen] + expected_events_without_listen + # Disconnected again, but different direction. + await swarms[1].close_peer(swarms[0].get_peer_id()) + await trio.sleep(0.01) - assert events_0_0 == expected_events - assert events_1_0 == expected_events - assert events_0_without_listen == expected_events_without_listen + expected_events_without_listen = [ + Event.Connected, + Event.OpenedStream, + Event.OpenedStream, + Event.OpenedStream, + Event.Disconnected, + Event.Connected, + Event.Disconnected, + ] + expected_events = [Event.Listen] + expected_events_without_listen - # Clean up - await asyncio.gather(*[swarm.close() for swarm in swarms]) + assert events_0_0 == expected_events + assert events_1_0 == expected_events + assert events_0_without_listen == expected_events_without_listen diff --git a/tests/network/test_swarm.py b/tests/network/test_swarm.py index 3cc9375b..70b82477 100644 --- a/tests/network/test_swarm.py +++ b/tests/network/test_swarm.py @@ -1,89 +1,84 @@ -import asyncio - from multiaddr import Multiaddr import pytest +import trio +from trio.testing import wait_all_tasks_blocked from libp2p.network.exceptions import SwarmException from libp2p.tools.factories import SwarmFactory from libp2p.tools.utils import connect_swarm -@pytest.mark.asyncio +@pytest.mark.trio async def test_swarm_dial_peer(is_host_secure): - swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3) - # Test: No addr found. - with pytest.raises(SwarmException): + async with SwarmFactory.create_batch_and_listen(is_host_secure, 3) as swarms: + # Test: No addr found. + with pytest.raises(SwarmException): + await swarms[0].dial_peer(swarms[1].get_peer_id()) + + # Test: len(addr) in the peerstore is 0. + swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000) + with pytest.raises(SwarmException): + await swarms[0].dial_peer(swarms[1].get_peer_id()) + + # Test: Succeed if addrs of the peer_id are present in the peerstore. + addrs = tuple( + addr + for transport in swarms[1].listeners.values() + for addr in transport.get_addrs() + ) + swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000) await swarms[0].dial_peer(swarms[1].get_peer_id()) + assert swarms[0].get_peer_id() in swarms[1].connections + assert swarms[1].get_peer_id() in swarms[0].connections - # Test: len(addr) in the peerstore is 0. - swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000) - with pytest.raises(SwarmException): - await swarms[0].dial_peer(swarms[1].get_peer_id()) - - # Test: Succeed if addrs of the peer_id are present in the peerstore. 
- addrs = tuple( - addr - for transport in swarms[1].listeners.values() - for addr in transport.get_addrs() - ) - swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000) - await swarms[0].dial_peer(swarms[1].get_peer_id()) - assert swarms[0].get_peer_id() in swarms[1].connections - assert swarms[1].get_peer_id() in swarms[0].connections - - # Test: Reuse connections when we already have ones with a peer. - conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()] - conn = await swarms[0].dial_peer(swarms[1].get_peer_id()) - assert conn is conn_to_1 - - # Clean up - await asyncio.gather(*[swarm.close() for swarm in swarms]) + # Test: Reuse connections when we already have ones with a peer. + conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()] + conn = await swarms[0].dial_peer(swarms[1].get_peer_id()) + assert conn is conn_to_1 -@pytest.mark.asyncio +@pytest.mark.trio async def test_swarm_close_peer(is_host_secure): - swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3) - # 0 <> 1 <> 2 - await connect_swarm(swarms[0], swarms[1]) - await connect_swarm(swarms[1], swarms[2]) + async with SwarmFactory.create_batch_and_listen(is_host_secure, 3) as swarms: + # 0 <> 1 <> 2 + await connect_swarm(swarms[0], swarms[1]) + await connect_swarm(swarms[1], swarms[2]) - # peer 1 closes peer 0 - await swarms[1].close_peer(swarms[0].get_peer_id()) - await asyncio.sleep(0.01) - # 0 1 <> 2 - assert len(swarms[0].connections) == 0 - assert ( - len(swarms[1].connections) == 1 - and swarms[2].get_peer_id() in swarms[1].connections - ) + # peer 1 closes peer 0 + await swarms[1].close_peer(swarms[0].get_peer_id()) + await trio.sleep(0.01) + await wait_all_tasks_blocked() + # 0 1 <> 2 + assert len(swarms[0].connections) == 0 + assert ( + len(swarms[1].connections) == 1 + and swarms[2].get_peer_id() in swarms[1].connections + ) - # peer 1 is closed by peer 2 - await swarms[2].close_peer(swarms[1].get_peer_id()) - await asyncio.sleep(0.01) - # 0 1 2 - assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0 + # peer 1 is closed by peer 2 + await swarms[2].close_peer(swarms[1].get_peer_id()) + await trio.sleep(0.01) + # 0 1 2 + assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0 - await connect_swarm(swarms[0], swarms[1]) - # 0 <> 1 2 - assert ( - len(swarms[0].connections) == 1 - and swarms[1].get_peer_id() in swarms[0].connections - ) - assert ( - len(swarms[1].connections) == 1 - and swarms[0].get_peer_id() in swarms[1].connections - ) - # peer 0 closes peer 1 - await swarms[0].close_peer(swarms[1].get_peer_id()) - await asyncio.sleep(0.01) - # 0 1 2 - assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0 - - # Clean up - await asyncio.gather(*[swarm.close() for swarm in swarms]) + await connect_swarm(swarms[0], swarms[1]) + # 0 <> 1 2 + assert ( + len(swarms[0].connections) == 1 + and swarms[1].get_peer_id() in swarms[0].connections + ) + assert ( + len(swarms[1].connections) == 1 + and swarms[0].get_peer_id() in swarms[1].connections + ) + # peer 0 closes peer 1 + await swarms[0].close_peer(swarms[1].get_peer_id()) + await trio.sleep(0.01) + # 0 1 2 + assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0 -@pytest.mark.asyncio +@pytest.mark.trio async def test_swarm_remove_conn(swarm_pair): swarm_0, swarm_1 = swarm_pair conn_0 = swarm_0.connections[swarm_1.get_peer_id()] @@ -94,57 +89,54 @@ async def test_swarm_remove_conn(swarm_pair): assert swarm_1.get_peer_id() not in swarm_0.connections 
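`trio.testing.wait_all_tasks_blocked` appears next to the short sleeps above: it parks the current task until every other task has nothing runnable left, which makes assertions about background cleanup far less timing-sensitive than a bare sleep. A self-contained illustration of the same sleep-then-settle pattern:

```python
import trio
from trio.testing import wait_all_tasks_blocked


async def main() -> None:
    closed = trio.Event()

    async def background_cleanup() -> None:
        await trio.sleep(0.01)  # e.g. a peer noticing the connection went away
        closed.set()

    async with trio.open_nursery() as nursery:
        nursery.start_soon(background_cleanup)
        await trio.sleep(0.05)
        # Wait until every other task is blocked again, so the assertion is not
        # racing whatever the background task still had left to do.
        await wait_all_tasks_blocked()
        assert closed.is_set()


trio.run(main)
```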
-@pytest.mark.asyncio +@pytest.mark.trio async def test_swarm_multiaddr(is_host_secure): - swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3) + async with SwarmFactory.create_batch_and_listen(is_host_secure, 3) as swarms: - def clear(): - swarms[0].peerstore.clear_addrs(swarms[1].get_peer_id()) + def clear(): + swarms[0].peerstore.clear_addrs(swarms[1].get_peer_id()) - clear() - # No addresses - with pytest.raises(SwarmException): + clear() + # No addresses + with pytest.raises(SwarmException): + await swarms[0].dial_peer(swarms[1].get_peer_id()) + + clear() + # Wrong addresses + swarms[0].peerstore.add_addrs( + swarms[1].get_peer_id(), [Multiaddr("/ip4/0.0.0.0/tcp/9999")], 10000 + ) + + with pytest.raises(SwarmException): + await swarms[0].dial_peer(swarms[1].get_peer_id()) + + clear() + # Multiple wrong addresses + swarms[0].peerstore.add_addrs( + swarms[1].get_peer_id(), + [Multiaddr("/ip4/0.0.0.0/tcp/9999"), Multiaddr("/ip4/0.0.0.0/tcp/9998")], + 10000, + ) + + with pytest.raises(SwarmException): + await swarms[0].dial_peer(swarms[1].get_peer_id()) + + # Test one address + addrs = tuple( + addr + for transport in swarms[1].listeners.values() + for addr in transport.get_addrs() + ) + + swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs[:1], 10000) await swarms[0].dial_peer(swarms[1].get_peer_id()) - clear() - # Wrong addresses - swarms[0].peerstore.add_addrs( - swarms[1].get_peer_id(), [Multiaddr("/ip4/0.0.0.0/tcp/9999")], 10000 - ) + # Test multiple addresses + addrs = tuple( + addr + for transport in swarms[1].listeners.values() + for addr in transport.get_addrs() + ) - with pytest.raises(SwarmException): + swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs + addrs, 10000) await swarms[0].dial_peer(swarms[1].get_peer_id()) - - clear() - # Multiple wrong addresses - swarms[0].peerstore.add_addrs( - swarms[1].get_peer_id(), - [Multiaddr("/ip4/0.0.0.0/tcp/9999"), Multiaddr("/ip4/0.0.0.0/tcp/9998")], - 10000, - ) - - with pytest.raises(SwarmException): - await swarms[0].dial_peer(swarms[1].get_peer_id()) - - # Test one address - addrs = tuple( - addr - for transport in swarms[1].listeners.values() - for addr in transport.get_addrs() - ) - - swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs[:1], 10000) - await swarms[0].dial_peer(swarms[1].get_peer_id()) - - # Test multiple addresses - addrs = tuple( - addr - for transport in swarms[1].listeners.values() - for addr in transport.get_addrs() - ) - - swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs + addrs, 10000) - await swarms[0].dial_peer(swarms[1].get_peer_id()) - - for swarm in swarms: - await swarm.close() diff --git a/tests/network/test_swarm_conn.py b/tests/network/test_swarm_conn.py index 2abc7d0f..dc692f44 100644 --- a/tests/network/test_swarm_conn.py +++ b/tests/network/test_swarm_conn.py @@ -1,45 +1,46 @@ -import asyncio - import pytest +import trio +from trio.testing import wait_all_tasks_blocked -@pytest.mark.asyncio +@pytest.mark.trio async def test_swarm_conn_close(swarm_conn_pair): conn_0, conn_1 = swarm_conn_pair - assert not conn_0.event_closed.is_set() - assert not conn_1.event_closed.is_set() + assert not conn_0.is_closed + assert not conn_1.is_closed await conn_0.close() - await asyncio.sleep(0.01) + await trio.sleep(0.1) + await wait_all_tasks_blocked() - assert conn_0.event_closed.is_set() - assert conn_1.event_closed.is_set() + assert conn_0.is_closed + assert conn_1.is_closed assert conn_0 not in conn_0.swarm.connections.values() assert conn_1 not in 
conn_1.swarm.connections.values() -@pytest.mark.asyncio +@pytest.mark.trio async def test_swarm_conn_streams(swarm_conn_pair): conn_0, conn_1 = swarm_conn_pair - assert len(await conn_0.get_streams()) == 0 - assert len(await conn_1.get_streams()) == 0 + assert len(conn_0.get_streams()) == 0 + assert len(conn_1.get_streams()) == 0 stream_0_0 = await conn_0.new_stream() - await asyncio.sleep(0.01) - assert len(await conn_0.get_streams()) == 1 - assert len(await conn_1.get_streams()) == 1 + await trio.sleep(0.01) + assert len(conn_0.get_streams()) == 1 + assert len(conn_1.get_streams()) == 1 stream_0_1 = await conn_0.new_stream() - await asyncio.sleep(0.01) - assert len(await conn_0.get_streams()) == 2 - assert len(await conn_1.get_streams()) == 2 + await trio.sleep(0.01) + assert len(conn_0.get_streams()) == 2 + assert len(conn_1.get_streams()) == 2 conn_0.remove_stream(stream_0_0) - assert len(await conn_0.get_streams()) == 1 + assert len(conn_0.get_streams()) == 1 conn_0.remove_stream(stream_0_1) - assert len(await conn_0.get_streams()) == 0 + assert len(conn_0.get_streams()) == 0 # Nothing happen if `stream_0_1` is not present or already removed. conn_0.remove_stream(stream_0_1) diff --git a/tests/peer/test_peerinfo.py b/tests/peer/test_peerinfo.py index 29c46887..deb760d4 100644 --- a/tests/peer/test_peerinfo.py +++ b/tests/peer/test_peerinfo.py @@ -25,8 +25,6 @@ def test_init_(): @pytest.mark.parametrize( "addr", ( - pytest.param(None), - pytest.param(random.randint(0, 255), id="random integer"), pytest.param(multiaddr.Multiaddr("/"), id="empty multiaddr"), pytest.param( multiaddr.Multiaddr("/ip4/127.0.0.1"), diff --git a/tests/protocol_muxer/test_protocol_muxer.py b/tests/protocol_muxer/test_protocol_muxer.py index 42dae60c..cd82652c 100644 --- a/tests/protocol_muxer/test_protocol_muxer.py +++ b/tests/protocol_muxer/test_protocol_muxer.py @@ -1,83 +1,94 @@ import pytest from libp2p.host.exceptions import StreamFailure -from libp2p.tools.utils import echo_stream_handler, set_up_nodes_by_transport_opt +from libp2p.tools.factories import HostFactory +from libp2p.tools.utils import create_echo_stream_handler -# TODO: Add tests for multiple streams being opened on different -# protocols through the same connection +PROTOCOL_ECHO = "/echo/1.0.0" +PROTOCOL_POTATO = "/potato/1.0.0" +PROTOCOL_FOO = "/foo/1.0.0" +PROTOCOL_ROCK = "/rock/1.0.0" -# Note: async issues occurred when using the same port -# so that's why I use different ports here. 
-# TODO: modify tests so that those async issues don't occur -# when using the same ports across tests +ACK_PREFIX = "ack:" async def perform_simple_test( - expected_selected_protocol, protocols_for_client, protocols_with_handlers + expected_selected_protocol, + protocols_for_client, + protocols_with_handlers, + is_host_secure, ): - transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list) + async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: + for protocol in protocols_with_handlers: + hosts[1].set_stream_handler( + protocol, create_echo_stream_handler(ACK_PREFIX) + ) - for protocol in protocols_with_handlers: - node_b.set_stream_handler(protocol, echo_stream_handler) + # Associate the peer with local ip address (see default parameters of Libp2p()) + hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10) + stream = await hosts[0].new_stream(hosts[1].get_id(), protocols_for_client) + messages = ["hello" + str(x) for x in range(10)] + for message in messages: + expected_resp = "ack:" + message + await stream.write(message.encode()) + response = (await stream.read(len(expected_resp))).decode() + assert response == expected_resp - # Associate the peer with local ip address (see default parameters of Libp2p()) - node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - - stream = await node_a.new_stream(node_b.get_id(), protocols_for_client) - messages = ["hello" + str(x) for x in range(10)] - for message in messages: - expected_resp = "ack:" + message - await stream.write(message.encode()) - response = (await stream.read(len(expected_resp))).decode() - assert response == expected_resp - - assert expected_selected_protocol == stream.get_protocol() - - # Success, terminate pending tasks. 
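`create_echo_stream_handler(ACK_PREFIX)` replaces the fixed `echo_stream_handler`, so the ack prefix is now injected rather than hard-coded. The real handler reads from and writes to an `INetStream`; the toy version below only shows the closure-factory shape, and its names are illustrative rather than py-libp2p's:

```python
from typing import Awaitable, Callable

import trio


def create_echo_handler(ack_prefix: str) -> Callable[[str], Awaitable[str]]:
    # Close over the prefix and hand back an async handler that acks messages.
    async def handler(message: str) -> str:
        await trio.sleep(0)  # checkpoint, standing in for stream I/O
        return ack_prefix + message

    return handler


async def main() -> None:
    handler = create_echo_handler("ack:")
    assert await handler("hello0") == "ack:hello0"


trio.run(main)
```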
+ assert expected_selected_protocol == stream.get_protocol() -@pytest.mark.asyncio -async def test_single_protocol_succeeds(): - expected_selected_protocol = "/echo/1.0.0" +@pytest.mark.trio +async def test_single_protocol_succeeds(is_host_secure): + expected_selected_protocol = PROTOCOL_ECHO await perform_simple_test( - expected_selected_protocol, ["/echo/1.0.0"], ["/echo/1.0.0"] + expected_selected_protocol, + [expected_selected_protocol], + [expected_selected_protocol], + is_host_secure, ) -@pytest.mark.asyncio -async def test_single_protocol_fails(): +@pytest.mark.trio +async def test_single_protocol_fails(is_host_secure): with pytest.raises(StreamFailure): - await perform_simple_test("", ["/echo/1.0.0"], ["/potato/1.0.0"]) + await perform_simple_test( + "", [PROTOCOL_ECHO], [PROTOCOL_POTATO], is_host_secure + ) # Cleanup not reached on error -@pytest.mark.asyncio -async def test_multiple_protocol_first_is_valid_succeeds(): - expected_selected_protocol = "/echo/1.0.0" - protocols_for_client = ["/echo/1.0.0", "/potato/1.0.0"] - protocols_for_listener = ["/foo/1.0.0", "/echo/1.0.0"] +@pytest.mark.trio +async def test_multiple_protocol_first_is_valid_succeeds(is_host_secure): + expected_selected_protocol = PROTOCOL_ECHO + protocols_for_client = [PROTOCOL_ECHO, PROTOCOL_POTATO] + protocols_for_listener = [PROTOCOL_FOO, PROTOCOL_ECHO] await perform_simple_test( - expected_selected_protocol, protocols_for_client, protocols_for_listener + expected_selected_protocol, + protocols_for_client, + protocols_for_listener, + is_host_secure, ) -@pytest.mark.asyncio -async def test_multiple_protocol_second_is_valid_succeeds(): - expected_selected_protocol = "/foo/1.0.0" - protocols_for_client = ["/rock/1.0.0", "/foo/1.0.0"] - protocols_for_listener = ["/foo/1.0.0", "/echo/1.0.0"] +@pytest.mark.trio +async def test_multiple_protocol_second_is_valid_succeeds(is_host_secure): + expected_selected_protocol = PROTOCOL_FOO + protocols_for_client = [PROTOCOL_ROCK, PROTOCOL_FOO] + protocols_for_listener = [PROTOCOL_FOO, PROTOCOL_ECHO] await perform_simple_test( - expected_selected_protocol, protocols_for_client, protocols_for_listener + expected_selected_protocol, + protocols_for_client, + protocols_for_listener, + is_host_secure, ) -@pytest.mark.asyncio -async def test_multiple_protocol_fails(): - protocols_for_client = ["/rock/1.0.0", "/foo/1.0.0", "/bar/1.0.0"] +@pytest.mark.trio +async def test_multiple_protocol_fails(is_host_secure): + protocols_for_client = [PROTOCOL_ROCK, PROTOCOL_FOO, "/bar/1.0.0"] protocols_for_listener = ["/aspyn/1.0.0", "/rob/1.0.0", "/zx/1.0.0", "/alex/1.0.0"] with pytest.raises(StreamFailure): - await perform_simple_test("", protocols_for_client, protocols_for_listener) - - # Cleanup not reached on error + await perform_simple_test( + "", protocols_for_client, protocols_for_listener, is_host_secure + ) diff --git a/tests/pubsub/conftest.py b/tests/pubsub/conftest.py deleted file mode 100644 index 520fdf4b..00000000 --- a/tests/pubsub/conftest.py +++ /dev/null @@ -1,58 +0,0 @@ -import pytest - -from libp2p.tools.constants import GOSSIPSUB_PARAMS -from libp2p.tools.factories import FloodsubFactory, GossipsubFactory, PubsubFactory - - -@pytest.fixture -def is_strict_signing(): - return False - - -def _make_pubsubs(hosts, pubsub_routers, cache_size, is_strict_signing): - if len(pubsub_routers) != len(hosts): - raise ValueError( - f"lenght of pubsub_routers={pubsub_routers} should be equaled to the " - f"length of hosts={len(hosts)}" - ) - return tuple( - PubsubFactory( - host=host, - 
router=router, - cache_size=cache_size, - strict_signing=is_strict_signing, - ) - for host, router in zip(hosts, pubsub_routers) - ) - - -@pytest.fixture -def pubsub_cache_size(): - return None # default - - -@pytest.fixture -def gossipsub_params(): - return GOSSIPSUB_PARAMS - - -@pytest.fixture -def pubsubs_fsub(num_hosts, hosts, pubsub_cache_size, is_strict_signing): - floodsubs = FloodsubFactory.create_batch(num_hosts) - _pubsubs_fsub = _make_pubsubs( - hosts, floodsubs, pubsub_cache_size, is_strict_signing - ) - yield _pubsubs_fsub - # TODO: Clean up - - -@pytest.fixture -def pubsubs_gsub( - num_hosts, hosts, pubsub_cache_size, gossipsub_params, is_strict_signing -): - gossipsubs = GossipsubFactory.create_batch(num_hosts, **gossipsub_params._asdict()) - _pubsubs_gsub = _make_pubsubs( - hosts, gossipsubs, pubsub_cache_size, is_strict_signing - ) - yield _pubsubs_gsub - # TODO: Clean up diff --git a/tests/pubsub/test_dummyaccount_demo.py b/tests/pubsub/test_dummyaccount_demo.py index cdda6035..24d5bd41 100644 --- a/tests/pubsub/test_dummyaccount_demo.py +++ b/tests/pubsub/test_dummyaccount_demo.py @@ -1,19 +1,10 @@ -import asyncio -from threading import Thread - import pytest +import trio from libp2p.tools.pubsub.dummy_account_node import DummyAccountNode from libp2p.tools.utils import connect -def create_setup_in_new_thread_func(dummy_node): - def setup_in_new_thread(): - asyncio.ensure_future(dummy_node.setup_crypto_networking()) - - return setup_in_new_thread - - async def perform_test(num_nodes, adjacency_map, action_func, assertion_func): """ Helper function to allow for easy construction of custom tests for dummy @@ -26,47 +17,35 @@ async def perform_test(num_nodes, adjacency_map, action_func, assertion_func): :param assertion_func: assertions for testing the results of the actions are correct """ - # Create nodes - dummy_nodes = [] - for _ in range(num_nodes): - dummy_nodes.append(await DummyAccountNode.create()) + async with DummyAccountNode.create(num_nodes) as dummy_nodes: + # Create connections between nodes according to `adjacency_map` + async with trio.open_nursery() as nursery: + for source_num in adjacency_map: + target_nums = adjacency_map[source_num] + for target_num in target_nums: + nursery.start_soon( + connect, + dummy_nodes[source_num].host, + dummy_nodes[target_num].host, + ) - # Create network - for source_num in adjacency_map: - target_nums = adjacency_map[source_num] - for target_num in target_nums: - await connect( - dummy_nodes[source_num].libp2p_node, dummy_nodes[target_num].libp2p_node - ) + # Allow time for network creation to take place + await trio.sleep(0.25) - # Allow time for network creation to take place - await asyncio.sleep(0.25) + # Perform action function + await action_func(dummy_nodes) - # Start a thread for each node so that each node can listen and respond - # to messages on its own thread, which will avoid waiting indefinitely - # on the main thread. On this thread, call the setup func for the node, - # which subscribes the node to the CRYPTO_TOPIC topic - for dummy_node in dummy_nodes: - thread = Thread(target=create_setup_in_new_thread_func(dummy_node)) - thread.run() + # Allow time for action function to be performed (i.e. messages to propogate) + await trio.sleep(1) - # Allow time for nodes to subscribe to CRYPTO_TOPIC topic - await asyncio.sleep(0.25) - - # Perform action function - await action_func(dummy_nodes) - - # Allow time for action function to be performed (i.e. 
messages to propogate) - await asyncio.sleep(1) - - # Perform assertion function - for dummy_node in dummy_nodes: - assertion_func(dummy_node) + # Perform assertion function + for dummy_node in dummy_nodes: + assertion_func(dummy_node) # Success, terminate pending tasks. -@pytest.mark.asyncio +@pytest.mark.trio async def test_simple_two_nodes(): num_nodes = 2 adj_map = {0: [1]} @@ -80,7 +59,7 @@ async def test_simple_two_nodes(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio async def test_simple_three_nodes_line_topography(): num_nodes = 3 adj_map = {0: [1], 1: [2]} @@ -94,7 +73,7 @@ async def test_simple_three_nodes_line_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio async def test_simple_three_nodes_triangle_topography(): num_nodes = 3 adj_map = {0: [1, 2], 1: [2]} @@ -108,7 +87,7 @@ async def test_simple_three_nodes_triangle_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio async def test_simple_seven_nodes_tree_topography(): num_nodes = 7 adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]} @@ -122,14 +101,14 @@ async def test_simple_seven_nodes_tree_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio async def test_set_then_send_from_root_seven_nodes_tree_topography(): num_nodes = 7 adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]} async def action_func(dummy_nodes): await dummy_nodes[0].publish_set_crypto("aspyn", 20) - await asyncio.sleep(0.25) + await trio.sleep(0.25) await dummy_nodes[0].publish_send_crypto("aspyn", "alex", 5) def assertion_func(dummy_node): @@ -139,14 +118,14 @@ async def test_set_then_send_from_root_seven_nodes_tree_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography(): num_nodes = 7 adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]} async def action_func(dummy_nodes): await dummy_nodes[6].publish_set_crypto("aspyn", 20) - await asyncio.sleep(0.25) + await trio.sleep(0.25) await dummy_nodes[4].publish_send_crypto("aspyn", "alex", 5) def assertion_func(dummy_node): @@ -156,7 +135,7 @@ async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio async def test_simple_five_nodes_ring_topography(): num_nodes = 5 adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]} @@ -170,14 +149,14 @@ async def test_simple_five_nodes_ring_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography(): num_nodes = 5 adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]} async def action_func(dummy_nodes): await dummy_nodes[0].publish_set_crypto("alex", 20) - await asyncio.sleep(0.25) + await trio.sleep(0.25) await dummy_nodes[3].publish_send_crypto("alex", "rob", 12) def assertion_func(dummy_node): @@ -187,7 +166,7 @@ async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio @pytest.mark.slow async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography(): num_nodes = 5 @@ -195,13 +174,13 
@@ async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography(): async def action_func(dummy_nodes): await dummy_nodes[0].publish_set_crypto("alex", 20) - await asyncio.sleep(1) + await trio.sleep(1) await dummy_nodes[1].publish_send_crypto("alex", "rob", 3) - await asyncio.sleep(1) + await trio.sleep(1) await dummy_nodes[2].publish_send_crypto("rob", "aspyn", 2) - await asyncio.sleep(1) + await trio.sleep(1) await dummy_nodes[3].publish_send_crypto("aspyn", "zx", 1) - await asyncio.sleep(1) + await trio.sleep(1) await dummy_nodes[4].publish_send_crypto("zx", "raul", 1) def assertion_func(dummy_node): diff --git a/tests/pubsub/test_floodsub.py b/tests/pubsub/test_floodsub.py index 7564a949..148c001b 100644 --- a/tests/pubsub/test_floodsub.py +++ b/tests/pubsub/test_floodsub.py @@ -1,9 +1,10 @@ -import asyncio +import functools import pytest +import trio from libp2p.peer.id import ID -from libp2p.tools.factories import FloodsubFactory +from libp2p.tools.factories import PubsubFactory from libp2p.tools.pubsub.floodsub_integration_test_settings import ( floodsub_protocol_pytest_params, perform_test_from_obj, @@ -11,79 +12,80 @@ from libp2p.tools.pubsub.floodsub_integration_test_settings import ( from libp2p.tools.utils import connect -@pytest.mark.parametrize("num_hosts", (2,)) -@pytest.mark.asyncio -async def test_simple_two_nodes(pubsubs_fsub): - topic = "my_topic" - data = b"some data" +@pytest.mark.trio +async def test_simple_two_nodes(): + async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub: + topic = "my_topic" + data = b"some data" - await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) - await asyncio.sleep(0.25) + await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) + await trio.sleep(0.25) - sub_b = await pubsubs_fsub[1].subscribe(topic) - # Sleep to let a know of b's subscription - await asyncio.sleep(0.25) + sub_b = await pubsubs_fsub[1].subscribe(topic) + # Sleep to let a know of b's subscription + await trio.sleep(0.25) - await pubsubs_fsub[0].publish(topic, data) + await pubsubs_fsub[0].publish(topic, data) - res_b = await sub_b.get() - - # Check that the msg received by node_b is the same - # as the message sent by node_a - assert ID(res_b.from_id) == pubsubs_fsub[0].host.get_id() - assert res_b.data == data - assert res_b.topicIDs == [topic] - - # Success, terminate pending tasks. - - -# Initialize Pubsub with a cache_size of 4 -@pytest.mark.parametrize("num_hosts, pubsub_cache_size", ((2, 4),)) -@pytest.mark.asyncio -async def test_lru_cache_two_nodes(pubsubs_fsub, monkeypatch): - # two nodes with cache_size of 4 - # `node_a` send the following messages to node_b - message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1] - # `node_b` should only receive the following - expected_received_indices = [1, 2, 3, 4, 5, 1] - - topic = "my_topic" - - # Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`. 
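Why `[1, 1, 2, 1, 3, 1, 4, 1, 5, 1]` should arrive as `[1, 2, 3, 4, 5, 1]`: with `cache_size=4` and message ids keyed on the payload (the monkeypatched `get_msg_id`), a duplicate is dropped while its id is still cached, and the final `1` gets through only because its id has been evicted by then. A toy seen-cache reproduces the sequence, assuming oldest-first eviction at capacity:

```python
from collections import OrderedDict


class SeenCache:
    """Tiny 'seen messages' cache, assuming oldest-first eviction at capacity."""

    def __init__(self, size: int) -> None:
        self.size = size
        self.entries: "OrderedDict[int, None]" = OrderedDict()

    def add(self, msg_id: int) -> bool:
        """Record the id and return True if it was new, False if already seen."""
        if msg_id in self.entries:
            return False
        self.entries[msg_id] = None
        if len(self.entries) > self.size:
            self.entries.popitem(last=False)  # drop the oldest id
        return True


cache = SeenCache(4)
published = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
received = [i for i in published if cache.add(i)]
assert received == [1, 2, 3, 4, 5, 1]
```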
- def get_msg_id(msg): - # Originally it is `(msg.seqno, msg.from_id)` - return (msg.data, msg.from_id) - - import libp2p.pubsub.pubsub - - monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id) - - await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) - await asyncio.sleep(0.25) - - sub_b = await pubsubs_fsub[1].subscribe(topic) - await asyncio.sleep(0.25) - - def _make_testing_data(i: int) -> bytes: - num_int_bytes = 4 - if i >= 2 ** (num_int_bytes * 8): - raise ValueError("integer is too large to be serialized") - return b"data" + i.to_bytes(num_int_bytes, "big") - - for index in message_indices: - await pubsubs_fsub[0].publish(topic, _make_testing_data(index)) - await asyncio.sleep(0.25) - - for index in expected_received_indices: res_b = await sub_b.get() - assert res_b.data == _make_testing_data(index) - assert sub_b.empty() - # Success, terminate pending tasks. + # Check that the msg received by node_b is the same + # as the message sent by node_a + assert ID(res_b.from_id) == pubsubs_fsub[0].host.get_id() + assert res_b.data == data + assert res_b.topicIDs == [topic] + + +@pytest.mark.trio +async def test_lru_cache_two_nodes(monkeypatch): + # two nodes with cache_size of 4 + async with PubsubFactory.create_batch_with_floodsub( + 2, cache_size=4 + ) as pubsubs_fsub: + # `node_a` send the following messages to node_b + message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1] + # `node_b` should only receive the following + expected_received_indices = [1, 2, 3, 4, 5, 1] + + topic = "my_topic" + + # Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`. + def get_msg_id(msg): + # Originally it is `(msg.seqno, msg.from_id)` + return (msg.data, msg.from_id) + + import libp2p.pubsub.pubsub + + monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id) + + await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) + await trio.sleep(0.25) + + sub_b = await pubsubs_fsub[1].subscribe(topic) + await trio.sleep(0.25) + + def _make_testing_data(i: int) -> bytes: + num_int_bytes = 4 + if i >= 2 ** (num_int_bytes * 8): + raise ValueError("integer is too large to be serialized") + return b"data" + i.to_bytes(num_int_bytes, "big") + + for index in message_indices: + await pubsubs_fsub[0].publish(topic, _make_testing_data(index)) + await trio.sleep(0.25) + + for index in expected_received_indices: + res_b = await sub_b.get() + assert res_b.data == _make_testing_data(index) @pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params) -@pytest.mark.asyncio +@pytest.mark.trio @pytest.mark.slow -async def test_gossipsub_run_with_floodsub_tests(test_case_obj): - await perform_test_from_obj(test_case_obj, FloodsubFactory) +async def test_gossipsub_run_with_floodsub_tests(test_case_obj, is_host_secure): + await perform_test_from_obj( + test_case_obj, + functools.partial( + PubsubFactory.create_batch_with_floodsub, is_secure=is_host_secure + ), + ) diff --git a/tests/pubsub/test_gossipsub.py b/tests/pubsub/test_gossipsub.py index 1bc34260..a423fbd6 100644 --- a/tests/pubsub/test_gossipsub.py +++ b/tests/pubsub/test_gossipsub.py @@ -1,495 +1,478 @@ -import asyncio import random import pytest +import trio -from libp2p.peer.id import ID from libp2p.pubsub.gossipsub import PROTOCOL_ID -from libp2p.tools.constants import GOSSIPSUB_PARAMS, GossipsubParams +from libp2p.tools.factories import IDFactory, PubsubFactory from libp2p.tools.pubsub.utils import dense_connect, one_to_all_connect from libp2p.tools.utils import connect -@pytest.mark.parametrize( - "num_hosts, 
gossipsub_params", - ((4, GossipsubParams(degree=4, degree_low=3, degree_high=5)),), -) -@pytest.mark.asyncio -async def test_join(num_hosts, hosts, pubsubs_gsub): - gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub) - hosts_indices = list(range(num_hosts)) +@pytest.mark.trio +async def test_join(): + async with PubsubFactory.create_batch_with_gossipsub( + 4, degree=4, degree_low=3, degree_high=5 + ) as pubsubs_gsub: + gossipsubs = [pubsub.router for pubsub in pubsubs_gsub] + hosts = [pubsub.host for pubsub in pubsubs_gsub] + hosts_indices = list(range(len(pubsubs_gsub))) - topic = "test_join" - central_node_index = 0 - # Remove index of central host from the indices - hosts_indices.remove(central_node_index) - num_subscribed_peer = 2 - subscribed_peer_indices = random.sample(hosts_indices, num_subscribed_peer) + topic = "test_join" + central_node_index = 0 + # Remove index of central host from the indices + hosts_indices.remove(central_node_index) + num_subscribed_peer = 2 + subscribed_peer_indices = random.sample(hosts_indices, num_subscribed_peer) - # All pubsub except the one of central node subscribe to topic - for i in subscribed_peer_indices: - await pubsubs_gsub[i].subscribe(topic) + # All pubsub except the one of central node subscribe to topic + for i in subscribed_peer_indices: + await pubsubs_gsub[i].subscribe(topic) - # Connect central host to all other hosts - await one_to_all_connect(hosts, central_node_index) + # Connect central host to all other hosts + await one_to_all_connect(hosts, central_node_index) - # Wait 2 seconds for heartbeat to allow mesh to connect - await asyncio.sleep(2) - - # Central node publish to the topic so that this topic - # is added to central node's fanout - # publish from the randomly chosen host - await pubsubs_gsub[central_node_index].publish(topic, b"data") - - # Check that the gossipsub of central node has fanout for the topic - assert topic in gossipsubs[central_node_index].fanout - # Check that the gossipsub of central node does not have a mesh for the topic - assert topic not in gossipsubs[central_node_index].mesh - - # Central node subscribes the topic - await pubsubs_gsub[central_node_index].subscribe(topic) - - await asyncio.sleep(2) - - # Check that the gossipsub of central node no longer has fanout for the topic - assert topic not in gossipsubs[central_node_index].fanout - - for i in hosts_indices: - if i in subscribed_peer_indices: - assert hosts[i].get_id() in gossipsubs[central_node_index].mesh[topic] - assert hosts[central_node_index].get_id() in gossipsubs[i].mesh[topic] - else: - assert hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic] - assert topic not in gossipsubs[i].mesh - - -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_leave(pubsubs_gsub): - gossipsub = pubsubs_gsub[0].router - topic = "test_leave" - - assert topic not in gossipsub.mesh - - await gossipsub.join(topic) - assert topic in gossipsub.mesh - - await gossipsub.leave(topic) - assert topic not in gossipsub.mesh - - # Test re-leave - await gossipsub.leave(topic) - - -@pytest.mark.parametrize("num_hosts", (2,)) -@pytest.mark.asyncio -async def test_handle_graft(pubsubs_gsub, hosts, event_loop, monkeypatch): - gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub) - - index_alice = 0 - id_alice = hosts[index_alice].get_id() - index_bob = 1 - id_bob = hosts[index_bob].get_id() - await connect(hosts[index_alice], hosts[index_bob]) - - # Wait 2 seconds for heartbeat to allow mesh to connect - await 
asyncio.sleep(2) - - topic = "test_handle_graft" - # Only lice subscribe to the topic - await gossipsubs[index_alice].join(topic) - - # Monkey patch bob's `emit_prune` function so we can - # check if it is called in `handle_graft` - event_emit_prune = asyncio.Event() - - async def emit_prune(topic, sender_peer_id): - event_emit_prune.set() - - monkeypatch.setattr(gossipsubs[index_bob], "emit_prune", emit_prune) - - # Check that alice is bob's peer but not his mesh peer - assert gossipsubs[index_bob].peer_protocol[id_alice] == PROTOCOL_ID - assert topic not in gossipsubs[index_bob].mesh - - await gossipsubs[index_alice].emit_graft(topic, id_bob) - - # Check that `emit_prune` is called - await asyncio.wait_for(event_emit_prune.wait(), timeout=1, loop=event_loop) - assert event_emit_prune.is_set() - - # Check that bob is alice's peer but not her mesh peer - assert topic in gossipsubs[index_alice].mesh - assert id_bob not in gossipsubs[index_alice].mesh[topic] - assert gossipsubs[index_alice].peer_protocol[id_bob] == PROTOCOL_ID - - await gossipsubs[index_bob].emit_graft(topic, id_alice) - - await asyncio.sleep(1) - - # Check that bob is now alice's mesh peer - assert id_bob in gossipsubs[index_alice].mesh[topic] - - -@pytest.mark.parametrize( - "num_hosts, gossipsub_params", ((2, GossipsubParams(heartbeat_interval=3)),) -) -@pytest.mark.asyncio -async def test_handle_prune(pubsubs_gsub, hosts): - gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub) - - index_alice = 0 - id_alice = hosts[index_alice].get_id() - index_bob = 1 - id_bob = hosts[index_bob].get_id() - - topic = "test_handle_prune" - for pubsub in pubsubs_gsub: - await pubsub.subscribe(topic) - - await connect(hosts[index_alice], hosts[index_bob]) - - # Wait for heartbeat to allow mesh to connect - await asyncio.sleep(1) - - # Check that they are each other's mesh peer - assert id_alice in gossipsubs[index_bob].mesh[topic] - assert id_bob in gossipsubs[index_alice].mesh[topic] - - # alice emit prune message to bob, alice should be removed - # from bob's mesh peer - await gossipsubs[index_alice].emit_prune(topic, id_bob) - # `emit_prune` does not remove bob from alice's mesh peers - assert id_bob in gossipsubs[index_alice].mesh[topic] - - # NOTE: We increase `heartbeat_interval` to 3 seconds so that bob will not - # add alice back to his mesh after heartbeat. 
- # Wait for bob to `handle_prune` - await asyncio.sleep(0.1) - - # Check that alice is no longer bob's mesh peer - assert id_alice not in gossipsubs[index_bob].mesh[topic] - - -@pytest.mark.parametrize("num_hosts", (10,)) -@pytest.mark.asyncio -async def test_dense(num_hosts, pubsubs_gsub, hosts): - num_msgs = 5 - - # All pubsub subscribe to foobar - queues = [] - for pubsub in pubsubs_gsub: - q = await pubsub.subscribe("foobar") - - # Add each blocking queue to an array of blocking queues - queues.append(q) - - # Densely connect libp2p hosts in a random way - await dense_connect(hosts) - - # Wait 2 seconds for heartbeat to allow mesh to connect - await asyncio.sleep(2) - - for i in range(num_msgs): - msg_content = b"foo " + i.to_bytes(1, "big") - - # randomly pick a message origin - origin_idx = random.randint(0, num_hosts - 1) + # Wait 2 seconds for heartbeat to allow mesh to connect + await trio.sleep(2) + # Central node publish to the topic so that this topic + # is added to central node's fanout # publish from the randomly chosen host - await pubsubs_gsub[origin_idx].publish("foobar", msg_content) + await pubsubs_gsub[central_node_index].publish(topic, b"data") - await asyncio.sleep(0.5) - # Assert that all blocking queues receive the message - for queue in queues: - msg = await queue.get() - assert msg.data == msg_content + # Check that the gossipsub of central node has fanout for the topic + assert topic in gossipsubs[central_node_index].fanout + # Check that the gossipsub of central node does not have a mesh for the topic + assert topic not in gossipsubs[central_node_index].mesh + + # Central node subscribes the topic + await pubsubs_gsub[central_node_index].subscribe(topic) + + await trio.sleep(2) + + # Check that the gossipsub of central node no longer has fanout for the topic + assert topic not in gossipsubs[central_node_index].fanout + + for i in hosts_indices: + if i in subscribed_peer_indices: + assert hosts[i].get_id() in gossipsubs[central_node_index].mesh[topic] + assert hosts[central_node_index].get_id() in gossipsubs[i].mesh[topic] + else: + assert ( + hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic] + ) + assert topic not in gossipsubs[i].mesh -@pytest.mark.parametrize("num_hosts", (10,)) -@pytest.mark.asyncio -async def test_fanout(hosts, pubsubs_gsub): - num_msgs = 5 +@pytest.mark.trio +async def test_leave(): + async with PubsubFactory.create_batch_with_gossipsub(1) as pubsubs_gsub: + gossipsub = pubsubs_gsub[0].router + topic = "test_leave" - # All pubsub subscribe to foobar except for `pubsubs_gsub[0]` - queues = [] - for i in range(1, len(pubsubs_gsub)): - q = await pubsubs_gsub[i].subscribe("foobar") + assert topic not in gossipsub.mesh - # Add each blocking queue to an array of blocking queues - queues.append(q) + await gossipsub.join(topic) + assert topic in gossipsub.mesh - # Sparsely connect libp2p hosts in random way - await dense_connect(hosts) + await gossipsub.leave(topic) + assert topic not in gossipsub.mesh - # Wait 2 seconds for heartbeat to allow mesh to connect - await asyncio.sleep(2) - - topic = "foobar" - # Send messages with origin not subscribed - for i in range(num_msgs): - msg_content = b"foo " + i.to_bytes(1, "big") - - # Pick the message origin to the node that is not subscribed to 'foobar' - origin_idx = 0 - - # publish from the randomly chosen host - await pubsubs_gsub[origin_idx].publish(topic, msg_content) - - await asyncio.sleep(0.5) - # Assert that all blocking queues receive the message - for queue in queues: - msg 
= await queue.get() - assert msg.data == msg_content - - # Subscribe message origin - queues.insert(0, await pubsubs_gsub[0].subscribe(topic)) - - # Send messages again - for i in range(num_msgs): - msg_content = b"bar " + i.to_bytes(1, "big") - - # Pick the message origin to the node that is not subscribed to 'foobar' - origin_idx = 0 - - # publish from the randomly chosen host - await pubsubs_gsub[origin_idx].publish(topic, msg_content) - - await asyncio.sleep(0.5) - # Assert that all blocking queues receive the message - for queue in queues: - msg = await queue.get() - assert msg.data == msg_content + # Test re-leave + await gossipsub.leave(topic) -@pytest.mark.parametrize("num_hosts", (10,)) -@pytest.mark.asyncio +@pytest.mark.trio +async def test_handle_graft(monkeypatch): + async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub: + gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub) + + index_alice = 0 + id_alice = pubsubs_gsub[index_alice].my_id + index_bob = 1 + id_bob = pubsubs_gsub[index_bob].my_id + await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host) + + # Wait 2 seconds for heartbeat to allow mesh to connect + await trio.sleep(2) + + topic = "test_handle_graft" + # Only lice subscribe to the topic + await gossipsubs[index_alice].join(topic) + + # Monkey patch bob's `emit_prune` function so we can + # check if it is called in `handle_graft` + event_emit_prune = trio.Event() + + async def emit_prune(topic, sender_peer_id): + event_emit_prune.set() + await trio.hazmat.checkpoint() + + monkeypatch.setattr(gossipsubs[index_bob], "emit_prune", emit_prune) + + # Check that alice is bob's peer but not his mesh peer + assert gossipsubs[index_bob].peer_protocol[id_alice] == PROTOCOL_ID + assert topic not in gossipsubs[index_bob].mesh + + await gossipsubs[index_alice].emit_graft(topic, id_bob) + + # Check that `emit_prune` is called + await event_emit_prune.wait() + + # Check that bob is alice's peer but not her mesh peer + assert topic in gossipsubs[index_alice].mesh + assert id_bob not in gossipsubs[index_alice].mesh[topic] + assert gossipsubs[index_alice].peer_protocol[id_bob] == PROTOCOL_ID + + await gossipsubs[index_bob].emit_graft(topic, id_alice) + + await trio.sleep(1) + + # Check that bob is now alice's mesh peer + assert id_bob in gossipsubs[index_alice].mesh[topic] + + +@pytest.mark.trio +async def test_handle_prune(): + async with PubsubFactory.create_batch_with_gossipsub( + 2, heartbeat_interval=3 + ) as pubsubs_gsub: + gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub) + + index_alice = 0 + id_alice = pubsubs_gsub[index_alice].my_id + index_bob = 1 + id_bob = pubsubs_gsub[index_bob].my_id + + topic = "test_handle_prune" + for pubsub in pubsubs_gsub: + await pubsub.subscribe(topic) + + await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host) + + # Wait for heartbeat to allow mesh to connect + await trio.sleep(1) + + # Check that they are each other's mesh peer + assert id_alice in gossipsubs[index_bob].mesh[topic] + assert id_bob in gossipsubs[index_alice].mesh[topic] + + # alice emit prune message to bob, alice should be removed + # from bob's mesh peer + await gossipsubs[index_alice].emit_prune(topic, id_bob) + # `emit_prune` does not remove bob from alice's mesh peers + assert id_bob in gossipsubs[index_alice].mesh[topic] + + # NOTE: We increase `heartbeat_interval` to 3 seconds so that bob will not + # add alice back to his mesh after heartbeat. 
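The monkeypatched `emit_prune` above signals through a `trio.Event` and calls `trio.hazmat.checkpoint()` (renamed `trio.lowlevel.checkpoint()` in later trio releases) so the patched coroutine still yields to the scheduler. Trio has no `asyncio.wait_for`; the old timeout around the wait is expressed with a cancel scope instead, roughly like this:

```python
import trio


async def main() -> None:
    event_emit_prune = trio.Event()

    async def emit_prune() -> None:
        await trio.sleep(0.01)  # pretend some RPC handling happened first
        event_emit_prune.set()

    async with trio.open_nursery() as nursery:
        nursery.start_soon(emit_prune)
        # `trio.fail_after` plays the role `asyncio.wait_for(..., timeout=1)` used to.
        with trio.fail_after(1):
            await event_emit_prune.wait()


trio.run(main)
```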
+ # Wait for bob to `handle_prune` + await trio.sleep(0.1) + + # Check that alice is no longer bob's mesh peer + assert id_alice not in gossipsubs[index_bob].mesh[topic] + + +@pytest.mark.trio +async def test_dense(): + async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub: + hosts = [pubsub.host for pubsub in pubsubs_gsub] + num_msgs = 5 + + # All pubsub subscribe to foobar + queues = [await pubsub.subscribe("foobar") for pubsub in pubsubs_gsub] + + # Densely connect libp2p hosts in a random way + await dense_connect(hosts) + + # Wait 2 seconds for heartbeat to allow mesh to connect + await trio.sleep(2) + + for i in range(num_msgs): + msg_content = b"foo " + i.to_bytes(1, "big") + + # randomly pick a message origin + origin_idx = random.randint(0, len(hosts) - 1) + + # publish from the randomly chosen host + await pubsubs_gsub[origin_idx].publish("foobar", msg_content) + + await trio.sleep(0.5) + # Assert that all blocking queues receive the message + for queue in queues: + msg = await queue.get() + assert msg.data == msg_content + + +@pytest.mark.trio +async def test_fanout(): + async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub: + hosts = [pubsub.host for pubsub in pubsubs_gsub] + num_msgs = 5 + + # All pubsub subscribe to foobar except for `pubsubs_gsub[0]` + subs = [await pubsub.subscribe("foobar") for pubsub in pubsubs_gsub[1:]] + + # Sparsely connect libp2p hosts in random way + await dense_connect(hosts) + + # Wait 2 seconds for heartbeat to allow mesh to connect + await trio.sleep(2) + + topic = "foobar" + # Send messages with origin not subscribed + for i in range(num_msgs): + msg_content = b"foo " + i.to_bytes(1, "big") + + # Pick the message origin to the node that is not subscribed to 'foobar' + origin_idx = 0 + + # publish from the randomly chosen host + await pubsubs_gsub[origin_idx].publish(topic, msg_content) + + await trio.sleep(0.5) + # Assert that all blocking queues receive the message + for sub in subs: + msg = await sub.get() + assert msg.data == msg_content + + # Subscribe message origin + subs.insert(0, await pubsubs_gsub[0].subscribe(topic)) + + # Send messages again + for i in range(num_msgs): + msg_content = b"bar " + i.to_bytes(1, "big") + + # Pick the message origin to the node that is not subscribed to 'foobar' + origin_idx = 0 + + # publish from the randomly chosen host + await pubsubs_gsub[origin_idx].publish(topic, msg_content) + + await trio.sleep(0.5) + # Assert that all blocking queues receive the message + for sub in subs: + msg = await sub.get() + assert msg.data == msg_content + + +@pytest.mark.trio @pytest.mark.slow -async def test_fanout_maintenance(hosts, pubsubs_gsub): - num_msgs = 5 +async def test_fanout_maintenance(): + async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub: + hosts = [pubsub.host for pubsub in pubsubs_gsub] + num_msgs = 5 - # All pubsub subscribe to foobar - queues = [] - topic = "foobar" - for i in range(1, len(pubsubs_gsub)): - q = await pubsubs_gsub[i].subscribe(topic) + # All pubsub subscribe to foobar + queues = [] + topic = "foobar" + for i in range(1, len(pubsubs_gsub)): + q = await pubsubs_gsub[i].subscribe(topic) - # Add each blocking queue to an array of blocking queues - queues.append(q) + # Add each blocking queue to an array of blocking queues + queues.append(q) - # Sparsely connect libp2p hosts in random way - await dense_connect(hosts) + # Sparsely connect libp2p hosts in random way + await dense_connect(hosts) - # Wait 2 seconds for heartbeat to allow 
mesh to connect - await asyncio.sleep(2) + # Wait 2 seconds for heartbeat to allow mesh to connect + await trio.sleep(2) - # Send messages with origin not subscribed - for i in range(num_msgs): - msg_content = b"foo " + i.to_bytes(1, "big") + # Send messages with origin not subscribed + for i in range(num_msgs): + msg_content = b"foo " + i.to_bytes(1, "big") - # Pick the message origin to the node that is not subscribed to 'foobar' - origin_idx = 0 + # Pick the message origin to the node that is not subscribed to 'foobar' + origin_idx = 0 + + # publish from the randomly chosen host + await pubsubs_gsub[origin_idx].publish(topic, msg_content) + + await trio.sleep(0.5) + # Assert that all blocking queues receive the message + for queue in queues: + msg = await queue.get() + assert msg.data == msg_content + + for sub in pubsubs_gsub: + await sub.unsubscribe(topic) + + queues = [] + + await trio.sleep(2) + + # Resub and repeat + for i in range(1, len(pubsubs_gsub)): + q = await pubsubs_gsub[i].subscribe(topic) + + # Add each blocking queue to an array of blocking queues + queues.append(q) + + await trio.sleep(2) + + # Check messages can still be sent + for i in range(num_msgs): + msg_content = b"bar " + i.to_bytes(1, "big") + + # Pick the message origin to the node that is not subscribed to 'foobar' + origin_idx = 0 + + # publish from the randomly chosen host + await pubsubs_gsub[origin_idx].publish(topic, msg_content) + + await trio.sleep(0.5) + # Assert that all blocking queues receive the message + for queue in queues: + msg = await queue.get() + assert msg.data == msg_content + + +@pytest.mark.trio +async def test_gossip_propagation(): + async with PubsubFactory.create_batch_with_gossipsub( + 2, degree=1, degree_low=0, degree_high=2, gossip_window=50, gossip_history=100 + ) as pubsubs_gsub: + topic = "foo" + queue_0 = await pubsubs_gsub[0].subscribe(topic) + + # node 0 publish to topic + msg_content = b"foo_msg" # publish from the randomly chosen host - await pubsubs_gsub[origin_idx].publish(topic, msg_content) + await pubsubs_gsub[0].publish(topic, msg_content) - await asyncio.sleep(0.5) - # Assert that all blocking queues receive the message - for queue in queues: - msg = await queue.get() - assert msg.data == msg_content - - for sub in pubsubs_gsub: - await sub.unsubscribe(topic) - - queues = [] - - await asyncio.sleep(2) - - # Resub and repeat - for i in range(1, len(pubsubs_gsub)): - q = await pubsubs_gsub[i].subscribe(topic) - - # Add each blocking queue to an array of blocking queues - queues.append(q) - - await asyncio.sleep(2) - - # Check messages can still be sent - for i in range(num_msgs): - msg_content = b"bar " + i.to_bytes(1, "big") - - # Pick the message origin to the node that is not subscribed to 'foobar' - origin_idx = 0 - - # publish from the randomly chosen host - await pubsubs_gsub[origin_idx].publish(topic, msg_content) - - await asyncio.sleep(0.5) - # Assert that all blocking queues receive the message - for queue in queues: - msg = await queue.get() - assert msg.data == msg_content + await trio.sleep(0.5) + # Assert that the blocking queues receive the message + msg = await queue_0.get() + assert msg.data == msg_content -@pytest.mark.parametrize( - "num_hosts, gossipsub_params", - ( - ( - 2, - GossipsubParams( - degree=1, - degree_low=0, - degree_high=2, - gossip_window=50, - gossip_history=100, - ), - ), - ), -) -@pytest.mark.asyncio -async def test_gossip_propagation(hosts, pubsubs_gsub): - topic = "foo" - await pubsubs_gsub[0].subscribe(topic) - - # node 0 
publish to topic - msg_content = b"foo_msg" - - # publish from the randomly chosen host - await pubsubs_gsub[0].publish(topic, msg_content) - - # now node 1 subscribes - queue_1 = await pubsubs_gsub[1].subscribe(topic) - - await connect(hosts[0], hosts[1]) - - # wait for gossip heartbeat - await asyncio.sleep(2) - - # should be able to read message - msg = await queue_1.get() - assert msg.data == msg_content - - -@pytest.mark.parametrize( - "num_hosts, gossipsub_params", ((1, GossipsubParams(heartbeat_initial_delay=100)),) -) @pytest.mark.parametrize("initial_mesh_peer_count", (7, 10, 13)) -@pytest.mark.asyncio -async def test_mesh_heartbeat( - num_hosts, initial_mesh_peer_count, pubsubs_gsub, hosts, monkeypatch -): - # It's difficult to set up the initial peer subscription condition. - # Ideally I would like to have initial mesh peer count that's below ``GossipSubDegree`` - # so I can test if `mesh_heartbeat` return correct peers to GRAFT. - # The problem is that I can not set it up so that we have peers subscribe to the topic - # but not being part of our mesh peers (as these peers are the peers to GRAFT). - # So I monkeypatch the peer subscriptions and our mesh peers. - total_peer_count = 14 - topic = "TEST_MESH_HEARTBEAT" +@pytest.mark.trio +async def test_mesh_heartbeat(initial_mesh_peer_count, monkeypatch): + async with PubsubFactory.create_batch_with_gossipsub( + 1, heartbeat_initial_delay=100 + ) as pubsubs_gsub: + # It's difficult to set up the initial peer subscription condition. + # Ideally I would like to have initial mesh peer count that's below ``GossipSubDegree`` + # so I can test if `mesh_heartbeat` return correct peers to GRAFT. + # The problem is that I can not set it up so that we have peers subscribe to the topic + # but not being part of our mesh peers (as these peers are the peers to GRAFT). + # So I monkeypatch the peer subscriptions and our mesh peers. 
+ total_peer_count = 14 + topic = "TEST_MESH_HEARTBEAT" - fake_peer_ids = [ - ID((i).to_bytes(2, byteorder="big")) for i in range(total_peer_count) - ] - peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids} - monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol) + fake_peer_ids = [IDFactory() for _ in range(total_peer_count)] + peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids} + monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol) - peer_topics = {topic: set(fake_peer_ids)} - # Monkeypatch the peer subscriptions - monkeypatch.setattr(pubsubs_gsub[0], "peer_topics", peer_topics) + peer_topics = {topic: set(fake_peer_ids)} + # Monkeypatch the peer subscriptions + monkeypatch.setattr(pubsubs_gsub[0], "peer_topics", peer_topics) - mesh_peer_indices = random.sample(range(total_peer_count), initial_mesh_peer_count) - mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices] - router_mesh = {topic: set(mesh_peers)} - # Monkeypatch our mesh peers - monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh) + mesh_peer_indices = random.sample( + range(total_peer_count), initial_mesh_peer_count + ) + mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices] + router_mesh = {topic: set(mesh_peers)} + # Monkeypatch our mesh peers + monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh) - peers_to_graft, peers_to_prune = pubsubs_gsub[0].router.mesh_heartbeat() - if initial_mesh_peer_count > GOSSIPSUB_PARAMS.degree: - # If number of initial mesh peers is more than `GossipSubDegree`, we should PRUNE mesh peers - assert len(peers_to_graft) == 0 - assert len(peers_to_prune) == initial_mesh_peer_count - GOSSIPSUB_PARAMS.degree - for peer in peers_to_prune: - assert peer in mesh_peers - elif initial_mesh_peer_count < GOSSIPSUB_PARAMS.degree: - # If number of initial mesh peers is less than `GossipSubDegree`, we should GRAFT more peers - assert len(peers_to_prune) == 0 - assert len(peers_to_graft) == GOSSIPSUB_PARAMS.degree - initial_mesh_peer_count - for peer in peers_to_graft: - assert peer not in mesh_peers - else: - assert len(peers_to_prune) == 0 and len(peers_to_graft) == 0 - - -@pytest.mark.parametrize( - "num_hosts, gossipsub_params", ((1, GossipsubParams(heartbeat_initial_delay=100)),) -) -@pytest.mark.parametrize("initial_peer_count", (1, 4, 7)) -@pytest.mark.asyncio -async def test_gossip_heartbeat( - num_hosts, initial_peer_count, pubsubs_gsub, hosts, monkeypatch -): - # The problem is that I can not set it up so that we have peers subscribe to the topic - # but not being part of our mesh peers (as these peers are the peers to GRAFT). - # So I monkeypatch the peer subscriptions and our mesh peers. 
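The assertions above encode the mesh-maintenance rule: with fewer mesh peers than `degree`, GRAFT enough subscribed-but-unmeshed peers to reach `degree`; with more, PRUNE the surplus; otherwise do nothing. A toy version of that selection, as a sketch of the rule rather than py-libp2p's actual `mesh_heartbeat`:

```python
import random
from typing import List, Set, Tuple


def mesh_heartbeat_sketch(
    mesh_peers: Set[str], topic_peers: Set[str], degree: int
) -> Tuple[List[str], List[str]]:
    """Toy GRAFT/PRUNE selection matching the assertions above."""
    to_graft: List[str] = []
    to_prune: List[str] = []
    if len(mesh_peers) < degree:
        # Too few mesh peers: GRAFT subscribed peers that are not yet in the mesh.
        candidates = list(topic_peers - mesh_peers)
        to_graft = random.sample(candidates, min(degree - len(mesh_peers), len(candidates)))
    elif len(mesh_peers) > degree:
        # Too many mesh peers: PRUNE the surplus.
        to_prune = random.sample(sorted(mesh_peers), len(mesh_peers) - degree)
    return to_graft, to_prune


topic_peers = {f"peer-{i}" for i in range(14)}
mesh = set(random.sample(sorted(topic_peers), 7))
graft, prune = mesh_heartbeat_sketch(mesh, topic_peers, degree=10)
assert len(graft) == 3 and not prune and not set(graft) & mesh
```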
- total_peer_count = 28 - topic_mesh = "TEST_GOSSIP_HEARTBEAT_1" - topic_fanout = "TEST_GOSSIP_HEARTBEAT_2" - - fake_peer_ids = [ - ID((i).to_bytes(2, byteorder="big")) for i in range(total_peer_count) - ] - peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids} - monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol) - - topic_mesh_peer_count = 14 - # Split into mesh peers and fanout peers - peer_topics = { - topic_mesh: set(fake_peer_ids[:topic_mesh_peer_count]), - topic_fanout: set(fake_peer_ids[topic_mesh_peer_count:]), - } - # Monkeypatch the peer subscriptions - monkeypatch.setattr(pubsubs_gsub[0], "peer_topics", peer_topics) - - mesh_peer_indices = random.sample(range(topic_mesh_peer_count), initial_peer_count) - mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices] - router_mesh = {topic_mesh: set(mesh_peers)} - # Monkeypatch our mesh peers - monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh) - fanout_peer_indices = random.sample( - range(topic_mesh_peer_count, total_peer_count), initial_peer_count - ) - fanout_peers = [fake_peer_ids[i] for i in fanout_peer_indices] - router_fanout = {topic_fanout: set(fanout_peers)} - # Monkeypatch our fanout peers - monkeypatch.setattr(pubsubs_gsub[0].router, "fanout", router_fanout) - - def window(topic): - if topic == topic_mesh: - return [topic_mesh] - elif topic == topic_fanout: - return [topic_fanout] + peers_to_graft, peers_to_prune = pubsubs_gsub[0].router.mesh_heartbeat() + if initial_mesh_peer_count > pubsubs_gsub[0].router.degree: + # If number of initial mesh peers is more than `GossipSubDegree`, + # we should PRUNE mesh peers + assert len(peers_to_graft) == 0 + assert ( + len(peers_to_prune) + == initial_mesh_peer_count - pubsubs_gsub[0].router.degree + ) + for peer in peers_to_prune: + assert peer in mesh_peers + elif initial_mesh_peer_count < pubsubs_gsub[0].router.degree: + # If number of initial mesh peers is less than `GossipSubDegree`, + # we should GRAFT more peers + assert len(peers_to_prune) == 0 + assert ( + len(peers_to_graft) + == pubsubs_gsub[0].router.degree - initial_mesh_peer_count + ) + for peer in peers_to_graft: + assert peer not in mesh_peers else: - return [] + assert len(peers_to_prune) == 0 and len(peers_to_graft) == 0 - # Monkeypatch the memory cache messages - monkeypatch.setattr(pubsubs_gsub[0].router.mcache, "window", window) - peers_to_gossip = pubsubs_gsub[0].router.gossip_heartbeat() - # If our mesh peer count is less than `GossipSubDegree`, we should gossip to up to - # `GossipSubDegree` peers (exclude mesh peers). - if topic_mesh_peer_count - initial_peer_count < GOSSIPSUB_PARAMS.degree: - # The same goes for fanout so it's two times the number of peers to gossip. - assert len(peers_to_gossip) == 2 * (topic_mesh_peer_count - initial_peer_count) - elif topic_mesh_peer_count - initial_peer_count >= GOSSIPSUB_PARAMS.degree: - assert len(peers_to_gossip) == 2 * (GOSSIPSUB_PARAMS.degree) +@pytest.mark.parametrize("initial_peer_count", (1, 4, 7)) +@pytest.mark.trio +async def test_gossip_heartbeat(initial_peer_count, monkeypatch): + async with PubsubFactory.create_batch_with_gossipsub( + 1, heartbeat_initial_delay=100 + ) as pubsubs_gsub: + # The problem is that I can not set it up so that we have peers subscribe to the topic + # but not being part of our mesh peers (as these peers are the peers to GRAFT). + # So I monkeypatch the peer subscriptions and our mesh peers. 
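The remainder of this test (below) checks that gossip targets are drawn from peers subscribed to a topic but not already in our mesh or fanout for it, capped at `degree` per topic; the `2 *` in the later assertions comes from one mesh topic plus one fanout topic being in play. A toy selection under those assumptions, not the library's own `gossip_heartbeat`:

```python
import random
from typing import List, Set


def gossip_targets(topic_peers: Set[str], exclude: Set[str], degree: int) -> List[str]:
    """Pick up to `degree` subscribed peers that are not already mesh/fanout peers."""
    candidates = sorted(topic_peers - exclude)
    return random.sample(candidates, min(degree, len(candidates)))


subscribed = {f"peer-{i}" for i in range(14)}
mesh = set(random.sample(sorted(subscribed), 4))
targets = gossip_targets(subscribed, mesh, degree=10)
assert len(targets) == 10 and not set(targets) & mesh
```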
+ total_peer_count = 28 + topic_mesh = "TEST_GOSSIP_HEARTBEAT_1" + topic_fanout = "TEST_GOSSIP_HEARTBEAT_2" - for peer in peers_to_gossip: - if peer in peer_topics[topic_mesh]: - # Check that the peer to gossip to is not in our mesh peers - assert peer not in mesh_peers - assert topic_mesh in peers_to_gossip[peer] - elif peer in peer_topics[topic_fanout]: - # Check that the peer to gossip to is not in our fanout peers - assert peer not in fanout_peers - assert topic_fanout in peers_to_gossip[peer] + fake_peer_ids = [IDFactory() for _ in range(total_peer_count)] + peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids} + monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol) + + topic_mesh_peer_count = 14 + # Split into mesh peers and fanout peers + peer_topics = { + topic_mesh: set(fake_peer_ids[:topic_mesh_peer_count]), + topic_fanout: set(fake_peer_ids[topic_mesh_peer_count:]), + } + # Monkeypatch the peer subscriptions + monkeypatch.setattr(pubsubs_gsub[0], "peer_topics", peer_topics) + + mesh_peer_indices = random.sample( + range(topic_mesh_peer_count), initial_peer_count + ) + mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices] + router_mesh = {topic_mesh: set(mesh_peers)} + # Monkeypatch our mesh peers + monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh) + fanout_peer_indices = random.sample( + range(topic_mesh_peer_count, total_peer_count), initial_peer_count + ) + fanout_peers = [fake_peer_ids[i] for i in fanout_peer_indices] + router_fanout = {topic_fanout: set(fanout_peers)} + # Monkeypatch our fanout peers + monkeypatch.setattr(pubsubs_gsub[0].router, "fanout", router_fanout) + + def window(topic): + if topic == topic_mesh: + return [topic_mesh] + elif topic == topic_fanout: + return [topic_fanout] + else: + return [] + + # Monkeypatch the memory cache messages + monkeypatch.setattr(pubsubs_gsub[0].router.mcache, "window", window) + + peers_to_gossip = pubsubs_gsub[0].router.gossip_heartbeat() + # If our mesh peer count is less than `GossipSubDegree`, we should gossip to up to + # `GossipSubDegree` peers (exclude mesh peers). + if topic_mesh_peer_count - initial_peer_count < pubsubs_gsub[0].router.degree: + # The same goes for fanout so it's two times the number of peers to gossip. 
+ assert len(peers_to_gossip) == 2 * ( + topic_mesh_peer_count - initial_peer_count + ) + elif ( + topic_mesh_peer_count - initial_peer_count >= pubsubs_gsub[0].router.degree + ): + assert len(peers_to_gossip) == 2 * (pubsubs_gsub[0].router.degree) + + for peer in peers_to_gossip: + if peer in peer_topics[topic_mesh]: + # Check that the peer to gossip to is not in our mesh peers + assert peer not in mesh_peers + assert topic_mesh in peers_to_gossip[peer] + elif peer in peer_topics[topic_fanout]: + # Check that the peer to gossip to is not in our fanout peers + assert peer not in fanout_peers + assert topic_fanout in peers_to_gossip[peer] diff --git a/tests/pubsub/test_gossipsub_backward_compatibility.py b/tests/pubsub/test_gossipsub_backward_compatibility.py index d82fd229..08f0284b 100644 --- a/tests/pubsub/test_gossipsub_backward_compatibility.py +++ b/tests/pubsub/test_gossipsub_backward_compatibility.py @@ -3,25 +3,25 @@ import functools import pytest from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID -from libp2p.tools.factories import GossipsubFactory +from libp2p.tools.factories import PubsubFactory from libp2p.tools.pubsub.floodsub_integration_test_settings import ( floodsub_protocol_pytest_params, perform_test_from_obj, ) -@pytest.mark.asyncio -async def test_gossipsub_initialize_with_floodsub_protocol(): - GossipsubFactory(protocols=[FLOODSUB_PROTOCOL_ID]) - - @pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params) -@pytest.mark.asyncio +@pytest.mark.trio @pytest.mark.slow async def test_gossipsub_run_with_floodsub_tests(test_case_obj): await perform_test_from_obj( test_case_obj, functools.partial( - GossipsubFactory, degree=3, degree_low=2, degree_high=4, time_to_live=30 + PubsubFactory.create_batch_with_gossipsub, + protocols=[FLOODSUB_PROTOCOL_ID], + degree=3, + degree_low=2, + degree_high=4, + time_to_live=30, ), ) diff --git a/tests/pubsub/test_mcache.py b/tests/pubsub/test_mcache.py index e80ad27a..fb764b31 100644 --- a/tests/pubsub/test_mcache.py +++ b/tests/pubsub/test_mcache.py @@ -1,5 +1,3 @@ -import pytest - from libp2p.pubsub.mcache import MessageCache @@ -12,8 +10,7 @@ class Msg: self.from_id = from_id -@pytest.mark.asyncio -async def test_mcache(): +def test_mcache(): # Ported from: # https://github.com/libp2p/go-libp2p-pubsub/blob/51b7501433411b5096cac2b4994a36a68515fc03/mcache_test.go mcache = MessageCache(3, 5) diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index 6f3c6725..d6c29310 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -1,12 +1,14 @@ -import asyncio +from contextlib import contextmanager from typing import NamedTuple import pytest +import trio from libp2p.exceptions import ValidationError -from libp2p.peer.id import ID from libp2p.pubsub.pb import rpc_pb2 -from libp2p.pubsub.pubsub import PUBSUB_SIGNING_PREFIX +from libp2p.pubsub.pubsub import PUBSUB_SIGNING_PREFIX, SUBSCRIPTION_CHANNEL_SIZE +from libp2p.tools.constants import MAX_READ_LEN +from libp2p.tools.factories import IDFactory, PubsubFactory, net_stream_pair_factory from libp2p.tools.pubsub.utils import make_pubsub_msg from libp2p.tools.utils import connect from libp2p.utils import encode_varint_prefixed @@ -15,335 +17,317 @@ TESTING_TOPIC = "TEST_SUBSCRIBE" TESTING_DATA = b"data" -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_subscribe_and_unsubscribe(pubsubs_fsub): - await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - assert TESTING_TOPIC in pubsubs_fsub[0].my_topics 
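# [Editor's note: illustrative sketch, not part of the patch.] Throughout this diff,
# `@pytest.mark.asyncio` tests driven by fixtures are rewritten as `@pytest.mark.trio`
# tests that build their resources through async context managers, so teardown happens
# in `async with` instead of fixture finalizers. A minimal self-contained version of
# that shape; `dummy_channel_factory` is a stand-in, not a py-libp2p API.
from contextlib import asynccontextmanager

import pytest
import trio


@asynccontextmanager
async def dummy_channel_factory():
    send, receive = trio.open_memory_channel(0)
    try:
        yield send, receive
    finally:
        await send.aclose()
        await receive.aclose()


@pytest.mark.trio
async def test_example_shape():
    async with dummy_channel_factory() as (send, receive):
        async with trio.open_nursery() as nursery:
            nursery.start_soon(send.send, b"ping")
            assert await receive.receive() == b"ping"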
+@pytest.mark.trio +async def test_subscribe_and_unsubscribe(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids - await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) - assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics + await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) + assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_re_subscribe(pubsubs_fsub): - await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - assert TESTING_TOPIC in pubsubs_fsub[0].my_topics +@pytest.mark.trio +async def test_re_subscribe(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids - await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - assert TESTING_TOPIC in pubsubs_fsub[0].my_topics + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_re_unsubscribe(pubsubs_fsub): - # Unsubscribe from topic we didn't even subscribe to - assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics - await pubsubs_fsub[0].unsubscribe("NOT_MY_TOPIC") - assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics +@pytest.mark.trio +async def test_re_unsubscribe(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + # Unsubscribe from topic we didn't even subscribe to + assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].topic_ids + await pubsubs_fsub[0].unsubscribe("NOT_MY_TOPIC") + assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].topic_ids - await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - assert TESTING_TOPIC in pubsubs_fsub[0].my_topics + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids - await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) - assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics + await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) + assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids - await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) - assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics + await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) + assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids -@pytest.mark.asyncio -async def test_peers_subscribe(pubsubs_fsub): - await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) - await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - # Yield to let 0 notify 1 - await asyncio.sleep(1) - assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] - await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) - # Yield to let 0 notify 1 - await asyncio.sleep(1) - assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] +@pytest.mark.trio +async def test_peers_subscribe(): + async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub: + await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + # Yield to let 0 notify 1 + await trio.sleep(1) + assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] + await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) + # Yield to let 0 notify 1 + await trio.sleep(1) + assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def 
test_get_hello_packet(pubsubs_fsub): - def _get_hello_packet_topic_ids(): - packet = pubsubs_fsub[0].get_hello_packet() - return tuple(sub.topicid for sub in packet.subscriptions) +@pytest.mark.trio +async def test_get_hello_packet(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: - # Test: No subscription, so there should not be any topic ids in the hello packet. - assert len(_get_hello_packet_topic_ids()) == 0 + def _get_hello_packet_topic_ids(): + packet = pubsubs_fsub[0].get_hello_packet() + return tuple(sub.topicid for sub in packet.subscriptions) - # Test: After subscriptions, topic ids should be in the hello packet. - topic_ids = ["t", "o", "p", "i", "c"] - await asyncio.gather(*[pubsubs_fsub[0].subscribe(topic) for topic in topic_ids]) - topic_ids_in_hello = _get_hello_packet_topic_ids() - for topic in topic_ids: - assert topic in topic_ids_in_hello + # Test: No subscription, so there should not be any topic ids in the hello packet. + assert len(_get_hello_packet_topic_ids()) == 0 + + # Test: After subscriptions, topic ids should be in the hello packet. + topic_ids = ["t", "o", "p", "i", "c"] + for topic in topic_ids: + await pubsubs_fsub[0].subscribe(topic) + topic_ids_in_hello = _get_hello_packet_topic_ids() + for topic in topic_ids: + assert topic in topic_ids_in_hello -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_set_and_remove_topic_validator(pubsubs_fsub): +@pytest.mark.trio +async def test_set_and_remove_topic_validator(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + is_sync_validator_called = False - is_sync_validator_called = False + def sync_validator(peer_id, msg): + nonlocal is_sync_validator_called + is_sync_validator_called = True - def sync_validator(peer_id, msg): - nonlocal is_sync_validator_called - is_sync_validator_called = True + is_async_validator_called = False - is_async_validator_called = False + async def async_validator(peer_id, msg): + nonlocal is_async_validator_called + is_async_validator_called = True + await trio.hazmat.checkpoint() - async def async_validator(peer_id, msg): - nonlocal is_async_validator_called - is_async_validator_called = True + topic = "TEST_VALIDATOR" - topic = "TEST_VALIDATOR" + assert topic not in pubsubs_fsub[0].topic_validators - assert topic not in pubsubs_fsub[0].topic_validators + # Register sync validator + pubsubs_fsub[0].set_topic_validator(topic, sync_validator, False) - # Register sync validator - pubsubs_fsub[0].set_topic_validator(topic, sync_validator, False) + assert topic in pubsubs_fsub[0].topic_validators + topic_validator = pubsubs_fsub[0].topic_validators[topic] + assert not topic_validator.is_async - assert topic in pubsubs_fsub[0].topic_validators - topic_validator = pubsubs_fsub[0].topic_validators[topic] - assert not topic_validator.is_async + # Validate with sync validator + topic_validator.validator(peer_id=IDFactory(), msg="msg") - # Validate with sync validator - topic_validator.validator(peer_id=ID(b"peer"), msg="msg") + assert is_sync_validator_called + assert not is_async_validator_called - assert is_sync_validator_called - assert not is_async_validator_called + # Register with async validator + pubsubs_fsub[0].set_topic_validator(topic, async_validator, True) - # Register with async validator - pubsubs_fsub[0].set_topic_validator(topic, async_validator, True) + is_sync_validator_called = False + assert topic in pubsubs_fsub[0].topic_validators + topic_validator = pubsubs_fsub[0].topic_validators[topic] + 
assert topic_validator.is_async - is_sync_validator_called = False - assert topic in pubsubs_fsub[0].topic_validators - topic_validator = pubsubs_fsub[0].topic_validators[topic] - assert topic_validator.is_async + # Validate with async validator + await topic_validator.validator(peer_id=IDFactory(), msg="msg") - # Validate with async validator - await topic_validator.validator(peer_id=ID(b"peer"), msg="msg") + assert is_async_validator_called + assert not is_sync_validator_called - assert is_async_validator_called - assert not is_sync_validator_called - - # Remove validator - pubsubs_fsub[0].remove_topic_validator(topic) - assert topic not in pubsubs_fsub[0].topic_validators + # Remove validator + pubsubs_fsub[0].remove_topic_validator(topic) + assert topic not in pubsubs_fsub[0].topic_validators -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_get_msg_validators(pubsubs_fsub): +@pytest.mark.trio +async def test_get_msg_validators(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + times_sync_validator_called = 0 - times_sync_validator_called = 0 + def sync_validator(peer_id, msg): + nonlocal times_sync_validator_called + times_sync_validator_called += 1 - def sync_validator(peer_id, msg): - nonlocal times_sync_validator_called - times_sync_validator_called += 1 + times_async_validator_called = 0 - times_async_validator_called = 0 + async def async_validator(peer_id, msg): + nonlocal times_async_validator_called + times_async_validator_called += 1 + await trio.hazmat.checkpoint() - async def async_validator(peer_id, msg): - nonlocal times_async_validator_called - times_async_validator_called += 1 + topic_1 = "TEST_VALIDATOR_1" + topic_2 = "TEST_VALIDATOR_2" + topic_3 = "TEST_VALIDATOR_3" - topic_1 = "TEST_VALIDATOR_1" - topic_2 = "TEST_VALIDATOR_2" - topic_3 = "TEST_VALIDATOR_3" + # Register sync validator for topic 1 and 2 + pubsubs_fsub[0].set_topic_validator(topic_1, sync_validator, False) + pubsubs_fsub[0].set_topic_validator(topic_2, sync_validator, False) - # Register sync validator for topic 1 and 2 - pubsubs_fsub[0].set_topic_validator(topic_1, sync_validator, False) - pubsubs_fsub[0].set_topic_validator(topic_2, sync_validator, False) + # Register async validator for topic 3 + pubsubs_fsub[0].set_topic_validator(topic_3, async_validator, True) - # Register async validator for topic 3 - pubsubs_fsub[0].set_topic_validator(topic_3, async_validator, True) + msg = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=[topic_1, topic_2, topic_3], + data=b"1234", + seqno=b"\x00" * 8, + ) - msg = make_pubsub_msg( - origin_id=pubsubs_fsub[0].my_id, - topic_ids=[topic_1, topic_2, topic_3], - data=b"1234", - seqno=b"\x00" * 8, - ) + topic_validators = pubsubs_fsub[0].get_msg_validators(msg) + for topic_validator in topic_validators: + if topic_validator.is_async: + await topic_validator.validator(peer_id=IDFactory(), msg="msg") + else: + topic_validator.validator(peer_id=IDFactory(), msg="msg") - topic_validators = pubsubs_fsub[0].get_msg_validators(msg) - for topic_validator in topic_validators: - if topic_validator.is_async: - await topic_validator.validator(peer_id=ID(b"peer"), msg="msg") - else: - topic_validator.validator(peer_id=ID(b"peer"), msg="msg") - - assert times_sync_validator_called == 2 - assert times_async_validator_called == 1 + assert times_sync_validator_called == 2 + assert times_async_validator_called == 1 -@pytest.mark.parametrize("num_hosts", (1,)) @pytest.mark.parametrize( "is_topic_1_val_passed, 
is_topic_2_val_passed", ((False, True), (True, False), (True, True)), ) -@pytest.mark.asyncio -async def test_validate_msg(pubsubs_fsub, is_topic_1_val_passed, is_topic_2_val_passed): - def passed_sync_validator(peer_id, msg): - return True +@pytest.mark.trio +async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: - def failed_sync_validator(peer_id, msg): - return False + def passed_sync_validator(peer_id, msg): + return True - async def passed_async_validator(peer_id, msg): - return True + def failed_sync_validator(peer_id, msg): + return False - async def failed_async_validator(peer_id, msg): - return False + async def passed_async_validator(peer_id, msg): + await trio.hazmat.checkpoint() + return True - topic_1 = "TEST_SYNC_VALIDATOR" - topic_2 = "TEST_ASYNC_VALIDATOR" + async def failed_async_validator(peer_id, msg): + await trio.hazmat.checkpoint() + return False - if is_topic_1_val_passed: - pubsubs_fsub[0].set_topic_validator(topic_1, passed_sync_validator, False) - else: - pubsubs_fsub[0].set_topic_validator(topic_1, failed_sync_validator, False) + topic_1 = "TEST_SYNC_VALIDATOR" + topic_2 = "TEST_ASYNC_VALIDATOR" - if is_topic_2_val_passed: - pubsubs_fsub[0].set_topic_validator(topic_2, passed_async_validator, True) - else: - pubsubs_fsub[0].set_topic_validator(topic_2, failed_async_validator, True) - - msg = make_pubsub_msg( - origin_id=pubsubs_fsub[0].my_id, - topic_ids=[topic_1, topic_2], - data=b"1234", - seqno=b"\x00" * 8, - ) - - if is_topic_1_val_passed and is_topic_2_val_passed: - await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg) - else: - with pytest.raises(ValidationError): - await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg) - - -class FakeNetStream: - _queue: asyncio.Queue - - class FakeMplexConn(NamedTuple): - peer_id: ID = ID(b"\x12\x20" + b"\x00" * 32) - - muxed_conn = FakeMplexConn() - - def __init__(self) -> None: - self._queue = asyncio.Queue() - - async def read(self, n: int = -1) -> bytes: - buf = bytearray() - # Force to blocking wait if no data available now. - if self._queue.empty(): - first_byte = await self._queue.get() - buf.extend(first_byte) - # If `n == -1`, read until no data is in the buffer(_queue). - # Else, read until no data is in the buffer(_queue) or we have read `n` bytes. 
- while (n == -1) or (len(buf) < n): - if self._queue.empty(): - break - buf.extend(await self._queue.get()) - return bytes(buf) - - async def write(self, data: bytes) -> int: - for i in data: - await self._queue.put(i.to_bytes(1, "big")) - return len(data) - - -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_continuously_read_stream(pubsubs_fsub, monkeypatch): - stream = FakeNetStream() - - await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - - event_push_msg = asyncio.Event() - event_handle_subscription = asyncio.Event() - event_handle_rpc = asyncio.Event() - - async def mock_push_msg(msg_forwarder, msg): - event_push_msg.set() - - def mock_handle_subscription(origin_id, sub_message): - event_handle_subscription.set() - - async def mock_handle_rpc(rpc, sender_peer_id): - event_handle_rpc.set() - - monkeypatch.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg) - monkeypatch.setattr( - pubsubs_fsub[0], "handle_subscription", mock_handle_subscription - ) - monkeypatch.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc) - - async def wait_for_event_occurring(event): - try: - await asyncio.wait_for(event.wait(), timeout=1) - except asyncio.TimeoutError as error: - event.clear() - raise asyncio.TimeoutError( - f"Event {event} is not set before the timeout. " - "This indicates the mocked functions are not called properly." - ) from error + if is_topic_1_val_passed: + pubsubs_fsub[0].set_topic_validator(topic_1, passed_sync_validator, False) else: - event.clear() + pubsubs_fsub[0].set_topic_validator(topic_1, failed_sync_validator, False) - # Kick off the task `continuously_read_stream` - task = asyncio.ensure_future(pubsubs_fsub[0].continuously_read_stream(stream)) + if is_topic_2_val_passed: + pubsubs_fsub[0].set_topic_validator(topic_2, passed_async_validator, True) + else: + pubsubs_fsub[0].set_topic_validator(topic_2, failed_async_validator, True) - # Test: `push_msg` is called when publishing to a subscribed topic. - publish_subscribed_topic = rpc_pb2.RPC( - publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])] - ) - await stream.write( - encode_varint_prefixed(publish_subscribed_topic.SerializeToString()) - ) - await wait_for_event_occurring(event_push_msg) - # Make sure the other events are not emitted. - with pytest.raises(asyncio.TimeoutError): - await wait_for_event_occurring(event_handle_subscription) - with pytest.raises(asyncio.TimeoutError): - await wait_for_event_occurring(event_handle_rpc) + msg = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=[topic_1, topic_2], + data=b"1234", + seqno=b"\x00" * 8, + ) - # Test: `push_msg` is not called when publishing to a topic-not-subscribed. - publish_not_subscribed_topic = rpc_pb2.RPC( - publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])] - ) - await stream.write( - encode_varint_prefixed(publish_not_subscribed_topic.SerializeToString()) - ) - with pytest.raises(asyncio.TimeoutError): - await wait_for_event_occurring(event_push_msg) + if is_topic_1_val_passed and is_topic_2_val_passed: + await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg) + else: + with pytest.raises(ValidationError): + await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg) - # Test: `handle_subscription` is called when a subscription message is received. 
- subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()]) - await stream.write(encode_varint_prefixed(subscription_msg.SerializeToString())) - await wait_for_event_occurring(event_handle_subscription) - # Make sure the other events are not emitted. - with pytest.raises(asyncio.TimeoutError): - await wait_for_event_occurring(event_push_msg) - with pytest.raises(asyncio.TimeoutError): - await wait_for_event_occurring(event_handle_rpc) - # Test: `handle_rpc` is called when a control message is received. - control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage()) - await stream.write(encode_varint_prefixed(control_msg.SerializeToString())) - await wait_for_event_occurring(event_handle_rpc) - # Make sure the other events are not emitted. - with pytest.raises(asyncio.TimeoutError): - await wait_for_event_occurring(event_push_msg) - with pytest.raises(asyncio.TimeoutError): - await wait_for_event_occurring(event_handle_subscription) +@pytest.mark.trio +async def test_continuously_read_stream(monkeypatch, nursery, is_host_secure): + async def wait_for_event_occurring(event): + await trio.hazmat.checkpoint() + with trio.fail_after(0.1): + await event.wait() - task.cancel() + class Events(NamedTuple): + push_msg: trio.Event + handle_subscription: trio.Event + handle_rpc: trio.Event + + @contextmanager + def mock_methods(): + event_push_msg = trio.Event() + event_handle_subscription = trio.Event() + event_handle_rpc = trio.Event() + + async def mock_push_msg(msg_forwarder, msg): + event_push_msg.set() + await trio.hazmat.checkpoint() + + def mock_handle_subscription(origin_id, sub_message): + event_handle_subscription.set() + + async def mock_handle_rpc(rpc, sender_peer_id): + event_handle_rpc.set() + await trio.hazmat.checkpoint() + + with monkeypatch.context() as m: + m.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg) + m.setattr(pubsubs_fsub[0], "handle_subscription", mock_handle_subscription) + m.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc) + yield Events(event_push_msg, event_handle_subscription, event_handle_rpc) + + async with PubsubFactory.create_batch_with_floodsub( + 1, is_secure=is_host_secure + ) as pubsubs_fsub, net_stream_pair_factory(is_secure=is_host_secure) as stream_pair: + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + # Kick off the task `continuously_read_stream` + nursery.start_soon(pubsubs_fsub[0].continuously_read_stream, stream_pair[0]) + + # Test: `push_msg` is called when publishing to a subscribed topic. + publish_subscribed_topic = rpc_pb2.RPC( + publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])] + ) + with mock_methods() as events: + await stream_pair[1].write( + encode_varint_prefixed(publish_subscribed_topic.SerializeToString()) + ) + await wait_for_event_occurring(events.push_msg) + # Make sure the other events are not emitted. + with pytest.raises(trio.TooSlowError): + await wait_for_event_occurring(events.handle_subscription) + with pytest.raises(trio.TooSlowError): + await wait_for_event_occurring(events.handle_rpc) + + # Test: `push_msg` is not called when publishing to a topic-not-subscribed. + publish_not_subscribed_topic = rpc_pb2.RPC( + publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])] + ) + with mock_methods() as events: + await stream_pair[1].write( + encode_varint_prefixed(publish_not_subscribed_topic.SerializeToString()) + ) + with pytest.raises(trio.TooSlowError): + await wait_for_event_occurring(events.push_msg) + + # Test: `handle_subscription` is called when a subscription message is received. 
+ subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()]) + with mock_methods() as events: + await stream_pair[1].write( + encode_varint_prefixed(subscription_msg.SerializeToString()) + ) + await wait_for_event_occurring(events.handle_subscription) + # Make sure the other events are not emitted. + with pytest.raises(trio.TooSlowError): + await wait_for_event_occurring(events.push_msg) + with pytest.raises(trio.TooSlowError): + await wait_for_event_occurring(events.handle_rpc) + + # Test: `handle_rpc` is called when a control message is received. + control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage()) + with mock_methods() as events: + await stream_pair[1].write( + encode_varint_prefixed(control_msg.SerializeToString()) + ) + await wait_for_event_occurring(events.handle_rpc) + # Make sure the other events are not emitted. + with pytest.raises(trio.TooSlowError): + await wait_for_event_occurring(events.push_msg) + with pytest.raises(trio.TooSlowError): + await wait_for_event_occurring(events.handle_subscription) # TODO: Add the following tests after they are aligned with Go. @@ -352,229 +336,321 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch): # - `test_handle_peer_queue` -@pytest.mark.parametrize("num_hosts", (1,)) -def test_handle_subscription(pubsubs_fsub): - assert len(pubsubs_fsub[0].peer_topics) == 0 - sub_msg_0 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=TESTING_TOPIC) - peer_ids = [ID(b"\x12\x20" + i.to_bytes(32, "big")) for i in range(2)] - # Test: One peer is subscribed - pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_0) - assert ( - len(pubsubs_fsub[0].peer_topics) == 1 - and TESTING_TOPIC in pubsubs_fsub[0].peer_topics - ) - assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 1 - assert peer_ids[0] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] - # Test: Another peer is subscribed - pubsubs_fsub[0].handle_subscription(peer_ids[1], sub_msg_0) - assert len(pubsubs_fsub[0].peer_topics) == 1 - assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 2 - assert peer_ids[1] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] - # Test: Subscribe to another topic - another_topic = "ANOTHER_TOPIC" - sub_msg_1 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=another_topic) - pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_1) - assert len(pubsubs_fsub[0].peer_topics) == 2 - assert another_topic in pubsubs_fsub[0].peer_topics - assert peer_ids[0] in pubsubs_fsub[0].peer_topics[another_topic] - # Test: unsubscribe - unsub_msg = rpc_pb2.RPC.SubOpts(subscribe=False, topicid=TESTING_TOPIC) - pubsubs_fsub[0].handle_subscription(peer_ids[0], unsub_msg) - assert peer_ids[0] not in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] +@pytest.mark.trio +async def test_handle_subscription(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + assert len(pubsubs_fsub[0].peer_topics) == 0 + sub_msg_0 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=TESTING_TOPIC) + peer_ids = [IDFactory() for _ in range(2)] + # Test: One peer is subscribed + pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_0) + assert ( + len(pubsubs_fsub[0].peer_topics) == 1 + and TESTING_TOPIC in pubsubs_fsub[0].peer_topics + ) + assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 1 + assert peer_ids[0] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] + # Test: Another peer is subscribed + pubsubs_fsub[0].handle_subscription(peer_ids[1], sub_msg_0) + assert len(pubsubs_fsub[0].peer_topics) == 1 + assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 
2 + assert peer_ids[1] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] + # Test: Subscribe to another topic + another_topic = "ANOTHER_TOPIC" + sub_msg_1 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=another_topic) + pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_1) + assert len(pubsubs_fsub[0].peer_topics) == 2 + assert another_topic in pubsubs_fsub[0].peer_topics + assert peer_ids[0] in pubsubs_fsub[0].peer_topics[another_topic] + # Test: unsubscribe + unsub_msg = rpc_pb2.RPC.SubOpts(subscribe=False, topicid=TESTING_TOPIC) + pubsubs_fsub[0].handle_subscription(peer_ids[0], unsub_msg) + assert peer_ids[0] not in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_handle_talk(pubsubs_fsub): - sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - msg_0 = make_pubsub_msg( - origin_id=pubsubs_fsub[0].my_id, - topic_ids=[TESTING_TOPIC], - data=b"1234", - seqno=b"\x00" * 8, - ) - await pubsubs_fsub[0].handle_talk(msg_0) - msg_1 = make_pubsub_msg( - origin_id=pubsubs_fsub[0].my_id, - topic_ids=["NOT_SUBSCRIBED"], - data=b"1234", - seqno=b"\x11" * 8, - ) - await pubsubs_fsub[0].handle_talk(msg_1) - assert ( - len(pubsubs_fsub[0].my_topics) == 1 - and sub == pubsubs_fsub[0].my_topics[TESTING_TOPIC] - ) - assert sub.qsize() == 1 - assert (await sub.get()) == msg_0 +@pytest.mark.trio +async def test_handle_talk(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + msg_0 = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=[TESTING_TOPIC], + data=b"1234", + seqno=b"\x00" * 8, + ) + pubsubs_fsub[0].notify_subscriptions(msg_0) + msg_1 = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=["NOT_SUBSCRIBED"], + data=b"1234", + seqno=b"\x11" * 8, + ) + pubsubs_fsub[0].notify_subscriptions(msg_1) + assert ( + len(pubsubs_fsub[0].topic_ids) == 1 + and sub == pubsubs_fsub[0].subscribed_topics_receive[TESTING_TOPIC] + ) + assert (await sub.get()) == msg_0 -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_message_all_peers(pubsubs_fsub, monkeypatch): - peer_ids = [ID(b"\x12\x20" + i.to_bytes(32, "big")) for i in range(10)] - mock_peers = {peer_id: FakeNetStream() for peer_id in peer_ids} - monkeypatch.setattr(pubsubs_fsub[0], "peers", mock_peers) +@pytest.mark.trio +async def test_message_all_peers(monkeypatch, is_host_secure): + async with PubsubFactory.create_batch_with_floodsub( + 1, is_secure=is_host_secure + ) as pubsubs_fsub, net_stream_pair_factory(is_secure=is_host_secure) as stream_pair: + peer_id = IDFactory() + mock_peers = {peer_id: stream_pair[0]} + with monkeypatch.context() as m: + m.setattr(pubsubs_fsub[0], "peers", mock_peers) - empty_rpc = rpc_pb2.RPC() - empty_rpc_bytes = empty_rpc.SerializeToString() - empty_rpc_bytes_len_prefixed = encode_varint_prefixed(empty_rpc_bytes) - await pubsubs_fsub[0].message_all_peers(empty_rpc_bytes) - for stream in mock_peers.values(): - assert (await stream.read()) == empty_rpc_bytes_len_prefixed + empty_rpc = rpc_pb2.RPC() + empty_rpc_bytes = empty_rpc.SerializeToString() + empty_rpc_bytes_len_prefixed = encode_varint_prefixed(empty_rpc_bytes) + await pubsubs_fsub[0].message_all_peers(empty_rpc_bytes) + assert ( + await stream_pair[1].read(MAX_READ_LEN) + ) == empty_rpc_bytes_len_prefixed -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_publish(pubsubs_fsub, monkeypatch): +@pytest.mark.trio +async def 
test_subscribe_and_publish(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + pubsub = pubsubs_fsub[0] + + list_data = [b"d0", b"d1"] + event_receive_data_started = trio.Event() + + async def publish_data(topic): + await event_receive_data_started.wait() + for data in list_data: + await pubsub.publish(topic, data) + + async def receive_data(topic): + i = 0 + event_receive_data_started.set() + assert topic not in pubsub.topic_ids + subscription = await pubsub.subscribe(topic) + async with subscription: + assert topic in pubsub.topic_ids + async for msg in subscription: + assert msg.data == list_data[i] + i += 1 + if i == len(list_data): + break + assert topic not in pubsub.topic_ids + + async with trio.open_nursery() as nursery: + nursery.start_soon(receive_data, TESTING_TOPIC) + nursery.start_soon(publish_data, TESTING_TOPIC) + + +@pytest.mark.trio +async def test_subscribe_and_publish_full_channel(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + pubsub = pubsubs_fsub[0] + + extra_data_0 = b"extra_data_0" + extra_data_1 = b"extra_data_1" + + # Test: Subscription channel is of size `SUBSCRIPTION_CHANNEL_SIZE`. + # When the channel is full, new received messages are dropped. + # After the channel has empty slot, the channel can receive new messages. + + # Assume `SUBSCRIPTION_CHANNEL_SIZE` is smaller than `2**(4*8)`. + list_data = [i.to_bytes(4, "big") for i in range(SUBSCRIPTION_CHANNEL_SIZE)] + # Expect `extra_data_0` is dropped and `extra_data_1` is appended. + expected_list_data = list_data + [extra_data_1] + + subscription = await pubsub.subscribe(TESTING_TOPIC) + for data in list_data: + await pubsub.publish(TESTING_TOPIC, data) + + # Publish `extra_data_0` which should be dropped since the channel is already full. + await pubsub.publish(TESTING_TOPIC, extra_data_0) + # Consume a message and there is an empty slot in the channel. + assert (await subscription.get()).data == expected_list_data.pop(0) + # Publish `extra_data_1` which should be appended to the channel. 
+ await pubsub.publish(TESTING_TOPIC, extra_data_1) + + for expected_data in expected_list_data: + assert (await subscription.get()).data == expected_data + + +@pytest.mark.trio +async def test_publish_push_msg_is_called(monkeypatch): msg_forwarders = [] msgs = [] async def push_msg(msg_forwarder, msg): msg_forwarders.append(msg_forwarder) msgs.append(msg) + await trio.hazmat.checkpoint() - monkeypatch.setattr(pubsubs_fsub[0], "push_msg", push_msg) + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + with monkeypatch.context() as m: + m.setattr(pubsubs_fsub[0], "push_msg", push_msg) - await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA) - await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA) + await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA) + await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA) - assert len(msgs) == 2, "`push_msg` should be called every time `publish` is called" - assert (msg_forwarders[0] == msg_forwarders[1]) and ( - msg_forwarders[1] == pubsubs_fsub[0].my_id - ) - assert msgs[0].seqno != msgs[1].seqno, "`seqno` should be different every time" + assert ( + len(msgs) == 2 + ), "`push_msg` should be called every time `publish` is called" + assert (msg_forwarders[0] == msg_forwarders[1]) and ( + msg_forwarders[1] == pubsubs_fsub[0].my_id + ) + assert ( + msgs[0].seqno != msgs[1].seqno + ), "`seqno` should be different every time" -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_push_msg(pubsubs_fsub, monkeypatch): - msg_0 = make_pubsub_msg( - origin_id=pubsubs_fsub[0].my_id, - topic_ids=[TESTING_TOPIC], - data=TESTING_DATA, - seqno=b"\x00" * 8, - ) +@pytest.mark.trio +async def test_push_msg(monkeypatch): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + msg_0 = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=[TESTING_TOPIC], + data=TESTING_DATA, + seqno=b"\x00" * 8, + ) - event = asyncio.Event() + @contextmanager + def mock_router_publish(): - async def router_publish(*args, **kwargs): - event.set() + event = trio.Event() - monkeypatch.setattr(pubsubs_fsub[0].router, "publish", router_publish) + async def router_publish(*args, **kwargs): + event.set() + await trio.hazmat.checkpoint() - # Test: `msg` is not seen before `push_msg`, and is seen after `push_msg`. - assert not pubsubs_fsub[0]._is_msg_seen(msg_0) - await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0) - assert pubsubs_fsub[0]._is_msg_seen(msg_0) - # Test: Ensure `router.publish` is called in `push_msg` - await asyncio.wait_for(event.wait(), timeout=0.1) + with monkeypatch.context() as m: + m.setattr(pubsubs_fsub[0].router, "publish", router_publish) + yield event - # Test: `push_msg` the message again and it will be reject. - # `router_publish` is not called then. - event.clear() - await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0) - await asyncio.sleep(0.01) - assert not event.is_set() + with mock_router_publish() as event: + # Test: `msg` is not seen before `push_msg`, and is seen after `push_msg`. + assert not pubsubs_fsub[0]._is_msg_seen(msg_0) + await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0) + assert pubsubs_fsub[0]._is_msg_seen(msg_0) + # Test: Ensure `router.publish` is called in `push_msg` + with trio.fail_after(0.1): + await event.wait() - sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - # Test: `push_msg` succeeds with another unseen msg. 
- msg_1 = make_pubsub_msg( - origin_id=pubsubs_fsub[0].my_id, - topic_ids=[TESTING_TOPIC], - data=TESTING_DATA, - seqno=b"\x11" * 8, - ) - assert not pubsubs_fsub[0]._is_msg_seen(msg_1) - await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_1) - assert pubsubs_fsub[0]._is_msg_seen(msg_1) - await asyncio.wait_for(event.wait(), timeout=0.1) - # Test: Subscribers are notified when `push_msg` new messages. - assert (await sub.get()) == msg_1 + with mock_router_publish() as event: + # Test: `push_msg` the message again and it will be reject. + # `router_publish` is not called then. + await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0) + await trio.sleep(0.01) + assert not event.is_set() - # Test: add a topic validator and `push_msg` the message that - # does not pass the validation. - # `router_publish` is not called then. - def failed_sync_validator(peer_id, msg): - return False + sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + # Test: `push_msg` succeeds with another unseen msg. + msg_1 = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=[TESTING_TOPIC], + data=TESTING_DATA, + seqno=b"\x11" * 8, + ) + assert not pubsubs_fsub[0]._is_msg_seen(msg_1) + await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_1) + assert pubsubs_fsub[0]._is_msg_seen(msg_1) + with trio.fail_after(0.1): + await event.wait() + # Test: Subscribers are notified when `push_msg` new messages. + assert (await sub.get()) == msg_1 - pubsubs_fsub[0].set_topic_validator(TESTING_TOPIC, failed_sync_validator, False) + with mock_router_publish() as event: + # Test: add a topic validator and `push_msg` the message that + # does not pass the validation. + # `router_publish` is not called then. + def failed_sync_validator(peer_id, msg): + return False - msg_2 = make_pubsub_msg( - origin_id=pubsubs_fsub[0].my_id, - topic_ids=[TESTING_TOPIC], - data=TESTING_DATA, - seqno=b"\x22" * 8, - ) + pubsubs_fsub[0].set_topic_validator( + TESTING_TOPIC, failed_sync_validator, False + ) - event.clear() - await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_2) - await asyncio.sleep(0.01) - assert not event.is_set() + msg_2 = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=[TESTING_TOPIC], + data=TESTING_DATA, + seqno=b"\x22" * 8, + ) + + await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_2) + await trio.sleep(0.01) + assert not event.is_set() -@pytest.mark.parametrize("num_hosts, is_strict_signing", ((2, True),)) -@pytest.mark.asyncio -async def test_strict_signing(pubsubs_fsub, hosts): - await connect(hosts[0], hosts[1]) - await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - await pubsubs_fsub[1].subscribe(TESTING_TOPIC) - await asyncio.sleep(1) +@pytest.mark.trio +async def test_strict_signing(): + async with PubsubFactory.create_batch_with_floodsub( + 2, strict_signing=True + ) as pubsubs_fsub: + await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + await pubsubs_fsub[1].subscribe(TESTING_TOPIC) + await trio.sleep(1) - await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA) - await asyncio.sleep(1) + await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA) + await trio.sleep(1) - assert len(pubsubs_fsub[0].seen_messages) == 1 - assert len(pubsubs_fsub[1].seen_messages) == 1 + assert len(pubsubs_fsub[0].seen_messages) == 1 + assert len(pubsubs_fsub[1].seen_messages) == 1 -@pytest.mark.parametrize("num_hosts, is_strict_signing", ((2, True),)) -@pytest.mark.asyncio -async def test_strict_signing_failed_validation(pubsubs_fsub, 
hosts, monkeypatch): - msg = make_pubsub_msg( - origin_id=pubsubs_fsub[0].my_id, - topic_ids=[TESTING_TOPIC], - data=TESTING_DATA, - seqno=b"\x00" * 8, - ) - priv_key = pubsubs_fsub[0].sign_key - signature = priv_key.sign(PUBSUB_SIGNING_PREFIX.encode() + msg.SerializeToString()) +@pytest.mark.trio +async def test_strict_signing_failed_validation(monkeypatch): + async with PubsubFactory.create_batch_with_floodsub( + 2, strict_signing=True + ) as pubsubs_fsub: + msg = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=[TESTING_TOPIC], + data=TESTING_DATA, + seqno=b"\x00" * 8, + ) + priv_key = pubsubs_fsub[0].sign_key + signature = priv_key.sign( + PUBSUB_SIGNING_PREFIX.encode() + msg.SerializeToString() + ) - event = asyncio.Event() + event = trio.Event() - def _is_msg_seen(msg): - return False + def _is_msg_seen(msg): + return False - # Use router publish to check if `push_msg` succeed. - async def router_publish(*args, **kwargs): - # The event will only be set if `push_msg` succeed. - event.set() + # Use router publish to check if `push_msg` succeed. + async def router_publish(*args, **kwargs): + await trio.hazmat.checkpoint() + # The event will only be set if `push_msg` succeed. + event.set() - monkeypatch.setattr(pubsubs_fsub[0], "_is_msg_seen", _is_msg_seen) - monkeypatch.setattr(pubsubs_fsub[0].router, "publish", router_publish) + monkeypatch.setattr(pubsubs_fsub[0], "_is_msg_seen", _is_msg_seen) + monkeypatch.setattr(pubsubs_fsub[0].router, "publish", router_publish) - # Test: no signature attached in `msg` - await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg) - await asyncio.sleep(0.01) - assert not event.is_set() + # Test: no signature attached in `msg` + await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg) + await trio.sleep(0.01) + assert not event.is_set() - # Test: `msg.key` does not match `msg.from_id` - msg.key = hosts[1].get_public_key().serialize() - msg.signature = signature - await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg) - await asyncio.sleep(0.01) - assert not event.is_set() + # Test: `msg.key` does not match `msg.from_id` + msg.key = pubsubs_fsub[1].host.get_public_key().serialize() + msg.signature = signature + await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg) + await trio.sleep(0.01) + assert not event.is_set() - # Test: invalid signature - msg.key = hosts[0].get_public_key().serialize() - msg.signature = b"\x12" * 100 - await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg) - await asyncio.sleep(0.01) - assert not event.is_set() + # Test: invalid signature + msg.key = pubsubs_fsub[0].host.get_public_key().serialize() + msg.signature = b"\x12" * 100 + await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg) + await trio.sleep(0.01) + assert not event.is_set() - # Finally, assert the signature indeed will pass validation - msg.key = hosts[0].get_public_key().serialize() - msg.signature = signature - await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg) - await asyncio.sleep(0.01) - assert event.is_set() + # Finally, assert the signature indeed will pass validation + msg.key = pubsubs_fsub[0].host.get_public_key().serialize() + msg.signature = signature + await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg) + await trio.sleep(0.01) + assert event.is_set() diff --git a/tests/pubsub/test_subscription.py b/tests/pubsub/test_subscription.py new file mode 100644 index 00000000..a0a6c10c --- /dev/null +++ b/tests/pubsub/test_subscription.py @@ -0,0 +1,84 @@ +import math + +import pytest +import trio + +from 
libp2p.pubsub.pb import rpc_pb2 +from libp2p.pubsub.subscription import TrioSubscriptionAPI + +GET_TIMEOUT = 0.001 + + +def make_trio_subscription(): + send_channel, receive_channel = trio.open_memory_channel(math.inf) + + async def unsubscribe_fn(): + await send_channel.aclose() + + return ( + send_channel, + TrioSubscriptionAPI(receive_channel, unsubscribe_fn=unsubscribe_fn), + ) + + +def make_pubsub_msg(): + return rpc_pb2.Message() + + +async def send_something(send_channel): + msg = make_pubsub_msg() + await send_channel.send(msg) + return msg + + +@pytest.mark.trio +async def test_trio_subscription_get(): + send_channel, sub = make_trio_subscription() + data_0 = await send_something(send_channel) + data_1 = await send_something(send_channel) + assert data_0 == await sub.get() + assert data_1 == await sub.get() + # No more message + with pytest.raises(trio.TooSlowError): + with trio.fail_after(GET_TIMEOUT): + await sub.get() + + +@pytest.mark.trio +async def test_trio_subscription_iter(): + send_channel, sub = make_trio_subscription() + received_data = [] + + async def iter_subscriptions(subscription): + async for data in sub: + received_data.append(data) + + async with trio.open_nursery() as nursery: + nursery.start_soon(iter_subscriptions, sub) + await send_something(send_channel) + await send_something(send_channel) + await send_channel.aclose() + + assert len(received_data) == 2 + + +@pytest.mark.trio +async def test_trio_subscription_unsubscribe(): + send_channel, sub = make_trio_subscription() + await sub.unsubscribe() + # Test: If the subscription is unsubscribed, `send_channel` should be closed. + with pytest.raises(trio.ClosedResourceError): + await send_something(send_channel) + # Test: No side effect when cancelled twice. + await sub.unsubscribe() + + +@pytest.mark.trio +async def test_trio_subscription_async_context_manager(): + send_channel, sub = make_trio_subscription() + async with sub: + # Test: `sub` is not cancelled yet, so `send_something` works fine. + await send_something(send_channel) + # Test: `sub` is cancelled, `send_something` fails + with pytest.raises(trio.ClosedResourceError): + await send_something(send_channel) diff --git a/tests/security/test_secio.py b/tests/security/test_secio.py index 50374809..d009a738 100644 --- a/tests/security/test_secio.py +++ b/tests/security/test_secio.py @@ -1,70 +1,15 @@ -import asyncio - import pytest +import trio from libp2p.crypto.secp256k1 import create_new_key_pair -from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID from libp2p.security.secio.transport import NONCE_SIZE, create_secure_session +from libp2p.tools.constants import MAX_READ_LEN +from libp2p.tools.factories import raw_conn_factory -class InMemoryConnection(IRawConnection): - def __init__(self, peer, is_initiator=False): - self.peer = peer - self.recv_queue = asyncio.Queue() - self.send_queue = asyncio.Queue() - self.is_initiator = is_initiator - - self.current_msg = None - self.current_position = 0 - - self.closed = False - - async def write(self, data: bytes) -> int: - if self.closed: - raise Exception("InMemoryConnection is closed for writing") - - await self.send_queue.put(data) - return len(data) - - async def read(self, n: int = -1) -> bytes: - """ - NOTE: have to buffer the current message and juggle packets - off the recv queue to satisfy the semantics of this function. 
- """ - if self.closed: - raise Exception("InMemoryConnection is closed for reading") - - if not self.current_msg: - self.current_msg = await self.recv_queue.get() - self.current_position = 0 - - if n < 0: - msg = self.current_msg - self.current_msg = None - return msg - - next_msg = self.current_msg[self.current_position : self.current_position + n] - self.current_position += n - if self.current_position == len(self.current_msg): - self.current_msg = None - return next_msg - - async def close(self) -> None: - self.closed = True - - -async def create_pipe(local_conn, remote_conn): - try: - while True: - next_msg = await local_conn.send_queue.get() - await remote_conn.recv_queue.put(next_msg) - except asyncio.CancelledError: - return - - -@pytest.mark.asyncio -async def test_create_secure_session(): +@pytest.mark.trio +async def test_create_secure_session(nursery): local_nonce = b"\x01" * NONCE_SIZE local_key_pair = create_new_key_pair(b"a") local_peer = ID.from_pubkey(local_key_pair.public_key) @@ -73,30 +18,32 @@ async def test_create_secure_session(): remote_key_pair = create_new_key_pair(b"b") remote_peer = ID.from_pubkey(remote_key_pair.public_key) - local_conn = InMemoryConnection(local_peer, is_initiator=True) - remote_conn = InMemoryConnection(remote_peer) + async with raw_conn_factory(nursery) as conns: + local_conn, remote_conn = conns - local_pipe_task = asyncio.ensure_future(create_pipe(local_conn, remote_conn)) - remote_pipe_task = asyncio.ensure_future(create_pipe(remote_conn, local_conn)) + local_secure_conn, remote_secure_conn = None, None - local_session_builder = create_secure_session( - local_nonce, local_peer, local_key_pair.private_key, local_conn, remote_peer - ) - remote_session_builder = create_secure_session( - remote_nonce, remote_peer, remote_key_pair.private_key, remote_conn - ) - local_secure_conn, remote_secure_conn = await asyncio.gather( - local_session_builder, remote_session_builder - ) + async def local_create_secure_session(): + nonlocal local_secure_conn + local_secure_conn = await create_secure_session( + local_nonce, + local_peer, + local_key_pair.private_key, + local_conn, + remote_peer, + ) - msg = b"abc" - await local_secure_conn.write(msg) - received_msg = await remote_secure_conn.read() - assert received_msg == msg + async def remote_create_secure_session(): + nonlocal remote_secure_conn + remote_secure_conn = await create_secure_session( + remote_nonce, remote_peer, remote_key_pair.private_key, remote_conn + ) - await asyncio.gather(local_secure_conn.close(), remote_secure_conn.close()) + async with trio.open_nursery() as nursery_1: + nursery_1.start_soon(local_create_secure_session) + nursery_1.start_soon(remote_create_secure_session) - local_pipe_task.cancel() - remote_pipe_task.cancel() - await local_pipe_task - await remote_pipe_task + msg = b"abc" + await local_secure_conn.write(msg) + received_msg = await remote_secure_conn.read(MAX_READ_LEN) + assert received_msg == msg diff --git a/tests/security/test_security_multistream.py b/tests/security/test_security_multistream.py index c4eb3ecb..cd968ac2 100644 --- a/tests/security/test_security_multistream.py +++ b/tests/security/test_security_multistream.py @@ -1,8 +1,7 @@ -import asyncio - import pytest +import trio -from libp2p import new_node +from libp2p import new_host from libp2p.crypto.rsa import create_new_key_pair from libp2p.security.insecure.transport import InsecureSession, InsecureTransport from libp2p.tools.constants import LISTEN_MADDR @@ -24,42 +23,36 @@ noninitiator_key_pair = 
create_new_key_pair() async def perform_simple_test( assertion_func, transports_for_initiator, transports_for_noninitiator ): - # Create libp2p nodes and connect them, then secure the connection, then check # the proper security was chosen # TODO: implement -- note we need to introduce the notion of communicating over a raw connection # for testing, we do NOT want to communicate over a stream so we can't just create two nodes # and use their conn because our mplex will internally relay messages to a stream - node1 = await new_node( - key_pair=initiator_key_pair, sec_opt=transports_for_initiator - ) - node2 = await new_node( + node1 = new_host(key_pair=initiator_key_pair, sec_opt=transports_for_initiator) + node2 = new_host( key_pair=noninitiator_key_pair, sec_opt=transports_for_noninitiator ) + async with node1.run(listen_addrs=[LISTEN_MADDR]), node2.run( + listen_addrs=[LISTEN_MADDR] + ): + await connect(node1, node2) - await node1.get_network().listen(LISTEN_MADDR) - await node2.get_network().listen(LISTEN_MADDR) + # Wait a very short period to allow conns to be stored (since the functions + # storing the conns are async, they may happen at slightly different times + # on each node) + await trio.sleep(0.1) - await connect(node1, node2) + # Get conns + node1_conn = node1.get_network().connections[peer_id_for_node(node2)] + node2_conn = node2.get_network().connections[peer_id_for_node(node1)] - # Wait a very short period to allow conns to be stored (since the functions - # storing the conns are async, they may happen at slightly different times - # on each node) - await asyncio.sleep(0.1) - - # Get conns - node1_conn = node1.get_network().connections[peer_id_for_node(node2)] - node2_conn = node2.get_network().connections[peer_id_for_node(node1)] - - # Perform assertion - assertion_func(node1_conn.muxed_conn.secured_conn) - assertion_func(node2_conn.muxed_conn.secured_conn) - - # Success, terminate pending tasks. 
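# [Editor's note: illustrative sketch, not part of the patch.] The secio test above
# runs both sides of the handshake concurrently and collects their results through
# `nonlocal` variables, because nursery.start_soon() does not return values the way
# asyncio.gather() did. The same pattern in isolation; `handshake` is a dummy stand-in.
import trio


async def handshake(side: str) -> str:
    await trio.sleep(0.01)  # stand-in for the real negotiation round-trips
    return f"secure-session-{side}"


async def main() -> None:
    local_session, remote_session = None, None

    async def run_local() -> None:
        nonlocal local_session
        local_session = await handshake("local")

    async def run_remote() -> None:
        nonlocal remote_session
        remote_session = await handshake("remote")

    async with trio.open_nursery() as nursery:
        nursery.start_soon(run_local)
        nursery.start_soon(run_remote)

    assert local_session == "secure-session-local"
    assert remote_session == "secure-session-remote"


trio.run(main)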
+ # Perform assertion + assertion_func(node1_conn.muxed_conn.secured_conn) + assertion_func(node2_conn.muxed_conn.secured_conn) -@pytest.mark.asyncio +@pytest.mark.trio async def test_single_insecure_security_transport_succeeds(): transports_for_initiator = {"foo": InsecureTransport(initiator_key_pair)} transports_for_noninitiator = {"foo": InsecureTransport(noninitiator_key_pair)} @@ -72,7 +65,7 @@ async def test_single_insecure_security_transport_succeeds(): ) -@pytest.mark.asyncio +@pytest.mark.trio async def test_default_insecure_security(): transports_for_initiator = None transports_for_noninitiator = None diff --git a/tests/stream_muxer/conftest.py b/tests/stream_muxer/conftest.py index cdb57e8f..248422d9 100644 --- a/tests/stream_muxer/conftest.py +++ b/tests/stream_muxer/conftest.py @@ -1,5 +1,3 @@ -import asyncio - import pytest from libp2p.tools.factories import mplex_conn_pair_factory, mplex_stream_pair_factory @@ -7,23 +5,13 @@ from libp2p.tools.factories import mplex_conn_pair_factory, mplex_stream_pair_fa @pytest.fixture async def mplex_conn_pair(is_host_secure): - mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory( - is_host_secure - ) - assert mplex_conn_0.is_initiator - assert not mplex_conn_1.is_initiator - try: - yield mplex_conn_0, mplex_conn_1 - finally: - await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) + async with mplex_conn_pair_factory(is_host_secure) as mplex_conn_pair: + assert mplex_conn_pair[0].is_initiator + assert not mplex_conn_pair[1].is_initiator + yield mplex_conn_pair[0], mplex_conn_pair[1] @pytest.fixture async def mplex_stream_pair(is_host_secure): - mplex_stream_0, swarm_0, mplex_stream_1, swarm_1 = await mplex_stream_pair_factory( - is_host_secure - ) - try: - yield mplex_stream_0, mplex_stream_1 - finally: - await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) + async with mplex_stream_pair_factory(is_host_secure) as mplex_stream_pair: + yield mplex_stream_pair diff --git a/tests/stream_muxer/test_mplex_conn.py b/tests/stream_muxer/test_mplex_conn.py index 6dc98ad6..df1097dd 100644 --- a/tests/stream_muxer/test_mplex_conn.py +++ b/tests/stream_muxer/test_mplex_conn.py @@ -1,49 +1,40 @@ -import asyncio - import pytest +import trio -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_conn(mplex_conn_pair): conn_0, conn_1 = mplex_conn_pair assert len(conn_0.streams) == 0 assert len(conn_1.streams) == 0 - assert not conn_0.event_shutting_down.is_set() - assert not conn_1.event_shutting_down.is_set() - assert not conn_0.event_closed.is_set() - assert not conn_1.event_closed.is_set() # Test: Open a stream, and both side get 1 more stream. stream_0 = await conn_0.open_stream() - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert len(conn_0.streams) == 1 assert len(conn_1.streams) == 1 # Test: From another side. stream_1 = await conn_1.open_stream() - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert len(conn_0.streams) == 2 assert len(conn_1.streams) == 2 # Close from one side. await conn_0.close() # Sleep for a while for both side to handle `close`. - await asyncio.sleep(0.01) + await trio.sleep(0.01) # Test: Both side is closed. - assert conn_0.event_shutting_down.is_set() - assert conn_0.event_closed.is_set() - assert conn_1.event_shutting_down.is_set() - assert conn_1.event_closed.is_set() + assert conn_0.is_closed + assert conn_1.is_closed # Test: All streams should have been closed. 
assert stream_0.event_remote_closed.is_set() assert stream_0.event_reset.is_set() assert stream_0.event_local_closed.is_set() - assert conn_0.streams is None # Test: All streams on the other side are also closed. assert stream_1.event_remote_closed.is_set() assert stream_1.event_reset.is_set() assert stream_1.event_local_closed.is_set() - assert conn_1.streams is None # Test: No effect to close more than once between two side. await conn_0.close() diff --git a/tests/stream_muxer/test_mplex_stream.py b/tests/stream_muxer/test_mplex_stream.py index f3458d8f..3bc8bc1c 100644 --- a/tests/stream_muxer/test_mplex_stream.py +++ b/tests/stream_muxer/test_mplex_stream.py @@ -1,25 +1,48 @@ -import asyncio - import pytest +import trio +from trio.testing import wait_all_tasks_blocked from libp2p.stream_muxer.mplex.exceptions import ( MplexStreamClosed, MplexStreamEOF, MplexStreamReset, ) +from libp2p.stream_muxer.mplex.mplex import MPLEX_MESSAGE_CHANNEL_SIZE from libp2p.tools.constants import MAX_READ_LEN DATA = b"data_123" -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_read_write(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair await stream_0.write(DATA) assert (await stream_1.read(MAX_READ_LEN)) == DATA -@pytest.mark.asyncio +@pytest.mark.trio +async def test_mplex_stream_full_buffer(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + # Test: The message channel is of size `MPLEX_MESSAGE_CHANNEL_SIZE`. + # It should be fine to read even there are already `MPLEX_MESSAGE_CHANNEL_SIZE` + # messages arriving. + for _ in range(MPLEX_MESSAGE_CHANNEL_SIZE): + await stream_0.write(DATA) + await wait_all_tasks_blocked() + # Sanity check + assert MAX_READ_LEN >= MPLEX_MESSAGE_CHANNEL_SIZE * len(DATA) + assert (await stream_1.read(MAX_READ_LEN)) == MPLEX_MESSAGE_CHANNEL_SIZE * DATA + + # Test: Read after `MPLEX_MESSAGE_CHANNEL_SIZE + 1` messages has arrived, which + # exceeds the channel size. The stream should have been reset. + for _ in range(MPLEX_MESSAGE_CHANNEL_SIZE + 1): + await stream_0.write(DATA) + await wait_all_tasks_blocked() + with pytest.raises(MplexStreamReset): + await stream_1.read(MAX_READ_LEN) + + +@pytest.mark.trio async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair): read_bytes = bytearray() stream_0, stream_1 = mplex_stream_pair @@ -27,43 +50,46 @@ async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair): async def read_until_eof(): read_bytes.extend(await stream_1.read()) - task = asyncio.ensure_future(read_until_eof()) - expected_data = bytearray() - # Test: `read` doesn't return before `close` is called. - await stream_0.write(DATA) - expected_data.extend(DATA) - await asyncio.sleep(0.01) - assert len(read_bytes) == 0 - # Test: `read` doesn't return before `close` is called. - await stream_0.write(DATA) - expected_data.extend(DATA) - await asyncio.sleep(0.01) - assert len(read_bytes) == 0 + async with trio.open_nursery() as nursery: + nursery.start_soon(read_until_eof) + # Test: `read` doesn't return before `close` is called. + await stream_0.write(DATA) + expected_data.extend(DATA) + await trio.sleep(0.01) + assert len(read_bytes) == 0 + # Test: `read` doesn't return before `close` is called. + await stream_0.write(DATA) + expected_data.extend(DATA) + await trio.sleep(0.01) + assert len(read_bytes) == 0 + + # Test: Close the stream, `read` returns, and receive previous sent data. + await stream_0.close() - # Test: Close the stream, `read` returns, and receive previous sent data. 
- await stream_0.close() - await asyncio.sleep(0.01) assert read_bytes == expected_data - task.cancel() - -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair assert not stream_1.event_remote_closed.is_set() await stream_0.write(DATA) + assert not stream_0.event_local_closed.is_set() + await trio.sleep(0.01) + await wait_all_tasks_blocked() await stream_0.close() - await asyncio.sleep(0.01) + assert stream_0.event_local_closed.is_set() + await trio.sleep(0.01) + await wait_all_tasks_blocked() assert stream_1.event_remote_closed.is_set() assert (await stream_1.read(MAX_READ_LEN)) == DATA with pytest.raises(MplexStreamEOF): await stream_1.read(MAX_READ_LEN) -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_read_after_local_reset(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair await stream_0.reset() @@ -71,29 +97,30 @@ async def test_mplex_stream_read_after_local_reset(mplex_stream_pair): await stream_0.read(MAX_READ_LEN) -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_read_after_remote_reset(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair await stream_0.write(DATA) await stream_0.reset() # Sleep to let `stream_1` receive the message. - await asyncio.sleep(0.01) + await trio.sleep(0.1) + await wait_all_tasks_blocked() with pytest.raises(MplexStreamReset): await stream_1.read(MAX_READ_LEN) -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_read_after_remote_closed_and_reset(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair await stream_0.write(DATA) await stream_0.close() await stream_0.reset() # Sleep to let `stream_1` receive the message. - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert (await stream_1.read(MAX_READ_LEN)) == DATA -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_write_after_local_closed(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair await stream_0.write(DATA) @@ -102,7 +129,7 @@ async def test_mplex_stream_write_after_local_closed(mplex_stream_pair): await stream_0.write(DATA) -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_write_after_local_reset(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair await stream_0.reset() @@ -110,16 +137,16 @@ async def test_mplex_stream_write_after_local_reset(mplex_stream_pair): await stream_0.write(DATA) -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_write_after_remote_reset(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair await stream_1.reset() - await asyncio.sleep(0.01) + await trio.sleep(0.01) with pytest.raises(MplexStreamClosed): await stream_0.write(DATA) -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_both_close(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair # Flags are not set initially. @@ -133,7 +160,7 @@ async def test_mplex_stream_both_close(mplex_stream_pair): # Test: Close one side. await stream_0.close() - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert stream_0.event_local_closed.is_set() assert not stream_1.event_local_closed.is_set() @@ -145,7 +172,7 @@ async def test_mplex_stream_both_close(mplex_stream_pair): # Test: Close the other side. await stream_1.close() - await asyncio.sleep(0.01) + await trio.sleep(0.01) # Both sides are closed. 
     assert stream_0.event_local_closed.is_set()
     assert stream_1.event_local_closed.is_set()
 
@@ -159,11 +186,11 @@ async def test_mplex_stream_both_close(mplex_stream_pair):
     await stream_0.reset()
 
 
-@pytest.mark.asyncio
+@pytest.mark.trio
 async def test_mplex_stream_reset(mplex_stream_pair):
     stream_0, stream_1 = mplex_stream_pair
     await stream_0.reset()
-    await asyncio.sleep(0.01)
+    await trio.sleep(0.01)
 
     # Both sides are closed.
     assert stream_0.event_local_closed.is_set()
diff --git a/tests/transport/test_tcp.py b/tests/transport/test_tcp.py
index 7231a060..130b3cc4 100644
--- a/tests/transport/test_tcp.py
+++ b/tests/transport/test_tcp.py
@@ -1,20 +1,53 @@
-import asyncio
-
+from multiaddr import Multiaddr
 import pytest
+import trio
 
-from libp2p.transport.tcp.tcp import _multiaddr_from_socket
+from libp2p.network.connection.raw_connection import RawConnection
+from libp2p.tools.constants import LISTEN_MADDR
+from libp2p.transport.exceptions import OpenConnectionError
+from libp2p.transport.tcp.tcp import TCP
 
 
-@pytest.mark.asyncio
-async def test_multiaddr_from_socket():
-    def handler(r, w):
+@pytest.mark.trio
+async def test_tcp_listener(nursery):
+    transport = TCP()
+
+    async def handler(tcp_stream):
         pass
 
-    server = await asyncio.start_server(handler, "127.0.0.1", 8000)
-    assert str(_multiaddr_from_socket(server.sockets[0])) == "/ip4/127.0.0.1/tcp/8000"
+    listener = transport.create_listener(handler)
+    assert len(listener.get_addrs()) == 0
+    await listener.listen(LISTEN_MADDR, nursery)
+    assert len(listener.get_addrs()) == 1
+    await listener.listen(LISTEN_MADDR, nursery)
+    assert len(listener.get_addrs()) == 2
 
-    server = await asyncio.start_server(handler, "127.0.0.1", 0)
-    addr = _multiaddr_from_socket(server.sockets[0])
-    assert addr.value_for_protocol("ip4") == "127.0.0.1"
-    port = addr.value_for_protocol("tcp")
-    assert int(port) > 0
+
+@pytest.mark.trio
+async def test_tcp_dial(nursery):
+    transport = TCP()
+    raw_conn_other_side = None
+    event = trio.Event()
+
+    async def handler(tcp_stream):
+        nonlocal raw_conn_other_side
+        raw_conn_other_side = RawConnection(tcp_stream, False)
+        event.set()
+        await trio.sleep_forever()
+
+    # Test: `OpenConnectionError` is raised when trying to dial a port that
+    # no one is listening on.
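# --- Editor's aside (illustrative, not part of the patch): the assertion that follows
# relies on the trio-based TCP transport translating a failed connection attempt into
# `OpenConnectionError`. The transport implementation is not shown in this diff, so the
# sketch below is only an assumption about how that translation might look:
#
#     import trio
#     from libp2p.transport.exceptions import OpenConnectionError
#
#     async def _dial_tcp(host: str, port: int) -> trio.SocketStream:
#         try:
#             return await trio.open_tcp_stream(host, port)
#         except OSError as error:
#             # Connection refused/unreachable surfaces as OSError in trio.
#             raise OpenConnectionError() from error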
+ with pytest.raises(OpenConnectionError): + await transport.dial(Multiaddr("/ip4/127.0.0.1/tcp/1")) + + listener = transport.create_listener(handler) + await listener.listen(LISTEN_MADDR, nursery) + addrs = listener.get_addrs() + assert len(addrs) == 1 + listen_addr = addrs[0] + raw_conn = await transport.dial(listen_addr) + await event.wait() + + data = b"123" + await raw_conn_other_side.write(data) + assert (await raw_conn.read(len(data))) == data diff --git a/tests_interop/conftest.py b/tests_interop/conftest.py index 08df614c..067db83f 100644 --- a/tests_interop/conftest.py +++ b/tests_interop/conftest.py @@ -1,20 +1,13 @@ -import asyncio -import sys -from typing import Union - +import anyio +from async_exit_stack import AsyncExitStack from p2pclient.datastructures import StreamInfo -import pexpect +from p2pclient.utils import get_unused_tcp_port import pytest +import trio from libp2p.io.abc import ReadWriteCloser -from libp2p.tools.constants import GOSSIPSUB_PARAMS, LISTEN_MADDR -from libp2p.tools.factories import ( - FloodsubFactory, - GossipsubFactory, - HostFactory, - PubsubFactory, -) -from libp2p.tools.interop.daemon import Daemon, make_p2pd +from libp2p.tools.factories import HostFactory, PubsubFactory +from libp2p.tools.interop.daemon import make_p2pd from libp2p.tools.interop.utils import connect @@ -23,48 +16,6 @@ def is_host_secure(): return False -@pytest.fixture -def num_hosts(): - return 3 - - -@pytest.fixture -async def hosts(num_hosts, is_host_secure): - _hosts = HostFactory.create_batch(num_hosts, is_secure=is_host_secure) - await asyncio.gather( - *[_host.get_network().listen(LISTEN_MADDR) for _host in _hosts] - ) - try: - yield _hosts - finally: - # TODO: It's possible that `close` raises exceptions currently, - # due to the connection reset things. Though we don't care much about that when - # cleaning up the tasks, it is probably better to handle the exceptions properly. - await asyncio.gather( - *[_host.close() for _host in _hosts], return_exceptions=True - ) - - -@pytest.fixture -def proc_factory(): - procs = [] - - def call_proc(cmd, args, logfile=None, encoding=None): - if logfile is None: - logfile = sys.stdout - if encoding is None: - encoding = "utf-8" - proc = pexpect.spawn(cmd, args, logfile=logfile, encoding=encoding) - procs.append(proc) - return proc - - try: - yield call_proc - finally: - for proc in procs: - proc.close() - - @pytest.fixture def num_p2pds(): return 1 @@ -87,79 +38,57 @@ def is_pubsub_signing_strict(): @pytest.fixture async def p2pds( - num_p2pds, - is_host_secure, - is_gossipsub, - unused_tcp_port_factory, - is_pubsub_signing, - is_pubsub_signing_strict, + num_p2pds, is_host_secure, is_gossipsub, is_pubsub_signing, is_pubsub_signing_strict ): - p2pds: Union[Daemon, Exception] = await asyncio.gather( - *[ - make_p2pd( - unused_tcp_port_factory(), - unused_tcp_port_factory(), - is_host_secure, - is_gossipsub=is_gossipsub, - is_pubsub_signing=is_pubsub_signing, - is_pubsub_signing_strict=is_pubsub_signing_strict, + async with AsyncExitStack() as stack: + p2pds = [ + await stack.enter_async_context( + make_p2pd( + get_unused_tcp_port(), + get_unused_tcp_port(), + is_host_secure, + is_gossipsub=is_gossipsub, + is_pubsub_signing=is_pubsub_signing, + is_pubsub_signing_strict=is_pubsub_signing_strict, + ) ) for _ in range(num_p2pds) - ], - return_exceptions=True, - ) - p2pds_succeeded = tuple(p2pd for p2pd in p2pds if isinstance(p2pd, Daemon)) - if len(p2pds_succeeded) != len(p2pds): - # Not all succeeded. 
Close the succeeded ones and print the failed ones(exceptions). - await asyncio.gather(*[p2pd.close() for p2pd in p2pds_succeeded]) - exceptions = tuple(p2pd for p2pd in p2pds if isinstance(p2pd, Exception)) - raise Exception(f"not all p2pds succeed: first exception={exceptions[0]}") - try: - yield p2pds - finally: - await asyncio.gather(*[p2pd.close() for p2pd in p2pds]) + ] + try: + yield p2pds + finally: + for p2pd in p2pds: + await p2pd.close() @pytest.fixture -def pubsubs(num_hosts, hosts, is_gossipsub, is_pubsub_signing_strict): +async def pubsubs(num_hosts, is_host_secure, is_gossipsub, is_pubsub_signing_strict): if is_gossipsub: - routers = GossipsubFactory.create_batch(num_hosts, **GOSSIPSUB_PARAMS._asdict()) + yield PubsubFactory.create_batch_with_gossipsub( + num_hosts, is_secure=is_host_secure, strict_signing=is_pubsub_signing_strict + ) else: - routers = FloodsubFactory.create_batch(num_hosts) - _pubsubs = tuple( - PubsubFactory(host=host, router=router, strict_signing=is_pubsub_signing_strict) - for host, router in zip(hosts, routers) - ) - yield _pubsubs - # TODO: Clean up + yield PubsubFactory.create_batch_with_floodsub( + num_hosts, is_host_secure, strict_signing=is_pubsub_signing_strict + ) class DaemonStream(ReadWriteCloser): stream_info: StreamInfo - reader: asyncio.StreamReader - writer: asyncio.StreamWriter + stream: anyio.abc.SocketStream - def __init__( - self, - stream_info: StreamInfo, - reader: asyncio.StreamReader, - writer: asyncio.StreamWriter, - ) -> None: + def __init__(self, stream_info: StreamInfo, stream: anyio.abc.SocketStream) -> None: self.stream_info = stream_info - self.reader = reader - self.writer = writer + self.stream = stream async def close(self) -> None: - self.writer.close() - if sys.version_info < (3, 7): - return - await self.writer.wait_closed() + await self.stream.close() - async def read(self, n: int = -1) -> bytes: - return await self.reader.read(n) + async def read(self, n: int = None) -> bytes: + return await self.stream.receive_some(n) - async def write(self, data: bytes) -> int: - return self.writer.write(data) + async def write(self, data: bytes) -> None: + return await self.stream.send_all(data) @pytest.fixture @@ -168,40 +97,38 @@ async def is_to_fail_daemon_stream(): @pytest.fixture -async def py_to_daemon_stream_pair(hosts, p2pds, is_to_fail_daemon_stream): - assert len(hosts) >= 1 - assert len(p2pds) >= 1 - host = hosts[0] - p2pd = p2pds[0] - protocol_id = "/protocol/id/123" - stream_py = None - stream_daemon = None - event_stream_handled = asyncio.Event() - await connect(host, p2pd) +async def py_to_daemon_stream_pair(p2pds, is_host_secure, is_to_fail_daemon_stream): + async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts: + assert len(p2pds) >= 1 + host = hosts[0] + p2pd = p2pds[0] + protocol_id = "/protocol/id/123" + stream_py = None + stream_daemon = None + event_stream_handled = trio.Event() + await connect(host, p2pd) - async def daemon_stream_handler(stream_info, reader, writer): - nonlocal stream_daemon - stream_daemon = DaemonStream(stream_info, reader, writer) - event_stream_handled.set() + async def daemon_stream_handler(stream_info, stream): + nonlocal stream_daemon + stream_daemon = DaemonStream(stream_info, stream) + event_stream_handled.set() + await trio.hazmat.checkpoint() - await p2pd.control.stream_handler(protocol_id, daemon_stream_handler) - # Sleep for a while to wait for the handler being registered. 
- await asyncio.sleep(0.01) + await p2pd.control.stream_handler(protocol_id, daemon_stream_handler) + # Sleep for a while to wait for the handler being registered. + await trio.sleep(0.01) - if is_to_fail_daemon_stream: - # FIXME: This is a workaround to make daemon reset the stream. - # We intentionally close the listener on the python side, it makes the connection from - # daemon to us fail, and therefore the daemon resets the opened stream on their side. - # Reference: https://github.com/libp2p/go-libp2p-daemon/blob/b95e77dbfcd186ccf817f51e95f73f9fd5982600/stream.go#L47-L50 # noqa: E501 - # We need it because we want to test against `stream_py` after the remote side(daemon) - # is reset. This should be removed after the API `stream.reset` is exposed in daemon - # some day. - listener = p2pds[0].control.control.listener - listener.close() - if sys.version_info[0:2] > (3, 6): - await listener.wait_closed() - stream_py = await host.new_stream(p2pd.peer_id, [protocol_id]) - if not is_to_fail_daemon_stream: - await event_stream_handled.wait() - # NOTE: If `is_to_fail_daemon_stream == True`, then `stream_daemon == None`. - yield stream_py, stream_daemon + if is_to_fail_daemon_stream: + # FIXME: This is a workaround to make daemon reset the stream. + # We intentionally close the listener on the python side, it makes the connection from + # daemon to us fail, and therefore the daemon resets the opened stream on their side. + # Reference: https://github.com/libp2p/go-libp2p-daemon/blob/b95e77dbfcd186ccf817f51e95f73f9fd5982600/stream.go#L47-L50 # noqa: E501 + # We need it because we want to test against `stream_py` after the remote side(daemon) + # is reset. This should be removed after the API `stream.reset` is exposed in daemon + # some day. + await p2pds[0].control.control.close() + stream_py = await host.new_stream(p2pd.peer_id, [protocol_id]) + if not is_to_fail_daemon_stream: + await event_stream_handled.wait() + # NOTE: If `is_to_fail_daemon_stream == True`, then `stream_daemon == None`. 
+ yield stream_py, stream_daemon diff --git a/tests_interop/test_bindings.py b/tests_interop/test_bindings.py index 9e70aa21..87cdbb1b 100644 --- a/tests_interop/test_bindings.py +++ b/tests_interop/test_bindings.py @@ -1,26 +1,26 @@ -import asyncio - import pytest +import trio +from libp2p.tools.factories import HostFactory from libp2p.tools.interop.utils import connect -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_connect(hosts, p2pds): - p2pd = p2pds[0] - host = hosts[0] - assert len(await p2pd.control.list_peers()) == 0 - # Test: connect from Py - await connect(host, p2pd) - assert len(await p2pd.control.list_peers()) == 1 - # Test: `disconnect` from Py - await host.disconnect(p2pd.peer_id) - assert len(await p2pd.control.list_peers()) == 0 - # Test: connect from Go - await connect(p2pd, host) - assert len(host.get_network().connections) == 1 - # Test: `disconnect` from Go - await p2pd.control.disconnect(host.get_id()) - await asyncio.sleep(0.01) - assert len(host.get_network().connections) == 0 +@pytest.mark.trio +async def test_connect(is_host_secure, p2pds): + async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts: + p2pd = p2pds[0] + host = hosts[0] + assert len(await p2pd.control.list_peers()) == 0 + # Test: connect from Py + await connect(host, p2pd) + assert len(await p2pd.control.list_peers()) == 1 + # Test: `disconnect` from Py + await host.disconnect(p2pd.peer_id) + assert len(await p2pd.control.list_peers()) == 0 + # Test: connect from Go + await connect(p2pd, host) + assert len(host.get_network().connections) == 1 + # Test: `disconnect` from Go + await p2pd.control.disconnect(host.get_id()) + await trio.sleep(0.01) + assert len(host.get_network().connections) == 0 diff --git a/tests_interop/test_echo.py b/tests_interop/test_echo.py index 6ac867ba..85810a7f 100644 --- a/tests_interop/test_echo.py +++ b/tests_interop/test_echo.py @@ -1,82 +1,99 @@ -import asyncio +import re from multiaddr import Multiaddr +from p2pclient.utils import get_unused_tcp_port import pytest +import trio -from libp2p.peer.peerinfo import info_from_p2p_addr -from libp2p.tools.interop.constants import PEXPECT_NEW_LINE +from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr +from libp2p.tools.factories import HostFactory from libp2p.tools.interop.envs import GO_BIN_PATH +from libp2p.tools.interop.process import BaseInteractiveProcess from libp2p.typing import TProtocol ECHO_PATH = GO_BIN_PATH / "echo" ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0") -async def make_echo_proc( - proc_factory, port: int, is_secure: bool, destination: Multiaddr = None -): - args = [f"-l={port}"] - if not is_secure: - args.append("-insecure") - if destination is not None: - args.append(f"-d={str(destination)}") - echo_proc = proc_factory(str(ECHO_PATH), args) - await echo_proc.expect(r"I am ([\w\./]+)" + PEXPECT_NEW_LINE, async_=True) - maddr_str_ipfs = echo_proc.match.group(1) - maddr_str = maddr_str_ipfs.replace("ipfs", "p2p") - maddr = Multiaddr(maddr_str) - go_pinfo = info_from_p2p_addr(maddr) - if destination is None: - await echo_proc.expect("listening for connections", async_=True) - return echo_proc, go_pinfo +class EchoProcess(BaseInteractiveProcess): + port: int + _peer_info: PeerInfo + + def __init__( + self, port: int, is_secure: bool, destination: Multiaddr = None + ) -> None: + args = [f"-l={port}"] + if not is_secure: + args.append("-insecure") + if destination is not None: + args.append(f"-d={str(destination)}") + + patterns = [b"I am"] + if 
destination is None:
+            patterns.append(b"listening for connections")
+
+        self.args = args
+        self.cmd = str(ECHO_PATH)
+        self.patterns = patterns
+        self.bytes_read = bytearray()
+        self.event_ready = trio.Event()
+
+        self.port = port
+        self._peer_info = None
+        self.regex_pat = re.compile(br"I am ([\w\./]+)")
+
+    @property
+    def peer_info(self) -> PeerInfo:
+        if self._peer_info is not None:
+            return self._peer_info
+        if not self.event_ready.is_set():
+            raise Exception("process is not ready yet. failed to parse the peer info")
+        # Example:
+        # b"I am /ip4/127.0.0.1/tcp/56171/ipfs/QmU41TRPs34WWqa1brJEojBLYZKrrBcJq9nyNfVvSrbZUJ\n"
+        m = re.search(br"I am ([\w\./]+)", self.bytes_read)
+        if m is None:
+            raise Exception("failed to find the pattern for the listening multiaddr")
+        maddr_bytes_str_ipfs = m.group(1)
+        maddr_str = maddr_bytes_str_ipfs.decode().replace("ipfs", "p2p")
+        maddr = Multiaddr(maddr_str)
+        self._peer_info = info_from_p2p_addr(maddr)
+        return self._peer_info
 
 
-@pytest.mark.parametrize("num_hosts", (1,))
-@pytest.mark.asyncio
-async def test_insecure_conn_py_to_go(
-    hosts, proc_factory, is_host_secure, unused_tcp_port
-):
-    go_proc, go_pinfo = await make_echo_proc(
-        proc_factory, unused_tcp_port, is_host_secure
-    )
+@pytest.mark.trio
+async def test_insecure_conn_py_to_go(is_host_secure):
+    async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts:
+        go_proc = EchoProcess(get_unused_tcp_port(), is_host_secure)
+        await go_proc.start()
 
-    host = hosts[0]
-    await host.connect(go_pinfo)
-    await go_proc.expect("swarm listener accepted connection", async_=True)
-    s = await host.new_stream(go_pinfo.peer_id, [ECHO_PROTOCOL_ID])
-
-    await go_proc.expect("Got a new stream!", async_=True)
-    data = "data321123\n"
-    await s.write(data.encode())
-    await go_proc.expect(f"read: {data[:-1]}", async_=True)
-    echoed_resp = await s.read(len(data))
-    assert echoed_resp.decode() == data
-    await s.close()
+        host = hosts[0]
+        peer_info = go_proc.peer_info
+        await host.connect(peer_info)
+        s = await host.new_stream(peer_info.peer_id, [ECHO_PROTOCOL_ID])
+        data = "data321123\n"
+        await s.write(data.encode())
+        echoed_resp = await s.read(len(data))
+        assert echoed_resp.decode() == data
+        await s.close()
 
 
-@pytest.mark.parametrize("num_hosts", (1,))
-@pytest.mark.asyncio
-async def test_insecure_conn_go_to_py(
-    hosts, proc_factory, is_host_secure, unused_tcp_port
-):
-    host = hosts[0]
-    expected_data = "Hello, world!\n"
-    reply_data = "Replyooo!\n"
-    event_handler_finished = asyncio.Event()
+@pytest.mark.trio
+async def test_insecure_conn_go_to_py(is_host_secure):
+    async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts:
+        host = hosts[0]
+        expected_data = "Hello, world!\n"
+        reply_data = "Replyooo!\n"
+        event_handler_finished = trio.Event()
 
-    async def _handle_echo(stream):
-        read_data = await stream.read(len(expected_data))
-        assert read_data == expected_data.encode()
-        event_handler_finished.set()
-        await stream.write(reply_data.encode())
-        await stream.close()
+        async def _handle_echo(stream):
+            read_data = await stream.read(len(expected_data))
+            assert read_data == expected_data.encode()
+            event_handler_finished.set()
+            await stream.write(reply_data.encode())
+            await stream.close()
 
-    host.set_stream_handler(ECHO_PROTOCOL_ID, _handle_echo)
-    py_maddr = host.get_addrs()[0]
-    go_proc, _ = await make_echo_proc(
-        proc_factory, unused_tcp_port, is_host_secure, py_maddr
-    )
-    await go_proc.expect("connect with peer", async_=True)
-    await go_proc.expect("opened stream",
async_=True) - await event_handler_finished.wait() - await go_proc.expect(f"read reply: .*{reply_data.rstrip()}.*", async_=True) + host.set_stream_handler(ECHO_PROTOCOL_ID, _handle_echo) + py_maddr = host.get_addrs()[0] + go_proc = EchoProcess(get_unused_tcp_port(), is_host_secure, py_maddr) + await go_proc.start() + await event_handler_finished.wait() diff --git a/tests_interop/test_net_stream.py b/tests_interop/test_net_stream.py index 2c897d2b..59812a9b 100644 --- a/tests_interop/test_net_stream.py +++ b/tests_interop/test_net_stream.py @@ -1,6 +1,5 @@ -import asyncio - import pytest +import trio from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset from libp2p.tools.constants import MAX_READ_LEN @@ -8,7 +7,7 @@ from libp2p.tools.constants import MAX_READ_LEN DATA = b"data" -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_read_write(py_to_daemon_stream_pair, p2pds): stream_py, stream_daemon = py_to_daemon_stream_pair assert ( @@ -19,19 +18,19 @@ async def test_net_stream_read_write(py_to_daemon_stream_pair, p2pds): assert (await stream_daemon.read(MAX_READ_LEN)) == DATA -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_read_after_remote_closed(py_to_daemon_stream_pair, p2pds): stream_py, stream_daemon = py_to_daemon_stream_pair await stream_daemon.write(DATA) await stream_daemon.close() - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert (await stream_py.read(MAX_READ_LEN)) == DATA # EOF with pytest.raises(StreamEOF): await stream_py.read(MAX_READ_LEN) -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_read_after_local_reset(py_to_daemon_stream_pair, p2pds): stream_py, _ = py_to_daemon_stream_pair await stream_py.reset() @@ -40,15 +39,15 @@ async def test_net_stream_read_after_local_reset(py_to_daemon_stream_pair, p2pds @pytest.mark.parametrize("is_to_fail_daemon_stream", (True,)) -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_read_after_remote_reset(py_to_daemon_stream_pair, p2pds): stream_py, _ = py_to_daemon_stream_pair - await asyncio.sleep(0.01) + await trio.sleep(0.01) with pytest.raises(StreamReset): await stream_py.read(MAX_READ_LEN) -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_write_after_local_closed(py_to_daemon_stream_pair, p2pds): stream_py, _ = py_to_daemon_stream_pair await stream_py.write(DATA) @@ -57,7 +56,7 @@ async def test_net_stream_write_after_local_closed(py_to_daemon_stream_pair, p2p await stream_py.write(DATA) -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_write_after_local_reset(py_to_daemon_stream_pair, p2pds): stream_py, stream_daemon = py_to_daemon_stream_pair await stream_py.reset() @@ -66,9 +65,9 @@ async def test_net_stream_write_after_local_reset(py_to_daemon_stream_pair, p2pd @pytest.mark.parametrize("is_to_fail_daemon_stream", (True,)) -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_write_after_remote_reset(py_to_daemon_stream_pair, p2pds): stream_py, _ = py_to_daemon_stream_pair - await asyncio.sleep(0.01) + await trio.sleep(0.01) with pytest.raises(StreamClosed): await stream_py.write(DATA) diff --git a/tests_interop/test_pubsub.py b/tests_interop/test_pubsub.py index db42c7cd..793f2446 100644 --- a/tests_interop/test_pubsub.py +++ b/tests_interop/test_pubsub.py @@ -1,11 +1,15 @@ -import asyncio import functools +import math from p2pclient.pb import p2pd_pb2 import pytest +import trio +from libp2p.io.trio import TrioTCPStream from libp2p.peer.id import ID from libp2p.pubsub.pb import 
rpc_pb2 +from libp2p.pubsub.subscription import TrioSubscriptionAPI +from libp2p.tools.factories import PubsubFactory from libp2p.tools.interop.utils import connect from libp2p.utils import read_varint_prefixed_bytes @@ -13,26 +17,15 @@ TOPIC_0 = "ABALA" TOPIC_1 = "YOOOO" -async def p2pd_subscribe(p2pd, topic) -> "asyncio.Queue[rpc_pb2.Message]": - reader, writer = await p2pd.control.pubsub_subscribe(topic) +async def p2pd_subscribe(p2pd, topic, nursery): + stream = TrioTCPStream(await p2pd.control.pubsub_subscribe(topic)) + send_channel, receive_channel = trio.open_memory_channel(math.inf) - queue = asyncio.Queue() + sub = TrioSubscriptionAPI(receive_channel, unsubscribe_fn=stream.close) async def _read_pubsub_msg() -> None: - writer_closed_task = asyncio.ensure_future(writer.wait_closed()) - while True: - done, pending = await asyncio.wait( - [read_varint_prefixed_bytes(reader), writer_closed_task], - return_when=asyncio.FIRST_COMPLETED, - ) - done_tasks = tuple(done) - if writer.is_closing(): - return - read_task = done_tasks[0] - # Sanity check - assert read_task._coro.__name__ == "read_varint_prefixed_bytes" - msg_bytes = read_task.result() + msg_bytes = await read_varint_prefixed_bytes(stream) ps_msg = p2pd_pb2.PSMessage() ps_msg.ParseFromString(msg_bytes) # Fill in the message used in py-libp2p @@ -44,11 +37,10 @@ async def p2pd_subscribe(p2pd, topic) -> "asyncio.Queue[rpc_pb2.Message]": signature=ps_msg.signature, key=ps_msg.key, ) - queue.put_nowait(msg) + await send_channel.send(msg) - asyncio.ensure_future(_read_pubsub_msg()) - await asyncio.sleep(0) - return queue + nursery.start_soon(_read_pubsub_msg) + return sub def validate_pubsub_msg(msg: rpc_pb2.Message, data: bytes, from_peer_id: ID) -> None: @@ -59,108 +51,119 @@ def validate_pubsub_msg(msg: rpc_pb2.Message, data: bytes, from_peer_id: ID) -> "is_pubsub_signing, is_pubsub_signing_strict", ((True, True), (False, False)) ) @pytest.mark.parametrize("is_gossipsub", (True, False)) -@pytest.mark.parametrize("num_hosts, num_p2pds", ((1, 2),)) -@pytest.mark.asyncio -async def test_pubsub(pubsubs, p2pds): - # - # Test: Recognize pubsub peers on connection. - # - py_pubsub = pubsubs[0] - # go0 <-> py <-> go1 - await connect(p2pds[0], py_pubsub.host) - await connect(py_pubsub.host, p2pds[1]) - py_peer_id = py_pubsub.host.get_id() - # Check pubsub peers - pubsub_peers_0 = await p2pds[0].control.pubsub_list_peers("") - assert len(pubsub_peers_0) == 1 and pubsub_peers_0[0] == py_peer_id - pubsub_peers_1 = await p2pds[1].control.pubsub_list_peers("") - assert len(pubsub_peers_1) == 1 and pubsub_peers_1[0] == py_peer_id - assert ( - len(py_pubsub.peers) == 2 - and p2pds[0].peer_id in py_pubsub.peers - and p2pds[1].peer_id in py_pubsub.peers - ) +@pytest.mark.parametrize("num_p2pds", (2,)) +@pytest.mark.trio +async def test_pubsub( + p2pds, is_gossipsub, is_host_secure, is_pubsub_signing_strict, nursery +): + pubsub_factory = None + if is_gossipsub: + pubsub_factory = PubsubFactory.create_batch_with_gossipsub + else: + pubsub_factory = PubsubFactory.create_batch_with_floodsub - # - # Test: `subscribe`. 
- # - # (name, topics) - # (go_0, [0, 1]) <-> (py, [0, 1]) <-> (go_1, [1]) - sub_py_topic_0 = await py_pubsub.subscribe(TOPIC_0) - sub_py_topic_1 = await py_pubsub.subscribe(TOPIC_1) - sub_go_0_topic_0 = await p2pd_subscribe(p2pds[0], TOPIC_0) - sub_go_0_topic_1 = await p2pd_subscribe(p2pds[0], TOPIC_1) - sub_go_1_topic_1 = await p2pd_subscribe(p2pds[1], TOPIC_1) - # Check topic peers - await asyncio.sleep(0.1) - # go_0 - go_0_topic_0_peers = await p2pds[0].control.pubsub_list_peers(TOPIC_0) - assert len(go_0_topic_0_peers) == 1 and py_peer_id == go_0_topic_0_peers[0] - go_0_topic_1_peers = await p2pds[0].control.pubsub_list_peers(TOPIC_1) - assert len(go_0_topic_1_peers) == 1 and py_peer_id == go_0_topic_1_peers[0] - # py - py_topic_0_peers = list(py_pubsub.peer_topics[TOPIC_0]) - assert len(py_topic_0_peers) == 1 and p2pds[0].peer_id == py_topic_0_peers[0] - # go_1 - go_1_topic_1_peers = await p2pds[1].control.pubsub_list_peers(TOPIC_1) - assert len(go_1_topic_1_peers) == 1 and py_peer_id == go_1_topic_1_peers[0] + async with pubsub_factory( + 1, is_secure=is_host_secure, strict_signing=is_pubsub_signing_strict + ) as pubsubs: + # + # Test: Recognize pubsub peers on connection. + # + py_pubsub = pubsubs[0] + # go0 <-> py <-> go1 + await connect(p2pds[0], py_pubsub.host) + await connect(py_pubsub.host, p2pds[1]) + py_peer_id = py_pubsub.host.get_id() + # Check pubsub peers + pubsub_peers_0 = await p2pds[0].control.pubsub_list_peers("") + assert len(pubsub_peers_0) == 1 and pubsub_peers_0[0] == py_peer_id + pubsub_peers_1 = await p2pds[1].control.pubsub_list_peers("") + assert len(pubsub_peers_1) == 1 and pubsub_peers_1[0] == py_peer_id + assert ( + len(py_pubsub.peers) == 2 + and p2pds[0].peer_id in py_pubsub.peers + and p2pds[1].peer_id in py_pubsub.peers + ) - # - # Test: `publish` - # - # 1. py publishes - # - 1.1. py publishes data_11 to topic_0, py and go_0 receives. - # - 1.2. py publishes data_12 to topic_1, all receive. - # 2. go publishes - # - 2.1. go_0 publishes data_21 to topic_0, py and go_0 receive. - # - 2.2. go_1 publishes data_22 to topic_1, all receive. + # + # Test: `subscribe`. + # + # (name, topics) + # (go_0, [0, 1]) <-> (py, [0, 1]) <-> (go_1, [1]) + sub_py_topic_0 = await py_pubsub.subscribe(TOPIC_0) + sub_py_topic_1 = await py_pubsub.subscribe(TOPIC_1) + sub_go_0_topic_0 = await p2pd_subscribe(p2pds[0], TOPIC_0, nursery) + sub_go_0_topic_1 = await p2pd_subscribe(p2pds[0], TOPIC_1, nursery) + sub_go_1_topic_1 = await p2pd_subscribe(p2pds[1], TOPIC_1, nursery) + # Check topic peers + await trio.sleep(0.1) + # go_0 + go_0_topic_0_peers = await p2pds[0].control.pubsub_list_peers(TOPIC_0) + assert len(go_0_topic_0_peers) == 1 and py_peer_id == go_0_topic_0_peers[0] + go_0_topic_1_peers = await p2pds[0].control.pubsub_list_peers(TOPIC_1) + assert len(go_0_topic_1_peers) == 1 and py_peer_id == go_0_topic_1_peers[0] + # py + py_topic_0_peers = list(py_pubsub.peer_topics[TOPIC_0]) + assert len(py_topic_0_peers) == 1 and p2pds[0].peer_id == py_topic_0_peers[0] + # go_1 + go_1_topic_1_peers = await p2pds[1].control.pubsub_list_peers(TOPIC_1) + assert len(go_1_topic_1_peers) == 1 and py_peer_id == go_1_topic_1_peers[0] - # 1.1. py publishes data_11 to topic_0, py and go_0 receives. - data_11 = b"data_11" - await py_pubsub.publish(TOPIC_0, data_11) - validate_11 = functools.partial( - validate_pubsub_msg, data=data_11, from_peer_id=py_peer_id - ) - validate_11(await sub_py_topic_0.get()) - validate_11(await sub_go_0_topic_0.get()) + # + # Test: `publish` + # + # 1. 
py publishes + # - 1.1. py publishes data_11 to topic_0, py and go_0 receives. + # - 1.2. py publishes data_12 to topic_1, all receive. + # 2. go publishes + # - 2.1. go_0 publishes data_21 to topic_0, py and go_0 receive. + # - 2.2. go_1 publishes data_22 to topic_1, all receive. - # 1.2. py publishes data_12 to topic_1, all receive. - data_12 = b"data_12" - validate_12 = functools.partial( - validate_pubsub_msg, data=data_12, from_peer_id=py_peer_id - ) - await py_pubsub.publish(TOPIC_1, data_12) - validate_12(await sub_py_topic_1.get()) - validate_12(await sub_go_0_topic_1.get()) - validate_12(await sub_go_1_topic_1.get()) + # 1.1. py publishes data_11 to topic_0, py and go_0 receives. + data_11 = b"data_11" + await py_pubsub.publish(TOPIC_0, data_11) + validate_11 = functools.partial( + validate_pubsub_msg, data=data_11, from_peer_id=py_peer_id + ) + validate_11(await sub_py_topic_0.get()) + validate_11(await sub_go_0_topic_0.get()) - # 2.1. go_0 publishes data_21 to topic_0, py and go_0 receive. - data_21 = b"data_21" - validate_21 = functools.partial( - validate_pubsub_msg, data=data_21, from_peer_id=p2pds[0].peer_id - ) - await p2pds[0].control.pubsub_publish(TOPIC_0, data_21) - validate_21(await sub_py_topic_0.get()) - validate_21(await sub_go_0_topic_0.get()) + # 1.2. py publishes data_12 to topic_1, all receive. + data_12 = b"data_12" + validate_12 = functools.partial( + validate_pubsub_msg, data=data_12, from_peer_id=py_peer_id + ) + await py_pubsub.publish(TOPIC_1, data_12) + validate_12(await sub_py_topic_1.get()) + validate_12(await sub_go_0_topic_1.get()) + validate_12(await sub_go_1_topic_1.get()) - # 2.2. go_1 publishes data_22 to topic_1, all receive. - data_22 = b"data_22" - validate_22 = functools.partial( - validate_pubsub_msg, data=data_22, from_peer_id=p2pds[1].peer_id - ) - await p2pds[1].control.pubsub_publish(TOPIC_1, data_22) - validate_22(await sub_py_topic_1.get()) - validate_22(await sub_go_0_topic_1.get()) - validate_22(await sub_go_1_topic_1.get()) + # 2.1. go_0 publishes data_21 to topic_0, py and go_0 receive. + data_21 = b"data_21" + validate_21 = functools.partial( + validate_pubsub_msg, data=data_21, from_peer_id=p2pds[0].peer_id + ) + await p2pds[0].control.pubsub_publish(TOPIC_0, data_21) + validate_21(await sub_py_topic_0.get()) + validate_21(await sub_go_0_topic_0.get()) - # - # Test: `unsubscribe` and re`subscribe` - # - await py_pubsub.unsubscribe(TOPIC_0) - await asyncio.sleep(0.1) - assert py_peer_id not in (await p2pds[0].control.pubsub_list_peers(TOPIC_0)) - assert py_peer_id not in (await p2pds[1].control.pubsub_list_peers(TOPIC_0)) - await py_pubsub.subscribe(TOPIC_0) - await asyncio.sleep(0.1) - assert py_peer_id in (await p2pds[0].control.pubsub_list_peers(TOPIC_0)) - assert py_peer_id in (await p2pds[1].control.pubsub_list_peers(TOPIC_0)) + # 2.2. go_1 publishes data_22 to topic_1, all receive. 
+ data_22 = b"data_22" + validate_22 = functools.partial( + validate_pubsub_msg, data=data_22, from_peer_id=p2pds[1].peer_id + ) + await p2pds[1].control.pubsub_publish(TOPIC_1, data_22) + validate_22(await sub_py_topic_1.get()) + validate_22(await sub_go_0_topic_1.get()) + validate_22(await sub_go_1_topic_1.get()) + + # + # Test: `unsubscribe` and re`subscribe` + # + await py_pubsub.unsubscribe(TOPIC_0) + await trio.sleep(0.1) + assert py_peer_id not in (await p2pds[0].control.pubsub_list_peers(TOPIC_0)) + assert py_peer_id not in (await p2pds[1].control.pubsub_list_peers(TOPIC_0)) + await py_pubsub.subscribe(TOPIC_0) + await trio.sleep(0.1) + assert py_peer_id in (await p2pds[0].control.pubsub_list_peers(TOPIC_0)) + assert py_peer_id in (await p2pds[1].control.pubsub_list_peers(TOPIC_0)) diff --git a/tox.ini b/tox.ini index afa838ab..a03627ee 100644 --- a/tox.ini +++ b/tox.ini @@ -12,7 +12,7 @@ envlist = combine_as_imports=False force_sort_within_sections=True include_trailing_comma=True -known_third_party=hypothesis,pytest,p2pclient,pexpect,factory,lru +known_third_party=anyio,factory,lru,p2pclient,pytest known_first_party=libp2p line_length=88 multi_line_output=3 @@ -58,7 +58,6 @@ commands = [testenv:py37-interop] deps = p2pclient - pexpect passenv = CI TRAVIS TRAVIS_* GOPATH extras = test commands =
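# --- Editor's recap (illustrative, not part of the patch): the test changes above all
# follow one recurring migration pattern -- `@pytest.mark.asyncio` becomes
# `@pytest.mark.trio`, background work moves from `asyncio.ensure_future` into a
# nursery, and `asyncio.Queue` is replaced by a trio memory channel. Assuming
# pytest-trio's built-in `nursery` fixture, a minimal self-contained sketch of that
# pattern looks roughly like this:

import math

import pytest
import trio


@pytest.mark.trio
async def test_trio_migration_pattern_sketch(nursery):
    # Unbounded memory channel standing in for the old `asyncio.Queue`.
    send_channel, receive_channel = trio.open_memory_channel(math.inf)

    async def producer() -> None:
        # Background task started via the nursery; it exits after sending once.
        await send_channel.send(b"data")

    nursery.start_soon(producer)
    assert (await receive_channel.receive()) == b"data"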