From ed17bfd6639993b26a316c39367e234b240c9c3d Mon Sep 17 00:00:00 2001 From: Chih Cheng Liang Date: Mon, 11 Nov 2019 16:38:30 +0800 Subject: [PATCH 01/81] hack chat example --- examples/chat/chat.py | 16 +++++++--------- libp2p/__init__.py | 15 --------------- 2 files changed, 7 insertions(+), 24 deletions(-) diff --git a/examples/chat/chat.py b/examples/chat/chat.py index 24c92699..73436a0d 100755 --- a/examples/chat/chat.py +++ b/examples/chat/chat.py @@ -1,5 +1,7 @@ import argparse import asyncio +import trio_asyncio +import trio import sys import urllib.request @@ -74,6 +76,10 @@ async def run(port: int, destination: str, localhost: bool) -> None: asyncio.ensure_future(write_data(stream)) print("Connected to peer %s" % info.addrs[0]) +async def async_main_wrapper(*args): + async with trio_asyncio.open_loop() as loop: + assert loop == asyncio.get_event_loop() + await run(*args) def main() -> None: description = """ @@ -112,15 +118,7 @@ def main() -> None: if not args.port: raise RuntimeError("was not able to determine a local port") - loop = asyncio.get_event_loop() - try: - asyncio.ensure_future(run(args.port, args.destination, args.localhost)) - loop.run_forever() - except KeyboardInterrupt: - pass - finally: - loop.close() - + trio.run(async_main_wrapper, *(args.port, args.destination, args.localhost)) if __name__ == "__main__": main() diff --git a/libp2p/__init__.py b/libp2p/__init__.py index a1dca535..3359f7de 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -24,18 +24,6 @@ from libp2p.transport.upgrader import TransportUpgrader from libp2p.typing import TProtocol -async def cleanup_done_tasks() -> None: - """clean up asyncio done tasks to free up resources.""" - while True: - for task in asyncio.all_tasks(): - if task.done(): - await task - - # Need not run often - # Some sleep necessary to context switch - await asyncio.sleep(3) - - def generate_new_rsa_identity() -> KeyPair: return create_new_key_pair() @@ -155,7 +143,4 @@ async def new_node( else: host = BasicHost(key_pair.public_key, swarm_opt) - # Kick off cleanup job - asyncio.ensure_future(cleanup_done_tasks()) - return host From f5c725788e2647bb2acf22331d552c589707a3e2 Mon Sep 17 00:00:00 2001 From: Chih Cheng Liang Date: Mon, 11 Nov 2019 22:52:48 +0800 Subject: [PATCH 02/81] need manual stop --- examples/chat/chat.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/chat/chat.py b/examples/chat/chat.py index 73436a0d..b1d416ff 100755 --- a/examples/chat/chat.py +++ b/examples/chat/chat.py @@ -79,7 +79,10 @@ async def run(port: int, destination: str, localhost: bool) -> None: async def async_main_wrapper(*args): async with trio_asyncio.open_loop() as loop: assert loop == asyncio.get_event_loop() - await run(*args) + stopped_event = trio.Event() + await trio_asyncio.run_asyncio(run, *args) + await stopped_event.wait() + def main() -> None: description = """ From d4d345c3c7850fe3dfca3a4d77c2902cdc58f7a8 Mon Sep 17 00:00:00 2001 From: Chih Cheng Liang Date: Mon, 18 Nov 2019 17:14:37 +0800 Subject: [PATCH 03/81] progressing --- examples/chat/chat.py | 18 +++++++----------- libp2p/__init__.py | 2 +- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/examples/chat/chat.py b/examples/chat/chat.py index b1d416ff..8f10cd30 100755 --- a/examples/chat/chat.py +++ b/examples/chat/chat.py @@ -40,9 +40,9 @@ async def run(port: int, destination: str, localhost: bool) -> None: else: ip = urllib.request.urlopen("https://v4.ident.me/").read().decode("utf8") transport_opt = f"/ip4/{ip}/tcp/{port}" - host = await new_node(transport_opt=[transport_opt]) + host = new_node(transport_opt=[transport_opt]) - await host.get_network().listen(multiaddr.Multiaddr(transport_opt)) + await trio_asyncio.run_asyncio(host.get_network().listen,multiaddr.Multiaddr(transport_opt) ) if not destination: # its the server @@ -66,22 +66,18 @@ async def run(port: int, destination: str, localhost: bool) -> None: maddr = multiaddr.Multiaddr(destination) info = info_from_p2p_addr(maddr) # Associate the peer with local ip address - await host.connect(info) + await trio_asyncio.run_asyncio(host.connect, info) # Start a stream with the destination. # Multiaddress of the destination peer is fetched from the peerstore using 'peerId'. - stream = await host.new_stream(info.peer_id, [PROTOCOL_ID]) + stream = await trio_asyncio.run_asyncio(host.new_stream, *(info.peer_id, [PROTOCOL_ID])) asyncio.ensure_future(read_data(stream)) asyncio.ensure_future(write_data(stream)) print("Connected to peer %s" % info.addrs[0]) -async def async_main_wrapper(*args): - async with trio_asyncio.open_loop() as loop: - assert loop == asyncio.get_event_loop() - stopped_event = trio.Event() - await trio_asyncio.run_asyncio(run, *args) - await stopped_event.wait() + stopped_event = trio.Event() + await stopped_event.wait() def main() -> None: @@ -121,7 +117,7 @@ def main() -> None: if not args.port: raise RuntimeError("was not able to determine a local port") - trio.run(async_main_wrapper, *(args.port, args.destination, args.localhost)) + trio_asyncio.run(run, *(args.port, args.destination, args.localhost)) if __name__ == "__main__": main() diff --git a/libp2p/__init__.py b/libp2p/__init__.py index 3359f7de..43beeeae 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -97,7 +97,7 @@ def initialize_default_swarm( return Swarm(id_opt, peerstore, upgrader, transport) -async def new_node( +def new_node( key_pair: KeyPair = None, swarm_opt: INetwork = None, transport_opt: Sequence[str] = None, From 41ff884eefa1ace322b433b793872c48c1715621 Mon Sep 17 00:00:00 2001 From: Chih Cheng Liang Date: Tue, 19 Nov 2019 14:01:12 +0800 Subject: [PATCH 04/81] rewrite tcp reader/writer interface --- libp2p/io/trio.py | 32 +++++++++++++++ libp2p/network/connection/raw_connection.py | 40 ++++++------------- libp2p/network/swarm.py | 22 +++-------- libp2p/stream_muxer/mplex/mplex.py | 35 +++-------------- libp2p/stream_muxer/mplex/mplex_stream.py | 26 +++++++------ libp2p/tools/utils.py | 7 ++-- libp2p/transport/tcp/tcp.py | 43 +++++++++++++++------ libp2p/transport/typing.py | 5 ++- tests/libp2p/test_libp2p.py | 8 ++-- 9 files changed, 112 insertions(+), 106 deletions(-) create mode 100644 libp2p/io/trio.py diff --git a/libp2p/io/trio.py b/libp2p/io/trio.py new file mode 100644 index 00000000..6e88f01d --- /dev/null +++ b/libp2p/io/trio.py @@ -0,0 +1,32 @@ +import trio +from trio import SocketStream +from libp2p.io.abc import ReadWriteCloser +from libp2p.io.exceptions import IOException +import logging + + +logger = logging.getLogger("libp2p.io.trio") + + +class TrioReadWriteCloser(ReadWriteCloser): + stream: SocketStream + + def __init__(self, stream: SocketStream) -> None: + self.stream = stream + + async def write(self, data: bytes) -> None: + """Raise `RawConnError` if the underlying connection breaks.""" + try: + await self.stream.send_all(data) + except (trio.ClosedResourceError, trio.BrokenResourceError) as error: + raise IOException(error) + + async def read(self, n: int = -1) -> bytes: + max_bytes = n if n != -1 else None + try: + return await self.stream.receive_some(max_bytes) + except (trio.ClosedResourceError, trio.BrokenResourceError) as error: + raise IOException(error) + + async def close(self) -> None: + await self.stream.aclose() diff --git a/libp2p/network/connection/raw_connection.py b/libp2p/network/connection/raw_connection.py index 08d22055..50b28984 100644 --- a/libp2p/network/connection/raw_connection.py +++ b/libp2p/network/connection/raw_connection.py @@ -1,42 +1,25 @@ -import asyncio +import trio +from libp2p.io.exceptions import IOException from .exceptions import RawConnError from .raw_connection_interface import IRawConnection +from libp2p.io.abc import ReadWriteCloser class RawConnection(IRawConnection): - reader: asyncio.StreamReader - writer: asyncio.StreamWriter + read_write_closer: ReadWriteCloser is_initiator: bool - _drain_lock: asyncio.Lock - - def __init__( - self, - reader: asyncio.StreamReader, - writer: asyncio.StreamWriter, - initiator: bool, - ) -> None: - self.reader = reader - self.writer = writer + def __init__(self, read_write_closer: ReadWriteCloser, initiator: bool) -> None: + self.read_write_closer = read_write_closer self.is_initiator = initiator - self._drain_lock = asyncio.Lock() - async def write(self, data: bytes) -> None: """Raise `RawConnError` if the underlying connection breaks.""" try: - self.writer.write(data) - except ConnectionResetError as error: + await self.read_write_closer.write(data) + except IOException as error: raise RawConnError(error) - # Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102 # noqa: E501 - # Use a lock to serialize drain() calls. Circumvents this bug: - # https://bugs.python.org/issue29930 - async with self._drain_lock: - try: - await self.writer.drain() - except ConnectionResetError as error: - raise RawConnError(error) async def read(self, n: int = -1) -> bytes: """ @@ -46,10 +29,9 @@ class RawConnection(IRawConnection): Raise `RawConnError` if the underlying connection breaks """ try: - return await self.reader.read(n) - except ConnectionResetError as error: + return await self.read_write_closer.read(n) + except IOException as error: raise RawConnError(error) async def close(self) -> None: - self.writer.close() - await self.writer.wait_closed() + await self.read_write_closer.close() diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 7bb40cee..9b89fa56 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -4,6 +4,7 @@ from typing import Dict, List, Optional from multiaddr import Multiaddr +from libp2p.io.abc import ReadWriteCloser from libp2p.network.connection.net_connection_interface import INetConn from libp2p.peer.id import ID from libp2p.peer.peerstore import PeerStoreError @@ -149,7 +150,7 @@ class Swarm(INetwork): logger.debug("successfully opened a stream to peer %s", peer_id) return net_stream - async def listen(self, *multiaddrs: Multiaddr) -> bool: + async def listen(self, *multiaddrs: Multiaddr, nursery) -> bool: """ :param multiaddrs: one or many multiaddrs to start listening on :return: true if at least one success @@ -167,15 +168,8 @@ class Swarm(INetwork): if str(maddr) in self.listeners: return True - async def conn_handler( - reader: asyncio.StreamReader, writer: asyncio.StreamWriter - ) -> None: - connection_info = writer.get_extra_info("peername") - # TODO make a proper multiaddr - peer_addr = f"/ip4/{connection_info[0]}/tcp/{connection_info[1]}" - logger.debug("inbound connection at %s", peer_addr) - # logger.debug("inbound connection request", peer_id) - raw_conn = RawConnection(reader, writer, False) + async def conn_handler(read_write_closer: ReadWriteCloser) -> None: + raw_conn = RawConnection(read_write_closer, False) # Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure # the conn and then mux the conn @@ -185,14 +179,10 @@ class Swarm(INetwork): raw_conn, ID(b""), False ) except SecurityUpgradeFailure as error: - error_msg = "fail to upgrade security for peer at %s" - logger.debug(error_msg, peer_addr) await raw_conn.close() - raise SwarmException(error_msg % peer_addr) from error + raise SwarmException() from error peer_id = secured_conn.get_remote_peer() - logger.debug("upgraded security for peer at %s", peer_addr) - logger.debug("identified peer at %s as %s", peer_addr, peer_id) try: muxed_conn = await self.upgrader.upgrade_connection( @@ -213,7 +203,7 @@ class Swarm(INetwork): # Success listener = self.transport.create_listener(conn_handler) self.listeners[str(maddr)] = listener - await listener.listen(maddr) + await listener.listen(maddr, nursery) # Call notifiers since event occurred self.notify_listen(maddr) diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 1a43c7cb..1f81c2f2 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -123,27 +123,10 @@ class Mplex(IMuxedConn): await self.send_message(HeaderTags.NewStream, name.encode(), stream_id) return stream - async def _wait_until_shutting_down_or_closed(self, coro: Awaitable[Any]) -> Any: - task_coro = asyncio.ensure_future(coro) - task_wait_closed = asyncio.ensure_future(self.event_closed.wait()) - task_wait_shutting_down = asyncio.ensure_future(self.event_shutting_down.wait()) - done, pending = await asyncio.wait( - [task_coro, task_wait_closed, task_wait_shutting_down], - return_when=asyncio.FIRST_COMPLETED, - ) - for fut in pending: - fut.cancel() - if task_wait_closed in done: - raise MplexUnavailable("Mplex is closed") - if task_wait_shutting_down in done: - raise MplexUnavailable("Mplex is shutting down") - return task_coro.result() async def accept_stream(self) -> IMuxedStream: """accepts a muxed stream opened by the other end.""" - return await self._wait_until_shutting_down_or_closed( - self.new_stream_queue.get() - ) + return await self.new_stream_queue.get() async def send_message( self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID @@ -163,9 +146,7 @@ class Mplex(IMuxedConn): _bytes = header + encode_varint_prefixed(data) - return await self._wait_until_shutting_down_or_closed( - self.write_to_stream(_bytes) - ) + return await self.write_to_stream(_bytes) async def write_to_stream(self, _bytes: bytes) -> int: """ @@ -226,9 +207,7 @@ class Mplex(IMuxedConn): :raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down. """ - channel_id, flag, message = await self._wait_until_shutting_down_or_closed( - self.read_message() - ) + channel_id, flag, message = await self.read_message() stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1)) if flag == HeaderTags.NewStream.value: @@ -258,9 +237,7 @@ class Mplex(IMuxedConn): f"received NewStream message for existing stream: {stream_id}" ) mplex_stream = await self._initialize_stream(stream_id, message.decode()) - await self._wait_until_shutting_down_or_closed( - self.new_stream_queue.put(mplex_stream) - ) + await self.new_stream_queue.put(mplex_stream) async def _handle_message(self, stream_id: StreamID, message: bytes) -> None: async with self.streams_lock: @@ -274,9 +251,7 @@ class Mplex(IMuxedConn): if stream.event_remote_closed.is_set(): # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501 return - await self._wait_until_shutting_down_or_closed( - stream.incoming_data.put(message) - ) + await stream.incoming_data.put(message) async def _handle_close(self, stream_id: StreamID) -> None: async with self.streams_lock: diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index f080d3cf..7659d32f 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -1,3 +1,4 @@ +import trio import asyncio from typing import TYPE_CHECKING @@ -22,14 +23,14 @@ class MplexStream(IMuxedStream): read_deadline: int write_deadline: int - close_lock: asyncio.Lock + close_lock: trio.Lock # NOTE: `dataIn` is size of 8 in Go implementation. incoming_data: "asyncio.Queue[bytes]" - event_local_closed: asyncio.Event - event_remote_closed: asyncio.Event - event_reset: asyncio.Event + event_local_closed: trio.Event + event_remote_closed: trio.Event + event_reset: trio.Event _buf: bytearray @@ -45,10 +46,10 @@ class MplexStream(IMuxedStream): self.muxed_conn = muxed_conn self.read_deadline = None self.write_deadline = None - self.event_local_closed = asyncio.Event() - self.event_remote_closed = asyncio.Event() - self.event_reset = asyncio.Event() - self.close_lock = asyncio.Lock() + self.event_local_closed = trio.Event() + self.event_remote_closed = trio.Event() + self.event_reset = trio.Event() + self.close_lock = trio.Lock() self.incoming_data = asyncio.Queue() self._buf = bytearray() @@ -199,10 +200,11 @@ class MplexStream(IMuxedStream): if self.is_initiator else HeaderTags.ResetReceiver ) - asyncio.ensure_future( - self.muxed_conn.send_message(flag, None, self.stream_id) - ) - await asyncio.sleep(0) + async with trio.open_nursery() as nursery: + nursery.start_soon( + self.muxed_conn.send_message, flag, None, self.stream_id + ) + await trio.sleep(0) self.event_local_closed.set() self.event_remote_closed.set() diff --git a/libp2p/tools/utils.py b/libp2p/tools/utils.py index 84e3edf9..89f864eb 100644 --- a/libp2p/tools/utils.py +++ b/libp2p/tools/utils.py @@ -1,3 +1,4 @@ +import trio from typing import List, Sequence, Tuple import multiaddr @@ -37,12 +38,12 @@ async def connect(node1: IHost, node2: IHost) -> None: async def set_up_nodes_by_transport_opt( - transport_opt_list: Sequence[Sequence[str]] + transport_opt_list: Sequence[Sequence[str]], nursery: trio.Nursery ) -> Tuple[BasicHost, ...]: nodes_list = [] for transport_opt in transport_opt_list: - node = await new_node(transport_opt=transport_opt) - await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0])) + node = new_node(transport_opt=transport_opt) + await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0]), nursery=nursery) nodes_list.append(node) return tuple(nodes_list) diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index 7470510d..8838b504 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -1,4 +1,5 @@ import asyncio +import trio from socket import socket from typing import List @@ -10,6 +11,10 @@ from libp2p.transport.exceptions import OpenConnectionError from libp2p.transport.listener_interface import IListener from libp2p.transport.transport_interface import ITransport from libp2p.transport.typing import THandler +from libp2p.io.trio import TrioReadWriteCloser +import logging + +logger = logging.getLogger("libp2p.transport.tcp") class TCPListener(IListener): @@ -21,20 +26,38 @@ class TCPListener(IListener): self.server = None self.handler = handler_function - async def listen(self, maddr: Multiaddr) -> bool: + async def listen(self, maddr: Multiaddr, nursery) -> bool: """ put listener in listening mode and wait for incoming connections. :param maddr: maddr of peer :return: return True if successful """ - self.server = await asyncio.start_server( - self.handler, - maddr.value_for_protocol("ip4"), - maddr.value_for_protocol("tcp"), + + async def serve_tcp(handler, port, host, task_status=None): + logger.debug("serve_tcp %s %s", host, port) + await trio.serve_tcp(handler, port, host=host, task_status=task_status) + + async def handler(stream): + read_write_closer = TrioReadWriteCloser(stream) + await self.handler(read_write_closer) + + listeners = await nursery.start( + serve_tcp, + *( + handler, + int(maddr.value_for_protocol("tcp")), + maddr.value_for_protocol("ip4"), + ), ) - socket = self.server.sockets[0] + # self.server = await asyncio.start_server( + # self.handler, + # maddr.value_for_protocol("ip4"), + # maddr.value_for_protocol("tcp"), + # ) + socket = listeners[0].socket self.multiaddrs.append(_multiaddr_from_socket(socket)) + logger.debug("Multiaddrs %s", self.multiaddrs) return True @@ -69,12 +92,10 @@ class TCP(ITransport): self.host = maddr.value_for_protocol("ip4") self.port = int(maddr.value_for_protocol("tcp")) - try: - reader, writer = await asyncio.open_connection(self.host, self.port) - except (ConnectionAbortedError, ConnectionRefusedError) as error: - raise OpenConnectionError(error) + stream = await trio.open_tcp_stream(self.host, self.port) + read_write_closer = TrioReadWriteCloser(stream) - return RawConnection(reader, writer, True) + return RawConnection(read_write_closer, True) def create_listener(self, handler_function: THandler) -> TCPListener: """ diff --git a/libp2p/transport/typing.py b/libp2p/transport/typing.py index f9b31dcb..0f42335e 100644 --- a/libp2p/transport/typing.py +++ b/libp2p/transport/typing.py @@ -1,11 +1,12 @@ -from asyncio import StreamReader, StreamWriter + from typing import Awaitable, Callable, Mapping, Type from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.stream_muxer.abc import IMuxedConn from libp2p.typing import TProtocol +from libp2p.io.abc import ReadWriteCloser -THandler = Callable[[StreamReader, StreamWriter], Awaitable[None]] +THandler = Callable[[ReadWriteCloser], Awaitable[None]] TSecurityOptions = Mapping[TProtocol, ISecureTransport] TMuxerClass = Type[IMuxedConn] TMuxerOptions = Mapping[TProtocol, TMuxerClass] diff --git a/tests/libp2p/test_libp2p.py b/tests/libp2p/test_libp2p.py index 330250c4..9f78dce5 100644 --- a/tests/libp2p/test_libp2p.py +++ b/tests/libp2p/test_libp2p.py @@ -1,3 +1,4 @@ +import trio import multiaddr import pytest @@ -6,10 +7,10 @@ from libp2p.tools.constants import MAX_READ_LEN from libp2p.tools.utils import set_up_nodes_by_transport_opt -@pytest.mark.asyncio -async def test_simple_messages(): +@pytest.mark.trio +async def test_simple_messages(nursery): transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list) + (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list, nursery) async def stream_handler(stream): while True: @@ -23,6 +24,7 @@ async def test_simple_messages(): # Associate the peer with local ip address (see default parameters of Libp2p()) node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) + stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) messages = ["hello" + str(x) for x in range(10)] From c55ea0e5bb98df4d5684eec8cac960a93371104e Mon Sep 17 00:00:00 2001 From: Chih Cheng Liang Date: Tue, 19 Nov 2019 18:01:29 +0800 Subject: [PATCH 05/81] implement trio queue interface --- libp2p/utils.py | 25 +++++++++++++++++++++++++ tests/test_utils.py | 19 +++++++++++++++++++ 2 files changed, 44 insertions(+) create mode 100644 tests/test_utils.py diff --git a/libp2p/utils.py b/libp2p/utils.py index 3d0794a1..bce9e580 100644 --- a/libp2p/utils.py +++ b/libp2p/utils.py @@ -6,6 +6,9 @@ from libp2p.io.abc import Reader from .io.utils import read_exactly +from typing import Generic, TypeVar +import trio + # Unsigned LEB128(varint codec) # Reference: https://github.com/ethereum/py-wasm/blob/master/wasm/parsers/leb128.py @@ -95,3 +98,25 @@ async def read_fixedint_prefixed(reader: Reader) -> bytes: len_bytes = await reader.read(SIZE_LEN_BYTES) len_int = int.from_bytes(len_bytes, "big") return await reader.read(len_int) + + +TItem = TypeVar("TItem") + + +class IQueue(Generic[TItem]): + async def put(self, item: TItem): + ... + + async def get(self) -> TItem: + ... + + +class TrioQueue(IQueue): + def __init__(self): + self.send_channel, self.receive_channel = trio.open_memory_channel(0) + + async def put(self, item: TItem): + await self.send_channel.send(item) + + async def get(self) -> TItem: + return await self.receive_channel.receive() diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..a6284371 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,19 @@ +import trio +import pytest +from libp2p.utils import TrioQueue + + +@pytest.mark.trio +async def test_trio_queue(): + queue = TrioQueue() + + async def queue_get(task_status=None): + result = await queue.get() + task_status.started(result) + + async with trio.open_nursery() as nursery: + nursery.start_soon(queue.put, 123) + result = await nursery.start(queue_get) + + assert result == 123 + From a397ccdc04fb5e9808bb38c5fcf911cde7a33257 Mon Sep 17 00:00:00 2001 From: Chih Cheng Liang Date: Tue, 19 Nov 2019 18:04:48 +0800 Subject: [PATCH 06/81] makes test_mplex_stream.py::test_mplex_stream_read_write work --- examples/chat/chat.py | 13 ++-- libp2p/io/trio.py | 9 +-- libp2p/network/connection/raw_connection.py | 3 +- libp2p/network/swarm.py | 12 ++-- libp2p/stream_muxer/mplex/mplex.py | 17 ++--- libp2p/stream_muxer/mplex/mplex_stream.py | 79 ++------------------- libp2p/tools/utils.py | 8 ++- libp2p/transport/tcp/tcp.py | 15 ++-- libp2p/transport/typing.py | 3 +- libp2p/utils.py | 6 +- tests/libp2p/test_libp2p.py | 4 +- tests/stream_muxer/test_mplex_stream.py | 19 +++-- tests/test_utils.py | 4 +- 13 files changed, 70 insertions(+), 122 deletions(-) diff --git a/examples/chat/chat.py b/examples/chat/chat.py index 8f10cd30..fbfd89e3 100755 --- a/examples/chat/chat.py +++ b/examples/chat/chat.py @@ -1,11 +1,11 @@ import argparse import asyncio -import trio_asyncio -import trio import sys import urllib.request import multiaddr +import trio +import trio_asyncio from libp2p import new_node from libp2p.network.stream.net_stream_interface import INetStream @@ -42,7 +42,9 @@ async def run(port: int, destination: str, localhost: bool) -> None: transport_opt = f"/ip4/{ip}/tcp/{port}" host = new_node(transport_opt=[transport_opt]) - await trio_asyncio.run_asyncio(host.get_network().listen,multiaddr.Multiaddr(transport_opt) ) + await trio_asyncio.run_asyncio( + host.get_network().listen, multiaddr.Multiaddr(transport_opt) + ) if not destination: # its the server @@ -70,7 +72,9 @@ async def run(port: int, destination: str, localhost: bool) -> None: # Start a stream with the destination. # Multiaddress of the destination peer is fetched from the peerstore using 'peerId'. - stream = await trio_asyncio.run_asyncio(host.new_stream, *(info.peer_id, [PROTOCOL_ID])) + stream = await trio_asyncio.run_asyncio( + host.new_stream, *(info.peer_id, [PROTOCOL_ID]) + ) asyncio.ensure_future(read_data(stream)) asyncio.ensure_future(write_data(stream)) @@ -119,5 +123,6 @@ def main() -> None: trio_asyncio.run(run, *(args.port, args.destination, args.localhost)) + if __name__ == "__main__": main() diff --git a/libp2p/io/trio.py b/libp2p/io/trio.py index 6e88f01d..7d0584e4 100644 --- a/libp2p/io/trio.py +++ b/libp2p/io/trio.py @@ -1,9 +1,10 @@ -import trio -from trio import SocketStream -from libp2p.io.abc import ReadWriteCloser -from libp2p.io.exceptions import IOException import logging +import trio +from trio import SocketStream + +from libp2p.io.abc import ReadWriteCloser +from libp2p.io.exceptions import IOException logger = logging.getLogger("libp2p.io.trio") diff --git a/libp2p/network/connection/raw_connection.py b/libp2p/network/connection/raw_connection.py index 50b28984..2bdb3b10 100644 --- a/libp2p/network/connection/raw_connection.py +++ b/libp2p/network/connection/raw_connection.py @@ -1,9 +1,10 @@ import trio +from libp2p.io.abc import ReadWriteCloser from libp2p.io.exceptions import IOException + from .exceptions import RawConnError from .raw_connection_interface import IRawConnection -from libp2p.io.abc import ReadWriteCloser class RawConnection(IRawConnection): diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 9b89fa56..e54ad6f7 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -3,6 +3,7 @@ import logging from typing import Dict, List, Optional from multiaddr import Multiaddr +import trio from libp2p.io.abc import ReadWriteCloser from libp2p.network.connection.net_connection_interface import INetConn @@ -69,7 +70,7 @@ class Swarm(INetwork): def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None: self.common_stream_handler = stream_handler - async def dial_peer(self, peer_id: ID) -> INetConn: + async def dial_peer(self, peer_id: ID, nursery) -> INetConn: """ dial_peer try to create a connection to peer_id. @@ -121,6 +122,7 @@ class Swarm(INetwork): try: muxed_conn = await self.upgrader.upgrade_connection(secured_conn, peer_id) + muxed_conn.run(nursery) except MuxerUpgradeFailure as error: error_msg = "fail to upgrade mux for peer %s" logger.debug(error_msg, peer_id) @@ -135,7 +137,7 @@ class Swarm(INetwork): return swarm_conn - async def new_stream(self, peer_id: ID) -> INetStream: + async def new_stream(self, peer_id: ID, nursery) -> INetStream: """ :param peer_id: peer_id of destination :param protocol_id: protocol id @@ -144,7 +146,7 @@ class Swarm(INetwork): """ logger.debug("attempting to open a stream to peer %s", peer_id) - swarm_conn = await self.dial_peer(peer_id) + swarm_conn = await self.dial_peer(peer_id, nursery) net_stream = await swarm_conn.new_stream() logger.debug("successfully opened a stream to peer %s", peer_id) @@ -183,11 +185,11 @@ class Swarm(INetwork): raise SwarmException() from error peer_id = secured_conn.get_remote_peer() - try: muxed_conn = await self.upgrader.upgrade_connection( secured_conn, peer_id ) + muxed_conn.run(nursery) except MuxerUpgradeFailure as error: error_msg = "fail to upgrade mux for peer %s" logger.debug(error_msg, peer_id) @@ -198,6 +200,8 @@ class Swarm(INetwork): await self.add_conn(muxed_conn) logger.debug("successfully opened connection to peer %s", peer_id) + event = trio.Event() + await event.wait() try: # Success diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 1f81c2f2..b6df526b 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -3,6 +3,8 @@ import logging from typing import Any # noqa: F401 from typing import Awaitable, Dict, List, Optional, Tuple +import trio + from libp2p.exceptions import ParseError from libp2p.io.exceptions import IncompleteReadError from libp2p.network.connection.exceptions import RawConnError @@ -41,8 +43,6 @@ class Mplex(IMuxedConn): event_shutting_down: asyncio.Event event_closed: asyncio.Event - _tasks: List["asyncio.Future[Any]"] - def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None: """ create a new muxed connection. @@ -66,10 +66,8 @@ class Mplex(IMuxedConn): self.event_shutting_down = asyncio.Event() self.event_closed = asyncio.Event() - self._tasks = [] - - # Kick off reading - self._tasks.append(asyncio.ensure_future(self.handle_incoming())) + def run(self, nursery): + nursery.start_soon(self.handle_incoming) @property def is_initiator(self) -> bool: @@ -123,7 +121,6 @@ class Mplex(IMuxedConn): await self.send_message(HeaderTags.NewStream, name.encode(), stream_id) return stream - async def accept_stream(self) -> IMuxedStream: """accepts a muxed stream opened by the other end.""" return await self.new_stream_queue.get() @@ -169,7 +166,7 @@ class Mplex(IMuxedConn): logger.debug("mplex unavailable while waiting for incoming: %s", e) break # Force context switch - await asyncio.sleep(0) + await trio.sleep(0) # If we enter here, it means this connection is shutting down. # We should clean things up. await self._cleanup() @@ -184,9 +181,7 @@ class Mplex(IMuxedConn): # FIXME: No timeout is used in Go implementation. try: header = await decode_uvarint_from_stream(self.secured_conn) - message = await asyncio.wait_for( - read_varint_prefixed_bytes(self.secured_conn), timeout=5 - ) + message = await read_varint_prefixed_bytes(self.secured_conn) except (ParseError, RawConnError, IncompleteReadError) as error: raise MplexUnavailable( "failed to read messages correctly from the underlying connection" diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 7659d32f..77458cbd 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -1,8 +1,10 @@ -import trio import asyncio from typing import TYPE_CHECKING +import trio + from libp2p.stream_muxer.abc import IMuxedStream +from libp2p.utils import IQueue, TrioQueue from .constants import HeaderTags from .datastructures import StreamID @@ -26,7 +28,7 @@ class MplexStream(IMuxedStream): close_lock: trio.Lock # NOTE: `dataIn` is size of 8 in Go implementation. - incoming_data: "asyncio.Queue[bytes]" + incoming_data: IQueue[bytes] event_local_closed: trio.Event event_remote_closed: trio.Event @@ -50,69 +52,13 @@ class MplexStream(IMuxedStream): self.event_remote_closed = trio.Event() self.event_reset = trio.Event() self.close_lock = trio.Lock() - self.incoming_data = asyncio.Queue() + self.incoming_data = TrioQueue() self._buf = bytearray() @property def is_initiator(self) -> bool: return self.stream_id.is_initiator - async def _wait_for_data(self) -> None: - task_event_reset = asyncio.ensure_future(self.event_reset.wait()) - task_incoming_data_get = asyncio.ensure_future(self.incoming_data.get()) - task_event_remote_closed = asyncio.ensure_future( - self.event_remote_closed.wait() - ) - done, pending = await asyncio.wait( # type: ignore - [ # type: ignore - task_event_reset, - task_incoming_data_get, - task_event_remote_closed, - ], - return_when=asyncio.FIRST_COMPLETED, - ) - for fut in pending: - fut.cancel() - - if task_event_reset in done: - if self.event_reset.is_set(): - raise MplexStreamReset - else: - # However, it is abnormal that `Event.wait` is unblocked without any of the flag - # is set. The task is probably cancelled. - raise Exception( - "Should not enter here. " - f"It is probably because {task_event_remote_closed} is cancelled." - ) - - if task_incoming_data_get in done: - data = task_incoming_data_get.result() - self._buf.extend(data) - return - - if task_event_remote_closed in done: - if self.event_remote_closed.is_set(): - raise MplexStreamEOF - else: - # However, it is abnormal that `Event.wait` is unblocked without any of the flag - # is set. The task is probably cancelled. - raise Exception( - "Should not enter here. " - f"It is probably because {task_event_remote_closed} is cancelled." - ) - - # TODO: Handle timeout when deadline is used. - - async def _read_until_eof(self) -> bytes: - while True: - try: - await self._wait_for_data() - except MplexStreamEOF: - break - payload = self._buf - self._buf = self._buf[len(payload) :] - return bytes(payload) - async def read(self, n: int = -1) -> bytes: """ Read up to n bytes. Read possibly returns fewer than `n` bytes, if @@ -128,20 +74,7 @@ class MplexStream(IMuxedStream): ) if self.event_reset.is_set(): raise MplexStreamReset - if n == -1: - return await self._read_until_eof() - if len(self._buf) == 0 and self.incoming_data.empty(): - await self._wait_for_data() - # Now we are sure we have something to read. - # Try to put enough incoming data into `self._buf`. - while len(self._buf) < n: - try: - self._buf.extend(self.incoming_data.get_nowait()) - except asyncio.QueueEmpty: - break - payload = self._buf[:n] - self._buf = self._buf[len(payload) :] - return bytes(payload) + return await self.incoming_data.get() async def write(self, data: bytes) -> int: """ diff --git a/libp2p/tools/utils.py b/libp2p/tools/utils.py index 89f864eb..d1c266a0 100644 --- a/libp2p/tools/utils.py +++ b/libp2p/tools/utils.py @@ -17,7 +17,7 @@ from libp2p.typing import StreamHandlerFn, TProtocol from .constants import MAX_READ_LEN -async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm) -> None: +async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm, nursery: trio.Nursery) -> None: peer_id = swarm_1.get_peer_id() addrs = tuple( addr @@ -25,7 +25,7 @@ async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm) -> None: for addr in transport.get_addrs() ) swarm_0.peerstore.add_addrs(peer_id, addrs, 10000) - await swarm_0.dial_peer(peer_id) + await swarm_0.dial_peer(peer_id, nursery) assert swarm_0.get_peer_id() in swarm_1.connections assert swarm_1.get_peer_id() in swarm_0.connections @@ -43,7 +43,9 @@ async def set_up_nodes_by_transport_opt( nodes_list = [] for transport_opt in transport_opt_list: node = new_node(transport_opt=transport_opt) - await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0]), nursery=nursery) + await node.get_network().listen( + multiaddr.Multiaddr(transport_opt[0]), nursery=nursery + ) nodes_list.append(node) return tuple(nodes_list) diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index 8838b504..0bc1a962 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -1,18 +1,18 @@ import asyncio -import trio +import logging from socket import socket from typing import List from multiaddr import Multiaddr +import trio +from libp2p.io.trio import TrioReadWriteCloser from libp2p.network.connection.raw_connection import RawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.transport.exceptions import OpenConnectionError from libp2p.transport.listener_interface import IListener from libp2p.transport.transport_interface import ITransport from libp2p.transport.typing import THandler -from libp2p.io.trio import TrioReadWriteCloser -import logging logger = logging.getLogger("libp2p.transport.tcp") @@ -44,11 +44,9 @@ class TCPListener(IListener): listeners = await nursery.start( serve_tcp, - *( - handler, - int(maddr.value_for_protocol("tcp")), - maddr.value_for_protocol("ip4"), - ), + handler, + int(maddr.value_for_protocol("tcp")), + maddr.value_for_protocol("ip4"), ) # self.server = await asyncio.start_server( # self.handler, @@ -57,7 +55,6 @@ class TCPListener(IListener): # ) socket = listeners[0].socket self.multiaddrs.append(_multiaddr_from_socket(socket)) - logger.debug("Multiaddrs %s", self.multiaddrs) return True diff --git a/libp2p/transport/typing.py b/libp2p/transport/typing.py index 0f42335e..d68a8aa4 100644 --- a/libp2p/transport/typing.py +++ b/libp2p/transport/typing.py @@ -1,10 +1,9 @@ - from typing import Awaitable, Callable, Mapping, Type +from libp2p.io.abc import ReadWriteCloser from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.stream_muxer.abc import IMuxedConn from libp2p.typing import TProtocol -from libp2p.io.abc import ReadWriteCloser THandler = Callable[[ReadWriteCloser], Awaitable[None]] TSecurityOptions = Mapping[TProtocol, ISecureTransport] diff --git a/libp2p/utils.py b/libp2p/utils.py index bce9e580..27880c48 100644 --- a/libp2p/utils.py +++ b/libp2p/utils.py @@ -1,14 +1,14 @@ import itertools import math +from typing import Generic, TypeVar + +import trio from libp2p.exceptions import ParseError from libp2p.io.abc import Reader from .io.utils import read_exactly -from typing import Generic, TypeVar -import trio - # Unsigned LEB128(varint codec) # Reference: https://github.com/ethereum/py-wasm/blob/master/wasm/parsers/leb128.py diff --git a/tests/libp2p/test_libp2p.py b/tests/libp2p/test_libp2p.py index 9f78dce5..628e008a 100644 --- a/tests/libp2p/test_libp2p.py +++ b/tests/libp2p/test_libp2p.py @@ -1,6 +1,6 @@ -import trio import multiaddr import pytest +import trio from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.tools.constants import MAX_READ_LEN @@ -24,11 +24,11 @@ async def test_simple_messages(nursery): # Associate the peer with local ip address (see default parameters of Libp2p()) node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) messages = ["hello" + str(x) for x in range(10)] for message in messages: + await stream.write(message.encode()) response = (await stream.read(MAX_READ_LEN)).decode() diff --git a/tests/stream_muxer/test_mplex_stream.py b/tests/stream_muxer/test_mplex_stream.py index f3458d8f..27c7f45e 100644 --- a/tests/stream_muxer/test_mplex_stream.py +++ b/tests/stream_muxer/test_mplex_stream.py @@ -1,20 +1,31 @@ import asyncio import pytest +import trio from libp2p.stream_muxer.mplex.exceptions import ( MplexStreamClosed, MplexStreamEOF, MplexStreamReset, ) -from libp2p.tools.constants import MAX_READ_LEN +from libp2p.tools.constants import MAX_READ_LEN, LISTEN_MADDR +from libp2p.tools.factories import SwarmFactory +from libp2p.tools.utils import connect_swarm DATA = b"data_123" -@pytest.mark.asyncio -async def test_mplex_stream_read_write(mplex_stream_pair): - stream_0, stream_1 = mplex_stream_pair +@pytest.mark.trio +async def test_mplex_stream_read_write(nursery): + swarm0, swarm1 = SwarmFactory(), SwarmFactory() + await swarm0.listen(LISTEN_MADDR, nursery=nursery) + await swarm1.listen(LISTEN_MADDR, nursery=nursery) + await connect_swarm(swarm0, swarm1, nursery) + conn_0 = swarm0.connections[swarm1.get_peer_id()] + conn_1 = swarm1.connections[swarm0.get_peer_id()] + stream_0 = await conn_0.muxed_conn.open_stream() + await trio.sleep(1) + stream_1 = tuple(conn_1.muxed_conn.streams.values())[0] await stream_0.write(DATA) assert (await stream_1.read(MAX_READ_LEN)) == DATA diff --git a/tests/test_utils.py b/tests/test_utils.py index a6284371..7bd807d5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,6 @@ -import trio import pytest +import trio + from libp2p.utils import TrioQueue @@ -16,4 +17,3 @@ async def test_trio_queue(): result = await nursery.start(queue_get) assert result == 123 - From 6ab0e108d3ac2b4fe4ce7a6028ffcd2c790a0ade Mon Sep 17 00:00:00 2001 From: Chih Cheng Liang Date: Tue, 19 Nov 2019 18:07:19 +0800 Subject: [PATCH 07/81] minor --- libp2p/transport/tcp/tcp.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index 0bc1a962..40a890b2 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -48,11 +48,6 @@ class TCPListener(IListener): int(maddr.value_for_protocol("tcp")), maddr.value_for_protocol("ip4"), ) - # self.server = await asyncio.start_server( - # self.handler, - # maddr.value_for_protocol("ip4"), - # maddr.value_for_protocol("tcp"), - # ) socket = listeners[0].socket self.multiaddrs.append(_multiaddr_from_socket(socket)) From 50db9e14741e831e8c0ea84699c16268b8d1cf54 Mon Sep 17 00:00:00 2001 From: Chih Cheng Liang Date: Tue, 19 Nov 2019 18:21:21 +0800 Subject: [PATCH 08/81] add setup.py --- setup.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/setup.py b/setup.py index eead1bfe..cbb8eaf8 100644 --- a/setup.py +++ b/setup.py @@ -9,6 +9,7 @@ extras_require = { "pytest>=4.6.3,<5.0.0", "pytest-asyncio>=0.10.0,<1.0.0", "pytest-xdist>=1.30.0", + "pytest-trio>=0.5.2", ], "lint": [ "mypy>=0.701,<1.0", @@ -68,6 +69,8 @@ setuptools.setup( "coincurve>=10.0.0,<11.0.0", "fastecdsa==1.7.4", "pynacl==1.3.0", + "trio-asyncio>=0.10.0", + "trio>=0.13.0", ], extras_require=extras_require, packages=setuptools.find_packages(exclude=["tests", "tests.*"]), From 417b5e7d6161218ad90fd4086a50fabea871e44b Mon Sep 17 00:00:00 2001 From: Chih Cheng Liang Date: Tue, 19 Nov 2019 18:36:53 +0800 Subject: [PATCH 09/81] remove unused asyncio --- libp2p/__init__.py | 1 - libp2p/kademlia/__init__.py | 3 +-- libp2p/stream_muxer/mplex/mplex_stream.py | 1 - libp2p/transport/tcp/tcp.py | 1 - 4 files changed, 1 insertion(+), 5 deletions(-) diff --git a/libp2p/__init__.py b/libp2p/__init__.py index 43beeeae..dc103be0 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -1,4 +1,3 @@ -import asyncio from typing import Sequence from libp2p.crypto.keys import KeyPair diff --git a/libp2p/kademlia/__init__.py b/libp2p/kademlia/__init__.py index 14568595..e85d6e0a 100644 --- a/libp2p/kademlia/__init__.py +++ b/libp2p/kademlia/__init__.py @@ -1,3 +1,2 @@ -"""Kademlia is a Python implementation of the Kademlia protocol which utilizes -the asyncio library.""" +"""Kademlia is a Python implementation of the Kademlia protocol.""" __version__ = "2.0" diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 77458cbd..44b3aef0 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -1,4 +1,3 @@ -import asyncio from typing import TYPE_CHECKING import trio diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index 40a890b2..fc000476 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -1,4 +1,3 @@ -import asyncio import logging from socket import socket from typing import List From ec43c25b45ccdd1b8fef9a2e14056f6e9f0a73bf Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 26 Nov 2019 19:24:30 +0800 Subject: [PATCH 10/81] Rewrite factories, made some of the test running --- libp2p/network/connection/swarm_connection.py | 49 +++---- libp2p/network/swarm.py | 75 ++++++---- libp2p/stream_muxer/mplex/mplex.py | 25 ++-- libp2p/stream_muxer/mplex/mplex_stream.py | 1 - libp2p/tools/factories.py | 135 +++++++++--------- libp2p/tools/utils.py | 4 +- libp2p/transport/tcp/tcp.py | 2 +- tests/host/test_ping.py | 6 +- tests/identity/identify/test_protocol.py | 4 +- tests/network/conftest.py | 21 +-- tests/network/test_swarm.py | 129 ++++++++--------- tests/stream_muxer/conftest.py | 22 +-- tests/stream_muxer/test_mplex_stream.py | 69 +++++---- 13 files changed, 260 insertions(+), 282 deletions(-) diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 29d544eb..b91783d1 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -1,6 +1,8 @@ -import asyncio from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple +import trio +from async_service import Service + from libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.stream.net_stream import NetStream from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream @@ -15,21 +17,17 @@ Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee """ -class SwarmConn(INetConn): +class SwarmConn(INetConn, Service): muxed_conn: IMuxedConn swarm: "Swarm" streams: Set[NetStream] - event_closed: asyncio.Event - - _tasks: List["asyncio.Future[Any]"] + event_closed: trio.Event def __init__(self, muxed_conn: IMuxedConn, swarm: "Swarm") -> None: self.muxed_conn = muxed_conn self.swarm = swarm self.streams = set() - self.event_closed = asyncio.Event() - - self._tasks = [] + self.event_closed = trio.Event() async def close(self) -> None: if self.event_closed.is_set(): @@ -45,16 +43,11 @@ class SwarmConn(INetConn): await stream.reset() # Force context switch for stream handlers to process the stream reset event we just emit # before we cancel the stream handler tasks. - await asyncio.sleep(0.1) + await trio.sleep(0.1) - for task in self._tasks: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass + # FIXME: Now let `_notify_disconnected` finish first. # Schedule `self._notify_disconnected` to make it execute after `close` is finished. - self._notify_disconnected() + await self._notify_disconnected() async def _handle_new_streams(self) -> None: while True: @@ -65,7 +58,7 @@ class SwarmConn(INetConn): # we should break the loop and close the connection. break # Asynchronously handle the accepted stream, to avoid blocking the next stream. - await self.run_task(self._handle_muxed_stream(stream)) + self.manager.run_task(self._handle_muxed_stream, stream) await self.close() @@ -79,28 +72,26 @@ class SwarmConn(INetConn): self.remove_stream(net_stream) async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None: - net_stream = self._add_stream(muxed_stream) + net_stream = await self._add_stream(muxed_stream) if self.swarm.common_stream_handler is not None: - await self.run_task(self._call_stream_handler(net_stream)) + await self._call_stream_handler(net_stream) - def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: + async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: net_stream = NetStream(muxed_stream) self.streams.add(net_stream) - self.swarm.notify_opened_stream(net_stream) + await self.swarm.notify_opened_stream(net_stream) return net_stream - def _notify_disconnected(self) -> None: - self.swarm.notify_disconnected(self) + async def _notify_disconnected(self) -> None: + await self.swarm.notify_disconnected(self) - async def start(self) -> None: - await self.run_task(self._handle_new_streams()) - - async def run_task(self, coro: Awaitable[Any]) -> None: - self._tasks.append(asyncio.ensure_future(coro)) + async def run(self) -> None: + self.manager.run_task(self._handle_new_streams) + await self.manager.wait_finished() async def new_stream(self) -> NetStream: muxed_stream = await self.muxed_conn.open_stream() - return self._add_stream(muxed_stream) + return await self._add_stream(muxed_stream) async def get_streams(self) -> Tuple[NetStream, ...]: return tuple(self.streams) diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index e54ad6f7..337da896 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -1,7 +1,8 @@ -import asyncio import logging from typing import Dict, List, Optional +from async_service import Service + from multiaddr import Multiaddr import trio @@ -31,7 +32,7 @@ from .stream.net_stream_interface import INetStream logger = logging.getLogger("libp2p.network.swarm") -class Swarm(INetwork): +class Swarm(INetwork, Service): self_id: ID peerstore: IPeerStore @@ -64,13 +65,16 @@ class Swarm(INetwork): self.common_stream_handler = None + async def run(self) -> None: + await self.manager.wait_finished() + def get_peer_id(self) -> ID: return self.self_id def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None: self.common_stream_handler = stream_handler - async def dial_peer(self, peer_id: ID, nursery) -> INetConn: + async def dial_peer(self, peer_id: ID) -> INetConn: """ dial_peer try to create a connection to peer_id. @@ -122,7 +126,7 @@ class Swarm(INetwork): try: muxed_conn = await self.upgrader.upgrade_connection(secured_conn, peer_id) - muxed_conn.run(nursery) + self.manager.run_child_service(muxed_conn) except MuxerUpgradeFailure as error: error_msg = "fail to upgrade mux for peer %s" logger.debug(error_msg, peer_id) @@ -137,7 +141,7 @@ class Swarm(INetwork): return swarm_conn - async def new_stream(self, peer_id: ID, nursery) -> INetStream: + async def new_stream(self, peer_id: ID) -> INetStream: """ :param peer_id: peer_id of destination :param protocol_id: protocol id @@ -146,13 +150,13 @@ class Swarm(INetwork): """ logger.debug("attempting to open a stream to peer %s", peer_id) - swarm_conn = await self.dial_peer(peer_id, nursery) + swarm_conn = await self.dial_peer(peer_id) net_stream = await swarm_conn.new_stream() logger.debug("successfully opened a stream to peer %s", peer_id) return net_stream - async def listen(self, *multiaddrs: Multiaddr, nursery) -> bool: + async def listen(self, *multiaddrs: Multiaddr) -> bool: """ :param multiaddrs: one or many multiaddrs to start listening on :return: true if at least one success @@ -189,7 +193,7 @@ class Swarm(INetwork): muxed_conn = await self.upgrader.upgrade_connection( secured_conn, peer_id ) - muxed_conn.run(nursery) + self.manager.run_child_service(muxed_conn) except MuxerUpgradeFailure as error: error_msg = "fail to upgrade mux for peer %s" logger.debug(error_msg, peer_id) @@ -200,6 +204,8 @@ class Swarm(INetwork): await self.add_conn(muxed_conn) logger.debug("successfully opened connection to peer %s", peer_id) + # FIXME: This is a intentional barrier to prevent from the handler exiting and + # closing the connection. event = trio.Event() await event.wait() @@ -207,10 +213,11 @@ class Swarm(INetwork): # Success listener = self.transport.create_listener(conn_handler) self.listeners[str(maddr)] = listener - await listener.listen(maddr, nursery) + # FIXME: Hack + await listener.listen(maddr, self.manager._task_nursery) # Call notifiers since event occurred - self.notify_listen(maddr) + await self.notify_listen(maddr) return True except IOError: @@ -225,15 +232,16 @@ class Swarm(INetwork): # Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501 # Close listeners - await asyncio.gather( - *[listener.close() for listener in self.listeners.values()] - ) - - # Close connections - await asyncio.gather( - *[connection.close() for connection in self.connections.values()] - ) + # await asyncio.gather( + # *[listener.close() for listener in self.listeners.values()] + # ) + # # Close connections + # await asyncio.gather( + # *[connection.close() for connection in self.connections.values()] + # ) + self.manager.stop() + await self.manager.wait_finished() logger.debug("swarm successfully closed") async def close_peer(self, peer_id: ID) -> None: @@ -253,11 +261,12 @@ class Swarm(INetwork): and start to monitor the connection for its new streams and disconnection.""" swarm_conn = SwarmConn(muxed_conn, self) + manager = self.manager.run_child_service(swarm_conn) # Store muxed_conn with peer id self.connections[muxed_conn.peer_id] = swarm_conn # Call notifiers since event occurred - self.notify_connected(swarm_conn) - await swarm_conn.start() + self.manager.run_task(self.notify_connected, swarm_conn) + await manager.wait_started() return swarm_conn def remove_conn(self, swarm_conn: SwarmConn) -> None: @@ -281,20 +290,26 @@ class Swarm(INetwork): """ self.notifees.append(notifee) - def notify_opened_stream(self, stream: INetStream) -> None: - asyncio.gather( - *[notifee.opened_stream(self, stream) for notifee in self.notifees] - ) + async def notify_opened_stream(self, stream: INetStream) -> None: + async with trio.open_nursery() as nursery: + for notifee in self.notifees: + nursery.start_soon(notifee.opened_stream, self, stream) # TODO: `notify_closed_stream` - def notify_connected(self, conn: INetConn) -> None: - asyncio.gather(*[notifee.connected(self, conn) for notifee in self.notifees]) + async def notify_connected(self, conn: INetConn) -> None: + async with trio.open_nursery() as nursery: + for notifee in self.notifees: + nursery.start_soon(notifee.connected, self, conn) - def notify_disconnected(self, conn: INetConn) -> None: - asyncio.gather(*[notifee.disconnected(self, conn) for notifee in self.notifees]) + async def notify_disconnected(self, conn: INetConn) -> None: + async with trio.open_nursery() as nursery: + for notifee in self.notifees: + nursery.start_soon(notifee.disconnected, self, conn) - def notify_listen(self, multiaddr: Multiaddr) -> None: - asyncio.gather(*[notifee.listen(self, multiaddr) for notifee in self.notifees]) + async def notify_listen(self, multiaddr: Multiaddr) -> None: + async with trio.open_nursery() as nursery: + for notifee in self.notifees: + nursery.start_soon(notifee.listen, self, multiaddr) # TODO: `notify_listen_close` diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index b6df526b..9ac56143 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -4,6 +4,7 @@ from typing import Any # noqa: F401 from typing import Awaitable, Dict, List, Optional, Tuple import trio +from async_service import Service from libp2p.exceptions import ParseError from libp2p.io.exceptions import IncompleteReadError @@ -17,6 +18,7 @@ from libp2p.utils import ( encode_uvarint, encode_varint_prefixed, read_varint_prefixed_bytes, + TrioQueue, ) from .constants import HeaderTags @@ -29,7 +31,7 @@ MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0") logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex") -class Mplex(IMuxedConn): +class Mplex(IMuxedConn, Service): """ reference: https://github.com/libp2p/go-mplex/blob/master/multiplex.go """ @@ -38,10 +40,10 @@ class Mplex(IMuxedConn): peer_id: ID next_channel_id: int streams: Dict[StreamID, MplexStream] - streams_lock: asyncio.Lock - new_stream_queue: "asyncio.Queue[IMuxedStream]" - event_shutting_down: asyncio.Event - event_closed: asyncio.Event + streams_lock: trio.Lock + new_stream_queue: "TrioQueue[IMuxedStream]" + event_shutting_down: trio.Event + event_closed: trio.Event def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None: """ @@ -61,13 +63,14 @@ class Mplex(IMuxedConn): # Mapping from stream ID -> buffer of messages for that stream self.streams = {} - self.streams_lock = asyncio.Lock() - self.new_stream_queue = asyncio.Queue() - self.event_shutting_down = asyncio.Event() - self.event_closed = asyncio.Event() + self.streams_lock = trio.Lock() + self.new_stream_queue = TrioQueue() + self.event_shutting_down = trio.Event() + self.event_closed = trio.Event() - def run(self, nursery): - nursery.start_soon(self.handle_incoming) + async def run(self): + self.manager.run_task(self.handle_incoming) + await self.manager.wait_finished() @property def is_initiator(self) -> bool: diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 44b3aef0..58da4040 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -136,7 +136,6 @@ class MplexStream(IMuxedStream): nursery.start_soon( self.muxed_conn.send_message, flag, None, self.stream_id ) - await trio.sleep(0) self.event_local_closed.set() self.event_remote_closed.set() diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index 2b63544d..f324371a 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -1,8 +1,9 @@ -import asyncio -from contextlib import asynccontextmanager +import trio +from contextlib import asynccontextmanager, AsyncExitStack from typing import Any, AsyncIterator, Dict, Tuple, cast import factory +from async_service import background_trio_service from libp2p import generate_new_rsa_identity, generate_peer_id_from from libp2p.crypto.keys import KeyPair @@ -61,6 +62,7 @@ class SwarmFactory(factory.Factory): transport = factory.LazyFunction(TCP) @classmethod + @asynccontextmanager async def create_and_listen( cls, is_secure: bool, key_pair: KeyPair = None, muxer_opt: TMuxerOptions = None ) -> Swarm: @@ -73,20 +75,23 @@ class SwarmFactory(factory.Factory): if muxer_opt is not None: optional_kwargs["muxer_opt"] = muxer_opt swarm = cls(is_secure=is_secure, **optional_kwargs) - await swarm.listen(LISTEN_MADDR) - return swarm + async with background_trio_service(swarm): + await swarm.listen(LISTEN_MADDR) + yield swarm @classmethod + @asynccontextmanager async def create_batch_and_listen( cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None ) -> Tuple[Swarm, ...]: - # Ignore typing since we are removing asyncio soon - return await asyncio.gather( # type: ignore - *[ - cls.create_and_listen(is_secure=is_secure, muxer_opt=muxer_opt) + async with AsyncExitStack() as stack: + ctx_mgrs = [ + await stack.enter_async_context( + cls.create_and_listen(is_secure=is_secure, muxer_opt=muxer_opt) + ) for _ in range(number) ] - ) + yield ctx_mgrs class HostFactory(factory.Factory): @@ -103,20 +108,23 @@ class HostFactory(factory.Factory): ) @classmethod + @asynccontextmanager async def create_batch_and_listen( cls, is_secure: bool, number: int ) -> Tuple[BasicHost, ...]: key_pairs = [generate_new_rsa_identity() for _ in range(number)] - swarms = await asyncio.gather( - *[ - SwarmFactory.create_and_listen(is_secure, key_pair) + async with AsyncExitStack() as stack: + swarms = [ + await stack.enter_async_context( + SwarmFactory.create_and_listen(is_secure, key_pair) + ) for key_pair in key_pairs ] - ) - return tuple( - BasicHost(key_pair.public_key, swarm) - for key_pair, swarm in zip(key_pairs, swarms) - ) + hosts = tuple( + BasicHost(key_pair.public_key, swarm) + for key_pair, swarm in zip(key_pairs, swarms) + ) + yield hosts class FloodsubFactory(factory.Factory): @@ -150,73 +158,60 @@ class PubsubFactory(factory.Factory): cache_size = None +@asynccontextmanager async def swarm_pair_factory( is_secure: bool, muxer_opt: TMuxerOptions = None ) -> Tuple[Swarm, Swarm]: - swarms = await SwarmFactory.create_batch_and_listen( + async with SwarmFactory.create_batch_and_listen( is_secure, 2, muxer_opt=muxer_opt - ) - await connect_swarm(swarms[0], swarms[1]) - return swarms[0], swarms[1] - - -async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]: - hosts = await HostFactory.create_batch_and_listen(is_secure, 2) - await connect(hosts[0], hosts[1]) - return hosts[0], hosts[1] + ) as swarms: + await connect_swarm(swarms[0], swarms[1]) + yield swarms[0], swarms[1] @asynccontextmanager -async def pair_of_connected_hosts( - is_secure: bool = True -) -> AsyncIterator[Tuple[BasicHost, BasicHost]]: - a, b = await host_pair_factory(is_secure) - yield a, b - close_tasks = (a.close(), b.close()) - await asyncio.gather(*close_tasks) +async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]: + async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts: + await connect(hosts[0], hosts[1]) + yield hosts[0], hosts[1] +@asynccontextmanager async def swarm_conn_pair_factory( is_secure: bool, muxer_opt: TMuxerOptions = None -) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]: - swarms = await swarm_pair_factory(is_secure) - conn_0 = swarms[0].connections[swarms[1].get_peer_id()] - conn_1 = swarms[1].connections[swarms[0].get_peer_id()] - return cast(SwarmConn, conn_0), swarms[0], cast(SwarmConn, conn_1), swarms[1] +) -> Tuple[SwarmConn, SwarmConn]: + async with swarm_pair_factory(is_secure) as swarms: + conn_0 = swarms[0].connections[swarms[1].get_peer_id()] + conn_1 = swarms[1].connections[swarms[0].get_peer_id()] + yield cast(SwarmConn, conn_0), cast(SwarmConn, conn_1) -async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Swarm, Mplex, Swarm]: +@asynccontextmanager +async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Mplex]: muxer_opt = {MPLEX_PROTOCOL_ID: Mplex} - conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory( - is_secure, muxer_opt=muxer_opt - ) - return ( - cast(Mplex, conn_0.muxed_conn), - swarm_0, - cast(Mplex, conn_1.muxed_conn), - swarm_1, - ) + async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair: + yield ( + cast(Mplex, swarm_pair[0].muxed_conn), + cast(Mplex, swarm_pair[1].muxed_conn), + ) -async def mplex_stream_pair_factory( - is_secure: bool -) -> Tuple[MplexStream, Swarm, MplexStream, Swarm]: - mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory( - is_secure - ) - stream_0 = await mplex_conn_0.open_stream() - await asyncio.sleep(0.01) - stream_1: MplexStream - async with mplex_conn_1.streams_lock: - if len(mplex_conn_1.streams) != 1: - raise Exception("Mplex should not have any stream upon connection") - stream_1 = tuple(mplex_conn_1.streams.values())[0] - return cast(MplexStream, stream_0), swarm_0, stream_1, swarm_1 +@asynccontextmanager +async def mplex_stream_pair_factory(is_secure: bool) -> Tuple[MplexStream, MplexStream]: + async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info: + mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info + stream_0 = await mplex_conn_0.open_stream() + await trio.sleep(0.01) + stream_1: MplexStream + async with mplex_conn_1.streams_lock: + if len(mplex_conn_1.streams) != 1: + raise Exception("Mplex should not have any stream upon connection") + stream_1 = tuple(mplex_conn_1.streams.values())[0] + yield cast(MplexStream, stream_0), cast(MplexStream, stream_1) -async def net_stream_pair_factory( - is_secure: bool -) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]: +@asynccontextmanager +async def net_stream_pair_factory(is_secure: bool) -> Tuple[INetStream, INetStream]: protocol_id = TProtocol("/example/id/1") stream_1: INetStream @@ -226,8 +221,8 @@ async def net_stream_pair_factory( nonlocal stream_1 stream_1 = stream - host_0, host_1 = await host_pair_factory(is_secure) - host_1.set_stream_handler(protocol_id, handler) + async with host_pair_factory(is_secure) as hosts: + hosts[1].set_stream_handler(protocol_id, handler) - stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id]) - return stream_0, host_0, stream_1, host_1 + stream_0 = await hosts[0].new_stream(hosts[1].get_id(), [protocol_id]) + yield stream_0, stream_1 diff --git a/libp2p/tools/utils.py b/libp2p/tools/utils.py index d1c266a0..0d39156b 100644 --- a/libp2p/tools/utils.py +++ b/libp2p/tools/utils.py @@ -17,7 +17,7 @@ from libp2p.typing import StreamHandlerFn, TProtocol from .constants import MAX_READ_LEN -async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm, nursery: trio.Nursery) -> None: +async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm) -> None: peer_id = swarm_1.get_peer_id() addrs = tuple( addr @@ -25,7 +25,7 @@ async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm, nursery: trio.Nursery) - for addr in transport.get_addrs() ) swarm_0.peerstore.add_addrs(peer_id, addrs, 10000) - await swarm_0.dial_peer(peer_id, nursery) + await swarm_0.dial_peer(peer_id) assert swarm_0.get_peer_id() in swarm_1.connections assert swarm_1.get_peer_id() in swarm_0.connections diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index fc000476..16365982 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -25,7 +25,7 @@ class TCPListener(IListener): self.server = None self.handler = handler_function - async def listen(self, maddr: Multiaddr, nursery) -> bool: + async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool: """ put listener in listening mode and wait for incoming connections. diff --git a/tests/host/test_ping.py b/tests/host/test_ping.py index fcc5a850..1bd02f0f 100644 --- a/tests/host/test_ping.py +++ b/tests/host/test_ping.py @@ -4,12 +4,12 @@ import secrets import pytest from libp2p.host.ping import ID, PING_LENGTH -from libp2p.tools.factories import pair_of_connected_hosts +from libp2p.tools.factories import host_pair_factory @pytest.mark.asyncio async def test_ping_once(): - async with pair_of_connected_hosts() as (host_a, host_b): + async with host_pair_factory() as (host_a, host_b): stream = await host_b.new_stream(host_a.get_id(), (ID,)) some_ping = secrets.token_bytes(PING_LENGTH) await stream.write(some_ping) @@ -23,7 +23,7 @@ SOME_PING_COUNT = 3 @pytest.mark.asyncio async def test_ping_several(): - async with pair_of_connected_hosts() as (host_a, host_b): + async with host_pair_factory() as (host_a, host_b): stream = await host_b.new_stream(host_a.get_id(), (ID,)) for _ in range(SOME_PING_COUNT): some_ping = secrets.token_bytes(PING_LENGTH) diff --git a/tests/identity/identify/test_protocol.py b/tests/identity/identify/test_protocol.py index fab78ec1..6136c876 100644 --- a/tests/identity/identify/test_protocol.py +++ b/tests/identity/identify/test_protocol.py @@ -2,12 +2,12 @@ import pytest from libp2p.identity.identify.pb.identify_pb2 import Identify from libp2p.identity.identify.protocol import ID, _mk_identify_protobuf -from libp2p.tools.factories import pair_of_connected_hosts +from libp2p.tools.factories import host_pair_factory @pytest.mark.asyncio async def test_identify_protocol(): - async with pair_of_connected_hosts() as (host_a, host_b): + async with host_pair_factory() as (host_a, host_b): stream = await host_b.new_stream(host_a.get_id(), (ID,)) response = await stream.read() await stream.close() diff --git a/tests/network/conftest.py b/tests/network/conftest.py index 6b75b756..c45dbdba 100644 --- a/tests/network/conftest.py +++ b/tests/network/conftest.py @@ -11,26 +11,17 @@ from libp2p.tools.factories import ( @pytest.fixture async def net_stream_pair(is_host_secure): - stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory(is_host_secure) - try: - yield stream_0, stream_1 - finally: - await asyncio.gather(*[host_0.close(), host_1.close()]) + async with net_stream_pair_factory(is_host_secure) as net_stream_pair: + yield net_stream_pair @pytest.fixture async def swarm_pair(is_host_secure): - swarm_0, swarm_1 = await swarm_pair_factory(is_host_secure) - try: - yield swarm_0, swarm_1 - finally: - await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) + async with swarm_pair_factory(is_host_secure) as swarms: + yield swarms @pytest.fixture async def swarm_conn_pair(is_host_secure): - conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory(is_host_secure) - try: - yield conn_0, conn_1 - finally: - await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) + async with swarm_conn_pair_factory(is_host_secure) as swarm_conn_pair: + yield swarm_conn_pair diff --git a/tests/network/test_swarm.py b/tests/network/test_swarm.py index 6fe25434..de086352 100644 --- a/tests/network/test_swarm.py +++ b/tests/network/test_swarm.py @@ -1,88 +1,83 @@ -import asyncio - +import trio import pytest +from trio.testing import wait_all_tasks_blocked from libp2p.network.exceptions import SwarmException from libp2p.tools.factories import SwarmFactory from libp2p.tools.utils import connect_swarm -@pytest.mark.asyncio +@pytest.mark.trio async def test_swarm_dial_peer(is_host_secure): - swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3) - # Test: No addr found. - with pytest.raises(SwarmException): + async with SwarmFactory.create_batch_and_listen(is_host_secure, 3) as swarms: + # Test: No addr found. + with pytest.raises(SwarmException): + await swarms[0].dial_peer(swarms[1].get_peer_id()) + + # Test: len(addr) in the peerstore is 0. + swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000) + with pytest.raises(SwarmException): + await swarms[0].dial_peer(swarms[1].get_peer_id()) + + # Test: Succeed if addrs of the peer_id are present in the peerstore. + addrs = tuple( + addr + for transport in swarms[1].listeners.values() + for addr in transport.get_addrs() + ) + swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000) await swarms[0].dial_peer(swarms[1].get_peer_id()) + assert swarms[0].get_peer_id() in swarms[1].connections + assert swarms[1].get_peer_id() in swarms[0].connections - # Test: len(addr) in the peerstore is 0. - swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000) - with pytest.raises(SwarmException): - await swarms[0].dial_peer(swarms[1].get_peer_id()) - - # Test: Succeed if addrs of the peer_id are present in the peerstore. - addrs = tuple( - addr - for transport in swarms[1].listeners.values() - for addr in transport.get_addrs() - ) - swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000) - await swarms[0].dial_peer(swarms[1].get_peer_id()) - assert swarms[0].get_peer_id() in swarms[1].connections - assert swarms[1].get_peer_id() in swarms[0].connections - - # Test: Reuse connections when we already have ones with a peer. - conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()] - conn = await swarms[0].dial_peer(swarms[1].get_peer_id()) - assert conn is conn_to_1 - - # Clean up - await asyncio.gather(*[swarm.close() for swarm in swarms]) + # Test: Reuse connections when we already have ones with a peer. + conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()] + conn = await swarms[0].dial_peer(swarms[1].get_peer_id()) + assert conn is conn_to_1 -@pytest.mark.asyncio +@pytest.mark.trio async def test_swarm_close_peer(is_host_secure): - swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3) - # 0 <> 1 <> 2 - await connect_swarm(swarms[0], swarms[1]) - await connect_swarm(swarms[1], swarms[2]) + async with SwarmFactory.create_batch_and_listen(is_host_secure, 3) as swarms: + # 0 <> 1 <> 2 + await connect_swarm(swarms[0], swarms[1]) + await connect_swarm(swarms[1], swarms[2]) - # peer 1 closes peer 0 - await swarms[1].close_peer(swarms[0].get_peer_id()) - await asyncio.sleep(0.01) - # 0 1 <> 2 - assert len(swarms[0].connections) == 0 - assert ( - len(swarms[1].connections) == 1 - and swarms[2].get_peer_id() in swarms[1].connections - ) + # peer 1 closes peer 0 + await swarms[1].close_peer(swarms[0].get_peer_id()) + await trio.sleep(0.01) + await wait_all_tasks_blocked() + # 0 1 <> 2 + assert len(swarms[0].connections) == 0 + assert ( + len(swarms[1].connections) == 1 + and swarms[2].get_peer_id() in swarms[1].connections + ) - # peer 1 is closed by peer 2 - await swarms[2].close_peer(swarms[1].get_peer_id()) - await asyncio.sleep(0.01) - # 0 1 2 - assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0 + # peer 1 is closed by peer 2 + await swarms[2].close_peer(swarms[1].get_peer_id()) + await trio.sleep(0.01) + # 0 1 2 + assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0 - await connect_swarm(swarms[0], swarms[1]) - # 0 <> 1 2 - assert ( - len(swarms[0].connections) == 1 - and swarms[1].get_peer_id() in swarms[0].connections - ) - assert ( - len(swarms[1].connections) == 1 - and swarms[0].get_peer_id() in swarms[1].connections - ) - # peer 0 closes peer 1 - await swarms[0].close_peer(swarms[1].get_peer_id()) - await asyncio.sleep(0.01) - # 0 1 2 - assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0 - - # Clean up - await asyncio.gather(*[swarm.close() for swarm in swarms]) + await connect_swarm(swarms[0], swarms[1]) + # 0 <> 1 2 + assert ( + len(swarms[0].connections) == 1 + and swarms[1].get_peer_id() in swarms[0].connections + ) + assert ( + len(swarms[1].connections) == 1 + and swarms[0].get_peer_id() in swarms[1].connections + ) + # peer 0 closes peer 1 + await swarms[0].close_peer(swarms[1].get_peer_id()) + await trio.sleep(0.01) + # 0 1 2 + assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0 -@pytest.mark.asyncio +@pytest.mark.trio async def test_swarm_remove_conn(swarm_pair): swarm_0, swarm_1 = swarm_pair conn_0 = swarm_0.connections[swarm_1.get_peer_id()] diff --git a/tests/stream_muxer/conftest.py b/tests/stream_muxer/conftest.py index cdb57e8f..5c5bc2bb 100644 --- a/tests/stream_muxer/conftest.py +++ b/tests/stream_muxer/conftest.py @@ -7,23 +7,13 @@ from libp2p.tools.factories import mplex_conn_pair_factory, mplex_stream_pair_fa @pytest.fixture async def mplex_conn_pair(is_host_secure): - mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory( - is_host_secure - ) - assert mplex_conn_0.is_initiator - assert not mplex_conn_1.is_initiator - try: - yield mplex_conn_0, mplex_conn_1 - finally: - await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) + async with mplex_conn_pair_factory(is_host_secure) as mplex_conn_pair: + assert mplex_conn_pair[0].is_initiator + assert not mplex_conn_pair[1].is_initiator + yield mplex_conn_pair[0], mplex_conn_pair[1] @pytest.fixture async def mplex_stream_pair(is_host_secure): - mplex_stream_0, swarm_0, mplex_stream_1, swarm_1 = await mplex_stream_pair_factory( - is_host_secure - ) - try: - yield mplex_stream_0, mplex_stream_1 - finally: - await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) + async with mplex_stream_pair_factory(is_host_secure) as mplex_stream_pair: + yield mplex_stream_pair diff --git a/tests/stream_muxer/test_mplex_stream.py b/tests/stream_muxer/test_mplex_stream.py index 27c7f45e..dae66571 100644 --- a/tests/stream_muxer/test_mplex_stream.py +++ b/tests/stream_muxer/test_mplex_stream.py @@ -1,5 +1,3 @@ -import asyncio - import pytest import trio @@ -12,25 +10,26 @@ from libp2p.tools.constants import MAX_READ_LEN, LISTEN_MADDR from libp2p.tools.factories import SwarmFactory from libp2p.tools.utils import connect_swarm + DATA = b"data_123" @pytest.mark.trio -async def test_mplex_stream_read_write(nursery): - swarm0, swarm1 = SwarmFactory(), SwarmFactory() - await swarm0.listen(LISTEN_MADDR, nursery=nursery) - await swarm1.listen(LISTEN_MADDR, nursery=nursery) - await connect_swarm(swarm0, swarm1, nursery) - conn_0 = swarm0.connections[swarm1.get_peer_id()] - conn_1 = swarm1.connections[swarm0.get_peer_id()] - stream_0 = await conn_0.muxed_conn.open_stream() - await trio.sleep(1) - stream_1 = tuple(conn_1.muxed_conn.streams.values())[0] - await stream_0.write(DATA) - assert (await stream_1.read(MAX_READ_LEN)) == DATA +async def test_mplex_stream_read_write(): + async with SwarmFactory.create_batch_and_listen(False, 2) as swarms: + await swarms[0].listen(LISTEN_MADDR) + await swarms[1].listen(LISTEN_MADDR) + await connect_swarm(swarms[0], swarms[1]) + conn_0 = swarms[0].connections[swarms[1].get_peer_id()] + conn_1 = swarms[1].connections[swarms[0].get_peer_id()] + stream_0 = await conn_0.muxed_conn.open_stream() + await trio.sleep(1) + stream_1 = tuple(conn_1.muxed_conn.streams.values())[0] + await stream_0.write(DATA) + assert (await stream_1.read(MAX_READ_LEN)) == DATA -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair): read_bytes = bytearray() stream_0, stream_1 = mplex_stream_pair @@ -38,43 +37,43 @@ async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair): async def read_until_eof(): read_bytes.extend(await stream_1.read()) - task = asyncio.ensure_future(read_until_eof()) + task = trio.ensure_future(read_until_eof()) expected_data = bytearray() # Test: `read` doesn't return before `close` is called. await stream_0.write(DATA) expected_data.extend(DATA) - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert len(read_bytes) == 0 # Test: `read` doesn't return before `close` is called. await stream_0.write(DATA) expected_data.extend(DATA) - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert len(read_bytes) == 0 # Test: Close the stream, `read` returns, and receive previous sent data. await stream_0.close() - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert read_bytes == expected_data task.cancel() -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair assert not stream_1.event_remote_closed.is_set() await stream_0.write(DATA) await stream_0.close() - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert stream_1.event_remote_closed.is_set() assert (await stream_1.read(MAX_READ_LEN)) == DATA with pytest.raises(MplexStreamEOF): await stream_1.read(MAX_READ_LEN) -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_read_after_local_reset(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair await stream_0.reset() @@ -82,29 +81,29 @@ async def test_mplex_stream_read_after_local_reset(mplex_stream_pair): await stream_0.read(MAX_READ_LEN) -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_read_after_remote_reset(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair await stream_0.write(DATA) await stream_0.reset() # Sleep to let `stream_1` receive the message. - await asyncio.sleep(0.01) + await trio.sleep(0.01) with pytest.raises(MplexStreamReset): await stream_1.read(MAX_READ_LEN) -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_read_after_remote_closed_and_reset(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair await stream_0.write(DATA) await stream_0.close() await stream_0.reset() # Sleep to let `stream_1` receive the message. - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert (await stream_1.read(MAX_READ_LEN)) == DATA -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_write_after_local_closed(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair await stream_0.write(DATA) @@ -113,7 +112,7 @@ async def test_mplex_stream_write_after_local_closed(mplex_stream_pair): await stream_0.write(DATA) -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_write_after_local_reset(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair await stream_0.reset() @@ -121,16 +120,16 @@ async def test_mplex_stream_write_after_local_reset(mplex_stream_pair): await stream_0.write(DATA) -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_write_after_remote_reset(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair await stream_1.reset() - await asyncio.sleep(0.01) + await trio.sleep(0.01) with pytest.raises(MplexStreamClosed): await stream_0.write(DATA) -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_both_close(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair # Flags are not set initially. @@ -144,7 +143,7 @@ async def test_mplex_stream_both_close(mplex_stream_pair): # Test: Close one side. await stream_0.close() - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert stream_0.event_local_closed.is_set() assert not stream_1.event_local_closed.is_set() @@ -156,7 +155,7 @@ async def test_mplex_stream_both_close(mplex_stream_pair): # Test: Close the other side. await stream_1.close() - await asyncio.sleep(0.01) + await trio.sleep(0.01) # Both sides are closed. assert stream_0.event_local_closed.is_set() assert stream_1.event_local_closed.is_set() @@ -170,11 +169,11 @@ async def test_mplex_stream_both_close(mplex_stream_pair): await stream_0.reset() -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_reset(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair await stream_0.reset() - await asyncio.sleep(0.01) + await trio.sleep(0.01) # Both sides are closed. assert stream_0.event_local_closed.is_set() From 1e600ea7e00b4f5fbb6aa160141abfa74488a1ba Mon Sep 17 00:00:00 2001 From: mhchia Date: Fri, 29 Nov 2019 19:09:56 +0800 Subject: [PATCH 11/81] Fix `Mplex` and `Swarm` --- libp2p/io/trio.py | 4 + libp2p/network/connection/swarm_connection.py | 8 +- libp2p/network/swarm.py | 5 +- libp2p/stream_muxer/mplex/mplex.py | 112 ++++++++++++++---- libp2p/stream_muxer/mplex/mplex_stream.py | 79 ++++++++++-- libp2p/tools/factories.py | 2 +- libp2p/utils.py | 25 ---- tests/conftest.py | 18 +-- tests/network/conftest.py | 2 - tests/stream_muxer/conftest.py | 2 - tests/stream_muxer/test_mplex_conn.py | 10 +- tests/stream_muxer/test_mplex_stream.py | 68 ++++++++--- tests/test_utils.py | 19 --- 13 files changed, 232 insertions(+), 122 deletions(-) delete mode 100644 tests/test_utils.py diff --git a/libp2p/io/trio.py b/libp2p/io/trio.py index 7d0584e4..e74e9ed2 100644 --- a/libp2p/io/trio.py +++ b/libp2p/io/trio.py @@ -23,6 +23,10 @@ class TrioReadWriteCloser(ReadWriteCloser): raise IOException(error) async def read(self, n: int = -1) -> bytes: + if n == 0: + # Check point + await trio.sleep(0) + return b"" max_bytes = n if n != -1 else None try: return await self.stream.receive_some(max_bytes) diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index b91783d1..46b3f6fd 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -50,8 +50,11 @@ class SwarmConn(INetConn, Service): await self._notify_disconnected() async def _handle_new_streams(self) -> None: - while True: + while self.manager.is_running: try: + print( + f"!@# SwarmConn._handle_new_streams: {self.muxed_conn._id}: waiting for new streams" + ) stream = await self.muxed_conn.accept_stream() except MuxedConnUnavailable: # If there is anything wrong in the MuxedConn, @@ -60,6 +63,9 @@ class SwarmConn(INetConn, Service): # Asynchronously handle the accepted stream, to avoid blocking the next stream. self.manager.run_task(self._handle_muxed_stream, stream) + print( + f"!@# SwarmConn._handle_new_streams: {self.muxed_conn._id}: out of the loop" + ) await self.close() async def _call_stream_handler(self, net_stream: NetStream) -> None: diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 337da896..78fb7fdc 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -206,8 +206,7 @@ class Swarm(INetwork, Service): logger.debug("successfully opened connection to peer %s", peer_id) # FIXME: This is a intentional barrier to prevent from the handler exiting and # closing the connection. - event = trio.Event() - await event.wait() + await trio.sleep_forever() try: # Success @@ -240,7 +239,7 @@ class Swarm(INetwork, Service): # await asyncio.gather( # *[connection.close() for connection in self.connections.values()] # ) - self.manager.stop() + await self.manager.stop() await self.manager.wait_finished() logger.debug("swarm successfully closed") diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 9ac56143..53d855b2 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -1,3 +1,4 @@ +import math import asyncio import logging from typing import Any # noqa: F401 @@ -18,7 +19,6 @@ from libp2p.utils import ( encode_uvarint, encode_varint_prefixed, read_varint_prefixed_bytes, - TrioQueue, ) from .constants import HeaderTags @@ -41,7 +41,10 @@ class Mplex(IMuxedConn, Service): next_channel_id: int streams: Dict[StreamID, MplexStream] streams_lock: trio.Lock - new_stream_queue: "TrioQueue[IMuxedStream]" + streams_msg_channels: Dict[StreamID, "trio.MemorySendChannel[bytes]"] + new_stream_send_channel: "trio.MemorySendChannel[IMuxedStream]" + new_stream_receive_channel: "trio.MemoryReceiveChannel[IMuxedStream]" + event_shutting_down: trio.Event event_closed: trio.Event @@ -64,7 +67,10 @@ class Mplex(IMuxedConn, Service): # Mapping from stream ID -> buffer of messages for that stream self.streams = {} self.streams_lock = trio.Lock() - self.new_stream_queue = TrioQueue() + self.streams_msg_channels = {} + send_channel, receive_channel = trio.open_memory_channel(math.inf) + self.new_stream_send_channel = send_channel + self.new_stream_receive_channel = receive_channel self.event_shutting_down = trio.Event() self.event_closed = trio.Event() @@ -105,9 +111,13 @@ class Mplex(IMuxedConn, Service): return next_id async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream: - stream = MplexStream(name, stream_id, self) + # Use an unbounded buffer, to avoid `handle_incoming` being blocked when doing + # `send_channel.send`. + send_channel, receive_channel = trio.open_memory_channel(math.inf) + stream = MplexStream(name, stream_id, self, receive_channel) async with self.streams_lock: self.streams[stream_id] = stream + self.streams_msg_channels[stream_id] = send_channel return stream async def open_stream(self) -> IMuxedStream: @@ -126,7 +136,10 @@ class Mplex(IMuxedConn, Service): async def accept_stream(self) -> IMuxedStream: """accepts a muxed stream opened by the other end.""" - return await self.new_stream_queue.get() + try: + return await self.new_stream_receive_channel.receive() + except (trio.ClosedResourceError, trio.EndOfChannel): + raise MplexUnavailable async def send_message( self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID @@ -138,6 +151,9 @@ class Mplex(IMuxedConn, Service): :param data: data to send in the message :param stream_id: stream the message is in """ + print( + f"!@# send_message: {self._id}: flag={flag}, data={data}, stream_id={stream_id}" + ) # << by 3, then or with flag header = encode_uvarint((stream_id.channel_id << 3) | flag.value) @@ -162,14 +178,21 @@ class Mplex(IMuxedConn, Service): """Read a message off of the secured connection and add it to the corresponding message buffer.""" - while True: + while self.manager.is_running: try: + print( + f"!@# handle_incoming: {self._id}: before _handle_incoming_message" + ) await self._handle_incoming_message() + print( + f"!@# handle_incoming: {self._id}: after _handle_incoming_message" + ) except MplexUnavailable as e: logger.debug("mplex unavailable while waiting for incoming: %s", e) + print(f"!@# handle_incoming: {self._id}: MplexUnavailable: {e}") break - # Force context switch - await trio.sleep(0) + + print(f"!@# handle_incoming: {self._id}: leaving") # If we enter here, it means this connection is shutting down. # We should clean things up. await self._cleanup() @@ -181,51 +204,73 @@ class Mplex(IMuxedConn, Service): :return: stream_id, flag, message contents """ - # FIXME: No timeout is used in Go implementation. try: header = await decode_uvarint_from_stream(self.secured_conn) + except (ParseError, RawConnError, IncompleteReadError) as error: + raise MplexUnavailable( + f"failed to read the header correctly from the underlying connection: {error}" + ) + try: message = await read_varint_prefixed_bytes(self.secured_conn) except (ParseError, RawConnError, IncompleteReadError) as error: raise MplexUnavailable( - "failed to read messages correctly from the underlying connection" - ) from error - except asyncio.TimeoutError as error: - raise MplexUnavailable( - "failed to read more message body within the timeout" - ) from error + "failed to read the message body correctly from the underlying connection: " + f"{error}" + ) flag = header & 0x07 channel_id = header >> 3 return channel_id, flag, message + @property + def _id(self) -> int: + return 0 if self.is_initiator else 1 + async def _handle_incoming_message(self) -> None: """ Read and handle a new incoming message. :raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down. """ + print(f"!@# _handle_incoming_message: {self._id}: before reading") channel_id, flag, message = await self.read_message() + print( + f"!@# _handle_incoming_message: {self._id}: channel_id={channel_id}, flag={flag}, message={message}" + ) stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1)) + print(f"!@# _handle_incoming_message: {self._id}: 2") if flag == HeaderTags.NewStream.value: + print(f"!@# _handle_incoming_message: {self._id}: 3") await self._handle_new_stream(stream_id, message) + print(f"!@# _handle_incoming_message: {self._id}: 4") elif flag in ( HeaderTags.MessageInitiator.value, HeaderTags.MessageReceiver.value, ): + print(f"!@# _handle_incoming_message: {self._id}: 5") await self._handle_message(stream_id, message) + print(f"!@# _handle_incoming_message: {self._id}: 6") elif flag in (HeaderTags.CloseInitiator.value, HeaderTags.CloseReceiver.value): + print(f"!@# _handle_incoming_message: {self._id}: 7") await self._handle_close(stream_id) + print(f"!@# _handle_incoming_message: {self._id}: 8") elif flag in (HeaderTags.ResetInitiator.value, HeaderTags.ResetReceiver.value): + print(f"!@# _handle_incoming_message: {self._id}: 9") await self._handle_reset(stream_id) + print(f"!@# _handle_incoming_message: {self._id}: 10") else: + print(f"!@# _handle_incoming_message: {self._id}: 11") # Receives messages with an unknown flag # TODO: logging async with self.streams_lock: + print(f"!@# _handle_incoming_message: {self._id}: 12") if stream_id in self.streams: + print(f"!@# _handle_incoming_message: {self._id}: 13") stream = self.streams[stream_id] await stream.reset() + print(f"!@# _handle_incoming_message: {self._id}: 14") async def _handle_new_stream(self, stream_id: StreamID, message: bytes) -> None: async with self.streams_lock: @@ -235,43 +280,65 @@ class Mplex(IMuxedConn, Service): f"received NewStream message for existing stream: {stream_id}" ) mplex_stream = await self._initialize_stream(stream_id, message.decode()) - await self.new_stream_queue.put(mplex_stream) + try: + await self.new_stream_send_channel.send(mplex_stream) + except (trio.BrokenResourceError, trio.EndOfChannel): + raise MplexUnavailable async def _handle_message(self, stream_id: StreamID, message: bytes) -> None: + print( + f"!@# _handle_message: {self._id}: stream_id={stream_id}, message={message}" + ) async with self.streams_lock: + print(f"!@# _handle_message: {self._id}: 1") if stream_id not in self.streams: # We receive a message of the stream `stream_id` which is not accepted # before. It is abnormal. Possibly disconnect? # TODO: Warn and emit logs about this. + print(f"!@# _handle_message: {self._id}: 2") return + print(f"!@# _handle_message: {self._id}: 3") stream = self.streams[stream_id] + send_channel = self.streams_msg_channels[stream_id] async with stream.close_lock: + print(f"!@# _handle_message: {self._id}: 4") if stream.event_remote_closed.is_set(): + print(f"!@# _handle_message: {self._id}: 5") # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501 return - await stream.incoming_data.put(message) + print(f"!@# _handle_message: {self._id}: 6") + await send_channel.send(message) + print(f"!@# _handle_message: {self._id}: 7") async def _handle_close(self, stream_id: StreamID) -> None: + print(f"!@# _handle_close: {self._id}: step=0") async with self.streams_lock: if stream_id not in self.streams: # Ignore unmatched messages for now. return stream = self.streams[stream_id] + send_channel = self.streams_msg_channels[stream_id] + print(f"!@# _handle_close: {self._id}: step=1") + await send_channel.aclose() + print(f"!@# _handle_close: {self._id}: step=2") # NOTE: If remote is already closed, then return: Technically a bug # on the other side. We should consider killing the connection. async with stream.close_lock: if stream.event_remote_closed.is_set(): return + print(f"!@# _handle_close: {self._id}: step=3") is_local_closed: bool async with stream.close_lock: stream.event_remote_closed.set() is_local_closed = stream.event_local_closed.is_set() + print(f"!@# _handle_close: {self._id}: step=4") # If local is also closed, both sides are closed. Then, we should clean up # the entry of this stream, to avoid others from accessing it. if is_local_closed: async with self.streams_lock: if stream_id in self.streams: del self.streams[stream_id] + print(f"!@# _handle_close: {self._id}: step=5") async def _handle_reset(self, stream_id: StreamID) -> None: async with self.streams_lock: @@ -279,11 +346,11 @@ class Mplex(IMuxedConn, Service): # This is *ok*. We forget the stream on reset. return stream = self.streams[stream_id] - + send_channel = self.streams_msg_channels[stream_id] + await send_channel.aclose() async with stream.close_lock: if not stream.event_remote_closed.is_set(): stream.event_reset.set() - stream.event_remote_closed.set() # If local is not closed, we should close it. if not stream.event_local_closed.is_set(): @@ -291,16 +358,21 @@ class Mplex(IMuxedConn, Service): async with self.streams_lock: if stream_id in self.streams: del self.streams[stream_id] + del self.streams_msg_channels[stream_id] async def _cleanup(self) -> None: if not self.event_shutting_down.is_set(): self.event_shutting_down.set() async with self.streams_lock: - for stream in self.streams.values(): + for stream_id, stream in self.streams.items(): async with stream.close_lock: if not stream.event_remote_closed.is_set(): stream.event_remote_closed.set() stream.event_reset.set() stream.event_local_closed.set() + send_channel = self.streams_msg_channels[stream_id] + await send_channel.aclose() self.streams = None self.event_closed.set() + await self.new_stream_send_channel.aclose() + await self.new_stream_receive_channel.aclose() diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 58da4040..6ecc4077 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING import trio from libp2p.stream_muxer.abc import IMuxedStream -from libp2p.utils import IQueue, TrioQueue from .constants import HeaderTags from .datastructures import StreamID @@ -24,10 +23,11 @@ class MplexStream(IMuxedStream): read_deadline: int write_deadline: int + # TODO: Add lock for read/write to avoid interleaving receiving messages? close_lock: trio.Lock # NOTE: `dataIn` is size of 8 in Go implementation. - incoming_data: IQueue[bytes] + incoming_data_channel: "trio.MemoryReceiveChannel[bytes]" event_local_closed: trio.Event event_remote_closed: trio.Event @@ -35,7 +35,13 @@ class MplexStream(IMuxedStream): _buf: bytearray - def __init__(self, name: str, stream_id: StreamID, muxed_conn: "Mplex") -> None: + def __init__( + self, + name: str, + stream_id: StreamID, + muxed_conn: "Mplex", + incoming_data_channel: "trio.MemoryReceiveChannel[bytes]", + ) -> None: """ create new MuxedStream in muxer. @@ -51,13 +57,30 @@ class MplexStream(IMuxedStream): self.event_remote_closed = trio.Event() self.event_reset = trio.Event() self.close_lock = trio.Lock() - self.incoming_data = TrioQueue() + self.incoming_data_channel = incoming_data_channel self._buf = bytearray() @property def is_initiator(self) -> bool: return self.stream_id.is_initiator + async def _read_until_eof(self) -> bytes: + async for data in self.incoming_data_channel: + self._buf.extend(data) + payload = self._buf + self._buf = self._buf[len(payload) :] + return bytes(payload) + + def _read_return_when_blocked(self) -> bytes: + buf = bytearray() + while True: + try: + data = self.incoming_data_channel.receive_nowait() + buf.extend(data) + except (trio.WouldBlock, trio.EndOfChannel): + break + return buf + async def read(self, n: int = -1) -> bytes: """ Read up to n bytes. Read possibly returns fewer than `n` bytes, if @@ -73,7 +96,40 @@ class MplexStream(IMuxedStream): ) if self.event_reset.is_set(): raise MplexStreamReset - return await self.incoming_data.get() + if n == -1: + return await self._read_until_eof() + if len(self._buf) == 0: + data: bytes + # Peek whether there is data available. If yes, we just read until there is no data, + # and then return. + try: + data = self.incoming_data_channel.receive_nowait() + except trio.EndOfChannel: + raise MplexStreamEOF + except trio.WouldBlock: + # We know `receive` will be blocked here. Wait for data here with `receive` and + # catch all kinds of errors here. + try: + data = await self.incoming_data_channel.receive() + except trio.EndOfChannel: + if self.event_reset.is_set(): + raise MplexStreamReset + if self.event_remote_closed.is_set(): + raise MplexStreamEOF + except trio.ClosedResourceError as error: + # Probably `incoming_data_channel` is closed in `reset` when we are waiting + # for `receive`. + if self.event_reset.is_set(): + raise MplexStreamReset + raise Exception( + "`incoming_data_channel` is closed but stream is not reset. " + "This should never happen." + ) from error + self._buf.extend(data) + self._buf.extend(self._read_return_when_blocked()) + payload = self._buf[:n] + self._buf = self._buf[len(payload) :] + return bytes(payload) async def write(self, data: bytes) -> int: """ @@ -99,22 +155,26 @@ class MplexStream(IMuxedStream): if self.event_local_closed.is_set(): return + print(f"!@# stream.close: {self.muxed_conn._id}: step=0") flag = ( HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver ) # TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown. await self.muxed_conn.send_message(flag, None, self.stream_id) + print(f"!@# stream.close: {self.muxed_conn._id}: step=1") _is_remote_closed: bool async with self.close_lock: self.event_local_closed.set() _is_remote_closed = self.event_remote_closed.is_set() + print(f"!@# stream.close: {self.muxed_conn._id}: step=2") if _is_remote_closed: # Both sides are closed, we can safely remove the buffer from the dict. async with self.muxed_conn.streams_lock: if self.stream_id in self.muxed_conn.streams: del self.muxed_conn.streams[self.stream_id] + print(f"!@# stream.close: {self.muxed_conn._id}: step=3") async def reset(self) -> None: """closes both ends of the stream tells this remote side to hang up.""" @@ -132,14 +192,15 @@ class MplexStream(IMuxedStream): if self.is_initiator else HeaderTags.ResetReceiver ) - async with trio.open_nursery() as nursery: - nursery.start_soon( - self.muxed_conn.send_message, flag, None, self.stream_id - ) + self.muxed_conn.manager.run_task( + self.muxed_conn.send_message, flag, None, self.stream_id + ) self.event_local_closed.set() self.event_remote_closed.set() + await self.incoming_data_channel.aclose() + async with self.muxed_conn.streams_lock: if ( self.muxed_conn.streams is not None diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index f324371a..448ec90d 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -205,7 +205,7 @@ async def mplex_stream_pair_factory(is_secure: bool) -> Tuple[MplexStream, Mplex stream_1: MplexStream async with mplex_conn_1.streams_lock: if len(mplex_conn_1.streams) != 1: - raise Exception("Mplex should not have any stream upon connection") + raise Exception("Mplex should not have any other stream") stream_1 = tuple(mplex_conn_1.streams.values())[0] yield cast(MplexStream, stream_0), cast(MplexStream, stream_1) diff --git a/libp2p/utils.py b/libp2p/utils.py index 27880c48..3d0794a1 100644 --- a/libp2p/utils.py +++ b/libp2p/utils.py @@ -1,8 +1,5 @@ import itertools import math -from typing import Generic, TypeVar - -import trio from libp2p.exceptions import ParseError from libp2p.io.abc import Reader @@ -98,25 +95,3 @@ async def read_fixedint_prefixed(reader: Reader) -> bytes: len_bytes = await reader.read(SIZE_LEN_BYTES) len_int = int.from_bytes(len_bytes, "big") return await reader.read(len_int) - - -TItem = TypeVar("TItem") - - -class IQueue(Generic[TItem]): - async def put(self, item: TItem): - ... - - async def get(self) -> TItem: - ... - - -class TrioQueue(IQueue): - def __init__(self): - self.send_channel, self.receive_channel = trio.open_memory_channel(0) - - async def put(self, item: TItem): - await self.send_channel.send(item) - - async def get(self) -> TItem: - return await self.receive_channel.receive() diff --git a/tests/conftest.py b/tests/conftest.py index 746fb026..48d705c2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,5 @@ -import asyncio - import pytest -from libp2p.tools.constants import LISTEN_MADDR from libp2p.tools.factories import HostFactory @@ -17,17 +14,6 @@ def num_hosts(): @pytest.fixture -async def hosts(num_hosts, is_host_secure): - _hosts = HostFactory.create_batch(num_hosts, is_secure=is_host_secure) - await asyncio.gather( - *[_host.get_network().listen(LISTEN_MADDR) for _host in _hosts] - ) - try: +async def hosts(num_hosts, is_host_secure, nursery): + async with HostFactory.create_batch_and_listen(is_host_secure, num_hosts) as _hosts: yield _hosts - finally: - # TODO: It's possible that `close` raises exceptions currently, - # due to the connection reset things. Though we don't care much about that when - # cleaning up the tasks, it is probably better to handle the exceptions properly. - await asyncio.gather( - *[_host.close() for _host in _hosts], return_exceptions=True - ) diff --git a/tests/network/conftest.py b/tests/network/conftest.py index c45dbdba..5aad36c9 100644 --- a/tests/network/conftest.py +++ b/tests/network/conftest.py @@ -1,5 +1,3 @@ -import asyncio - import pytest from libp2p.tools.factories import ( diff --git a/tests/stream_muxer/conftest.py b/tests/stream_muxer/conftest.py index 5c5bc2bb..248422d9 100644 --- a/tests/stream_muxer/conftest.py +++ b/tests/stream_muxer/conftest.py @@ -1,5 +1,3 @@ -import asyncio - import pytest from libp2p.tools.factories import mplex_conn_pair_factory, mplex_stream_pair_factory diff --git a/tests/stream_muxer/test_mplex_conn.py b/tests/stream_muxer/test_mplex_conn.py index 6dc98ad6..d48432d2 100644 --- a/tests/stream_muxer/test_mplex_conn.py +++ b/tests/stream_muxer/test_mplex_conn.py @@ -1,9 +1,9 @@ -import asyncio +import trio import pytest -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_conn(mplex_conn_pair): conn_0, conn_1 = mplex_conn_pair @@ -16,19 +16,19 @@ async def test_mplex_conn(mplex_conn_pair): # Test: Open a stream, and both side get 1 more stream. stream_0 = await conn_0.open_stream() - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert len(conn_0.streams) == 1 assert len(conn_1.streams) == 1 # Test: From another side. stream_1 = await conn_1.open_stream() - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert len(conn_0.streams) == 2 assert len(conn_1.streams) == 2 # Close from one side. await conn_0.close() # Sleep for a while for both side to handle `close`. - await asyncio.sleep(0.01) + await trio.sleep(0.01) # Test: Both side is closed. assert conn_0.event_shutting_down.is_set() assert conn_0.event_closed.is_set() diff --git a/tests/stream_muxer/test_mplex_stream.py b/tests/stream_muxer/test_mplex_stream.py index dae66571..e3d19a5d 100644 --- a/tests/stream_muxer/test_mplex_stream.py +++ b/tests/stream_muxer/test_mplex_stream.py @@ -1,5 +1,6 @@ import pytest import trio +from trio.testing import wait_all_tasks_blocked from libp2p.stream_muxer.mplex.exceptions import ( MplexStreamClosed, @@ -37,37 +38,65 @@ async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair): async def read_until_eof(): read_bytes.extend(await stream_1.read()) - task = trio.ensure_future(read_until_eof()) - expected_data = bytearray() - # Test: `read` doesn't return before `close` is called. - await stream_0.write(DATA) - expected_data.extend(DATA) - await trio.sleep(0.01) - assert len(read_bytes) == 0 - # Test: `read` doesn't return before `close` is called. - await stream_0.write(DATA) - expected_data.extend(DATA) - await trio.sleep(0.01) - assert len(read_bytes) == 0 + async with trio.open_nursery() as nursery: + nursery.start_soon(read_until_eof) + # Test: `read` doesn't return before `close` is called. + await stream_0.write(DATA) + expected_data.extend(DATA) + await trio.sleep(0.01) + assert len(read_bytes) == 0 + # Test: `read` doesn't return before `close` is called. + await stream_0.write(DATA) + expected_data.extend(DATA) + await trio.sleep(0.01) + assert len(read_bytes) == 0 + + # Test: Close the stream, `read` returns, and receive previous sent data. + await stream_0.close() - # Test: Close the stream, `read` returns, and receive previous sent data. - await stream_0.close() - await trio.sleep(0.01) assert read_bytes == expected_data - task.cancel() - @pytest.mark.trio async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair assert not stream_1.event_remote_closed.is_set() await stream_0.write(DATA) - await stream_0.close() + assert not stream_0.event_local_closed.is_set() await trio.sleep(0.01) + await wait_all_tasks_blocked() + await stream_0.close() + assert stream_0.event_local_closed.is_set() + await trio.sleep(0.01) + print( + "!@# ", + stream_0.muxed_conn.event_shutting_down.is_set(), + stream_0.muxed_conn.event_closed.is_set(), + stream_1.muxed_conn.event_shutting_down.is_set(), + stream_1.muxed_conn.event_closed.is_set(), + ) + # await trio.sleep(100000) + await wait_all_tasks_blocked() + print( + "!@# ", + stream_0.muxed_conn.event_shutting_down.is_set(), + stream_0.muxed_conn.event_closed.is_set(), + stream_1.muxed_conn.event_shutting_down.is_set(), + stream_1.muxed_conn.event_closed.is_set(), + ) + print("!@# sleeping") + print("!@# result=", stream_1.event_remote_closed.is_set()) + # await trio.sleep_forever() assert stream_1.event_remote_closed.is_set() + print( + "!@# ", + stream_0.muxed_conn.event_shutting_down.is_set(), + stream_0.muxed_conn.event_closed.is_set(), + stream_1.muxed_conn.event_shutting_down.is_set(), + stream_1.muxed_conn.event_closed.is_set(), + ) assert (await stream_1.read(MAX_READ_LEN)) == DATA with pytest.raises(MplexStreamEOF): await stream_1.read(MAX_READ_LEN) @@ -87,7 +116,8 @@ async def test_mplex_stream_read_after_remote_reset(mplex_stream_pair): await stream_0.write(DATA) await stream_0.reset() # Sleep to let `stream_1` receive the message. - await trio.sleep(0.01) + await trio.sleep(0.1) + await wait_all_tasks_blocked() with pytest.raises(MplexStreamReset): await stream_1.read(MAX_READ_LEN) diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 index 7bd807d5..00000000 --- a/tests/test_utils.py +++ /dev/null @@ -1,19 +0,0 @@ -import pytest -import trio - -from libp2p.utils import TrioQueue - - -@pytest.mark.trio -async def test_trio_queue(): - queue = TrioQueue() - - async def queue_get(task_status=None): - result = await queue.get() - task_status.started(result) - - async with trio.open_nursery() as nursery: - nursery.start_soon(queue.put, 123) - result = await nursery.start(queue_get) - - assert result == 123 From 79fcdf3a02115e9bccebc23418b638a665609f51 Mon Sep 17 00:00:00 2001 From: mhchia Date: Sun, 1 Dec 2019 16:26:16 +0800 Subject: [PATCH 12/81] Update tests in test_tcp.py Besides, run `make format` --- libp2p/network/connection/swarm_connection.py | 2 +- libp2p/network/swarm.py | 13 +---- libp2p/stream_muxer/mplex/mplex.py | 4 +- libp2p/tools/factories.py | 6 +- libp2p/tools/utils.py | 2 +- libp2p/transport/tcp/tcp.py | 8 +-- tests/network/test_swarm.py | 2 +- tests/stream_muxer/test_mplex_conn.py | 3 +- tests/stream_muxer/test_mplex_stream.py | 3 +- tests/transport/test_tcp.py | 55 ++++++++++++++----- 10 files changed, 55 insertions(+), 43 deletions(-) diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 46b3f6fd..48774ec2 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple -import trio from async_service import Service +import trio from libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.stream.net_stream import NetStream diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 78fb7fdc..37614bcd 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -2,7 +2,6 @@ import logging from typing import Dict, List, Optional from async_service import Service - from multiaddr import Multiaddr import trio @@ -205,7 +204,7 @@ class Swarm(INetwork, Service): logger.debug("successfully opened connection to peer %s", peer_id) # FIXME: This is a intentional barrier to prevent from the handler exiting and - # closing the connection. + # closing the connection. Probably change to `Service.manager.wait_finished`? await trio.sleep_forever() try: @@ -229,16 +228,6 @@ class Swarm(INetwork, Service): async def close(self) -> None: # TODO: Prevent from new listeners and conns being added. # Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501 - - # Close listeners - # await asyncio.gather( - # *[listener.close() for listener in self.listeners.values()] - # ) - - # # Close connections - # await asyncio.gather( - # *[connection.close() for connection in self.connections.values()] - # ) await self.manager.stop() await self.manager.wait_finished() logger.debug("swarm successfully closed") diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 53d855b2..f93acea7 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -1,11 +1,11 @@ -import math import asyncio import logging +import math from typing import Any # noqa: F401 from typing import Awaitable, Dict, List, Optional, Tuple -import trio from async_service import Service +import trio from libp2p.exceptions import ParseError from libp2p.io.exceptions import IncompleteReadError diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index 448ec90d..470cbc31 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -1,9 +1,9 @@ -import trio -from contextlib import asynccontextmanager, AsyncExitStack +from contextlib import AsyncExitStack, asynccontextmanager from typing import Any, AsyncIterator, Dict, Tuple, cast -import factory from async_service import background_trio_service +import factory +import trio from libp2p import generate_new_rsa_identity, generate_peer_id_from from libp2p.crypto.keys import KeyPair diff --git a/libp2p/tools/utils.py b/libp2p/tools/utils.py index 0d39156b..db1e8abe 100644 --- a/libp2p/tools/utils.py +++ b/libp2p/tools/utils.py @@ -1,7 +1,7 @@ -import trio from typing import List, Sequence, Tuple import multiaddr +import trio from libp2p import new_node from libp2p.host.basic_host import BasicHost diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index 16365982..745bafe8 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -25,7 +25,8 @@ class TCPListener(IListener): self.server = None self.handler = handler_function - async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool: + # TODO: Fix handling? + async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> None: """ put listener in listening mode and wait for incoming connections. @@ -50,16 +51,13 @@ class TCPListener(IListener): socket = listeners[0].socket self.multiaddrs.append(_multiaddr_from_socket(socket)) - return True - def get_addrs(self) -> List[Multiaddr]: """ retrieve list of addresses the listener is listening on. :return: return list of addrs """ - # TODO check if server is listening - return self.multiaddrs + return tuple(self.multiaddrs) async def close(self) -> None: """close the listener such that no more connections can be open on this diff --git a/tests/network/test_swarm.py b/tests/network/test_swarm.py index de086352..1492441f 100644 --- a/tests/network/test_swarm.py +++ b/tests/network/test_swarm.py @@ -1,5 +1,5 @@ -import trio import pytest +import trio from trio.testing import wait_all_tasks_blocked from libp2p.network.exceptions import SwarmException diff --git a/tests/stream_muxer/test_mplex_conn.py b/tests/stream_muxer/test_mplex_conn.py index d48432d2..4cedc36d 100644 --- a/tests/stream_muxer/test_mplex_conn.py +++ b/tests/stream_muxer/test_mplex_conn.py @@ -1,6 +1,5 @@ -import trio - import pytest +import trio @pytest.mark.trio diff --git a/tests/stream_muxer/test_mplex_stream.py b/tests/stream_muxer/test_mplex_stream.py index e3d19a5d..e47af49d 100644 --- a/tests/stream_muxer/test_mplex_stream.py +++ b/tests/stream_muxer/test_mplex_stream.py @@ -7,11 +7,10 @@ from libp2p.stream_muxer.mplex.exceptions import ( MplexStreamEOF, MplexStreamReset, ) -from libp2p.tools.constants import MAX_READ_LEN, LISTEN_MADDR +from libp2p.tools.constants import LISTEN_MADDR, MAX_READ_LEN from libp2p.tools.factories import SwarmFactory from libp2p.tools.utils import connect_swarm - DATA = b"data_123" diff --git a/tests/transport/test_tcp.py b/tests/transport/test_tcp.py index 7231a060..c8fe6f21 100644 --- a/tests/transport/test_tcp.py +++ b/tests/transport/test_tcp.py @@ -1,20 +1,47 @@ -import asyncio - +from multiaddr import Multiaddr import pytest +import trio -from libp2p.transport.tcp.tcp import _multiaddr_from_socket +from libp2p.network.connection.raw_connection import RawConnection +from libp2p.tools.constants import LISTEN_MADDR, MAX_READ_LEN +from libp2p.transport.tcp.tcp import TCP -@pytest.mark.asyncio -async def test_multiaddr_from_socket(): - def handler(r, w): - pass +@pytest.mark.trio +async def test_tcp_listener(nursery): + transport = TCP() - server = await asyncio.start_server(handler, "127.0.0.1", 8000) - assert str(_multiaddr_from_socket(server.sockets[0])) == "/ip4/127.0.0.1/tcp/8000" + async def handler(tcp_stream): + ... - server = await asyncio.start_server(handler, "127.0.0.1", 0) - addr = _multiaddr_from_socket(server.sockets[0]) - assert addr.value_for_protocol("ip4") == "127.0.0.1" - port = addr.value_for_protocol("tcp") - assert int(port) > 0 + listener = transport.create_listener(handler) + assert len(listener.get_addrs()) == 0 + await listener.listen(LISTEN_MADDR, nursery) + assert len(listener.get_addrs()) == 1 + await listener.listen(LISTEN_MADDR, nursery) + assert len(listener.get_addrs()) == 2 + + +@pytest.mark.trio +async def test_tcp_dial(nursery): + transport = TCP() + raw_conn_other_side = None + + async def handler(tcp_stream): + nonlocal raw_conn_other_side + raw_conn_other_side = RawConnection(tcp_stream, False) + await trio.sleep_forever() + + # Test: OSError is raised when trying to dial to a port which no one is not listening to. + with pytest.raises(OSError): + await transport.dial(Multiaddr("/ip4/127.0.0.1/tcp/1")) + + listener = transport.create_listener(handler) + await listener.listen(LISTEN_MADDR, nursery) + assert len(listener.multiaddrs) == 1 + listen_addr = listener.multiaddrs[0] + raw_conn = await transport.dial(listen_addr) + + data = b"123" + await raw_conn_other_side.write(data) + assert (await raw_conn.read(len(data))) == data From 62e47080f570bdc0230bc6f009eb331ce1ba3812 Mon Sep 17 00:00:00 2001 From: mhchia Date: Sun, 1 Dec 2019 16:51:06 +0800 Subject: [PATCH 13/81] Fix `tests/network` --- tests/network/test_net_stream.py | 64 +++++++++++------------ tests/network/test_notify.py | 89 +++++++++++++++++--------------- tests/network/test_swarm_conn.py | 12 ++--- 3 files changed, 84 insertions(+), 81 deletions(-) diff --git a/tests/network/test_net_stream.py b/tests/network/test_net_stream.py index d0fea932..2c2772a1 100644 --- a/tests/network/test_net_stream.py +++ b/tests/network/test_net_stream.py @@ -1,4 +1,4 @@ -import asyncio +import trio import pytest @@ -8,7 +8,7 @@ from libp2p.tools.constants import MAX_READ_LEN DATA = b"data_123" -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_read_write(net_stream_pair): stream_0, stream_1 = net_stream_pair assert ( @@ -19,7 +19,7 @@ async def test_net_stream_read_write(net_stream_pair): assert (await stream_1.read(MAX_READ_LEN)) == DATA -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_read_until_eof(net_stream_pair): read_bytes = bytearray() stream_0, stream_1 = net_stream_pair @@ -27,41 +27,39 @@ async def test_net_stream_read_until_eof(net_stream_pair): async def read_until_eof(): read_bytes.extend(await stream_1.read()) - task = asyncio.ensure_future(read_until_eof()) + async with trio.open_nursery() as nursery: + nursery.start_soon(read_until_eof) + expected_data = bytearray() - expected_data = bytearray() + # Test: `read` doesn't return before `close` is called. + await stream_0.write(DATA) + expected_data.extend(DATA) + await trio.sleep(0.01) + assert len(read_bytes) == 0 + # Test: `read` doesn't return before `close` is called. + await stream_0.write(DATA) + expected_data.extend(DATA) + await trio.sleep(0.01) + assert len(read_bytes) == 0 - # Test: `read` doesn't return before `close` is called. - await stream_0.write(DATA) - expected_data.extend(DATA) - await asyncio.sleep(0.01) - assert len(read_bytes) == 0 - # Test: `read` doesn't return before `close` is called. - await stream_0.write(DATA) - expected_data.extend(DATA) - await asyncio.sleep(0.01) - assert len(read_bytes) == 0 - - # Test: Close the stream, `read` returns, and receive previous sent data. - await stream_0.close() - await asyncio.sleep(0.01) - assert read_bytes == expected_data - - task.cancel() + # Test: Close the stream, `read` returns, and receive previous sent data. + await stream_0.close() + await trio.sleep(0.01) + assert read_bytes == expected_data -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_read_after_remote_closed(net_stream_pair): stream_0, stream_1 = net_stream_pair await stream_0.write(DATA) await stream_0.close() - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert (await stream_1.read(MAX_READ_LEN)) == DATA with pytest.raises(StreamEOF): await stream_1.read(MAX_READ_LEN) -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_read_after_local_reset(net_stream_pair): stream_0, stream_1 = net_stream_pair await stream_0.reset() @@ -69,29 +67,29 @@ async def test_net_stream_read_after_local_reset(net_stream_pair): await stream_0.read(MAX_READ_LEN) -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_read_after_remote_reset(net_stream_pair): stream_0, stream_1 = net_stream_pair await stream_0.write(DATA) await stream_0.reset() # Sleep to let `stream_1` receive the message. - await asyncio.sleep(0.01) + await trio.sleep(0.01) with pytest.raises(StreamReset): await stream_1.read(MAX_READ_LEN) -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_read_after_remote_closed_and_reset(net_stream_pair): stream_0, stream_1 = net_stream_pair await stream_0.write(DATA) await stream_0.close() await stream_0.reset() # Sleep to let `stream_1` receive the message. - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert (await stream_1.read(MAX_READ_LEN)) == DATA -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_write_after_local_closed(net_stream_pair): stream_0, stream_1 = net_stream_pair await stream_0.write(DATA) @@ -100,7 +98,7 @@ async def test_net_stream_write_after_local_closed(net_stream_pair): await stream_0.write(DATA) -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_write_after_local_reset(net_stream_pair): stream_0, stream_1 = net_stream_pair await stream_0.reset() @@ -108,10 +106,10 @@ async def test_net_stream_write_after_local_reset(net_stream_pair): await stream_0.write(DATA) -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_write_after_remote_reset(net_stream_pair): stream_0, stream_1 = net_stream_pair await stream_1.reset() - await asyncio.sleep(0.01) + await trio.sleep(0.01) with pytest.raises(StreamClosed): await stream_0.write(DATA) diff --git a/tests/network/test_notify.py b/tests/network/test_notify.py index f8187b1e..1d9982e8 100644 --- a/tests/network/test_notify.py +++ b/tests/network/test_notify.py @@ -8,12 +8,13 @@ into network after network has already started listening TODO: Add tests for closed_stream, listen_close when those features are implemented in swarm """ - -import asyncio +import trio import enum import pytest +from async_service import background_trio_service + from libp2p.network.notifee_interface import INotifee from libp2p.tools.constants import LISTEN_MADDR from libp2p.tools.factories import SwarmFactory @@ -54,59 +55,63 @@ class MyNotifee(INotifee): pass -@pytest.mark.asyncio +@pytest.mark.trio async def test_notify(is_host_secure): swarms = [SwarmFactory(is_secure=is_host_secure) for _ in range(2)] events_0_0 = [] events_1_0 = [] events_0_without_listen = [] - swarms[0].register_notifee(MyNotifee(events_0_0)) - swarms[1].register_notifee(MyNotifee(events_1_0)) - # Listen - await asyncio.gather(*[swarm.listen(LISTEN_MADDR) for swarm in swarms]) + # Run swarms. + async with background_trio_service(swarms[0]), background_trio_service(swarms[1]): + # Register events before listening, to allow `MyNotifee` is notified with the event + # `listen`. + swarms[0].register_notifee(MyNotifee(events_0_0)) + swarms[1].register_notifee(MyNotifee(events_1_0)) - swarms[0].register_notifee(MyNotifee(events_0_without_listen)) + # Listen + async with trio.open_nursery() as nursery: + nursery.start_soon(swarms[0].listen, LISTEN_MADDR) + nursery.start_soon(swarms[1].listen, LISTEN_MADDR) - # Connected - await connect_swarm(swarms[0], swarms[1]) - # OpenedStream: first - await swarms[0].new_stream(swarms[1].get_peer_id()) - # OpenedStream: second - await swarms[0].new_stream(swarms[1].get_peer_id()) - # OpenedStream: third, but different direction. - await swarms[1].new_stream(swarms[0].get_peer_id()) + swarms[0].register_notifee(MyNotifee(events_0_without_listen)) - await asyncio.sleep(0.01) + # Connected + await connect_swarm(swarms[0], swarms[1]) + # OpenedStream: first + await swarms[0].new_stream(swarms[1].get_peer_id()) + # OpenedStream: second + await swarms[0].new_stream(swarms[1].get_peer_id()) + # OpenedStream: third, but different direction. + await swarms[1].new_stream(swarms[0].get_peer_id()) - # TODO: Check `ClosedStream` and `ListenClose` events after they are ready. + await trio.sleep(0.01) - # Disconnected - await swarms[0].close_peer(swarms[1].get_peer_id()) - await asyncio.sleep(0.01) + # TODO: Check `ClosedStream` and `ListenClose` events after they are ready. - # Connected again, but different direction. - await connect_swarm(swarms[1], swarms[0]) - await asyncio.sleep(0.01) + # Disconnected + await swarms[0].close_peer(swarms[1].get_peer_id()) + await trio.sleep(0.01) - # Disconnected again, but different direction. - await swarms[1].close_peer(swarms[0].get_peer_id()) - await asyncio.sleep(0.01) + # Connected again, but different direction. + await connect_swarm(swarms[1], swarms[0]) + await trio.sleep(0.01) - expected_events_without_listen = [ - Event.Connected, - Event.OpenedStream, - Event.OpenedStream, - Event.OpenedStream, - Event.Disconnected, - Event.Connected, - Event.Disconnected, - ] - expected_events = [Event.Listen] + expected_events_without_listen + # Disconnected again, but different direction. + await swarms[1].close_peer(swarms[0].get_peer_id()) + await trio.sleep(0.01) - assert events_0_0 == expected_events - assert events_1_0 == expected_events - assert events_0_without_listen == expected_events_without_listen + expected_events_without_listen = [ + Event.Connected, + Event.OpenedStream, + Event.OpenedStream, + Event.OpenedStream, + Event.Disconnected, + Event.Connected, + Event.Disconnected, + ] + expected_events = [Event.Listen] + expected_events_without_listen - # Clean up - await asyncio.gather(*[swarm.close() for swarm in swarms]) + assert events_0_0 == expected_events + assert events_1_0 == expected_events + assert events_0_without_listen == expected_events_without_listen diff --git a/tests/network/test_swarm_conn.py b/tests/network/test_swarm_conn.py index 2abc7d0f..0b93808d 100644 --- a/tests/network/test_swarm_conn.py +++ b/tests/network/test_swarm_conn.py @@ -1,9 +1,9 @@ -import asyncio +import trio import pytest -@pytest.mark.asyncio +@pytest.mark.trio async def test_swarm_conn_close(swarm_conn_pair): conn_0, conn_1 = swarm_conn_pair @@ -12,7 +12,7 @@ async def test_swarm_conn_close(swarm_conn_pair): await conn_0.close() - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert conn_0.event_closed.is_set() assert conn_1.event_closed.is_set() @@ -20,7 +20,7 @@ async def test_swarm_conn_close(swarm_conn_pair): assert conn_1 not in conn_1.swarm.connections.values() -@pytest.mark.asyncio +@pytest.mark.trio async def test_swarm_conn_streams(swarm_conn_pair): conn_0, conn_1 = swarm_conn_pair @@ -28,12 +28,12 @@ async def test_swarm_conn_streams(swarm_conn_pair): assert len(await conn_1.get_streams()) == 0 stream_0_0 = await conn_0.new_stream() - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert len(await conn_0.get_streams()) == 1 assert len(await conn_1.get_streams()) == 1 stream_0_1 = await conn_0.new_stream() - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert len(await conn_0.get_streams()) == 2 assert len(await conn_1.get_streams()) == 2 From 31bf774a16479ac42033e07d50f6f0233bf75b70 Mon Sep 17 00:00:00 2001 From: mhchia Date: Sun, 1 Dec 2019 17:43:14 +0800 Subject: [PATCH 14/81] Fix tests in `protocol_muxer` and `libp2p` --- libp2p/tools/utils.py | 27 +- tests/libp2p/test_libp2p.py | 571 +++++++++----------- tests/network/test_net_stream.py | 3 +- tests/network/test_notify.py | 5 +- tests/network/test_swarm_conn.py | 3 +- tests/protocol_muxer/test_protocol_muxer.py | 114 ++-- 6 files changed, 325 insertions(+), 398 deletions(-) diff --git a/libp2p/tools/utils.py b/libp2p/tools/utils.py index db1e8abe..5ec48867 100644 --- a/libp2p/tools/utils.py +++ b/libp2p/tools/utils.py @@ -1,4 +1,4 @@ -from typing import List, Sequence, Tuple +from typing import Callable, List, Sequence, Tuple import multiaddr import trio @@ -12,7 +12,6 @@ from libp2p.network.swarm import Swarm from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.routing.interfaces import IPeerRouting from libp2p.routing.kademlia.kademlia_peer_router import KadmeliaPeerRouter -from libp2p.typing import StreamHandlerFn, TProtocol from .constants import MAX_READ_LEN @@ -79,22 +78,12 @@ async def set_up_routers( return routers -async def echo_stream_handler(stream: INetStream) -> None: - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() +def create_echo_stream_handler(ack_prefix: str) -> Callable[[INetStream], None]: + async def echo_stream_handler(stream: INetStream) -> None: + while True: + read_string = (await stream.read(MAX_READ_LEN)).decode() - resp = "ack:" + read_string - await stream.write(resp.encode()) + resp = ack_prefix + read_string + await stream.write(resp.encode()) - -async def perform_two_host_set_up( - handler: StreamHandlerFn = echo_stream_handler -) -> Tuple[BasicHost, BasicHost]: - transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list) - - node_b.set_stream_handler(TProtocol("/echo/1.0.0"), handler) - - # Associate the peer with local ip address (see default parameters of Libp2p()) - node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - return node_a, node_b + return echo_stream_handler diff --git a/tests/libp2p/test_libp2p.py b/tests/libp2p/test_libp2p.py index 628e008a..5af5ea10 100644 --- a/tests/libp2p/test_libp2p.py +++ b/tests/libp2p/test_libp2p.py @@ -1,351 +1,280 @@ import multiaddr import pytest -import trio from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.tools.constants import MAX_READ_LEN -from libp2p.tools.utils import set_up_nodes_by_transport_opt +from libp2p.tools.factories import HostFactory +from libp2p.tools.utils import create_echo_stream_handler +from libp2p.typing import TProtocol + +PROTOCOL_ID_0 = TProtocol("/echo/0") +PROTOCOL_ID_1 = TProtocol("/echo/1") +PROTOCOL_ID_2 = TProtocol("/echo/2") +PROTOCOL_ID_3 = TProtocol("/echo/3") + +ACK_STR_0 = "ack_0:" +ACK_STR_1 = "ack_1:" +ACK_STR_2 = "ack_2:" +ACK_STR_3 = "ack_3:" @pytest.mark.trio -async def test_simple_messages(nursery): - transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list, nursery) - - async def stream_handler(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() - - response = "ack:" + read_string - await stream.write(response.encode()) - - node_b.set_stream_handler("/echo/1.0.0", stream_handler) - - # Associate the peer with local ip address (see default parameters of Libp2p()) - node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - - stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) - - messages = ["hello" + str(x) for x in range(10)] - for message in messages: - - await stream.write(message.encode()) - - response = (await stream.read(MAX_READ_LEN)).decode() - - assert response == ("ack:" + message) - - # Success, terminate pending tasks. - - -@pytest.mark.asyncio -async def test_double_response(): - transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list) - - async def stream_handler(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() - - response = "ack1:" + read_string - await stream.write(response.encode()) - - response = "ack2:" + read_string - await stream.write(response.encode()) - - node_b.set_stream_handler("/echo/1.0.0", stream_handler) - - # Associate the peer with local ip address (see default parameters of Libp2p()) - node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) - - messages = ["hello" + str(x) for x in range(10)] - for message in messages: - await stream.write(message.encode()) - - response1 = (await stream.read(MAX_READ_LEN)).decode() - assert response1 == ("ack1:" + message) - - response2 = (await stream.read(MAX_READ_LEN)).decode() - assert response2 == ("ack2:" + message) - - # Success, terminate pending tasks. - - -@pytest.mark.asyncio -async def test_multiple_streams(): - # Node A should be able to open a stream with node B and then vice versa. - # Stream IDs should be generated uniquely so that the stream state is not overwritten - transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list) - - async def stream_handler_a(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() - - response = "ack_a:" + read_string - await stream.write(response.encode()) - - async def stream_handler_b(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() - - response = "ack_b:" + read_string - await stream.write(response.encode()) - - node_a.set_stream_handler("/echo_a/1.0.0", stream_handler_a) - node_b.set_stream_handler("/echo_b/1.0.0", stream_handler_b) - - # Associate the peer with local ip address (see default parameters of Libp2p()) - node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - node_b.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10) - - stream_a = await node_a.new_stream(node_b.get_id(), ["/echo_b/1.0.0"]) - stream_b = await node_b.new_stream(node_a.get_id(), ["/echo_a/1.0.0"]) - - # A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b - messages = ["hello" + str(x) for x in range(10)] - for message in messages: - a_message = message + "_a" - b_message = message + "_b" - - await stream_a.write(a_message.encode()) - await stream_b.write(b_message.encode()) - - response_a = (await stream_a.read(MAX_READ_LEN)).decode() - response_b = (await stream_b.read(MAX_READ_LEN)).decode() - - assert response_a == ("ack_b:" + a_message) and response_b == ( - "ack_a:" + b_message +async def test_simple_messages(is_host_secure): + async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: + hosts[1].set_stream_handler( + PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0) ) - # Success, terminate pending tasks. + # Associate the peer with local ip address (see default parameters of Libp2p()) + hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10) + + stream = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0]) + + messages = ["hello" + str(x) for x in range(10)] + for message in messages: + await stream.write(message.encode()) + response = (await stream.read(MAX_READ_LEN)).decode() + assert response == (ACK_STR_0 + message) -@pytest.mark.asyncio -async def test_multiple_streams_same_initiator_different_protocols(): - transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list) +@pytest.mark.trio +async def test_double_response(is_host_secure): + async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: - async def stream_handler_a1(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() + async def double_response_stream_handler(stream): + while True: + read_string = (await stream.read(MAX_READ_LEN)).decode() - response = "ack_a1:" + read_string - await stream.write(response.encode()) + response = ACK_STR_0 + read_string + await stream.write(response.encode()) - async def stream_handler_a2(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() + response = ACK_STR_1 + read_string + await stream.write(response.encode()) - response = "ack_a2:" + read_string - await stream.write(response.encode()) + hosts[1].set_stream_handler(PROTOCOL_ID_0, double_response_stream_handler) - async def stream_handler_a3(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() + # Associate the peer with local ip address (see default parameters of Libp2p()) + hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10) + stream = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0]) - response = "ack_a3:" + read_string - await stream.write(response.encode()) - - node_b.set_stream_handler("/echo_a1/1.0.0", stream_handler_a1) - node_b.set_stream_handler("/echo_a2/1.0.0", stream_handler_a2) - node_b.set_stream_handler("/echo_a3/1.0.0", stream_handler_a3) - - # Associate the peer with local ip address (see default parameters of Libp2p()) - node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - node_b.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10) - - # Open streams to node_b over echo_a1 echo_a2 echo_a3 protocols - stream_a1 = await node_a.new_stream(node_b.get_id(), ["/echo_a1/1.0.0"]) - stream_a2 = await node_a.new_stream(node_b.get_id(), ["/echo_a2/1.0.0"]) - stream_a3 = await node_a.new_stream(node_b.get_id(), ["/echo_a3/1.0.0"]) - - messages = ["hello" + str(x) for x in range(10)] - for message in messages: - a1_message = message + "_a1" - a2_message = message + "_a2" - a3_message = message + "_a3" - - await stream_a1.write(a1_message.encode()) - await stream_a2.write(a2_message.encode()) - await stream_a3.write(a3_message.encode()) - - response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode() - response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode() - response_a3 = (await stream_a3.read(MAX_READ_LEN)).decode() - - assert ( - response_a1 == ("ack_a1:" + a1_message) - and response_a2 == ("ack_a2:" + a2_message) - and response_a3 == ("ack_a3:" + a3_message) - ) - - # Success, terminate pending tasks. - - -@pytest.mark.asyncio -async def test_multiple_streams_two_initiators(): - transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list) - - async def stream_handler_a1(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() - - response = "ack_a1:" + read_string - await stream.write(response.encode()) - - async def stream_handler_a2(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() - - response = "ack_a2:" + read_string - await stream.write(response.encode()) - - async def stream_handler_b1(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() - - response = "ack_b1:" + read_string - await stream.write(response.encode()) - - async def stream_handler_b2(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() - - response = "ack_b2:" + read_string - await stream.write(response.encode()) - - node_a.set_stream_handler("/echo_b1/1.0.0", stream_handler_b1) - node_a.set_stream_handler("/echo_b2/1.0.0", stream_handler_b2) - - node_b.set_stream_handler("/echo_a1/1.0.0", stream_handler_a1) - node_b.set_stream_handler("/echo_a2/1.0.0", stream_handler_a2) - - # Associate the peer with local ip address (see default parameters of Libp2p()) - node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - node_b.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10) - - stream_a1 = await node_a.new_stream(node_b.get_id(), ["/echo_a1/1.0.0"]) - stream_a2 = await node_a.new_stream(node_b.get_id(), ["/echo_a2/1.0.0"]) - - stream_b1 = await node_b.new_stream(node_a.get_id(), ["/echo_b1/1.0.0"]) - stream_b2 = await node_b.new_stream(node_a.get_id(), ["/echo_b2/1.0.0"]) - - # A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b - messages = ["hello" + str(x) for x in range(10)] - for message in messages: - a1_message = message + "_a1" - a2_message = message + "_a2" - - b1_message = message + "_b1" - b2_message = message + "_b2" - - await stream_a1.write(a1_message.encode()) - await stream_a2.write(a2_message.encode()) - - await stream_b1.write(b1_message.encode()) - await stream_b2.write(b2_message.encode()) - - response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode() - response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode() - - response_b1 = (await stream_b1.read(MAX_READ_LEN)).decode() - response_b2 = (await stream_b2.read(MAX_READ_LEN)).decode() - - assert ( - response_a1 == ("ack_a1:" + a1_message) - and response_a2 == ("ack_a2:" + a2_message) - and response_b1 == ("ack_b1:" + b1_message) - and response_b2 == ("ack_b2:" + b2_message) - ) - - # Success, terminate pending tasks. - - -@pytest.mark.asyncio -async def test_triangle_nodes_connection(): - transport_opt_list = [ - ["/ip4/127.0.0.1/tcp/0"], - ["/ip4/127.0.0.1/tcp/0"], - ["/ip4/127.0.0.1/tcp/0"], - ] - (node_a, node_b, node_c) = await set_up_nodes_by_transport_opt(transport_opt_list) - - async def stream_handler(stream): - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() - - response = "ack:" + read_string - await stream.write(response.encode()) - - node_a.set_stream_handler("/echo/1.0.0", stream_handler) - node_b.set_stream_handler("/echo/1.0.0", stream_handler) - node_c.set_stream_handler("/echo/1.0.0", stream_handler) - - # Associate the peer with local ip address (see default parameters of Libp2p()) - # Associate all permutations - node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - node_a.get_peerstore().add_addrs(node_c.get_id(), node_c.get_addrs(), 10) - - node_b.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10) - node_b.get_peerstore().add_addrs(node_c.get_id(), node_c.get_addrs(), 10) - - node_c.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10) - node_c.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - - stream_a_to_b = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) - stream_a_to_c = await node_a.new_stream(node_c.get_id(), ["/echo/1.0.0"]) - - stream_b_to_a = await node_b.new_stream(node_a.get_id(), ["/echo/1.0.0"]) - stream_b_to_c = await node_b.new_stream(node_c.get_id(), ["/echo/1.0.0"]) - - stream_c_to_a = await node_c.new_stream(node_a.get_id(), ["/echo/1.0.0"]) - stream_c_to_b = await node_c.new_stream(node_b.get_id(), ["/echo/1.0.0"]) - - messages = ["hello" + str(x) for x in range(5)] - streams = [ - stream_a_to_b, - stream_a_to_c, - stream_b_to_a, - stream_b_to_c, - stream_c_to_a, - stream_c_to_b, - ] - - for message in messages: - for stream in streams: + messages = ["hello" + str(x) for x in range(10)] + for message in messages: await stream.write(message.encode()) - response = (await stream.read(MAX_READ_LEN)).decode() + response1 = (await stream.read(MAX_READ_LEN)).decode() + assert response1 == (ACK_STR_0 + message) - assert response == ("ack:" + message) - - # Success, terminate pending tasks. + response2 = (await stream.read(MAX_READ_LEN)).decode() + assert response2 == (ACK_STR_1 + message) -@pytest.mark.asyncio -async def test_host_connect(): - transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list) +@pytest.mark.trio +async def test_multiple_streams(is_host_secure): + # hosts[0] should be able to open a stream with hosts[1] and then vice versa. + # Stream IDs should be generated uniquely so that the stream state is not overwritten - assert not node_a.get_peerstore().peer_ids() + async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: + hosts[0].set_stream_handler( + PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0) + ) + hosts[1].set_stream_handler( + PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1) + ) - addr = node_b.get_addrs()[0] - info = info_from_p2p_addr(addr) - await node_a.connect(info) + # Associate the peer with local ip address (see default parameters of Libp2p()) + hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10) + hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10) - assert len(node_a.get_peerstore().peer_ids()) == 1 + stream_a = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1]) + stream_b = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_0]) - await node_a.connect(info) + # A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b + messages = ["hello" + str(x) for x in range(10)] + for message in messages: + a_message = message + "_a" + b_message = message + "_b" - # make sure we don't do double connection - assert len(node_a.get_peerstore().peer_ids()) == 1 + await stream_a.write(a_message.encode()) + await stream_b.write(b_message.encode()) - assert node_b.get_id() in node_a.get_peerstore().peer_ids() - ma_node_b = multiaddr.Multiaddr("/p2p/%s" % node_b.get_id().pretty()) - for addr in node_a.get_peerstore().addrs(node_b.get_id()): - assert addr.encapsulate(ma_node_b) in node_b.get_addrs() + response_a = (await stream_a.read(MAX_READ_LEN)).decode() + response_b = (await stream_b.read(MAX_READ_LEN)).decode() - # Success, terminate pending tasks. + assert response_a == (ACK_STR_1 + a_message) and response_b == ( + ACK_STR_0 + b_message + ) + + +@pytest.mark.trio +async def test_multiple_streams_same_initiator_different_protocols(is_host_secure): + async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: + + hosts[1].set_stream_handler( + PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0) + ) + hosts[1].set_stream_handler( + PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1) + ) + hosts[1].set_stream_handler( + PROTOCOL_ID_2, create_echo_stream_handler(ACK_STR_2) + ) + + # Associate the peer with local ip address (see default parameters of Libp2p()) + hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10) + hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10) + + # Open streams to hosts[1] over echo_a1 echo_a2 echo_a3 protocols + stream_a1 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0]) + stream_a2 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1]) + stream_a3 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_2]) + + messages = ["hello" + str(x) for x in range(10)] + for message in messages: + a1_message = message + "_a1" + a2_message = message + "_a2" + a3_message = message + "_a3" + + await stream_a1.write(a1_message.encode()) + await stream_a2.write(a2_message.encode()) + await stream_a3.write(a3_message.encode()) + + response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode() + response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode() + response_a3 = (await stream_a3.read(MAX_READ_LEN)).decode() + + assert ( + response_a1 == (ACK_STR_0 + a1_message) + and response_a2 == (ACK_STR_1 + a2_message) + and response_a3 == (ACK_STR_2 + a3_message) + ) + + # Success, terminate pending tasks. + + +@pytest.mark.trio +async def test_multiple_streams_two_initiators(is_host_secure): + async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: + hosts[0].set_stream_handler( + PROTOCOL_ID_2, create_echo_stream_handler(ACK_STR_2) + ) + hosts[0].set_stream_handler( + PROTOCOL_ID_3, create_echo_stream_handler(ACK_STR_3) + ) + + hosts[1].set_stream_handler( + PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0) + ) + hosts[1].set_stream_handler( + PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1) + ) + + # Associate the peer with local ip address (see default parameters of Libp2p()) + hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10) + hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10) + + stream_a1 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0]) + stream_a2 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1]) + + stream_b1 = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_2]) + stream_b2 = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_3]) + + # A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b + messages = ["hello" + str(x) for x in range(10)] + for message in messages: + a1_message = message + "_a1" + a2_message = message + "_a2" + + b1_message = message + "_b1" + b2_message = message + "_b2" + + await stream_a1.write(a1_message.encode()) + await stream_a2.write(a2_message.encode()) + + await stream_b1.write(b1_message.encode()) + await stream_b2.write(b2_message.encode()) + + response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode() + response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode() + + response_b1 = (await stream_b1.read(MAX_READ_LEN)).decode() + response_b2 = (await stream_b2.read(MAX_READ_LEN)).decode() + + assert ( + response_a1 == (ACK_STR_0 + a1_message) + and response_a2 == (ACK_STR_1 + a2_message) + and response_b1 == (ACK_STR_2 + b1_message) + and response_b2 == (ACK_STR_3 + b2_message) + ) + + +@pytest.mark.trio +async def test_triangle_nodes_connection(is_host_secure): + async with HostFactory.create_batch_and_listen(is_host_secure, 3) as hosts: + + hosts[0].set_stream_handler( + PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0) + ) + hosts[1].set_stream_handler( + PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0) + ) + hosts[2].set_stream_handler( + PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0) + ) + + # Associate the peer with local ip address (see default parameters of Libp2p()) + # Associate all permutations + hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10) + hosts[0].get_peerstore().add_addrs(hosts[2].get_id(), hosts[2].get_addrs(), 10) + + hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10) + hosts[1].get_peerstore().add_addrs(hosts[2].get_id(), hosts[2].get_addrs(), 10) + + hosts[2].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10) + hosts[2].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10) + + stream_0_to_1 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0]) + stream_0_to_2 = await hosts[0].new_stream(hosts[2].get_id(), [PROTOCOL_ID_0]) + + stream_1_to_0 = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_0]) + stream_1_to_2 = await hosts[1].new_stream(hosts[2].get_id(), [PROTOCOL_ID_0]) + + stream_2_to_0 = await hosts[2].new_stream(hosts[0].get_id(), [PROTOCOL_ID_0]) + stream_2_to_1 = await hosts[2].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0]) + + messages = ["hello" + str(x) for x in range(5)] + streams = [ + stream_0_to_1, + stream_0_to_2, + stream_1_to_0, + stream_1_to_2, + stream_2_to_0, + stream_2_to_1, + ] + + for message in messages: + for stream in streams: + await stream.write(message.encode()) + response = (await stream.read(MAX_READ_LEN)).decode() + assert response == (ACK_STR_0 + message) + + +@pytest.mark.trio +async def test_host_connect(is_host_secure): + async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: + assert not hosts[0].get_peerstore().peer_ids() + + addr = hosts[1].get_addrs()[0] + info = info_from_p2p_addr(addr) + await hosts[0].connect(info) + + assert len(hosts[0].get_peerstore().peer_ids()) == 1 + + await hosts[0].connect(info) + + # make sure we don't do double connection + assert len(hosts[0].get_peerstore().peer_ids()) == 1 + + assert hosts[1].get_id() in hosts[0].get_peerstore().peer_ids() + ma_node_b = multiaddr.Multiaddr("/p2p/%s" % hosts[1].get_id().pretty()) + for addr in hosts[0].get_peerstore().addrs(hosts[1].get_id()): + assert addr.encapsulate(ma_node_b) in hosts[1].get_addrs() diff --git a/tests/network/test_net_stream.py b/tests/network/test_net_stream.py index 2c2772a1..b558f1dd 100644 --- a/tests/network/test_net_stream.py +++ b/tests/network/test_net_stream.py @@ -1,6 +1,5 @@ -import trio - import pytest +import trio from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset from libp2p.tools.constants import MAX_READ_LEN diff --git a/tests/network/test_notify.py b/tests/network/test_notify.py index 1d9982e8..328ff128 100644 --- a/tests/network/test_notify.py +++ b/tests/network/test_notify.py @@ -8,12 +8,11 @@ into network after network has already started listening TODO: Add tests for closed_stream, listen_close when those features are implemented in swarm """ -import trio import enum -import pytest - from async_service import background_trio_service +import pytest +import trio from libp2p.network.notifee_interface import INotifee from libp2p.tools.constants import LISTEN_MADDR diff --git a/tests/network/test_swarm_conn.py b/tests/network/test_swarm_conn.py index 0b93808d..e1c1285c 100644 --- a/tests/network/test_swarm_conn.py +++ b/tests/network/test_swarm_conn.py @@ -1,6 +1,5 @@ -import trio - import pytest +import trio @pytest.mark.trio diff --git a/tests/protocol_muxer/test_protocol_muxer.py b/tests/protocol_muxer/test_protocol_muxer.py index 42dae60c..9533d1f3 100644 --- a/tests/protocol_muxer/test_protocol_muxer.py +++ b/tests/protocol_muxer/test_protocol_muxer.py @@ -1,83 +1,95 @@ import pytest from libp2p.host.exceptions import StreamFailure -from libp2p.tools.utils import echo_stream_handler, set_up_nodes_by_transport_opt +from libp2p.tools.factories import HostFactory +from libp2p.tools.utils import create_echo_stream_handler -# TODO: Add tests for multiple streams being opened on different -# protocols through the same connection -# Note: async issues occurred when using the same port -# so that's why I use different ports here. -# TODO: modify tests so that those async issues don't occur -# when using the same ports across tests +PROTOCOL_ECHO = "/echo/1.0.0" +PROTOCOL_POTATO = "/potato/1.0.0" +PROTOCOL_FOO = "/foo/1.0.0" +PROTOCOL_ROCK = "/rock/1.0.0" + +ACK_PREFIX = "ack:" async def perform_simple_test( - expected_selected_protocol, protocols_for_client, protocols_with_handlers + expected_selected_protocol, + protocols_for_client, + protocols_with_handlers, + is_host_secure, ): - transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list) + async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: + for protocol in protocols_with_handlers: + hosts[1].set_stream_handler( + protocol, create_echo_stream_handler(ACK_PREFIX) + ) - for protocol in protocols_with_handlers: - node_b.set_stream_handler(protocol, echo_stream_handler) + # Associate the peer with local ip address (see default parameters of Libp2p()) + hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10) + stream = await hosts[0].new_stream(hosts[1].get_id(), protocols_for_client) + messages = ["hello" + str(x) for x in range(10)] + for message in messages: + expected_resp = "ack:" + message + await stream.write(message.encode()) + response = (await stream.read(len(expected_resp))).decode() + assert response == expected_resp - # Associate the peer with local ip address (see default parameters of Libp2p()) - node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - - stream = await node_a.new_stream(node_b.get_id(), protocols_for_client) - messages = ["hello" + str(x) for x in range(10)] - for message in messages: - expected_resp = "ack:" + message - await stream.write(message.encode()) - response = (await stream.read(len(expected_resp))).decode() - assert response == expected_resp - - assert expected_selected_protocol == stream.get_protocol() - - # Success, terminate pending tasks. + assert expected_selected_protocol == stream.get_protocol() -@pytest.mark.asyncio -async def test_single_protocol_succeeds(): - expected_selected_protocol = "/echo/1.0.0" +@pytest.mark.trio +async def test_single_protocol_succeeds(is_host_secure): + expected_selected_protocol = PROTOCOL_ECHO await perform_simple_test( - expected_selected_protocol, ["/echo/1.0.0"], ["/echo/1.0.0"] + expected_selected_protocol, + [expected_selected_protocol], + [expected_selected_protocol], + is_host_secure, ) -@pytest.mark.asyncio -async def test_single_protocol_fails(): +@pytest.mark.trio +async def test_single_protocol_fails(is_host_secure): with pytest.raises(StreamFailure): - await perform_simple_test("", ["/echo/1.0.0"], ["/potato/1.0.0"]) + await perform_simple_test( + "", [PROTOCOL_ECHO], [PROTOCOL_POTATO], is_host_secure + ) # Cleanup not reached on error -@pytest.mark.asyncio -async def test_multiple_protocol_first_is_valid_succeeds(): - expected_selected_protocol = "/echo/1.0.0" - protocols_for_client = ["/echo/1.0.0", "/potato/1.0.0"] - protocols_for_listener = ["/foo/1.0.0", "/echo/1.0.0"] +@pytest.mark.trio +async def test_multiple_protocol_first_is_valid_succeeds(is_host_secure): + expected_selected_protocol = PROTOCOL_ECHO + protocols_for_client = [PROTOCOL_ECHO, PROTOCOL_POTATO] + protocols_for_listener = [PROTOCOL_FOO, PROTOCOL_ECHO] await perform_simple_test( - expected_selected_protocol, protocols_for_client, protocols_for_listener + expected_selected_protocol, + protocols_for_client, + protocols_for_listener, + is_host_secure, ) -@pytest.mark.asyncio -async def test_multiple_protocol_second_is_valid_succeeds(): - expected_selected_protocol = "/foo/1.0.0" - protocols_for_client = ["/rock/1.0.0", "/foo/1.0.0"] - protocols_for_listener = ["/foo/1.0.0", "/echo/1.0.0"] +@pytest.mark.trio +async def test_multiple_protocol_second_is_valid_succeeds(is_host_secure): + expected_selected_protocol = PROTOCOL_FOO + protocols_for_client = [PROTOCOL_ROCK, PROTOCOL_FOO] + protocols_for_listener = [PROTOCOL_FOO, PROTOCOL_ECHO] await perform_simple_test( - expected_selected_protocol, protocols_for_client, protocols_for_listener + expected_selected_protocol, + protocols_for_client, + protocols_for_listener, + is_host_secure, ) -@pytest.mark.asyncio -async def test_multiple_protocol_fails(): - protocols_for_client = ["/rock/1.0.0", "/foo/1.0.0", "/bar/1.0.0"] +@pytest.mark.trio +async def test_multiple_protocol_fails(is_host_secure): + protocols_for_client = [PROTOCOL_ROCK, PROTOCOL_FOO, "/bar/1.0.0"] protocols_for_listener = ["/aspyn/1.0.0", "/rob/1.0.0", "/zx/1.0.0", "/alex/1.0.0"] with pytest.raises(StreamFailure): - await perform_simple_test("", protocols_for_client, protocols_for_listener) - - # Cleanup not reached on error + await perform_simple_test( + "", protocols_for_client, protocols_for_listener, is_host_secure + ) From 6149aacc01ec6ddd931938ec8beb4e9eed3b6f67 Mon Sep 17 00:00:00 2001 From: mhchia Date: Sun, 1 Dec 2019 17:55:07 +0800 Subject: [PATCH 15/81] Fix `examples` --- .../{test_chat.py => test_examples.py} | 26 +++++++++---------- tests/identity/identify/test_protocol.py | 6 ++--- 2 files changed, 15 insertions(+), 17 deletions(-) rename tests/examples/{test_chat.py => test_examples.py} (85%) diff --git a/tests/examples/test_chat.py b/tests/examples/test_examples.py similarity index 85% rename from tests/examples/test_chat.py rename to tests/examples/test_examples.py index 536b5a05..c8ce9ed8 100644 --- a/tests/examples/test_chat.py +++ b/tests/examples/test_examples.py @@ -1,10 +1,10 @@ -import asyncio - import pytest +import trio from libp2p.host.exceptions import StreamFailure from libp2p.peer.peerinfo import info_from_p2p_addr -from libp2p.tools.utils import set_up_nodes_by_transport_opt +from libp2p.tools.factories import HostFactory +from libp2p.tools.utils import MAX_READ_LEN PROTOCOL_ID = "/chat/1.0.0" @@ -25,7 +25,7 @@ async def hello_world(host_a, host_b): # Multiaddress of the destination peer is fetched from the peerstore using 'peerId'. stream = await host_b.new_stream(host_a.get_id(), [PROTOCOL_ID]) await stream.write(hello_world_from_host_b) - read = await stream.read() + read = await stream.read(MAX_READ_LEN) assert read == hello_world_from_host_a await stream.close() @@ -47,7 +47,7 @@ async def connect_write(host_a, host_b): await stream.write(message.encode()) # Reader needs time due to async reads - await asyncio.sleep(2) + await trio.sleep(2) await stream.close() assert received == messages @@ -88,16 +88,14 @@ async def no_common_protocol(host_a, host_b): await host_b.new_stream(host_a.get_id(), ["/fakeproto/0.0.1"]) -@pytest.mark.asyncio @pytest.mark.parametrize( "test", [(hello_world), (connect_write), (connect_read), (no_common_protocol)] ) -async def test_chat(test): - transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - (host_a, host_b) = await set_up_nodes_by_transport_opt(transport_opt_list) +@pytest.mark.trio +async def test_chat(test, is_host_secure): + async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: + addr = hosts[0].get_addrs()[0] + info = info_from_p2p_addr(addr) + await hosts[1].connect(info) - addr = host_a.get_addrs()[0] - info = info_from_p2p_addr(addr) - await host_b.connect(info) - - await test(host_a, host_b) + await test(hosts[0], hosts[1]) diff --git a/tests/identity/identify/test_protocol.py b/tests/identity/identify/test_protocol.py index 6136c876..4bbdbcba 100644 --- a/tests/identity/identify/test_protocol.py +++ b/tests/identity/identify/test_protocol.py @@ -5,9 +5,9 @@ from libp2p.identity.identify.protocol import ID, _mk_identify_protobuf from libp2p.tools.factories import host_pair_factory -@pytest.mark.asyncio -async def test_identify_protocol(): - async with host_pair_factory() as (host_a, host_b): +@pytest.mark.trio +async def test_identify_protocol(is_host_secure): + async with host_pair_factory(is_host_secure) as (host_a, host_b): stream = await host_b.new_stream(host_a.get_id(), (ID,)) response = await stream.read() await stream.close() From eb494e8682334f77d1cffaf5a2369281b8de0a85 Mon Sep 17 00:00:00 2001 From: mhchia Date: Sun, 1 Dec 2019 19:17:32 +0800 Subject: [PATCH 16/81] Fix ping protocol --- libp2p/host/ping.py | 35 ++++++++++++----------- libp2p/stream_muxer/mplex/mplex.py | 1 - libp2p/stream_muxer/mplex/mplex_stream.py | 3 +- tests/host/test_ping.py | 17 +++++------ 4 files changed, 29 insertions(+), 27 deletions(-) diff --git a/libp2p/host/ping.py b/libp2p/host/ping.py index 3144ef4d..589fc917 100644 --- a/libp2p/host/ping.py +++ b/libp2p/host/ping.py @@ -1,4 +1,4 @@ -import asyncio +import trio import logging from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset @@ -16,22 +16,23 @@ logger = logging.getLogger("libp2p.host.ping") async def _handle_ping(stream: INetStream, peer_id: PeerID) -> bool: """Return a boolean indicating if we expect more pings from the peer at ``peer_id``.""" - try: - payload = await asyncio.wait_for(stream.read(PING_LENGTH), RESP_TIMEOUT) - except asyncio.TimeoutError as error: - logger.debug("Timed out waiting for ping from %s: %s", peer_id, error) - raise - except StreamEOF: - logger.debug("Other side closed while waiting for ping from %s", peer_id) - return False - except StreamReset as error: - logger.debug( - "Other side reset while waiting for ping from %s: %s", peer_id, error - ) - raise - except Exception as error: - logger.debug("Error while waiting to read ping for %s: %s", peer_id, error) - raise + with trio.fail_after(RESP_TIMEOUT): + try: + payload = await stream.read(PING_LENGTH) + except trio.TooSlowError as error: + logger.debug("Timed out waiting for ping from %s: %s", peer_id, error) + raise + except StreamEOF: + logger.debug("Other side closed while waiting for ping from %s", peer_id) + return False + except StreamReset as error: + logger.debug( + "Other side reset while waiting for ping from %s: %s", peer_id, error + ) + raise + except Exception as error: + logger.debug("Error while waiting to read ping for %s: %s", peer_id, error) + raise logger.debug("Received ping from %s with data: 0x%s", peer_id, payload.hex()) diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index f93acea7..6d8a64e6 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -1,4 +1,3 @@ -import asyncio import logging import math from typing import Any # noqa: F401 diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 6ecc4077..eeefc422 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -104,6 +104,7 @@ class MplexStream(IMuxedStream): # and then return. try: data = self.incoming_data_channel.receive_nowait() + self._buf.extend(data) except trio.EndOfChannel: raise MplexStreamEOF except trio.WouldBlock: @@ -111,6 +112,7 @@ class MplexStream(IMuxedStream): # catch all kinds of errors here. try: data = await self.incoming_data_channel.receive() + self._buf.extend(data) except trio.EndOfChannel: if self.event_reset.is_set(): raise MplexStreamReset @@ -125,7 +127,6 @@ class MplexStream(IMuxedStream): "`incoming_data_channel` is closed but stream is not reset. " "This should never happen." ) from error - self._buf.extend(data) self._buf.extend(self._read_return_when_blocked()) payload = self._buf[:n] self._buf = self._buf[len(payload) :] diff --git a/tests/host/test_ping.py b/tests/host/test_ping.py index 1bd02f0f..29135141 100644 --- a/tests/host/test_ping.py +++ b/tests/host/test_ping.py @@ -1,4 +1,4 @@ -import asyncio +import trio import secrets import pytest @@ -7,12 +7,13 @@ from libp2p.host.ping import ID, PING_LENGTH from libp2p.tools.factories import host_pair_factory -@pytest.mark.asyncio -async def test_ping_once(): - async with host_pair_factory() as (host_a, host_b): +@pytest.mark.trio +async def test_ping_once(is_host_secure): + async with host_pair_factory(is_host_secure) as (host_a, host_b): stream = await host_b.new_stream(host_a.get_id(), (ID,)) some_ping = secrets.token_bytes(PING_LENGTH) await stream.write(some_ping) + await trio.sleep(0.01) some_pong = await stream.read(PING_LENGTH) assert some_ping == some_pong await stream.close() @@ -21,9 +22,9 @@ async def test_ping_once(): SOME_PING_COUNT = 3 -@pytest.mark.asyncio -async def test_ping_several(): - async with host_pair_factory() as (host_a, host_b): +@pytest.mark.trio +async def test_ping_several(is_host_secure): + async with host_pair_factory(is_host_secure) as (host_a, host_b): stream = await host_b.new_stream(host_a.get_id(), (ID,)) for _ in range(SOME_PING_COUNT): some_ping = secrets.token_bytes(PING_LENGTH) @@ -33,5 +34,5 @@ async def test_ping_several(): # NOTE: simulate some time to sleep to mirror a real # world usage where a peer sends pings on some periodic interval # NOTE: this interval can be `0` for this test. - await asyncio.sleep(0) + await trio.sleep(0) await stream.close() From bdbb7b239422a9c88641f5a85b666aabf3e10622 Mon Sep 17 00:00:00 2001 From: mhchia Date: Sun, 1 Dec 2019 19:17:44 +0800 Subject: [PATCH 17/81] Add `RoutedHostFactory` And skip the tests for `RoutedHost` for now, since there are too many to be fixed in `Kademlia`, and it's not that necessary now. --- libp2p/tools/factories.py | 27 ++++++++++++++++ libp2p/tools/utils.py | 6 ++-- tests/host/test_routed_host.py | 57 +++++++++++++++++----------------- 3 files changed, 58 insertions(+), 32 deletions(-) diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index 470cbc31..6b6c78a5 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -8,6 +8,9 @@ import trio from libp2p import generate_new_rsa_identity, generate_peer_id_from from libp2p.crypto.keys import KeyPair from libp2p.host.basic_host import BasicHost +from libp2p.host.routed_host import RoutedHost +from libp2p.tools.utils import set_up_routers +from libp2p.kademlia.network import KademliaServer from libp2p.network.connection.swarm_connection import SwarmConn from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.swarm import Swarm @@ -127,6 +130,30 @@ class HostFactory(factory.Factory): yield hosts +class RoutedHostFactory(factory.Factory): + class Meta: + model = RoutedHost + + public_key = factory.LazyAttribute(lambda o: o.key_pair.public_key) + network = factory.LazyAttribute( + lambda o: SwarmFactory(is_secure=o.is_secure, key_pair=o.key_pair) + ) + router = factory.LazyFunction(KademliaServer) + + @classmethod + @asynccontextmanager + async def create_batch_and_listen( + cls, is_secure: bool, number: int + ) -> Tuple[RoutedHost, ...]: + key_pairs = [generate_new_rsa_identity() for _ in range(number)] + routers = await set_up_routers((0,) * number) + async with SwarmFactory.create_batch_and_listen(is_secure, number) as swarms: + yield tuple( + RoutedHost(key_pair.public_key, swarm, router) + for key_pair, swarm, router in zip(key_pairs, swarms, routers) + ) + + class FloodsubFactory(factory.Factory): class Meta: model = FloodSub diff --git a/libp2p/tools/utils.py b/libp2p/tools/utils.py index 5ec48867..9ad68150 100644 --- a/libp2p/tools/utils.py +++ b/libp2p/tools/utils.py @@ -61,15 +61,15 @@ async def set_up_nodes_by_transport_and_disc_opt( async def set_up_routers( - router_confs: Tuple[int, int] = (0, 0) + router_ports: Tuple[int, ...] = (0, 0) ) -> List[KadmeliaPeerRouter]: """The default ``router_confs`` selects two free ports local to this machine.""" bootstrap_node = KademliaServer() # type: ignore - await bootstrap_node.listen(router_confs[0]) + await bootstrap_node.listen(router_ports[0]) routers = [KadmeliaPeerRouter(bootstrap_node)] - for port in router_confs[1:]: + for port in router_ports[1:]: node = KademliaServer() # type: ignore await node.listen(port) diff --git a/tests/host/test_routed_host.py b/tests/host/test_routed_host.py index 9083d3fc..006dd222 100644 --- a/tests/host/test_routed_host.py +++ b/tests/host/test_routed_host.py @@ -1,5 +1,3 @@ -import asyncio - import pytest from libp2p.host.exceptions import ConnectionFailure @@ -10,38 +8,40 @@ from libp2p.tools.utils import ( set_up_nodes_by_transport_opt, set_up_routers, ) +from libp2p.tools.factories import RoutedHostFactory -@pytest.mark.asyncio -async def test_host_routing_success(): - routers = await set_up_routers() - transports = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - transport_disc_opt_list = zip(transports, routers) - (host_a, host_b) = await set_up_nodes_by_transport_and_disc_opt( - transport_disc_opt_list - ) +# FIXME: - # Set routing info - await routers[0].server.set( - host_a.get_id().xor_id, - peer_info_to_str(PeerInfo(host_a.get_id(), host_a.get_addrs())), - ) - await routers[1].server.set( - host_b.get_id().xor_id, - peer_info_to_str(PeerInfo(host_b.get_id(), host_b.get_addrs())), - ) +# TODO: Kademlia is full of asyncio code. Skip it for now +@pytest.mark.skip +@pytest.mark.trio +async def test_host_routing_success(is_host_secure): + async with RoutedHostFactory.create_batch_and_listen( + is_host_secure, 2 + ) as routed_hosts: + # Set routing info + await routed_hosts[0]._router.server.set( + routed_hosts[0].get_id().xor_id, + peer_info_to_str( + PeerInfo(routed_hosts[0].get_id(), routed_hosts[0].get_addrs()) + ), + ) + await routed_hosts[1]._router.server.set( + routed_hosts[1].get_id().xor_id, + peer_info_to_str( + PeerInfo(routed_hosts[1].get_id(), routed_hosts[1].get_addrs()) + ), + ) - # forces to use routing as no addrs are provided - await host_a.connect(PeerInfo(host_b.get_id(), [])) - await host_b.connect(PeerInfo(host_a.get_id(), [])) - - # Clean up - await asyncio.gather(*[host_a.close(), host_b.close()]) - routers[0].server.stop() - routers[1].server.stop() + # forces to use routing as no addrs are provided + await routed_hosts[0].connect(PeerInfo(routed_hosts[1].get_id(), [])) + await routed_hosts[1].connect(PeerInfo(routed_hosts[0].get_id(), [])) -@pytest.mark.asyncio +# TODO: Kademlia is full of asyncio code. Skip it for now +@pytest.mark.skip +@pytest.mark.trio async def test_host_routing_fail(): routers = await set_up_routers() transports = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] @@ -69,6 +69,5 @@ async def test_host_routing_fail(): await host_b.connect(PeerInfo(host_c.get_id(), [])) # Clean up - await asyncio.gather(*[host_a.close(), host_b.close(), host_c.close()]) routers[0].server.stop() routers[1].server.stop() From e9ab0646e38a1e13fd0ba5e2d6362da459b241db Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 3 Dec 2019 17:27:49 +0800 Subject: [PATCH 18/81] Fix Pubsub --- libp2p/pubsub/floodsub.py | 8 + libp2p/pubsub/pubsub.py | 181 ++++--- libp2p/pubsub/pubsub_notifee.py | 19 +- libp2p/stream_muxer/mplex/mplex.py | 2 +- libp2p/tools/factories.py | 49 +- tests/pubsub/conftest.py | 32 +- tests/pubsub/test_pubsub.py | 800 ++++++++++++++--------------- 7 files changed, 568 insertions(+), 523 deletions(-) diff --git a/libp2p/pubsub/floodsub.py b/libp2p/pubsub/floodsub.py index bac0bd77..8c15a441 100644 --- a/libp2p/pubsub/floodsub.py +++ b/libp2p/pubsub/floodsub.py @@ -1,6 +1,8 @@ import logging from typing import Iterable, List, Sequence +import trio + from libp2p.network.stream.exceptions import StreamClosed from libp2p.peer.id import ID from libp2p.typing import TProtocol @@ -61,6 +63,8 @@ class FloodSub(IPubsubRouter): :param rpc: rpc message """ + # Checkpoint + await trio.sleep(0) async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None: """ @@ -102,6 +106,8 @@ class FloodSub(IPubsubRouter): :param topic: topic to join """ + # Checkpoint + await trio.sleep(0) async def leave(self, topic: str) -> None: """ @@ -110,6 +116,8 @@ class FloodSub(IPubsubRouter): :param topic: topic to leave """ + # Checkpoint + await trio.sleep(0) def _get_peers_to_send( self, topic_ids: Iterable[str], msg_forwarder: ID, origin: ID diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 3834eb4b..7c4b50de 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -1,11 +1,13 @@ -import asyncio +from abc import ABC, abstractmethod import logging +import math import time from typing import ( TYPE_CHECKING, Awaitable, Callable, Dict, + KeysView, List, NamedTuple, Tuple, @@ -13,8 +15,10 @@ from typing import ( cast, ) +from async_service import Service import base58 from lru import LRU +import trio from libp2p.exceptions import ParseError, ValidationError from libp2p.host.host_interface import IHost @@ -53,24 +57,24 @@ class TopicValidator(NamedTuple): is_async: bool -class Pubsub: +class BasePubsub(ABC): + pass + + +class Pubsub(BasePubsub, Service): host: IHost - my_id: ID router: "IPubsubRouter" - peer_queue: "asyncio.Queue[ID]" - dead_peer_queue: "asyncio.Queue[ID]" - - protocols: List[TProtocol] - - incoming_msgs_from_peers: "asyncio.Queue[rpc_pb2.Message]" - outgoing_messages: "asyncio.Queue[rpc_pb2.Message]" + peer_receive_channel: "trio.MemoryReceiveChannel[ID]" + dead_peer_receive_channel: "trio.MemoryReceiveChannel[ID]" seen_messages: LRU - my_topics: Dict[str, "asyncio.Queue[rpc_pb2.Message]"] + # TODO: Implement `trio.abc.Channel`? + subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"] + subscribed_topics_receive: Dict[str, "trio.MemoryReceiveChannel[rpc_pb2.Message]"] peer_topics: Dict[str, List[ID]] peers: Dict[ID, INetStream] @@ -80,10 +84,8 @@ class Pubsub: # TODO: Be sure it is increased atomically everytime. counter: int # uint64 - _tasks: List["asyncio.Future[Any]"] - def __init__( - self, host: IHost, router: "IPubsubRouter", my_id: ID, cache_size: int = None + self, host: IHost, router: "IPubsubRouter", cache_size: int = None ) -> None: """ Construct a new Pubsub object, which is responsible for handling all @@ -97,28 +99,26 @@ class Pubsub: """ self.host = host self.router = router - self.my_id = my_id # Attach this new Pubsub object to the router self.router.attach(self) + peer_send_channel, peer_receive_channel = trio.open_memory_channel(0) + dead_peer_send_channel, dead_peer_receive_channel = trio.open_memory_channel(0) + # Only keep the receive channels in `Pubsub`. + # Therefore, we can only close from the receive side. + self.peer_receive_channel = peer_receive_channel + self.dead_peer_receive_channel = dead_peer_receive_channel # Register a notifee - self.peer_queue = asyncio.Queue() - self.dead_peer_queue = asyncio.Queue() self.host.get_network().register_notifee( - PubsubNotifee(self.peer_queue, self.dead_peer_queue) + PubsubNotifee(peer_send_channel, dead_peer_send_channel) ) # Register stream handlers for each pubsub router protocol to handle # the pubsub streams opened on those protocols - self.protocols = self.router.get_protocols() - for protocol in self.protocols: + for protocol in router.protocols: self.host.set_stream_handler(protocol, self.stream_handler) - # Use asyncio queues for proper context switching - self.incoming_msgs_from_peers = asyncio.Queue() - self.outgoing_messages = asyncio.Queue() - # keeps track of seen messages as LRU cache if cache_size is None: self.cache_size = 128 @@ -129,7 +129,8 @@ class Pubsub: # Map of topics we are subscribed to blocking queues # for when the given topic receives a message - self.my_topics = {} + self.subscribed_topics_send = {} + self.subscribed_topics_receive = {} # Map of topic to peers to keep track of what peers are subscribed to self.peer_topics = {} @@ -142,16 +143,28 @@ class Pubsub: self.counter = time.time_ns() - self._tasks = [] - # Call handle peer to keep waiting for updates to peer queue - self._tasks.append(asyncio.ensure_future(self.handle_peer_queue())) - self._tasks.append(asyncio.ensure_future(self.handle_dead_peer_queue())) + async def run(self) -> None: + self.manager.run_daemon_task(self.handle_peer_queue) + self.manager.run_daemon_task(self.handle_dead_peer_queue) + await self.manager.wait_finished() + + @property + def my_id(self) -> ID: + return self.host.get_id() + + @property + def protocols(self) -> Tuple[TProtocol, ...]: + return tuple(self.router.get_protocols()) + + @property + def topic_ids(self) -> KeysView[str]: + return self.subscribed_topics_receive.keys() def get_hello_packet(self) -> rpc_pb2.RPC: """Generate subscription message with all topics we are subscribed to only send hello packet if we have subscribed topics.""" packet = rpc_pb2.RPC() - for topic_id in self.my_topics: + for topic_id in self.topic_ids: packet.subscriptions.extend( [rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)] ) @@ -166,7 +179,7 @@ class Pubsub: """ peer_id = stream.muxed_conn.peer_id - while True: + while self.manager.is_running: incoming: bytes = await read_varint_prefixed_bytes(stream) rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC() rpc_incoming.ParseFromString(incoming) @@ -178,11 +191,7 @@ class Pubsub: logger.debug( "received `publish` message %s from peer %s", msg, peer_id ) - self._tasks.append( - asyncio.ensure_future( - self.push_msg(msg_forwarder=peer_id, msg=msg) - ) - ) + self.manager.run_task(self.push_msg, peer_id, msg) if rpc_incoming.subscriptions: # deal with RPC.subscriptions @@ -210,9 +219,6 @@ class Pubsub: ) await self.router.handle_rpc(rpc_incoming, peer_id) - # Force context switch - await asyncio.sleep(0) - def set_topic_validator( self, topic: str, validator: ValidatorFn, is_async_validator: bool ) -> None: @@ -285,7 +291,6 @@ class Pubsub: logger.debug("Fail to add new peer %s: stream closed", peer_id) del self.peers[peer_id] return - # TODO: Check EOF of this stream. # TODO: Check if the peer in black list. try: self.router.add_peer(peer_id, stream.get_protocol()) @@ -311,23 +316,25 @@ class Pubsub: async def handle_peer_queue(self) -> None: """ - Continuously read from peer queue and each time a new peer is found, + Continuously read from peer channel and each time a new peer is found, open a stream to the peer using a supported pubsub protocol TODO: Handle failure for when the peer does not support any of the pubsub protocols we support """ - while True: - peer_id: ID = await self.peer_queue.get() - # Add Peer - self._tasks.append(asyncio.ensure_future(self._handle_new_peer(peer_id))) + async with self.peer_receive_channel: + while self.manager.is_running: + peer_id: ID = await self.peer_receive_channel.receive() + # Add Peer + self.manager.run_task(self._handle_new_peer, peer_id) async def handle_dead_peer_queue(self) -> None: - """Continuously read from dead peer queue and close the stream between + """Continuously read from dead peer channel and close the stream between that peer and remove peer info from pubsub and pubsub router.""" - while True: - peer_id: ID = await self.dead_peer_queue.get() - # Remove Peer - self._handle_dead_peer(peer_id) + async with self.dead_peer_receive_channel: + while self.manager.is_running: + peer_id: ID = await self.dead_peer_receive_channel.receive() + # Remove Peer + self._handle_dead_peer(peer_id) def handle_subscription( self, origin_id: ID, sub_message: rpc_pb2.RPC.SubOpts @@ -361,13 +368,16 @@ class Pubsub: # Check if this message has any topics that we are subscribed to for topic in publish_message.topicIDs: - if topic in self.my_topics: + if topic in self.topic_ids: # we are subscribed to a topic this message was sent for, # so add message to the subscription output queue # for each topic - await self.my_topics[topic].put(publish_message) + await self.subscribed_topics_send[topic].send(publish_message) - async def subscribe(self, topic_id: str) -> "asyncio.Queue[rpc_pb2.Message]": + # TODO: Change to return an `AsyncIterable` to be I/O-agnostic? + async def subscribe( + self, topic_id: str + ) -> "trio.MemoryReceiveChannel[rpc_pb2.Message]": """ Subscribe ourself to a topic. @@ -377,11 +387,13 @@ class Pubsub: logger.debug("subscribing to topic %s", topic_id) # Already subscribed - if topic_id in self.my_topics: - return self.my_topics[topic_id] + if topic_id in self.topic_ids: + return self.subscribed_topics_receive[topic_id] - # Map topic_id to blocking queue - self.my_topics[topic_id] = asyncio.Queue() + # Map topic_id to a blocking channel + send_channel, receive_channel = trio.open_memory_channel(math.inf) + self.subscribed_topics_send[topic_id] = send_channel + self.subscribed_topics_receive[topic_id] = receive_channel # Create subscribe message packet: rpc_pb2.RPC = rpc_pb2.RPC() @@ -395,8 +407,8 @@ class Pubsub: # Tell router we are joining this topic await self.router.join(topic_id) - # Return the asyncio queue for messages on this topic - return self.my_topics[topic_id] + # Return the trio channel for messages on this topic + return receive_channel async def unsubscribe(self, topic_id: str) -> None: """ @@ -408,10 +420,14 @@ class Pubsub: logger.debug("unsubscribing from topic %s", topic_id) # Return if we already unsubscribed from the topic - if topic_id not in self.my_topics: + if topic_id not in self.topic_ids: return - # Remove topic_id from map if present - del self.my_topics[topic_id] + # Remove topic_id from the maps before yielding + send_channel = self.subscribed_topics_send[topic_id] + del self.subscribed_topics_send[topic_id] + del self.subscribed_topics_receive[topic_id] + # Only close the send side + await send_channel.aclose() # Create unsubscribe message packet: rpc_pb2.RPC = rpc_pb2.RPC() @@ -453,13 +469,13 @@ class Pubsub: data=data, topicIDs=[topic_id], # Origin is ourself. - from_id=self.host.get_id().to_bytes(), + from_id=self.my_id.to_bytes(), seqno=self._next_seqno(), ) # TODO: Sign with our signing key - await self.push_msg(self.host.get_id(), msg) + await self.push_msg(self.my_id, msg) logger.debug("successfully published message %s", msg) @@ -470,12 +486,12 @@ class Pubsub: :param msg_forwarder: the peer who forward us the message. :param msg: the message. """ - sync_topic_validators = [] - async_topic_validator_futures: List[Awaitable[bool]] = [] + sync_topic_validators: List[SyncValidatorFn] = [] + async_topic_validators: List[AsyncValidatorFn] = [] for topic_validator in self.get_msg_validators(msg): if topic_validator.is_async: - async_topic_validator_futures.append( - cast(Awaitable[bool], topic_validator.validator(msg_forwarder, msg)) + async_topic_validators.append( + cast(AsyncValidatorFn, topic_validator.validator) ) else: sync_topic_validators.append( @@ -488,9 +504,20 @@ class Pubsub: # TODO: Implement throttle on async validators - if len(async_topic_validator_futures) > 0: - results = await asyncio.gather(*async_topic_validator_futures) - if not all(results): + if len(async_topic_validators) > 0: + # TODO: Use a better pattern + final_result = True + + async def run_async_validator(func: AsyncValidatorFn) -> None: + nonlocal final_result + result = await func(msg_forwarder, msg) + final_result = final_result and result + + async with trio.open_nursery() as nursery: + for validator in async_topic_validators: + nursery.start_soon(run_async_validator, validator) + + if not final_result: raise ValidationError(f"Validation failed for msg={msg}") async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None: @@ -551,14 +578,4 @@ class Pubsub: self.seen_messages[msg_id] = 1 def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool: - if not self.my_topics: - return False - return any(topic in self.my_topics for topic in msg.topicIDs) - - async def close(self) -> None: - for task in self._tasks: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass + return any(topic in self.topic_ids for topic in msg.topicIDs) diff --git a/libp2p/pubsub/pubsub_notifee.py b/libp2p/pubsub/pubsub_notifee.py index 6afa9ad2..7394736e 100644 --- a/libp2p/pubsub/pubsub_notifee.py +++ b/libp2p/pubsub/pubsub_notifee.py @@ -8,19 +8,19 @@ from libp2p.network.notifee_interface import INotifee from libp2p.network.stream.net_stream_interface import INetStream if TYPE_CHECKING: - import asyncio # noqa: F401 + import trio # noqa: F401 from libp2p.peer.id import ID # noqa: F401 class PubsubNotifee(INotifee): - initiator_peers_queue: "asyncio.Queue[ID]" - dead_peers_queue: "asyncio.Queue[ID]" + initiator_peers_queue: "trio.MemorySendChannel[ID]" + dead_peers_queue: "trio.MemorySendChannel[ID]" def __init__( self, - initiator_peers_queue: "asyncio.Queue[ID]", - dead_peers_queue: "asyncio.Queue[ID]", + initiator_peers_queue: "trio.MemorySendChannel[ID]", + dead_peers_queue: "trio.MemorySendChannel[ID]", ) -> None: """ :param initiator_peers_queue: queue to add new peers to so that pubsub @@ -46,7 +46,12 @@ class PubsubNotifee(INotifee): :param network: network the connection was opened on :param conn: connection that was opened """ - await self.initiator_peers_queue.put(conn.muxed_conn.peer_id) + try: + await self.initiator_peers_queue.send(conn.muxed_conn.peer_id) + except trio.BrokenResourceError: + # Raised when the receive channel is closed. + # TODO: Do something with loggers? + ... async def disconnected(self, network: INetwork, conn: INetConn) -> None: """ @@ -56,7 +61,7 @@ class PubsubNotifee(INotifee): :param network: network the connection was opened on :param conn: connection that was opened """ - await self.dead_peers_queue.put(conn.muxed_conn.peer_id) + await self.dead_peers_queue.send(conn.muxed_conn.peer_id) async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None: pass diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 6d8a64e6..ac6cdcdb 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -281,7 +281,7 @@ class Mplex(IMuxedConn, Service): mplex_stream = await self._initialize_stream(stream_id, message.decode()) try: await self.new_stream_send_channel.send(mplex_stream) - except (trio.BrokenResourceError, trio.EndOfChannel): + except (trio.BrokenResourceError, trio.ClosedResourceError): raise MplexUnavailable async def _handle_message(self, stream_id: StreamID, message: bytes) -> None: diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index 6b6c78a5..ac243012 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -5,6 +5,7 @@ from async_service import background_trio_service import factory import trio +from libp2p.tools.constants import GOSSIPSUB_PARAMS from libp2p import generate_new_rsa_identity, generate_peer_id_from from libp2p.crypto.keys import KeyPair from libp2p.host.basic_host import BasicHost @@ -15,6 +16,7 @@ from libp2p.network.connection.swarm_connection import SwarmConn from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.swarm import Swarm from libp2p.peer.peerstore import PeerStore +from libp2p.peer.id import ID from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.pubsub import Pubsub @@ -28,15 +30,19 @@ from libp2p.transport.typing import TMuxerOptions from libp2p.transport.upgrader import TransportUpgrader from libp2p.typing import TProtocol -from .constants import ( - FLOODSUB_PROTOCOL_ID, - GOSSIPSUB_PARAMS, - GOSSIPSUB_PROTOCOL_ID, - LISTEN_MADDR, -) +from .constants import FLOODSUB_PROTOCOL_ID, GOSSIPSUB_PROTOCOL_ID, LISTEN_MADDR from .utils import connect, connect_swarm +class IDFactory(factory.Factory): + class Meta: + model = ID + + peer_id_bytes = factory.LazyFunction( + lambda: generate_peer_id_from(generate_new_rsa_identity()) + ) + + def security_transport_factory( is_secure: bool, key_pair: KeyPair ) -> Dict[TProtocol, BaseSecureTransport]: @@ -181,9 +187,38 @@ class PubsubFactory(factory.Factory): host = factory.SubFactory(HostFactory) router = None - my_id = factory.LazyAttribute(lambda obj: obj.host.get_id()) cache_size = None + @classmethod + @asynccontextmanager + async def create_and_start(cls, host, router, cache_size): + pubsub = PubsubFactory(host=host, router=router, cache_size=cache_size) + async with background_trio_service(pubsub): + yield pubsub + + @classmethod + @asynccontextmanager + async def create_batch_with_floodsub( + cls, number: int, is_secure: bool = False, cache_size: int = None + ): + floodsubs = FloodsubFactory.create_batch(number) + async with HostFactory.create_batch_and_listen(is_secure, number) as hosts: + # Pubsubs should exit before hosts + async with AsyncExitStack() as stack: + pubsubs = [ + await stack.enter_async_context( + cls.create_and_start(host, router, cache_size) + ) + for host, router in zip(hosts, floodsubs) + ] + yield pubsubs + + # @classmethod + # async def create_batch_with_gossipsub( + # cls, number: int, cache_size: int = None, gossipsub_params=GOSSIPSUB_PARAMS + # ): + # ... + @asynccontextmanager async def swarm_pair_factory( diff --git a/tests/pubsub/conftest.py b/tests/pubsub/conftest.py index 9dbe90b9..6c08dd78 100644 --- a/tests/pubsub/conftest.py +++ b/tests/pubsub/conftest.py @@ -4,18 +4,6 @@ from libp2p.tools.constants import GOSSIPSUB_PARAMS from libp2p.tools.factories import FloodsubFactory, GossipsubFactory, PubsubFactory -def _make_pubsubs(hosts, pubsub_routers, cache_size): - if len(pubsub_routers) != len(hosts): - raise ValueError( - f"lenght of pubsub_routers={pubsub_routers} should be equaled to the " - f"length of hosts={len(hosts)}" - ) - return tuple( - PubsubFactory(host=host, router=router, cache_size=cache_size) - for host, router in zip(hosts, pubsub_routers) - ) - - @pytest.fixture def pubsub_cache_size(): return None # default @@ -26,17 +14,9 @@ def gossipsub_params(): return GOSSIPSUB_PARAMS -@pytest.fixture -def pubsubs_fsub(num_hosts, hosts, pubsub_cache_size): - floodsubs = FloodsubFactory.create_batch(num_hosts) - _pubsubs_fsub = _make_pubsubs(hosts, floodsubs, pubsub_cache_size) - yield _pubsubs_fsub - # TODO: Clean up - - -@pytest.fixture -def pubsubs_gsub(num_hosts, hosts, pubsub_cache_size, gossipsub_params): - gossipsubs = GossipsubFactory.create_batch(num_hosts, **gossipsub_params._asdict()) - _pubsubs_gsub = _make_pubsubs(hosts, gossipsubs, pubsub_cache_size) - yield _pubsubs_gsub - # TODO: Clean up +# @pytest.fixture +# def pubsubs_gsub(num_hosts, hosts, pubsub_cache_size, gossipsub_params): +# gossipsubs = GossipsubFactory.create_batch(num_hosts, **gossipsub_params._asdict()) +# _pubsubs_gsub = _make_pubsubs(hosts, gossipsubs, pubsub_cache_size) +# yield _pubsubs_gsub +# # TODO: Clean up diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index ebe20037..22cea0c2 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -1,348 +1,328 @@ -import asyncio +from contextlib import contextmanager from typing import NamedTuple import pytest +import trio from libp2p.exceptions import ValidationError from libp2p.peer.id import ID from libp2p.pubsub.pb import rpc_pb2 from libp2p.tools.pubsub.utils import make_pubsub_msg from libp2p.tools.utils import connect +from libp2p.tools.constants import MAX_READ_LEN +from libp2p.tools.factories import PubsubFactory, net_stream_pair_factory, IDFactory from libp2p.utils import encode_varint_prefixed TESTING_TOPIC = "TEST_SUBSCRIBE" TESTING_DATA = b"data" -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_subscribe_and_unsubscribe(pubsubs_fsub): - await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - assert TESTING_TOPIC in pubsubs_fsub[0].my_topics +@pytest.mark.trio +async def test_subscribe_and_unsubscribe(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids - await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) - assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics + await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) + assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_re_subscribe(pubsubs_fsub): - await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - assert TESTING_TOPIC in pubsubs_fsub[0].my_topics +@pytest.mark.trio +async def test_re_subscribe(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids - await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - assert TESTING_TOPIC in pubsubs_fsub[0].my_topics + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_re_unsubscribe(pubsubs_fsub): - # Unsubscribe from topic we didn't even subscribe to - assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics - await pubsubs_fsub[0].unsubscribe("NOT_MY_TOPIC") - assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics +@pytest.mark.trio +async def test_re_unsubscribe(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + # Unsubscribe from topic we didn't even subscribe to + assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].topic_ids + await pubsubs_fsub[0].unsubscribe("NOT_MY_TOPIC") + assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].topic_ids - await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - assert TESTING_TOPIC in pubsubs_fsub[0].my_topics + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids - await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) - assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics + await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) + assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids - await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) - assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics + await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) + assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids -@pytest.mark.asyncio -async def test_peers_subscribe(pubsubs_fsub): - await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) - await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - # Yield to let 0 notify 1 - await asyncio.sleep(0.1) - assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] - await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) - # Yield to let 0 notify 1 - await asyncio.sleep(0.1) - assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] +@pytest.mark.trio +async def test_peers_subscribe(): + async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub: + await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + # Yield to let 0 notify 1 + await trio.sleep(0.1) + assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] + await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) + # Yield to let 0 notify 1 + await trio.sleep(0.1) + assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_get_hello_packet(pubsubs_fsub): - def _get_hello_packet_topic_ids(): - packet = pubsubs_fsub[0].get_hello_packet() - return tuple(sub.topicid for sub in packet.subscriptions) +@pytest.mark.trio +async def test_get_hello_packet(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: - # Test: No subscription, so there should not be any topic ids in the hello packet. - assert len(_get_hello_packet_topic_ids()) == 0 + def _get_hello_packet_topic_ids(): + packet = pubsubs_fsub[0].get_hello_packet() + return tuple(sub.topicid for sub in packet.subscriptions) - # Test: After subscriptions, topic ids should be in the hello packet. - topic_ids = ["t", "o", "p", "i", "c"] - await asyncio.gather(*[pubsubs_fsub[0].subscribe(topic) for topic in topic_ids]) - topic_ids_in_hello = _get_hello_packet_topic_ids() - for topic in topic_ids: - assert topic in topic_ids_in_hello + # Test: No subscription, so there should not be any topic ids in the hello packet. + assert len(_get_hello_packet_topic_ids()) == 0 + + # Test: After subscriptions, topic ids should be in the hello packet. + topic_ids = ["t", "o", "p", "i", "c"] + for topic in topic_ids: + await pubsubs_fsub[0].subscribe(topic) + topic_ids_in_hello = _get_hello_packet_topic_ids() + for topic in topic_ids: + assert topic in topic_ids_in_hello -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_set_and_remove_topic_validator(pubsubs_fsub): +@pytest.mark.trio +async def test_set_and_remove_topic_validator(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + is_sync_validator_called = False - is_sync_validator_called = False + def sync_validator(peer_id, msg): + nonlocal is_sync_validator_called + is_sync_validator_called = True - def sync_validator(peer_id, msg): - nonlocal is_sync_validator_called - is_sync_validator_called = True + is_async_validator_called = False - is_async_validator_called = False + async def async_validator(peer_id, msg): + nonlocal is_async_validator_called + is_async_validator_called = True - async def async_validator(peer_id, msg): - nonlocal is_async_validator_called - is_async_validator_called = True + topic = "TEST_VALIDATOR" - topic = "TEST_VALIDATOR" + assert topic not in pubsubs_fsub[0].topic_validators - assert topic not in pubsubs_fsub[0].topic_validators + # Register sync validator + pubsubs_fsub[0].set_topic_validator(topic, sync_validator, False) - # Register sync validator - pubsubs_fsub[0].set_topic_validator(topic, sync_validator, False) + assert topic in pubsubs_fsub[0].topic_validators + topic_validator = pubsubs_fsub[0].topic_validators[topic] + assert not topic_validator.is_async - assert topic in pubsubs_fsub[0].topic_validators - topic_validator = pubsubs_fsub[0].topic_validators[topic] - assert not topic_validator.is_async + # Validate with sync validator + topic_validator.validator(peer_id=IDFactory(), msg="msg") - # Validate with sync validator - topic_validator.validator(peer_id=ID(b"peer"), msg="msg") + assert is_sync_validator_called + assert not is_async_validator_called - assert is_sync_validator_called - assert not is_async_validator_called + # Register with async validator + pubsubs_fsub[0].set_topic_validator(topic, async_validator, True) - # Register with async validator - pubsubs_fsub[0].set_topic_validator(topic, async_validator, True) + is_sync_validator_called = False + assert topic in pubsubs_fsub[0].topic_validators + topic_validator = pubsubs_fsub[0].topic_validators[topic] + assert topic_validator.is_async - is_sync_validator_called = False - assert topic in pubsubs_fsub[0].topic_validators - topic_validator = pubsubs_fsub[0].topic_validators[topic] - assert topic_validator.is_async + # Validate with async validator + await topic_validator.validator(peer_id=IDFactory(), msg="msg") - # Validate with async validator - await topic_validator.validator(peer_id=ID(b"peer"), msg="msg") + assert is_async_validator_called + assert not is_sync_validator_called - assert is_async_validator_called - assert not is_sync_validator_called - - # Remove validator - pubsubs_fsub[0].remove_topic_validator(topic) - assert topic not in pubsubs_fsub[0].topic_validators + # Remove validator + pubsubs_fsub[0].remove_topic_validator(topic) + assert topic not in pubsubs_fsub[0].topic_validators -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_get_msg_validators(pubsubs_fsub): +@pytest.mark.trio +async def test_get_msg_validators(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + times_sync_validator_called = 0 - times_sync_validator_called = 0 + def sync_validator(peer_id, msg): + nonlocal times_sync_validator_called + times_sync_validator_called += 1 - def sync_validator(peer_id, msg): - nonlocal times_sync_validator_called - times_sync_validator_called += 1 + times_async_validator_called = 0 - times_async_validator_called = 0 + async def async_validator(peer_id, msg): + nonlocal times_async_validator_called + times_async_validator_called += 1 - async def async_validator(peer_id, msg): - nonlocal times_async_validator_called - times_async_validator_called += 1 + topic_1 = "TEST_VALIDATOR_1" + topic_2 = "TEST_VALIDATOR_2" + topic_3 = "TEST_VALIDATOR_3" - topic_1 = "TEST_VALIDATOR_1" - topic_2 = "TEST_VALIDATOR_2" - topic_3 = "TEST_VALIDATOR_3" + # Register sync validator for topic 1 and 2 + pubsubs_fsub[0].set_topic_validator(topic_1, sync_validator, False) + pubsubs_fsub[0].set_topic_validator(topic_2, sync_validator, False) - # Register sync validator for topic 1 and 2 - pubsubs_fsub[0].set_topic_validator(topic_1, sync_validator, False) - pubsubs_fsub[0].set_topic_validator(topic_2, sync_validator, False) + # Register async validator for topic 3 + pubsubs_fsub[0].set_topic_validator(topic_3, async_validator, True) - # Register async validator for topic 3 - pubsubs_fsub[0].set_topic_validator(topic_3, async_validator, True) + msg = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=[topic_1, topic_2, topic_3], + data=b"1234", + seqno=b"\x00" * 8, + ) - msg = make_pubsub_msg( - origin_id=pubsubs_fsub[0].my_id, - topic_ids=[topic_1, topic_2, topic_3], - data=b"1234", - seqno=b"\x00" * 8, - ) + topic_validators = pubsubs_fsub[0].get_msg_validators(msg) + for topic_validator in topic_validators: + if topic_validator.is_async: + await topic_validator.validator(peer_id=IDFactory(), msg="msg") + else: + topic_validator.validator(peer_id=IDFactory(), msg="msg") - topic_validators = pubsubs_fsub[0].get_msg_validators(msg) - for topic_validator in topic_validators: - if topic_validator.is_async: - await topic_validator.validator(peer_id=ID(b"peer"), msg="msg") - else: - topic_validator.validator(peer_id=ID(b"peer"), msg="msg") - - assert times_sync_validator_called == 2 - assert times_async_validator_called == 1 + assert times_sync_validator_called == 2 + assert times_async_validator_called == 1 -@pytest.mark.parametrize("num_hosts", (1,)) @pytest.mark.parametrize( "is_topic_1_val_passed, is_topic_2_val_passed", ((False, True), (True, False), (True, True)), ) -@pytest.mark.asyncio -async def test_validate_msg(pubsubs_fsub, is_topic_1_val_passed, is_topic_2_val_passed): - def passed_sync_validator(peer_id, msg): - return True +@pytest.mark.trio +async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: - def failed_sync_validator(peer_id, msg): - return False + def passed_sync_validator(peer_id, msg): + return True - async def passed_async_validator(peer_id, msg): - return True + def failed_sync_validator(peer_id, msg): + return False - async def failed_async_validator(peer_id, msg): - return False + async def passed_async_validator(peer_id, msg): + return True - topic_1 = "TEST_SYNC_VALIDATOR" - topic_2 = "TEST_ASYNC_VALIDATOR" + async def failed_async_validator(peer_id, msg): + return False - if is_topic_1_val_passed: - pubsubs_fsub[0].set_topic_validator(topic_1, passed_sync_validator, False) - else: - pubsubs_fsub[0].set_topic_validator(topic_1, failed_sync_validator, False) + topic_1 = "TEST_SYNC_VALIDATOR" + topic_2 = "TEST_ASYNC_VALIDATOR" - if is_topic_2_val_passed: - pubsubs_fsub[0].set_topic_validator(topic_2, passed_async_validator, True) - else: - pubsubs_fsub[0].set_topic_validator(topic_2, failed_async_validator, True) - - msg = make_pubsub_msg( - origin_id=pubsubs_fsub[0].my_id, - topic_ids=[topic_1, topic_2], - data=b"1234", - seqno=b"\x00" * 8, - ) - - if is_topic_1_val_passed and is_topic_2_val_passed: - await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg) - else: - with pytest.raises(ValidationError): - await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg) - - -class FakeNetStream: - _queue: asyncio.Queue - - class FakeMplexConn(NamedTuple): - peer_id: ID = ID(b"\x12\x20" + b"\x00" * 32) - - muxed_conn = FakeMplexConn() - - def __init__(self) -> None: - self._queue = asyncio.Queue() - - async def read(self, n: int = -1) -> bytes: - buf = bytearray() - # Force to blocking wait if no data available now. - if self._queue.empty(): - first_byte = await self._queue.get() - buf.extend(first_byte) - # If `n == -1`, read until no data is in the buffer(_queue). - # Else, read until no data is in the buffer(_queue) or we have read `n` bytes. - while (n == -1) or (len(buf) < n): - if self._queue.empty(): - break - buf.extend(await self._queue.get()) - return bytes(buf) - - async def write(self, data: bytes) -> int: - for i in data: - await self._queue.put(i.to_bytes(1, "big")) - return len(data) - - -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_continuously_read_stream(pubsubs_fsub, monkeypatch): - stream = FakeNetStream() - - await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - - event_push_msg = asyncio.Event() - event_handle_subscription = asyncio.Event() - event_handle_rpc = asyncio.Event() - - async def mock_push_msg(msg_forwarder, msg): - event_push_msg.set() - - def mock_handle_subscription(origin_id, sub_message): - event_handle_subscription.set() - - async def mock_handle_rpc(rpc, sender_peer_id): - event_handle_rpc.set() - - monkeypatch.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg) - monkeypatch.setattr( - pubsubs_fsub[0], "handle_subscription", mock_handle_subscription - ) - monkeypatch.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc) - - async def wait_for_event_occurring(event): - try: - await asyncio.wait_for(event.wait(), timeout=1) - except asyncio.TimeoutError as error: - event.clear() - raise asyncio.TimeoutError( - f"Event {event} is not set before the timeout. " - "This indicates the mocked functions are not called properly." - ) from error + if is_topic_1_val_passed: + pubsubs_fsub[0].set_topic_validator(topic_1, passed_sync_validator, False) else: - event.clear() + pubsubs_fsub[0].set_topic_validator(topic_1, failed_sync_validator, False) - # Kick off the task `continuously_read_stream` - task = asyncio.ensure_future(pubsubs_fsub[0].continuously_read_stream(stream)) + if is_topic_2_val_passed: + pubsubs_fsub[0].set_topic_validator(topic_2, passed_async_validator, True) + else: + pubsubs_fsub[0].set_topic_validator(topic_2, failed_async_validator, True) - # Test: `push_msg` is called when publishing to a subscribed topic. - publish_subscribed_topic = rpc_pb2.RPC( - publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])] - ) - await stream.write( - encode_varint_prefixed(publish_subscribed_topic.SerializeToString()) - ) - await wait_for_event_occurring(event_push_msg) - # Make sure the other events are not emitted. - with pytest.raises(asyncio.TimeoutError): - await wait_for_event_occurring(event_handle_subscription) - with pytest.raises(asyncio.TimeoutError): - await wait_for_event_occurring(event_handle_rpc) + msg = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=[topic_1, topic_2], + data=b"1234", + seqno=b"\x00" * 8, + ) - # Test: `push_msg` is not called when publishing to a topic-not-subscribed. - publish_not_subscribed_topic = rpc_pb2.RPC( - publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])] - ) - await stream.write( - encode_varint_prefixed(publish_not_subscribed_topic.SerializeToString()) - ) - with pytest.raises(asyncio.TimeoutError): - await wait_for_event_occurring(event_push_msg) + if is_topic_1_val_passed and is_topic_2_val_passed: + await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg) + else: + with pytest.raises(ValidationError): + await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg) - # Test: `handle_subscription` is called when a subscription message is received. - subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()]) - await stream.write(encode_varint_prefixed(subscription_msg.SerializeToString())) - await wait_for_event_occurring(event_handle_subscription) - # Make sure the other events are not emitted. - with pytest.raises(asyncio.TimeoutError): - await wait_for_event_occurring(event_push_msg) - with pytest.raises(asyncio.TimeoutError): - await wait_for_event_occurring(event_handle_rpc) - # Test: `handle_rpc` is called when a control message is received. - control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage()) - await stream.write(encode_varint_prefixed(control_msg.SerializeToString())) - await wait_for_event_occurring(event_handle_rpc) - # Make sure the other events are not emitted. - with pytest.raises(asyncio.TimeoutError): - await wait_for_event_occurring(event_push_msg) - with pytest.raises(asyncio.TimeoutError): - await wait_for_event_occurring(event_handle_subscription) +@pytest.mark.trio +async def test_continuously_read_stream(monkeypatch, nursery, is_host_secure): + async def wait_for_event_occurring(event): + with trio.fail_after(0.1): + await event.wait() - task.cancel() + class Events(NamedTuple): + push_msg: trio.Event + handle_subscription: trio.Event + handle_rpc: trio.Event + + @contextmanager + def mock_methods(): + event_push_msg = trio.Event() + event_handle_subscription = trio.Event() + event_handle_rpc = trio.Event() + + async def mock_push_msg(msg_forwarder, msg): + event_push_msg.set() + await trio.sleep(0) + + def mock_handle_subscription(origin_id, sub_message): + event_handle_subscription.set() + + async def mock_handle_rpc(rpc, sender_peer_id): + event_handle_rpc.set() + await trio.sleep(0) + + with monkeypatch.context() as m: + m.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg) + m.setattr(pubsubs_fsub[0], "handle_subscription", mock_handle_subscription) + m.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc) + yield Events(event_push_msg, event_handle_subscription, event_handle_rpc) + + async with PubsubFactory.create_batch_with_floodsub( + 1, is_secure=is_host_secure + ) as pubsubs_fsub, net_stream_pair_factory(is_secure=is_host_secure) as stream_pair: + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + # Kick off the task `continuously_read_stream` + nursery.start_soon(pubsubs_fsub[0].continuously_read_stream, stream_pair[0]) + + # Test: `push_msg` is called when publishing to a subscribed topic. + publish_subscribed_topic = rpc_pb2.RPC( + publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])] + ) + with mock_methods() as events: + await stream_pair[1].write( + encode_varint_prefixed(publish_subscribed_topic.SerializeToString()) + ) + await wait_for_event_occurring(events.push_msg) + # Make sure the other events are not emitted. + with pytest.raises(trio.TooSlowError): + await wait_for_event_occurring(events.handle_subscription) + with pytest.raises(trio.TooSlowError): + await wait_for_event_occurring(events.handle_rpc) + + # Test: `push_msg` is not called when publishing to a topic-not-subscribed. + publish_not_subscribed_topic = rpc_pb2.RPC( + publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])] + ) + with mock_methods() as events: + await stream_pair[1].write( + encode_varint_prefixed(publish_not_subscribed_topic.SerializeToString()) + ) + with pytest.raises(trio.TooSlowError): + await wait_for_event_occurring(events.push_msg) + + # Test: `handle_subscription` is called when a subscription message is received. + subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()]) + with mock_methods() as events: + await stream_pair[1].write( + encode_varint_prefixed(subscription_msg.SerializeToString()) + ) + await wait_for_event_occurring(events.handle_subscription) + # Make sure the other events are not emitted. + with pytest.raises(trio.TooSlowError): + await wait_for_event_occurring(events.push_msg) + with pytest.raises(trio.TooSlowError): + await wait_for_event_occurring(events.handle_rpc) + + # Test: `handle_rpc` is called when a control message is received. + control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage()) + with mock_methods() as events: + await stream_pair[1].write( + encode_varint_prefixed(control_msg.SerializeToString()) + ) + await wait_for_event_occurring(events.handle_rpc) + # Make sure the other events are not emitted. + with pytest.raises(trio.TooSlowError): + await wait_for_event_occurring(events.push_msg) + with pytest.raises(trio.TooSlowError): + await wait_for_event_occurring(events.handle_subscription) # TODO: Add the following tests after they are aligned with Go. @@ -351,81 +331,84 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch): # - `test_handle_peer_queue` -@pytest.mark.parametrize("num_hosts", (1,)) -def test_handle_subscription(pubsubs_fsub): - assert len(pubsubs_fsub[0].peer_topics) == 0 - sub_msg_0 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=TESTING_TOPIC) - peer_ids = [ID(b"\x12\x20" + i.to_bytes(32, "big")) for i in range(2)] - # Test: One peer is subscribed - pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_0) - assert ( - len(pubsubs_fsub[0].peer_topics) == 1 - and TESTING_TOPIC in pubsubs_fsub[0].peer_topics - ) - assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 1 - assert peer_ids[0] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] - # Test: Another peer is subscribed - pubsubs_fsub[0].handle_subscription(peer_ids[1], sub_msg_0) - assert len(pubsubs_fsub[0].peer_topics) == 1 - assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 2 - assert peer_ids[1] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] - # Test: Subscribe to another topic - another_topic = "ANOTHER_TOPIC" - sub_msg_1 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=another_topic) - pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_1) - assert len(pubsubs_fsub[0].peer_topics) == 2 - assert another_topic in pubsubs_fsub[0].peer_topics - assert peer_ids[0] in pubsubs_fsub[0].peer_topics[another_topic] - # Test: unsubscribe - unsub_msg = rpc_pb2.RPC.SubOpts(subscribe=False, topicid=TESTING_TOPIC) - pubsubs_fsub[0].handle_subscription(peer_ids[0], unsub_msg) - assert peer_ids[0] not in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] +@pytest.mark.trio +async def test_handle_subscription(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + assert len(pubsubs_fsub[0].peer_topics) == 0 + sub_msg_0 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=TESTING_TOPIC) + peer_ids = [IDFactory() for _ in range(2)] + # Test: One peer is subscribed + pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_0) + assert ( + len(pubsubs_fsub[0].peer_topics) == 1 + and TESTING_TOPIC in pubsubs_fsub[0].peer_topics + ) + assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 1 + assert peer_ids[0] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] + # Test: Another peer is subscribed + pubsubs_fsub[0].handle_subscription(peer_ids[1], sub_msg_0) + assert len(pubsubs_fsub[0].peer_topics) == 1 + assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 2 + assert peer_ids[1] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] + # Test: Subscribe to another topic + another_topic = "ANOTHER_TOPIC" + sub_msg_1 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=another_topic) + pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_1) + assert len(pubsubs_fsub[0].peer_topics) == 2 + assert another_topic in pubsubs_fsub[0].peer_topics + assert peer_ids[0] in pubsubs_fsub[0].peer_topics[another_topic] + # Test: unsubscribe + unsub_msg = rpc_pb2.RPC.SubOpts(subscribe=False, topicid=TESTING_TOPIC) + pubsubs_fsub[0].handle_subscription(peer_ids[0], unsub_msg) + assert peer_ids[0] not in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_handle_talk(pubsubs_fsub): - sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - msg_0 = make_pubsub_msg( - origin_id=pubsubs_fsub[0].my_id, - topic_ids=[TESTING_TOPIC], - data=b"1234", - seqno=b"\x00" * 8, - ) - await pubsubs_fsub[0].handle_talk(msg_0) - msg_1 = make_pubsub_msg( - origin_id=pubsubs_fsub[0].my_id, - topic_ids=["NOT_SUBSCRIBED"], - data=b"1234", - seqno=b"\x11" * 8, - ) - await pubsubs_fsub[0].handle_talk(msg_1) - assert ( - len(pubsubs_fsub[0].my_topics) == 1 - and sub == pubsubs_fsub[0].my_topics[TESTING_TOPIC] - ) - assert sub.qsize() == 1 - assert (await sub.get()) == msg_0 +@pytest.mark.trio +async def test_handle_talk(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + msg_0 = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=[TESTING_TOPIC], + data=b"1234", + seqno=b"\x00" * 8, + ) + await pubsubs_fsub[0].handle_talk(msg_0) + msg_1 = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=["NOT_SUBSCRIBED"], + data=b"1234", + seqno=b"\x11" * 8, + ) + await pubsubs_fsub[0].handle_talk(msg_1) + assert ( + len(pubsubs_fsub[0].topic_ids) == 1 + and sub == pubsubs_fsub[0].subscribed_topics_receive[TESTING_TOPIC] + ) + assert (await sub.receive()) == msg_0 -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_message_all_peers(pubsubs_fsub, monkeypatch): - peer_ids = [ID(b"\x12\x20" + i.to_bytes(32, "big")) for i in range(10)] - mock_peers = {peer_id: FakeNetStream() for peer_id in peer_ids} - monkeypatch.setattr(pubsubs_fsub[0], "peers", mock_peers) +@pytest.mark.trio +async def test_message_all_peers(monkeypatch, is_host_secure): + async with PubsubFactory.create_batch_with_floodsub( + 1, is_secure=is_host_secure + ) as pubsubs_fsub, net_stream_pair_factory(is_secure=is_host_secure) as stream_pair: + peer_id = IDFactory() + mock_peers = {peer_id: stream_pair[0]} + with monkeypatch.context() as m: + m.setattr(pubsubs_fsub[0], "peers", mock_peers) - empty_rpc = rpc_pb2.RPC() - empty_rpc_bytes = empty_rpc.SerializeToString() - empty_rpc_bytes_len_prefixed = encode_varint_prefixed(empty_rpc_bytes) - await pubsubs_fsub[0].message_all_peers(empty_rpc_bytes) - for stream in mock_peers.values(): - assert (await stream.read()) == empty_rpc_bytes_len_prefixed + empty_rpc = rpc_pb2.RPC() + empty_rpc_bytes = empty_rpc.SerializeToString() + empty_rpc_bytes_len_prefixed = encode_varint_prefixed(empty_rpc_bytes) + await pubsubs_fsub[0].message_all_peers(empty_rpc_bytes) + assert ( + await stream_pair[1].read(MAX_READ_LEN) + ) == empty_rpc_bytes_len_prefixed -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_publish(pubsubs_fsub, monkeypatch): +@pytest.mark.trio +async def test_publish(monkeypatch): msg_forwarders = [] msgs = [] @@ -433,80 +416,97 @@ async def test_publish(pubsubs_fsub, monkeypatch): msg_forwarders.append(msg_forwarder) msgs.append(msg) - monkeypatch.setattr(pubsubs_fsub[0], "push_msg", push_msg) + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + with monkeypatch.context() as m: + m.setattr(pubsubs_fsub[0], "push_msg", push_msg) - await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA) - await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA) + await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA) + await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA) - assert len(msgs) == 2, "`push_msg` should be called every time `publish` is called" - assert (msg_forwarders[0] == msg_forwarders[1]) and ( - msg_forwarders[1] == pubsubs_fsub[0].my_id - ) - assert msgs[0].seqno != msgs[1].seqno, "`seqno` should be different every time" + assert ( + len(msgs) == 2 + ), "`push_msg` should be called every time `publish` is called" + assert (msg_forwarders[0] == msg_forwarders[1]) and ( + msg_forwarders[1] == pubsubs_fsub[0].my_id + ) + assert ( + msgs[0].seqno != msgs[1].seqno + ), "`seqno` should be different every time" -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_push_msg(pubsubs_fsub, monkeypatch): - msg_0 = make_pubsub_msg( - origin_id=pubsubs_fsub[0].my_id, - topic_ids=[TESTING_TOPIC], - data=TESTING_DATA, - seqno=b"\x00" * 8, - ) +@pytest.mark.trio +async def test_push_msg(monkeypatch): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + msg_0 = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=[TESTING_TOPIC], + data=TESTING_DATA, + seqno=b"\x00" * 8, + ) - event = asyncio.Event() + @contextmanager + def mock_router_publish(): - async def router_publish(*args, **kwargs): - event.set() + event = trio.Event() - monkeypatch.setattr(pubsubs_fsub[0].router, "publish", router_publish) + async def router_publish(*args, **kwargs): + event.set() + await trio.sleep(0) - # Test: `msg` is not seen before `push_msg`, and is seen after `push_msg`. - assert not pubsubs_fsub[0]._is_msg_seen(msg_0) - await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0) - assert pubsubs_fsub[0]._is_msg_seen(msg_0) - # Test: Ensure `router.publish` is called in `push_msg` - await asyncio.wait_for(event.wait(), timeout=0.1) + with monkeypatch.context() as m: + m.setattr(pubsubs_fsub[0].router, "publish", router_publish) + yield event - # Test: `push_msg` the message again and it will be reject. - # `router_publish` is not called then. - event.clear() - await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0) - await asyncio.sleep(0.01) - assert not event.is_set() + with mock_router_publish() as event: + # Test: `msg` is not seen before `push_msg`, and is seen after `push_msg`. + assert not pubsubs_fsub[0]._is_msg_seen(msg_0) + await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0) + assert pubsubs_fsub[0]._is_msg_seen(msg_0) + # Test: Ensure `router.publish` is called in `push_msg` + with trio.fail_after(0.1): + await event.wait() - sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - # Test: `push_msg` succeeds with another unseen msg. - msg_1 = make_pubsub_msg( - origin_id=pubsubs_fsub[0].my_id, - topic_ids=[TESTING_TOPIC], - data=TESTING_DATA, - seqno=b"\x11" * 8, - ) - assert not pubsubs_fsub[0]._is_msg_seen(msg_1) - await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_1) - assert pubsubs_fsub[0]._is_msg_seen(msg_1) - await asyncio.wait_for(event.wait(), timeout=0.1) - # Test: Subscribers are notified when `push_msg` new messages. - assert (await sub.get()) == msg_1 + with mock_router_publish() as event: + # Test: `push_msg` the message again and it will be reject. + # `router_publish` is not called then. + await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0) + await trio.sleep(0.01) + assert not event.is_set() - # Test: add a topic validator and `push_msg` the message that - # does not pass the validation. - # `router_publish` is not called then. - def failed_sync_validator(peer_id, msg): - return False + sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + # Test: `push_msg` succeeds with another unseen msg. + msg_1 = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=[TESTING_TOPIC], + data=TESTING_DATA, + seqno=b"\x11" * 8, + ) + assert not pubsubs_fsub[0]._is_msg_seen(msg_1) + await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_1) + assert pubsubs_fsub[0]._is_msg_seen(msg_1) + with trio.fail_after(0.1): + await event.wait() + # Test: Subscribers are notified when `push_msg` new messages. + assert (await sub.receive()) == msg_1 - pubsubs_fsub[0].set_topic_validator(TESTING_TOPIC, failed_sync_validator, False) + with mock_router_publish() as event: + # Test: add a topic validator and `push_msg` the message that + # does not pass the validation. + # `router_publish` is not called then. + def failed_sync_validator(peer_id, msg): + return False - msg_2 = make_pubsub_msg( - origin_id=pubsubs_fsub[0].my_id, - topic_ids=[TESTING_TOPIC], - data=TESTING_DATA, - seqno=b"\x22" * 8, - ) + pubsubs_fsub[0].set_topic_validator( + TESTING_TOPIC, failed_sync_validator, False + ) - event.clear() - await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_2) - await asyncio.sleep(0.01) - assert not event.is_set() + msg_2 = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=[TESTING_TOPIC], + data=TESTING_DATA, + seqno=b"\x22" * 8, + ) + + await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_2) + await trio.sleep(0.01) + assert not event.is_set() From 1929f307fb64b6d526b0e5c07160dd53cd3cbfcf Mon Sep 17 00:00:00 2001 From: mhchia Date: Fri, 6 Dec 2019 17:06:37 +0800 Subject: [PATCH 19/81] Fix all modules except for security --- libp2p/host/ping.py | 3 +- libp2p/io/trio.py | 52 +- libp2p/network/connection/raw_connection.py | 14 +- libp2p/network/connection/swarm_connection.py | 13 +- libp2p/network/swarm.py | 12 +- libp2p/pubsub/floodsub.py | 6 +- libp2p/pubsub/gossipsub.py | 33 +- libp2p/pubsub/pubsub.py | 36 +- libp2p/stream_muxer/abc.py | 6 +- libp2p/stream_muxer/mplex/mplex.py | 66 +- libp2p/tools/factories.py | 155 +++-- libp2p/tools/pubsub/dummy_account_node.py | 51 +- .../floodsub_integration_test_settings.py | 151 ++-- libp2p/tools/utils.py | 56 +- libp2p/transport/listener_interface.py | 12 +- libp2p/transport/tcp/tcp.py | 42 +- tests/host/test_ping.py | 2 +- tests/host/test_routed_host.py | 73 -- tests/protocol_muxer/test_protocol_muxer.py | 1 - tests/pubsub/conftest.py | 22 - tests/pubsub/test_dummyaccount_demo.py | 95 +-- tests/pubsub/test_floodsub.py | 117 ++-- tests/pubsub/test_gossipsub.py | 646 +++++++++--------- .../test_gossipsub_backward_compatibility.py | 16 +- tests/pubsub/test_mcache.py | 5 +- tests/pubsub/test_pubsub.py | 9 +- tests/stream_muxer/test_mplex_stream.py | 23 - tests/transport/test_tcp.py | 2 +- 28 files changed, 764 insertions(+), 955 deletions(-) delete mode 100644 tests/host/test_routed_host.py delete mode 100644 tests/pubsub/conftest.py diff --git a/libp2p/host/ping.py b/libp2p/host/ping.py index 589fc917..9e23f1cc 100644 --- a/libp2p/host/ping.py +++ b/libp2p/host/ping.py @@ -1,6 +1,7 @@ -import trio import logging +import trio + from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset from libp2p.network.stream.net_stream_interface import INetStream from libp2p.peer.id import ID as PeerID diff --git a/libp2p/io/trio.py b/libp2p/io/trio.py index e74e9ed2..840c3bc8 100644 --- a/libp2p/io/trio.py +++ b/libp2p/io/trio.py @@ -1,7 +1,6 @@ import logging import trio -from trio import SocketStream from libp2p.io.abc import ReadWriteCloser from libp2p.io.exceptions import IOException @@ -9,29 +8,48 @@ from libp2p.io.exceptions import IOException logger = logging.getLogger("libp2p.io.trio") -class TrioReadWriteCloser(ReadWriteCloser): - stream: SocketStream +class TrioTCPStream(ReadWriteCloser): + stream: trio.SocketStream + # NOTE: Add both read and write lock to avoid `trio.BusyResourceError` + read_lock: trio.Lock + write_lock: trio.Lock - def __init__(self, stream: SocketStream) -> None: + def __init__(self, stream: trio.SocketStream) -> None: self.stream = stream + self.read_lock = trio.Lock() + self.write_lock = trio.Lock() async def write(self, data: bytes) -> None: """Raise `RawConnError` if the underlying connection breaks.""" - try: - await self.stream.send_all(data) - except (trio.ClosedResourceError, trio.BrokenResourceError) as error: - raise IOException(error) + async with self.write_lock: + try: + await self.stream.send_all(data) + except (trio.ClosedResourceError, trio.BrokenResourceError) as error: + raise IOException from error + except trio.BusyResourceError as error: + # This should never happen, since we already access streams with read/write locks. + raise Exception( + "this should never happen " + "since we already access streams with read/write locks." + ) from error async def read(self, n: int = -1) -> bytes: - if n == 0: - # Check point - await trio.sleep(0) - return b"" - max_bytes = n if n != -1 else None - try: - return await self.stream.receive_some(max_bytes) - except (trio.ClosedResourceError, trio.BrokenResourceError) as error: - raise IOException(error) + async with self.read_lock: + if n == 0: + # Checkpoint + await trio.hazmat.checkpoint() + return b"" + max_bytes = n if n != -1 else None + try: + return await self.stream.receive_some(max_bytes) + except (trio.ClosedResourceError, trio.BrokenResourceError) as error: + raise IOException from error + except trio.BusyResourceError as error: + # This should never happen, since we already access streams with read/write locks. + raise Exception( + "this should never happen " + "since we already access streams with read/write locks." + ) from error async def close(self) -> None: await self.stream.aclose() diff --git a/libp2p/network/connection/raw_connection.py b/libp2p/network/connection/raw_connection.py index 2bdb3b10..25b1049c 100644 --- a/libp2p/network/connection/raw_connection.py +++ b/libp2p/network/connection/raw_connection.py @@ -1,5 +1,3 @@ -import trio - from libp2p.io.abc import ReadWriteCloser from libp2p.io.exceptions import IOException @@ -8,17 +6,17 @@ from .raw_connection_interface import IRawConnection class RawConnection(IRawConnection): - read_write_closer: ReadWriteCloser + stream: ReadWriteCloser is_initiator: bool - def __init__(self, read_write_closer: ReadWriteCloser, initiator: bool) -> None: - self.read_write_closer = read_write_closer + def __init__(self, stream: ReadWriteCloser, initiator: bool) -> None: + self.stream = stream self.is_initiator = initiator async def write(self, data: bytes) -> None: """Raise `RawConnError` if the underlying connection breaks.""" try: - await self.read_write_closer.write(data) + await self.stream.write(data) except IOException as error: raise RawConnError(error) @@ -30,9 +28,9 @@ class RawConnection(IRawConnection): Raise `RawConnError` if the underlying connection breaks """ try: - return await self.read_write_closer.read(n) + return await self.stream.read(n) except IOException as error: raise RawConnError(error) async def close(self) -> None: - await self.read_write_closer.close() + await self.stream.close() diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 48774ec2..1e310338 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple +from typing import TYPE_CHECKING, Set, Tuple from async_service import Service import trio @@ -45,16 +45,11 @@ class SwarmConn(INetConn, Service): # before we cancel the stream handler tasks. await trio.sleep(0.1) - # FIXME: Now let `_notify_disconnected` finish first. - # Schedule `self._notify_disconnected` to make it execute after `close` is finished. await self._notify_disconnected() async def _handle_new_streams(self) -> None: while self.manager.is_running: try: - print( - f"!@# SwarmConn._handle_new_streams: {self.muxed_conn._id}: waiting for new streams" - ) stream = await self.muxed_conn.accept_stream() except MuxedConnUnavailable: # If there is anything wrong in the MuxedConn, @@ -63,9 +58,6 @@ class SwarmConn(INetConn, Service): # Asynchronously handle the accepted stream, to avoid blocking the next stream. self.manager.run_task(self._handle_muxed_stream, stream) - print( - f"!@# SwarmConn._handle_new_streams: {self.muxed_conn._id}: out of the loop" - ) await self.close() async def _call_stream_handler(self, net_stream: NetStream) -> None: @@ -92,8 +84,7 @@ class SwarmConn(INetConn, Service): await self.swarm.notify_disconnected(self) async def run(self) -> None: - self.manager.run_task(self._handle_new_streams) - await self.manager.wait_finished() + await self._handle_new_streams() async def new_stream(self) -> NetStream: muxed_stream = await self.muxed_conn.open_stream() diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 37614bcd..4bf86dde 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -203,16 +203,17 @@ class Swarm(INetwork, Service): await self.add_conn(muxed_conn) logger.debug("successfully opened connection to peer %s", peer_id) - # FIXME: This is a intentional barrier to prevent from the handler exiting and - # closing the connection. Probably change to `Service.manager.wait_finished`? - await trio.sleep_forever() + # NOTE: This is a intentional barrier to prevent from the handler exiting and + # closing the connection. + await self.manager.wait_finished() try: # Success listener = self.transport.create_listener(conn_handler) self.listeners[str(maddr)] = listener - # FIXME: Hack - await listener.listen(maddr, self.manager._task_nursery) + # TODO: `listener.listen` is not bounded with nursery. If we want to be + # I/O agnostic, we should change the API. + await listener.listen(maddr, self.manager._task_nursery) # type: ignore # Call notifiers since event occurred await self.notify_listen(maddr) @@ -278,6 +279,7 @@ class Swarm(INetwork, Service): """ self.notifees.append(notifee) + # TODO: Use `run_task`. async def notify_opened_stream(self, stream: INetStream) -> None: async with trio.open_nursery() as nursery: for notifee in self.notifees: diff --git a/libp2p/pubsub/floodsub.py b/libp2p/pubsub/floodsub.py index 8c15a441..9e323eb2 100644 --- a/libp2p/pubsub/floodsub.py +++ b/libp2p/pubsub/floodsub.py @@ -64,7 +64,7 @@ class FloodSub(IPubsubRouter): :param rpc: rpc message """ # Checkpoint - await trio.sleep(0) + await trio.hazmat.checkpoint() async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None: """ @@ -107,7 +107,7 @@ class FloodSub(IPubsubRouter): :param topic: topic to join """ # Checkpoint - await trio.sleep(0) + await trio.hazmat.checkpoint() async def leave(self, topic: str) -> None: """ @@ -117,7 +117,7 @@ class FloodSub(IPubsubRouter): :param topic: topic to leave """ # Checkpoint - await trio.sleep(0) + await trio.hazmat.checkpoint() def _get_peers_to_send( self, topic_ids: Iterable[str], msg_forwarder: ID, origin: ID diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py index 93faebdd..df0f83f4 100644 --- a/libp2p/pubsub/gossipsub.py +++ b/libp2p/pubsub/gossipsub.py @@ -1,15 +1,18 @@ from ast import literal_eval -import asyncio import logging import random from typing import Any, Dict, Iterable, List, Sequence, Set +from async_service import Service +import trio + from libp2p.network.stream.exceptions import StreamClosed from libp2p.peer.id import ID from libp2p.pubsub import floodsub from libp2p.typing import TProtocol from libp2p.utils import encode_varint_prefixed +from .exceptions import NoPubsubAttached from .mcache import MessageCache from .pb import rpc_pb2 from .pubsub import Pubsub @@ -20,8 +23,7 @@ PROTOCOL_ID = TProtocol("/meshsub/1.0.0") logger = logging.getLogger("libp2p.pubsub.gossipsub") -class GossipSub(IPubsubRouter): - +class GossipSub(IPubsubRouter, Service): protocols: List[TProtocol] pubsub: Pubsub @@ -86,6 +88,12 @@ class GossipSub(IPubsubRouter): # Create heartbeat timer self.heartbeat_interval = heartbeat_interval + async def run(self) -> None: + if self.pubsub is None: + raise NoPubsubAttached + self.manager.run_task(self.heartbeat) + await self.manager.wait_finished() + # Interface functions def get_protocols(self) -> List[TProtocol]: @@ -105,10 +113,6 @@ class GossipSub(IPubsubRouter): logger.debug("attached to pusub") - # Start heartbeat now that we have a pubsub instance - # TODO: Start after delay - asyncio.ensure_future(self.heartbeat()) - def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None: """ Notifies the router that a new peer has been connected. @@ -310,7 +314,7 @@ class GossipSub(IPubsubRouter): await self.fanout_heartbeat() await self.gossip_heartbeat() - await asyncio.sleep(self.heartbeat_interval) + await trio.sleep(self.heartbeat_interval) async def mesh_heartbeat(self) -> None: # Note: the comments here are the exact pseudocode from the spec @@ -338,7 +342,7 @@ class GossipSub(IPubsubRouter): if num_mesh_peers_in_topic > self.degree_high: # Select |mesh[topic]| - D peers from mesh[topic] - selected_peers = GossipSub.select_from_minus( + selected_peers = self.select_from_minus( num_mesh_peers_in_topic - self.degree, self.mesh[topic], [] ) for peer in selected_peers: @@ -353,7 +357,10 @@ class GossipSub(IPubsubRouter): for topic in self.fanout: # If time since last published > ttl # TODO: there's no way time_since_last_publish gets set anywhere yet - if self.time_since_last_publish[topic] > self.time_to_live: + if ( + topic in self.time_since_last_publish + and self.time_since_last_publish[topic] > self.time_to_live + ): # Remove topic from fanout del self.fanout[topic] del self.time_since_last_publish[topic] @@ -407,11 +414,7 @@ class GossipSub(IPubsubRouter): topic, self.degree, [] ) for peer in peers_to_emit_ihave_to: - if ( - peer not in self.mesh[topic] - and peer not in self.fanout[topic] - ): - + if peer not in self.fanout[topic]: msg_id_strs = [str(msg) for msg in msg_ids] await self.emit_ihave(topic, msg_id_strs, peer) diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 7c4b50de..3370ea30 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -1,4 +1,4 @@ -from abc import ABC, abstractmethod +from abc import ABC import logging import math import time @@ -57,6 +57,7 @@ class TopicValidator(NamedTuple): is_async: bool +# TODO: Add interface for Pubsub class BasePubsub(ABC): pass @@ -103,20 +104,24 @@ class Pubsub(BasePubsub, Service): # Attach this new Pubsub object to the router self.router.attach(self) - peer_send_channel, peer_receive_channel = trio.open_memory_channel(0) - dead_peer_send_channel, dead_peer_receive_channel = trio.open_memory_channel(0) + peer_channels: Tuple[ + "trio.MemorySendChannel[ID]", "trio.MemoryReceiveChannel[ID]" + ] = trio.open_memory_channel(0) + dead_peer_channels: Tuple[ + "trio.MemorySendChannel[ID]", "trio.MemoryReceiveChannel[ID]" + ] = trio.open_memory_channel(0) # Only keep the receive channels in `Pubsub`. # Therefore, we can only close from the receive side. - self.peer_receive_channel = peer_receive_channel - self.dead_peer_receive_channel = dead_peer_receive_channel + self.peer_receive_channel = peer_channels[1] + self.dead_peer_receive_channel = dead_peer_channels[1] # Register a notifee self.host.get_network().register_notifee( - PubsubNotifee(peer_send_channel, dead_peer_send_channel) + PubsubNotifee(peer_channels[0], dead_peer_channels[0]) ) # Register stream handlers for each pubsub router protocol to handle # the pubsub streams opened on those protocols - for protocol in router.protocols: + for protocol in router.get_protocols(): self.host.set_stream_handler(protocol, self.stream_handler) # keeps track of seen messages as LRU cache @@ -328,8 +333,9 @@ class Pubsub(BasePubsub, Service): self.manager.run_task(self._handle_new_peer, peer_id) async def handle_dead_peer_queue(self) -> None: - """Continuously read from dead peer channel and close the stream between - that peer and remove peer info from pubsub and pubsub router.""" + """Continuously read from dead peer channel and close the stream + between that peer and remove peer info from pubsub and pubsub + router.""" async with self.dead_peer_receive_channel: while self.manager.is_running: peer_id: ID = await self.dead_peer_receive_channel.receive() @@ -391,7 +397,11 @@ class Pubsub(BasePubsub, Service): return self.subscribed_topics_receive[topic_id] # Map topic_id to a blocking channel - send_channel, receive_channel = trio.open_memory_channel(math.inf) + channels: Tuple[ + "trio.MemorySendChannel[rpc_pb2.Message]", + "trio.MemoryReceiveChannel[rpc_pb2.Message]", + ] = trio.open_memory_channel(math.inf) + send_channel, receive_channel = channels self.subscribed_topics_send[topic_id] = send_channel self.subscribed_topics_receive[topic_id] = receive_channel @@ -506,7 +516,7 @@ class Pubsub(BasePubsub, Service): if len(async_topic_validators) > 0: # TODO: Use a better pattern - final_result = True + final_result: bool = True async def run_async_validator(func: AsyncValidatorFn) -> None: nonlocal final_result @@ -514,8 +524,8 @@ class Pubsub(BasePubsub, Service): final_result = final_result and result async with trio.open_nursery() as nursery: - for validator in async_topic_validators: - nursery.start_soon(run_async_validator, validator) + for async_validator in async_topic_validators: + nursery.start_soon(run_async_validator, async_validator) if not final_result: raise ValidationError(f"Validation failed for msg={msg}") diff --git a/libp2p/stream_muxer/abc.py b/libp2p/stream_muxer/abc.py index 71704c1e..12a8f805 100644 --- a/libp2p/stream_muxer/abc.py +++ b/libp2p/stream_muxer/abc.py @@ -1,11 +1,13 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod + +from async_service import ServiceAPI from libp2p.io.abc import ReadWriteCloser from libp2p.peer.id import ID from libp2p.security.secure_conn_interface import ISecureConn -class IMuxedConn(ABC): +class IMuxedConn(ServiceAPI): """ reference: https://github.com/libp2p/go-stream-muxer/blob/master/muxer.go """ diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index ac6cdcdb..e23da00c 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -1,7 +1,6 @@ import logging import math -from typing import Any # noqa: F401 -from typing import Awaitable, Dict, List, Optional, Tuple +from typing import Dict, Optional, Tuple from async_service import Service import trio @@ -67,13 +66,15 @@ class Mplex(IMuxedConn, Service): self.streams = {} self.streams_lock = trio.Lock() self.streams_msg_channels = {} - send_channel, receive_channel = trio.open_memory_channel(math.inf) - self.new_stream_send_channel = send_channel - self.new_stream_receive_channel = receive_channel + channels: Tuple[ + "trio.MemorySendChannel[IMuxedStream]", + "trio.MemoryReceiveChannel[IMuxedStream]", + ] = trio.open_memory_channel(math.inf) + self.new_stream_send_channel, self.new_stream_receive_channel = channels self.event_shutting_down = trio.Event() self.event_closed = trio.Event() - async def run(self): + async def run(self) -> None: self.manager.run_task(self.handle_incoming) await self.manager.wait_finished() @@ -112,11 +113,13 @@ class Mplex(IMuxedConn, Service): async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream: # Use an unbounded buffer, to avoid `handle_incoming` being blocked when doing # `send_channel.send`. - send_channel, receive_channel = trio.open_memory_channel(math.inf) - stream = MplexStream(name, stream_id, self, receive_channel) + channels: Tuple[ + "trio.MemorySendChannel[bytes]", "trio.MemoryReceiveChannel[bytes]" + ] = trio.open_memory_channel(math.inf) + stream = MplexStream(name, stream_id, self, channels[1]) async with self.streams_lock: self.streams[stream_id] = stream - self.streams_msg_channels[stream_id] = send_channel + self.streams_msg_channels[stream_id] = channels[0] return stream async def open_stream(self) -> IMuxedStream: @@ -150,9 +153,6 @@ class Mplex(IMuxedConn, Service): :param data: data to send in the message :param stream_id: stream the message is in """ - print( - f"!@# send_message: {self._id}: flag={flag}, data={data}, stream_id={stream_id}" - ) # << by 3, then or with flag header = encode_uvarint((stream_id.channel_id << 3) | flag.value) @@ -179,19 +179,10 @@ class Mplex(IMuxedConn, Service): while self.manager.is_running: try: - print( - f"!@# handle_incoming: {self._id}: before _handle_incoming_message" - ) await self._handle_incoming_message() - print( - f"!@# handle_incoming: {self._id}: after _handle_incoming_message" - ) except MplexUnavailable as e: logger.debug("mplex unavailable while waiting for incoming: %s", e) - print(f"!@# handle_incoming: {self._id}: MplexUnavailable: {e}") break - - print(f"!@# handle_incoming: {self._id}: leaving") # If we enter here, it means this connection is shutting down. # We should clean things up. await self._cleanup() @@ -232,44 +223,27 @@ class Mplex(IMuxedConn, Service): :raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down. """ - print(f"!@# _handle_incoming_message: {self._id}: before reading") channel_id, flag, message = await self.read_message() - print( - f"!@# _handle_incoming_message: {self._id}: channel_id={channel_id}, flag={flag}, message={message}" - ) stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1)) - print(f"!@# _handle_incoming_message: {self._id}: 2") if flag == HeaderTags.NewStream.value: - print(f"!@# _handle_incoming_message: {self._id}: 3") await self._handle_new_stream(stream_id, message) - print(f"!@# _handle_incoming_message: {self._id}: 4") elif flag in ( HeaderTags.MessageInitiator.value, HeaderTags.MessageReceiver.value, ): - print(f"!@# _handle_incoming_message: {self._id}: 5") await self._handle_message(stream_id, message) - print(f"!@# _handle_incoming_message: {self._id}: 6") elif flag in (HeaderTags.CloseInitiator.value, HeaderTags.CloseReceiver.value): - print(f"!@# _handle_incoming_message: {self._id}: 7") await self._handle_close(stream_id) - print(f"!@# _handle_incoming_message: {self._id}: 8") elif flag in (HeaderTags.ResetInitiator.value, HeaderTags.ResetReceiver.value): - print(f"!@# _handle_incoming_message: {self._id}: 9") await self._handle_reset(stream_id) - print(f"!@# _handle_incoming_message: {self._id}: 10") else: - print(f"!@# _handle_incoming_message: {self._id}: 11") # Receives messages with an unknown flag # TODO: logging async with self.streams_lock: - print(f"!@# _handle_incoming_message: {self._id}: 12") if stream_id in self.streams: - print(f"!@# _handle_incoming_message: {self._id}: 13") stream = self.streams[stream_id] await stream.reset() - print(f"!@# _handle_incoming_message: {self._id}: 14") async def _handle_new_stream(self, stream_id: StreamID, message: bytes) -> None: async with self.streams_lock: @@ -285,59 +259,43 @@ class Mplex(IMuxedConn, Service): raise MplexUnavailable async def _handle_message(self, stream_id: StreamID, message: bytes) -> None: - print( - f"!@# _handle_message: {self._id}: stream_id={stream_id}, message={message}" - ) async with self.streams_lock: - print(f"!@# _handle_message: {self._id}: 1") if stream_id not in self.streams: # We receive a message of the stream `stream_id` which is not accepted # before. It is abnormal. Possibly disconnect? # TODO: Warn and emit logs about this. - print(f"!@# _handle_message: {self._id}: 2") return - print(f"!@# _handle_message: {self._id}: 3") stream = self.streams[stream_id] send_channel = self.streams_msg_channels[stream_id] async with stream.close_lock: - print(f"!@# _handle_message: {self._id}: 4") if stream.event_remote_closed.is_set(): - print(f"!@# _handle_message: {self._id}: 5") # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501 return - print(f"!@# _handle_message: {self._id}: 6") await send_channel.send(message) - print(f"!@# _handle_message: {self._id}: 7") async def _handle_close(self, stream_id: StreamID) -> None: - print(f"!@# _handle_close: {self._id}: step=0") async with self.streams_lock: if stream_id not in self.streams: # Ignore unmatched messages for now. return stream = self.streams[stream_id] send_channel = self.streams_msg_channels[stream_id] - print(f"!@# _handle_close: {self._id}: step=1") await send_channel.aclose() - print(f"!@# _handle_close: {self._id}: step=2") # NOTE: If remote is already closed, then return: Technically a bug # on the other side. We should consider killing the connection. async with stream.close_lock: if stream.event_remote_closed.is_set(): return - print(f"!@# _handle_close: {self._id}: step=3") is_local_closed: bool async with stream.close_lock: stream.event_remote_closed.set() is_local_closed = stream.event_local_closed.is_set() - print(f"!@# _handle_close: {self._id}: step=4") # If local is also closed, both sides are closed. Then, we should clean up # the entry of this stream, to avoid others from accessing it. if is_local_closed: async with self.streams_lock: if stream_id in self.streams: del self.streams[stream_id] - print(f"!@# _handle_close: {self._id}: step=5") async def _handle_reset(self, stream_id: StreamID) -> None: async with self.streams_lock: diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index ac243012..568a2762 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -1,30 +1,29 @@ from contextlib import AsyncExitStack, asynccontextmanager -from typing import Any, AsyncIterator, Dict, Tuple, cast +from typing import Any, AsyncIterator, Dict, Sequence, Tuple, cast from async_service import background_trio_service import factory import trio -from libp2p.tools.constants import GOSSIPSUB_PARAMS from libp2p import generate_new_rsa_identity, generate_peer_id_from from libp2p.crypto.keys import KeyPair from libp2p.host.basic_host import BasicHost -from libp2p.host.routed_host import RoutedHost -from libp2p.tools.utils import set_up_routers -from libp2p.kademlia.network import KademliaServer +from libp2p.host.host_interface import IHost from libp2p.network.connection.swarm_connection import SwarmConn from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.swarm import Swarm -from libp2p.peer.peerstore import PeerStore from libp2p.peer.id import ID +from libp2p.peer.peerstore import PeerStore from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.pubsub import Pubsub +from libp2p.pubsub.pubsub_router_interface import IPubsubRouter from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport import libp2p.security.secio.transport as secio from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex from libp2p.stream_muxer.mplex.mplex_stream import MplexStream +from libp2p.tools.constants import GOSSIPSUB_PARAMS from libp2p.transport.tcp.tcp import TCP from libp2p.transport.typing import TMuxerOptions from libp2p.transport.upgrader import TransportUpgrader @@ -74,7 +73,7 @@ class SwarmFactory(factory.Factory): @asynccontextmanager async def create_and_listen( cls, is_secure: bool, key_pair: KeyPair = None, muxer_opt: TMuxerOptions = None - ) -> Swarm: + ) -> AsyncIterator[Swarm]: # `factory.Factory.__init__` does *not* prepare a *default value* if we pass # an argument explicitly with `None`. If an argument is `None`, we don't pass it to # `factory.Factory.__init__`, in order to let the function initialize it. @@ -92,7 +91,7 @@ class SwarmFactory(factory.Factory): @asynccontextmanager async def create_batch_and_listen( cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None - ) -> Tuple[Swarm, ...]: + ) -> AsyncIterator[Tuple[Swarm, ...]]: async with AsyncExitStack() as stack: ctx_mgrs = [ await stack.enter_async_context( @@ -100,7 +99,7 @@ class SwarmFactory(factory.Factory): ) for _ in range(number) ] - yield ctx_mgrs + yield tuple(ctx_mgrs) class HostFactory(factory.Factory): @@ -120,7 +119,7 @@ class HostFactory(factory.Factory): @asynccontextmanager async def create_batch_and_listen( cls, is_secure: bool, number: int - ) -> Tuple[BasicHost, ...]: + ) -> AsyncIterator[Tuple[BasicHost, ...]]: key_pairs = [generate_new_rsa_identity() for _ in range(number)] async with AsyncExitStack() as stack: swarms = [ @@ -136,30 +135,6 @@ class HostFactory(factory.Factory): yield hosts -class RoutedHostFactory(factory.Factory): - class Meta: - model = RoutedHost - - public_key = factory.LazyAttribute(lambda o: o.key_pair.public_key) - network = factory.LazyAttribute( - lambda o: SwarmFactory(is_secure=o.is_secure, key_pair=o.key_pair) - ) - router = factory.LazyFunction(KademliaServer) - - @classmethod - @asynccontextmanager - async def create_batch_and_listen( - cls, is_secure: bool, number: int - ) -> Tuple[RoutedHost, ...]: - key_pairs = [generate_new_rsa_identity() for _ in range(number)] - routers = await set_up_routers((0,) * number) - async with SwarmFactory.create_batch_and_listen(is_secure, number) as swarms: - yield tuple( - RoutedHost(key_pair.public_key, swarm, router) - for key_pair, swarm, router in zip(key_pairs, swarms, routers) - ) - - class FloodsubFactory(factory.Factory): class Meta: model = FloodSub @@ -191,17 +166,22 @@ class PubsubFactory(factory.Factory): @classmethod @asynccontextmanager - async def create_and_start(cls, host, router, cache_size): + async def create_and_start( + cls, host: IHost, router: IPubsubRouter, cache_size: int + ) -> AsyncIterator[Pubsub]: pubsub = PubsubFactory(host=host, router=router, cache_size=cache_size) async with background_trio_service(pubsub): yield pubsub @classmethod @asynccontextmanager - async def create_batch_with_floodsub( - cls, number: int, is_secure: bool = False, cache_size: int = None - ): - floodsubs = FloodsubFactory.create_batch(number) + async def _create_batch_with_router( + cls, + number: int, + routers: Sequence[IPubsubRouter], + is_secure: bool = False, + cache_size: int = None, + ) -> AsyncIterator[Tuple[Pubsub, ...]]: async with HostFactory.create_batch_and_listen(is_secure, number) as hosts: # Pubsubs should exit before hosts async with AsyncExitStack() as stack: @@ -209,21 +189,80 @@ class PubsubFactory(factory.Factory): await stack.enter_async_context( cls.create_and_start(host, router, cache_size) ) - for host, router in zip(hosts, floodsubs) + for host, router in zip(hosts, routers) ] - yield pubsubs + yield tuple(pubsubs) - # @classmethod - # async def create_batch_with_gossipsub( - # cls, number: int, cache_size: int = None, gossipsub_params=GOSSIPSUB_PARAMS - # ): - # ... + @classmethod + @asynccontextmanager + async def create_batch_with_floodsub( + cls, + number: int, + is_secure: bool = False, + cache_size: int = None, + protocols: Sequence[TProtocol] = None, + ) -> AsyncIterator[Tuple[Pubsub, ...]]: + if protocols is not None: + floodsubs = FloodsubFactory.create_batch(number, protocols=list(protocols)) + else: + floodsubs = FloodsubFactory.create_batch(number) + async with cls._create_batch_with_router( + number, floodsubs, is_secure, cache_size + ) as pubsubs: + yield pubsubs + + @classmethod + @asynccontextmanager + async def create_batch_with_gossipsub( + cls, + number: int, + *, + is_secure: bool = False, + cache_size: int = None, + protocols: Sequence[TProtocol] = None, + degree: int = GOSSIPSUB_PARAMS.degree, + degree_low: int = GOSSIPSUB_PARAMS.degree_low, + degree_high: int = GOSSIPSUB_PARAMS.degree_high, + time_to_live: int = GOSSIPSUB_PARAMS.time_to_live, + gossip_window: int = GOSSIPSUB_PARAMS.gossip_window, + gossip_history: int = GOSSIPSUB_PARAMS.gossip_history, + heartbeat_interval: float = GOSSIPSUB_PARAMS.heartbeat_interval, + ) -> AsyncIterator[Tuple[Pubsub, ...]]: + if protocols is not None: + gossipsubs = GossipsubFactory.create_batch( + number, + protocols=protocols, + degree=degree, + degree_low=degree_low, + degree_high=degree_high, + time_to_live=time_to_live, + gossip_window=gossip_window, + heartbeat_interval=heartbeat_interval, + ) + else: + gossipsubs = GossipsubFactory.create_batch( + number, + degree=degree, + degree_low=degree_low, + degree_high=degree_high, + time_to_live=time_to_live, + gossip_window=gossip_window, + heartbeat_interval=heartbeat_interval, + ) + + async with cls._create_batch_with_router( + number, gossipsubs, is_secure, cache_size + ) as pubsubs: + async with AsyncExitStack() as stack: + for router in gossipsubs: + await stack.enter_async_context(background_trio_service(router)) + yield pubsubs @asynccontextmanager async def swarm_pair_factory( is_secure: bool, muxer_opt: TMuxerOptions = None -) -> Tuple[Swarm, Swarm]: +) -> AsyncIterator[Tuple[Swarm, Swarm]]: async with SwarmFactory.create_batch_and_listen( is_secure, 2, muxer_opt=muxer_opt ) as swarms: @@ -232,7 +271,9 @@ async def swarm_pair_factory( @asynccontextmanager -async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]: +async def host_pair_factory( + is_secure: bool +) -> AsyncIterator[Tuple[BasicHost, BasicHost]]: async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts: await connect(hosts[0], hosts[1]) yield hosts[0], hosts[1] @@ -241,7 +282,7 @@ async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]: @asynccontextmanager async def swarm_conn_pair_factory( is_secure: bool, muxer_opt: TMuxerOptions = None -) -> Tuple[SwarmConn, SwarmConn]: +) -> AsyncIterator[Tuple[SwarmConn, SwarmConn]]: async with swarm_pair_factory(is_secure) as swarms: conn_0 = swarms[0].connections[swarms[1].get_peer_id()] conn_1 = swarms[1].connections[swarms[0].get_peer_id()] @@ -249,7 +290,9 @@ async def swarm_conn_pair_factory( @asynccontextmanager -async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Mplex]: +async def mplex_conn_pair_factory( + is_secure: bool +) -> AsyncIterator[Tuple[Mplex, Mplex]]: muxer_opt = {MPLEX_PROTOCOL_ID: Mplex} async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair: yield ( @@ -259,21 +302,25 @@ async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Mplex]: @asynccontextmanager -async def mplex_stream_pair_factory(is_secure: bool) -> Tuple[MplexStream, MplexStream]: +async def mplex_stream_pair_factory( + is_secure: bool +) -> AsyncIterator[Tuple[MplexStream, MplexStream]]: async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info: mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info - stream_0 = await mplex_conn_0.open_stream() + stream_0 = cast(MplexStream, await mplex_conn_0.open_stream()) await trio.sleep(0.01) stream_1: MplexStream async with mplex_conn_1.streams_lock: if len(mplex_conn_1.streams) != 1: raise Exception("Mplex should not have any other stream") stream_1 = tuple(mplex_conn_1.streams.values())[0] - yield cast(MplexStream, stream_0), cast(MplexStream, stream_1) + yield stream_0, stream_1 @asynccontextmanager -async def net_stream_pair_factory(is_secure: bool) -> Tuple[INetStream, INetStream]: +async def net_stream_pair_factory( + is_secure: bool +) -> AsyncIterator[Tuple[INetStream, INetStream]]: protocol_id = TProtocol("/example/id/1") stream_1: INetStream diff --git a/libp2p/tools/pubsub/dummy_account_node.py b/libp2p/tools/pubsub/dummy_account_node.py index 94f65763..5a61ed69 100644 --- a/libp2p/tools/pubsub/dummy_account_node.py +++ b/libp2p/tools/pubsub/dummy_account_node.py @@ -1,12 +1,11 @@ -import asyncio -from typing import Dict -import uuid +from contextlib import AsyncExitStack, asynccontextmanager +from typing import AsyncIterator, Dict, Tuple + +from async_service import Service, background_trio_service from libp2p.host.host_interface import IHost -from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.pubsub import Pubsub -from libp2p.tools.constants import LISTEN_MADDR -from libp2p.tools.factories import FloodsubFactory, PubsubFactory +from libp2p.tools.factories import PubsubFactory CRYPTO_TOPIC = "ethereum" @@ -18,7 +17,7 @@ CRYPTO_TOPIC = "ethereum" # Determine message type by looking at first item before first comma -class DummyAccountNode: +class DummyAccountNode(Service): """ Node which has an internal balance mapping, meant to serve as a dummy crypto blockchain. @@ -27,19 +26,24 @@ class DummyAccountNode: crypto each user in the mappings holds """ - libp2p_node: IHost pubsub: Pubsub - floodsub: FloodSub - def __init__(self, libp2p_node: IHost, pubsub: Pubsub, floodsub: FloodSub): - self.libp2p_node = libp2p_node + def __init__(self, pubsub: Pubsub) -> None: self.pubsub = pubsub - self.floodsub = floodsub self.balances: Dict[str, int] = {} - self.node_id = str(uuid.uuid1()) + + @property + def host(self) -> IHost: + return self.pubsub.host + + async def run(self) -> None: + self.subscription = await self.pubsub.subscribe(CRYPTO_TOPIC) + self.manager.run_daemon_task(self.handle_incoming_msgs) + await self.manager.wait_finished() @classmethod - async def create(cls) -> "DummyAccountNode": + @asynccontextmanager + async def create(cls, number: int) -> AsyncIterator[Tuple["DummyAccountNode", ...]]: """ Create a new DummyAccountNode and attach a libp2p node, a floodsub, and a pubsub instance to this new node. @@ -47,15 +51,17 @@ class DummyAccountNode: We use create as this serves as a factory function and allows us to use async await, unlike the init function """ - - pubsub = PubsubFactory(router=FloodsubFactory()) - await pubsub.host.get_network().listen(LISTEN_MADDR) - return cls(libp2p_node=pubsub.host, pubsub=pubsub, floodsub=pubsub.router) + async with PubsubFactory.create_batch_with_floodsub(number) as pubsubs: + async with AsyncExitStack() as stack: + dummy_acount_nodes = tuple(cls(pubsub) for pubsub in pubsubs) + for node in dummy_acount_nodes: + await stack.enter_async_context(background_trio_service(node)) + yield dummy_acount_nodes async def handle_incoming_msgs(self) -> None: """Handle all incoming messages on the CRYPTO_TOPIC from peers.""" while True: - incoming = await self.q.get() + incoming = await self.subscription.receive() msg_comps = incoming.data.decode("utf-8").split(",") if msg_comps[0] == "send": @@ -63,13 +69,6 @@ class DummyAccountNode: elif msg_comps[0] == "set": self.handle_set_crypto(msg_comps[1], int(msg_comps[2])) - async def setup_crypto_networking(self) -> None: - """Subscribe to CRYPTO_TOPIC and perform call to function that handles - all incoming messages on said topic.""" - self.q = await self.pubsub.subscribe(CRYPTO_TOPIC) - - asyncio.ensure_future(self.handle_incoming_msgs()) - async def publish_send_crypto( self, source_user: str, dest_user: str, amount: int ) -> None: diff --git a/libp2p/tools/pubsub/floodsub_integration_test_settings.py b/libp2p/tools/pubsub/floodsub_integration_test_settings.py index 90939dec..58a5b242 100644 --- a/libp2p/tools/pubsub/floodsub_integration_test_settings.py +++ b/libp2p/tools/pubsub/floodsub_integration_test_settings.py @@ -1,12 +1,10 @@ # type: ignore # To add typing to this module, it's better to do it after refactoring test cases into classes -import asyncio - import pytest +import trio -from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID, LISTEN_MADDR -from libp2p.tools.factories import PubsubFactory +from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID from libp2p.tools.utils import connect SUPPORTED_PROTOCOLS = [FLOODSUB_PROTOCOL_ID] @@ -15,6 +13,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "simple_two_nodes", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["A", "B"], "adj_list": {"A": ["B"]}, "topic_map": {"topic1": ["B"]}, "messages": [{"topics": ["topic1"], "data": b"foo", "node_id": "A"}], @@ -22,6 +21,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "three_nodes_two_topics", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["A", "B", "C"], "adj_list": {"A": ["B"], "B": ["C"]}, "topic_map": {"topic1": ["B", "C"], "topic2": ["B", "C"]}, "messages": [ @@ -32,6 +32,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "two_nodes_one_topic_single_subscriber_is_sender", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["A", "B"], "adj_list": {"A": ["B"]}, "topic_map": {"topic1": ["B"]}, "messages": [{"topics": ["topic1"], "data": b"Alex is tall", "node_id": "B"}], @@ -39,6 +40,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "two_nodes_one_topic_two_msgs", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["A", "B"], "adj_list": {"A": ["B"]}, "topic_map": {"topic1": ["B"]}, "messages": [ @@ -49,6 +51,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "seven_nodes_tree_one_topics", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["1", "2", "3", "4", "5", "6", "7"], "adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]}, "topic_map": {"astrophysics": ["2", "3", "4", "5", "6", "7"]}, "messages": [{"topics": ["astrophysics"], "data": b"e=mc^2", "node_id": "1"}], @@ -56,6 +59,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "seven_nodes_tree_three_topics", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["1", "2", "3", "4", "5", "6", "7"], "adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]}, "topic_map": { "astrophysics": ["2", "3", "4", "5", "6", "7"], @@ -71,6 +75,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "seven_nodes_tree_three_topics_diff_origin", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["1", "2", "3", "4", "5", "6", "7"], "adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]}, "topic_map": { "astrophysics": ["1", "2", "3", "4", "5", "6", "7"], @@ -86,6 +91,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "three_nodes_clique_two_topic_diff_origin", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["1", "2", "3"], "adj_list": {"1": ["2", "3"], "2": ["3"]}, "topic_map": {"astrophysics": ["1", "2", "3"], "school": ["1", "2", "3"]}, "messages": [ @@ -97,6 +103,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "four_nodes_clique_two_topic_diff_origin_many_msgs", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["1", "2", "3", "4"], "adj_list": { "1": ["2", "3", "4"], "2": ["1", "3", "4"], @@ -120,6 +127,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "five_nodes_ring_two_topic_diff_origin_many_msgs", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["1", "2", "3", "4", "5"], "adj_list": {"1": ["2"], "2": ["3"], "3": ["4"], "4": ["5"], "5": ["1"]}, "topic_map": { "astrophysics": ["1", "2", "3", "4", "5"], @@ -143,7 +151,7 @@ floodsub_protocol_pytest_params = [ ] -async def perform_test_from_obj(obj, router_factory) -> None: +async def perform_test_from_obj(obj, pubsub_factory) -> None: """ Perform pubsub tests from a test obj. test obj are composed as follows: @@ -174,88 +182,75 @@ async def perform_test_from_obj(obj, router_factory) -> None: # Step 1) Create graph adj_list = obj["adj_list"] + node_list = obj["nodes"] node_map = {} pubsub_map = {} - async def add_node(node_id_str: str) -> None: - pubsub_router = router_factory(protocols=obj["supported_protocols"]) - pubsub = PubsubFactory(router=pubsub_router) - await pubsub.host.get_network().listen(LISTEN_MADDR) - node_map[node_id_str] = pubsub.host - pubsub_map[node_id_str] = pubsub + async with pubsub_factory( + number=len(node_list), protocols=obj["supported_protocols"] + ) as pubsubs: + for node_id_str, pubsub in zip(node_list, pubsubs): + node_map[node_id_str] = pubsub.host + pubsub_map[node_id_str] = pubsub - tasks_connect = [] - for start_node_id in adj_list: - # Create node if node does not yet exist - if start_node_id not in node_map: - await add_node(start_node_id) + # Connect nodes and wait at least for 2 seconds + async with trio.open_nursery() as nursery: + for start_node_id in adj_list: + # For each neighbor of start_node, create if does not yet exist, + # then connect start_node to neighbor + for neighbor_id in adj_list[start_node_id]: + nursery.start_soon( + connect, node_map[start_node_id], node_map[neighbor_id] + ) + nursery.start_soon(trio.sleep, 2) - # For each neighbor of start_node, create if does not yet exist, - # then connect start_node to neighbor - for neighbor_id in adj_list[start_node_id]: - # Create neighbor if neighbor does not yet exist - if neighbor_id not in node_map: - await add_node(neighbor_id) - tasks_connect.append( - connect(node_map[start_node_id], node_map[neighbor_id]) - ) - # Connect nodes and wait at least for 2 seconds - await asyncio.gather(*tasks_connect, asyncio.sleep(2)) + # Step 2) Subscribe to topics + queues_map = {} + topic_map = obj["topic_map"] - # Step 2) Subscribe to topics - queues_map = {} - topic_map = obj["topic_map"] + async def subscribe_node(node_id, topic): + if node_id not in queues_map: + queues_map[node_id] = {} + # Avoid repeated works + if topic in queues_map[node_id]: + # Checkpoint + await trio.hazmat.checkpoint() + return + sub = await pubsub_map[node_id].subscribe(topic) + queues_map[node_id][topic] = sub - tasks_topic = [] - tasks_topic_data = [] - for topic, node_ids in topic_map.items(): - for node_id in node_ids: - tasks_topic.append(pubsub_map[node_id].subscribe(topic)) - tasks_topic_data.append((node_id, topic)) - tasks_topic.append(asyncio.sleep(2)) + async with trio.open_nursery() as nursery: + for topic, node_ids in topic_map.items(): + for node_id in node_ids: + nursery.start_soon(subscribe_node, node_id, topic) + nursery.start_soon(trio.sleep, 2) - # Gather is like Promise.all - responses = await asyncio.gather(*tasks_topic) - for i in range(len(responses) - 1): - node_id, topic = tasks_topic_data[i] - if node_id not in queues_map: - queues_map[node_id] = {} - # Store queue in topic-queue map for node - queues_map[node_id][topic] = responses[i] + # Step 3) Publish messages + topics_in_msgs_ordered = [] + messages = obj["messages"] - # Allow time for subscribing before continuing - await asyncio.sleep(0.01) + for msg in messages: + topics = msg["topics"] + data = msg["data"] + node_id = msg["node_id"] - # Step 3) Publish messages - topics_in_msgs_ordered = [] - messages = obj["messages"] - tasks_publish = [] + # Publish message + # TODO: Should be single RPC package with several topics + for topic in topics: + await pubsub_map[node_id].publish(topic, data) - for msg in messages: - topics = msg["topics"] - data = msg["data"] - node_id = msg["node_id"] + # For each topic in topics, add (topic, node_id, data) tuple to ordered test list + for topic in topics: + topics_in_msgs_ordered.append((topic, node_id, data)) + # Allow time for publishing before continuing + await trio.sleep(1) - # Publish message - # TODO: Should be single RPC package with several topics - for topic in topics: - tasks_publish.append(pubsub_map[node_id].publish(topic, data)) - - # For each topic in topics, add (topic, node_id, data) tuple to ordered test list - for topic in topics: - topics_in_msgs_ordered.append((topic, node_id, data)) - - # Allow time for publishing before continuing - await asyncio.gather(*tasks_publish, asyncio.sleep(2)) - - # Step 4) Check that all messages were received correctly. - for topic, origin_node_id, data in topics_in_msgs_ordered: - # Look at each node in each topic - for node_id in topic_map[topic]: - # Get message from subscription queue - msg = await queues_map[node_id][topic].get() - assert data == msg.data - # Check the message origin - assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id - - # Success, terminate pending tasks. + # Step 4) Check that all messages were received correctly. + for topic, origin_node_id, data in topics_in_msgs_ordered: + # Look at each node in each topic + for node_id in topic_map[topic]: + # Get message from subscription queue + msg = await queues_map[node_id][topic].receive() + assert data == msg.data + # Check the message origin + assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id diff --git a/libp2p/tools/utils.py b/libp2p/tools/utils.py index 9ad68150..a66155c3 100644 --- a/libp2p/tools/utils.py +++ b/libp2p/tools/utils.py @@ -1,17 +1,9 @@ -from typing import Callable, List, Sequence, Tuple +from typing import Awaitable, Callable -import multiaddr -import trio - -from libp2p import new_node -from libp2p.host.basic_host import BasicHost from libp2p.host.host_interface import IHost -from libp2p.kademlia.network import KademliaServer from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.swarm import Swarm from libp2p.peer.peerinfo import info_from_p2p_addr -from libp2p.routing.interfaces import IPeerRouting -from libp2p.routing.kademlia.kademlia_peer_router import KadmeliaPeerRouter from .constants import MAX_READ_LEN @@ -36,49 +28,9 @@ async def connect(node1: IHost, node2: IHost) -> None: await node1.connect(info) -async def set_up_nodes_by_transport_opt( - transport_opt_list: Sequence[Sequence[str]], nursery: trio.Nursery -) -> Tuple[BasicHost, ...]: - nodes_list = [] - for transport_opt in transport_opt_list: - node = new_node(transport_opt=transport_opt) - await node.get_network().listen( - multiaddr.Multiaddr(transport_opt[0]), nursery=nursery - ) - nodes_list.append(node) - return tuple(nodes_list) - - -async def set_up_nodes_by_transport_and_disc_opt( - transport_disc_opt_list: Sequence[Tuple[Sequence[str], IPeerRouting]] -) -> Tuple[BasicHost, ...]: - nodes_list = [] - for transport_opt, disc_opt in transport_disc_opt_list: - node = await new_node(transport_opt=transport_opt, disc_opt=disc_opt) - await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0])) - nodes_list.append(node) - return tuple(nodes_list) - - -async def set_up_routers( - router_ports: Tuple[int, ...] = (0, 0) -) -> List[KadmeliaPeerRouter]: - """The default ``router_confs`` selects two free ports local to this - machine.""" - bootstrap_node = KademliaServer() # type: ignore - await bootstrap_node.listen(router_ports[0]) - - routers = [KadmeliaPeerRouter(bootstrap_node)] - for port in router_ports[1:]: - node = KademliaServer() # type: ignore - await node.listen(port) - - await node.bootstrap_node(bootstrap_node.address) - routers.append(KadmeliaPeerRouter(node)) - return routers - - -def create_echo_stream_handler(ack_prefix: str) -> Callable[[INetStream], None]: +def create_echo_stream_handler( + ack_prefix: str +) -> Callable[[INetStream], Awaitable[None]]: async def echo_stream_handler(stream: INetStream) -> None: while True: read_string = (await stream.read(MAX_READ_LEN)).decode() diff --git a/libp2p/transport/listener_interface.py b/libp2p/transport/listener_interface.py index 1b22531b..6d737233 100644 --- a/libp2p/transport/listener_interface.py +++ b/libp2p/transport/listener_interface.py @@ -1,12 +1,13 @@ from abc import ABC, abstractmethod -from typing import List +from typing import Tuple from multiaddr import Multiaddr +import trio class IListener(ABC): @abstractmethod - async def listen(self, maddr: Multiaddr) -> bool: + async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool: """ put listener in listening mode and wait for incoming connections. @@ -15,14 +16,9 @@ class IListener(ABC): """ @abstractmethod - def get_addrs(self) -> List[Multiaddr]: + def get_addrs(self) -> Tuple[Multiaddr, ...]: """ retrieve list of addresses the listener is listening on. :return: return list of addrs """ - - @abstractmethod - async def close(self) -> None: - """close the listener such that no more connections can be open on this - transport instance.""" diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index 745bafe8..04d8874f 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -1,14 +1,13 @@ import logging -from socket import socket -from typing import List +from typing import Awaitable, Callable, List, Sequence, Tuple from multiaddr import Multiaddr import trio +from trio_typing import TaskStatus -from libp2p.io.trio import TrioReadWriteCloser +from libp2p.io.trio import TrioTCPStream from libp2p.network.connection.raw_connection import RawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection -from libp2p.transport.exceptions import OpenConnectionError from libp2p.transport.listener_interface import IListener from libp2p.transport.transport_interface import ITransport from libp2p.transport.typing import THandler @@ -18,14 +17,12 @@ logger = logging.getLogger("libp2p.transport.tcp") class TCPListener(IListener): multiaddrs: List[Multiaddr] - server = None def __init__(self, handler_function: THandler) -> None: self.multiaddrs = [] - self.server = None self.handler = handler_function - # TODO: Fix handling? + # TODO: Get rid of `nursery`? async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> None: """ put listener in listening mode and wait for incoming connections. @@ -34,13 +31,18 @@ class TCPListener(IListener): :return: return True if successful """ - async def serve_tcp(handler, port, host, task_status=None): + async def serve_tcp( + handler: Callable[[trio.SocketStream], Awaitable[None]], + port: int, + host: str, + task_status: TaskStatus[Sequence[trio.SocketListener]] = None, + ) -> None: logger.debug("serve_tcp %s %s", host, port) await trio.serve_tcp(handler, port, host=host, task_status=task_status) - async def handler(stream): - read_write_closer = TrioReadWriteCloser(stream) - await self.handler(read_write_closer) + async def handler(stream: trio.SocketStream) -> None: + tcp_stream = TrioTCPStream(stream) + await self.handler(tcp_stream) listeners = await nursery.start( serve_tcp, @@ -51,7 +53,7 @@ class TCPListener(IListener): socket = listeners[0].socket self.multiaddrs.append(_multiaddr_from_socket(socket)) - def get_addrs(self) -> List[Multiaddr]: + def get_addrs(self) -> Tuple[Multiaddr, ...]: """ retrieve list of addresses the listener is listening on. @@ -59,15 +61,6 @@ class TCPListener(IListener): """ return tuple(self.multiaddrs) - async def close(self) -> None: - """close the listener such that no more connections can be open on this - transport instance.""" - if self.server is None: - return - self.server.close() - await self.server.wait_closed() - self.server = None - class TCP(ITransport): async def dial(self, maddr: Multiaddr) -> IRawConnection: @@ -82,7 +75,7 @@ class TCP(ITransport): self.port = int(maddr.value_for_protocol("tcp")) stream = await trio.open_tcp_stream(self.host, self.port) - read_write_closer = TrioReadWriteCloser(stream) + read_write_closer = TrioTCPStream(stream) return RawConnection(read_write_closer, True) @@ -97,5 +90,6 @@ class TCP(ITransport): return TCPListener(handler_function) -def _multiaddr_from_socket(socket: socket) -> Multiaddr: - return Multiaddr("/ip4/%s/tcp/%s" % socket.getsockname()) +def _multiaddr_from_socket(socket: trio.socket.SocketType) -> Multiaddr: + ip, port = socket.getsockname() # type: ignore + return Multiaddr(f"/ip4/{ip}/tcp/{port}") diff --git a/tests/host/test_ping.py b/tests/host/test_ping.py index 29135141..7a0f8db5 100644 --- a/tests/host/test_ping.py +++ b/tests/host/test_ping.py @@ -1,7 +1,7 @@ -import trio import secrets import pytest +import trio from libp2p.host.ping import ID, PING_LENGTH from libp2p.tools.factories import host_pair_factory diff --git a/tests/host/test_routed_host.py b/tests/host/test_routed_host.py deleted file mode 100644 index 006dd222..00000000 --- a/tests/host/test_routed_host.py +++ /dev/null @@ -1,73 +0,0 @@ -import pytest - -from libp2p.host.exceptions import ConnectionFailure -from libp2p.peer.peerinfo import PeerInfo -from libp2p.routing.kademlia.kademlia_peer_router import peer_info_to_str -from libp2p.tools.utils import ( - set_up_nodes_by_transport_and_disc_opt, - set_up_nodes_by_transport_opt, - set_up_routers, -) -from libp2p.tools.factories import RoutedHostFactory - - -# FIXME: - -# TODO: Kademlia is full of asyncio code. Skip it for now -@pytest.mark.skip -@pytest.mark.trio -async def test_host_routing_success(is_host_secure): - async with RoutedHostFactory.create_batch_and_listen( - is_host_secure, 2 - ) as routed_hosts: - # Set routing info - await routed_hosts[0]._router.server.set( - routed_hosts[0].get_id().xor_id, - peer_info_to_str( - PeerInfo(routed_hosts[0].get_id(), routed_hosts[0].get_addrs()) - ), - ) - await routed_hosts[1]._router.server.set( - routed_hosts[1].get_id().xor_id, - peer_info_to_str( - PeerInfo(routed_hosts[1].get_id(), routed_hosts[1].get_addrs()) - ), - ) - - # forces to use routing as no addrs are provided - await routed_hosts[0].connect(PeerInfo(routed_hosts[1].get_id(), [])) - await routed_hosts[1].connect(PeerInfo(routed_hosts[0].get_id(), [])) - - -# TODO: Kademlia is full of asyncio code. Skip it for now -@pytest.mark.skip -@pytest.mark.trio -async def test_host_routing_fail(): - routers = await set_up_routers() - transports = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - transport_disc_opt_list = zip(transports, routers) - (host_a, host_b) = await set_up_nodes_by_transport_and_disc_opt( - transport_disc_opt_list - ) - - host_c = (await set_up_nodes_by_transport_opt([["/ip4/127.0.0.1/tcp/0"]]))[0] - - # Set routing info - await routers[0].server.set( - host_a.get_id().xor_id, - peer_info_to_str(PeerInfo(host_a.get_id(), host_a.get_addrs())), - ) - await routers[1].server.set( - host_b.get_id().xor_id, - peer_info_to_str(PeerInfo(host_b.get_id(), host_b.get_addrs())), - ) - - # routing fails because host_c does not use routing - with pytest.raises(ConnectionFailure): - await host_a.connect(PeerInfo(host_c.get_id(), [])) - with pytest.raises(ConnectionFailure): - await host_b.connect(PeerInfo(host_c.get_id(), [])) - - # Clean up - routers[0].server.stop() - routers[1].server.stop() diff --git a/tests/protocol_muxer/test_protocol_muxer.py b/tests/protocol_muxer/test_protocol_muxer.py index 9533d1f3..cd82652c 100644 --- a/tests/protocol_muxer/test_protocol_muxer.py +++ b/tests/protocol_muxer/test_protocol_muxer.py @@ -4,7 +4,6 @@ from libp2p.host.exceptions import StreamFailure from libp2p.tools.factories import HostFactory from libp2p.tools.utils import create_echo_stream_handler - PROTOCOL_ECHO = "/echo/1.0.0" PROTOCOL_POTATO = "/potato/1.0.0" PROTOCOL_FOO = "/foo/1.0.0" diff --git a/tests/pubsub/conftest.py b/tests/pubsub/conftest.py deleted file mode 100644 index 6c08dd78..00000000 --- a/tests/pubsub/conftest.py +++ /dev/null @@ -1,22 +0,0 @@ -import pytest - -from libp2p.tools.constants import GOSSIPSUB_PARAMS -from libp2p.tools.factories import FloodsubFactory, GossipsubFactory, PubsubFactory - - -@pytest.fixture -def pubsub_cache_size(): - return None # default - - -@pytest.fixture -def gossipsub_params(): - return GOSSIPSUB_PARAMS - - -# @pytest.fixture -# def pubsubs_gsub(num_hosts, hosts, pubsub_cache_size, gossipsub_params): -# gossipsubs = GossipsubFactory.create_batch(num_hosts, **gossipsub_params._asdict()) -# _pubsubs_gsub = _make_pubsubs(hosts, gossipsubs, pubsub_cache_size) -# yield _pubsubs_gsub -# # TODO: Clean up diff --git a/tests/pubsub/test_dummyaccount_demo.py b/tests/pubsub/test_dummyaccount_demo.py index cdda6035..24d5bd41 100644 --- a/tests/pubsub/test_dummyaccount_demo.py +++ b/tests/pubsub/test_dummyaccount_demo.py @@ -1,19 +1,10 @@ -import asyncio -from threading import Thread - import pytest +import trio from libp2p.tools.pubsub.dummy_account_node import DummyAccountNode from libp2p.tools.utils import connect -def create_setup_in_new_thread_func(dummy_node): - def setup_in_new_thread(): - asyncio.ensure_future(dummy_node.setup_crypto_networking()) - - return setup_in_new_thread - - async def perform_test(num_nodes, adjacency_map, action_func, assertion_func): """ Helper function to allow for easy construction of custom tests for dummy @@ -26,47 +17,35 @@ async def perform_test(num_nodes, adjacency_map, action_func, assertion_func): :param assertion_func: assertions for testing the results of the actions are correct """ - # Create nodes - dummy_nodes = [] - for _ in range(num_nodes): - dummy_nodes.append(await DummyAccountNode.create()) + async with DummyAccountNode.create(num_nodes) as dummy_nodes: + # Create connections between nodes according to `adjacency_map` + async with trio.open_nursery() as nursery: + for source_num in adjacency_map: + target_nums = adjacency_map[source_num] + for target_num in target_nums: + nursery.start_soon( + connect, + dummy_nodes[source_num].host, + dummy_nodes[target_num].host, + ) - # Create network - for source_num in adjacency_map: - target_nums = adjacency_map[source_num] - for target_num in target_nums: - await connect( - dummy_nodes[source_num].libp2p_node, dummy_nodes[target_num].libp2p_node - ) + # Allow time for network creation to take place + await trio.sleep(0.25) - # Allow time for network creation to take place - await asyncio.sleep(0.25) + # Perform action function + await action_func(dummy_nodes) - # Start a thread for each node so that each node can listen and respond - # to messages on its own thread, which will avoid waiting indefinitely - # on the main thread. On this thread, call the setup func for the node, - # which subscribes the node to the CRYPTO_TOPIC topic - for dummy_node in dummy_nodes: - thread = Thread(target=create_setup_in_new_thread_func(dummy_node)) - thread.run() + # Allow time for action function to be performed (i.e. messages to propogate) + await trio.sleep(1) - # Allow time for nodes to subscribe to CRYPTO_TOPIC topic - await asyncio.sleep(0.25) - - # Perform action function - await action_func(dummy_nodes) - - # Allow time for action function to be performed (i.e. messages to propogate) - await asyncio.sleep(1) - - # Perform assertion function - for dummy_node in dummy_nodes: - assertion_func(dummy_node) + # Perform assertion function + for dummy_node in dummy_nodes: + assertion_func(dummy_node) # Success, terminate pending tasks. -@pytest.mark.asyncio +@pytest.mark.trio async def test_simple_two_nodes(): num_nodes = 2 adj_map = {0: [1]} @@ -80,7 +59,7 @@ async def test_simple_two_nodes(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio async def test_simple_three_nodes_line_topography(): num_nodes = 3 adj_map = {0: [1], 1: [2]} @@ -94,7 +73,7 @@ async def test_simple_three_nodes_line_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio async def test_simple_three_nodes_triangle_topography(): num_nodes = 3 adj_map = {0: [1, 2], 1: [2]} @@ -108,7 +87,7 @@ async def test_simple_three_nodes_triangle_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio async def test_simple_seven_nodes_tree_topography(): num_nodes = 7 adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]} @@ -122,14 +101,14 @@ async def test_simple_seven_nodes_tree_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio async def test_set_then_send_from_root_seven_nodes_tree_topography(): num_nodes = 7 adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]} async def action_func(dummy_nodes): await dummy_nodes[0].publish_set_crypto("aspyn", 20) - await asyncio.sleep(0.25) + await trio.sleep(0.25) await dummy_nodes[0].publish_send_crypto("aspyn", "alex", 5) def assertion_func(dummy_node): @@ -139,14 +118,14 @@ async def test_set_then_send_from_root_seven_nodes_tree_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography(): num_nodes = 7 adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]} async def action_func(dummy_nodes): await dummy_nodes[6].publish_set_crypto("aspyn", 20) - await asyncio.sleep(0.25) + await trio.sleep(0.25) await dummy_nodes[4].publish_send_crypto("aspyn", "alex", 5) def assertion_func(dummy_node): @@ -156,7 +135,7 @@ async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio async def test_simple_five_nodes_ring_topography(): num_nodes = 5 adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]} @@ -170,14 +149,14 @@ async def test_simple_five_nodes_ring_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography(): num_nodes = 5 adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]} async def action_func(dummy_nodes): await dummy_nodes[0].publish_set_crypto("alex", 20) - await asyncio.sleep(0.25) + await trio.sleep(0.25) await dummy_nodes[3].publish_send_crypto("alex", "rob", 12) def assertion_func(dummy_node): @@ -187,7 +166,7 @@ async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio @pytest.mark.slow async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography(): num_nodes = 5 @@ -195,13 +174,13 @@ async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography(): async def action_func(dummy_nodes): await dummy_nodes[0].publish_set_crypto("alex", 20) - await asyncio.sleep(1) + await trio.sleep(1) await dummy_nodes[1].publish_send_crypto("alex", "rob", 3) - await asyncio.sleep(1) + await trio.sleep(1) await dummy_nodes[2].publish_send_crypto("rob", "aspyn", 2) - await asyncio.sleep(1) + await trio.sleep(1) await dummy_nodes[3].publish_send_crypto("aspyn", "zx", 1) - await asyncio.sleep(1) + await trio.sleep(1) await dummy_nodes[4].publish_send_crypto("zx", "raul", 1) def assertion_func(dummy_node): diff --git a/tests/pubsub/test_floodsub.py b/tests/pubsub/test_floodsub.py index 7564a949..dbeb6833 100644 --- a/tests/pubsub/test_floodsub.py +++ b/tests/pubsub/test_floodsub.py @@ -1,9 +1,10 @@ -import asyncio +import functools import pytest +import trio from libp2p.peer.id import ID -from libp2p.tools.factories import FloodsubFactory +from libp2p.tools.factories import PubsubFactory from libp2p.tools.pubsub.floodsub_integration_test_settings import ( floodsub_protocol_pytest_params, perform_test_from_obj, @@ -11,79 +12,83 @@ from libp2p.tools.pubsub.floodsub_integration_test_settings import ( from libp2p.tools.utils import connect -@pytest.mark.parametrize("num_hosts", (2,)) -@pytest.mark.asyncio -async def test_simple_two_nodes(pubsubs_fsub): - topic = "my_topic" - data = b"some data" +@pytest.mark.trio +async def test_simple_two_nodes(): + async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub: + topic = "my_topic" + data = b"some data" - await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) - await asyncio.sleep(0.25) + await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) + await trio.sleep(0.25) - sub_b = await pubsubs_fsub[1].subscribe(topic) - # Sleep to let a know of b's subscription - await asyncio.sleep(0.25) + sub_b = await pubsubs_fsub[1].subscribe(topic) + # Sleep to let a know of b's subscription + await trio.sleep(0.25) - await pubsubs_fsub[0].publish(topic, data) + await pubsubs_fsub[0].publish(topic, data) - res_b = await sub_b.get() + res_b = await sub_b.receive() - # Check that the msg received by node_b is the same - # as the message sent by node_a - assert ID(res_b.from_id) == pubsubs_fsub[0].host.get_id() - assert res_b.data == data - assert res_b.topicIDs == [topic] - - # Success, terminate pending tasks. + # Check that the msg received by node_b is the same + # as the message sent by node_a + assert ID(res_b.from_id) == pubsubs_fsub[0].host.get_id() + assert res_b.data == data + assert res_b.topicIDs == [topic] -# Initialize Pubsub with a cache_size of 4 -@pytest.mark.parametrize("num_hosts, pubsub_cache_size", ((2, 4),)) -@pytest.mark.asyncio -async def test_lru_cache_two_nodes(pubsubs_fsub, monkeypatch): +@pytest.mark.trio +async def test_lru_cache_two_nodes(monkeypatch): # two nodes with cache_size of 4 - # `node_a` send the following messages to node_b - message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1] - # `node_b` should only receive the following - expected_received_indices = [1, 2, 3, 4, 5, 1] + async with PubsubFactory.create_batch_with_floodsub( + 2, cache_size=4 + ) as pubsubs_fsub: + # `node_a` send the following messages to node_b + message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1] + # `node_b` should only receive the following + expected_received_indices = [1, 2, 3, 4, 5, 1] - topic = "my_topic" + topic = "my_topic" - # Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`. - def get_msg_id(msg): - # Originally it is `(msg.seqno, msg.from_id)` - return (msg.data, msg.from_id) + # Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`. + def get_msg_id(msg): + # Originally it is `(msg.seqno, msg.from_id)` + return (msg.data, msg.from_id) - import libp2p.pubsub.pubsub + import libp2p.pubsub.pubsub - monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id) + monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id) - await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) - await asyncio.sleep(0.25) + await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) + await trio.sleep(0.25) - sub_b = await pubsubs_fsub[1].subscribe(topic) - await asyncio.sleep(0.25) + sub_b = await pubsubs_fsub[1].subscribe(topic) + await trio.sleep(0.25) - def _make_testing_data(i: int) -> bytes: - num_int_bytes = 4 - if i >= 2 ** (num_int_bytes * 8): - raise ValueError("integer is too large to be serialized") - return b"data" + i.to_bytes(num_int_bytes, "big") + def _make_testing_data(i: int) -> bytes: + num_int_bytes = 4 + if i >= 2 ** (num_int_bytes * 8): + raise ValueError("integer is too large to be serialized") + return b"data" + i.to_bytes(num_int_bytes, "big") - for index in message_indices: - await pubsubs_fsub[0].publish(topic, _make_testing_data(index)) - await asyncio.sleep(0.25) + for index in message_indices: + await pubsubs_fsub[0].publish(topic, _make_testing_data(index)) + await trio.sleep(0.25) - for index in expected_received_indices: - res_b = await sub_b.get() - assert res_b.data == _make_testing_data(index) - assert sub_b.empty() + for index in expected_received_indices: + res_b = await sub_b.receive() + assert res_b.data == _make_testing_data(index) - # Success, terminate pending tasks. + with pytest.raises(trio.WouldBlock): + sub_b.receive_nowait() @pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params) -@pytest.mark.asyncio +@pytest.mark.trio @pytest.mark.slow -async def test_gossipsub_run_with_floodsub_tests(test_case_obj): - await perform_test_from_obj(test_case_obj, FloodsubFactory) +async def test_gossipsub_run_with_floodsub_tests(test_case_obj, is_host_secure): + await perform_test_from_obj( + test_case_obj, + functools.partial( + PubsubFactory.create_batch_with_floodsub, is_secure=is_host_secure + ), + ) diff --git a/tests/pubsub/test_gossipsub.py b/tests/pubsub/test_gossipsub.py index 2121f8fb..b1ed3af1 100644 --- a/tests/pubsub/test_gossipsub.py +++ b/tests/pubsub/test_gossipsub.py @@ -1,368 +1,350 @@ -import asyncio import random import pytest +import trio -from libp2p.tools.constants import GossipsubParams +from libp2p.tools.factories import PubsubFactory from libp2p.tools.pubsub.utils import dense_connect, one_to_all_connect from libp2p.tools.utils import connect -@pytest.mark.parametrize( - "num_hosts, gossipsub_params", - ((4, GossipsubParams(degree=4, degree_low=3, degree_high=5)),), -) -@pytest.mark.asyncio -async def test_join(num_hosts, hosts, pubsubs_gsub): - gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub) - hosts_indices = list(range(num_hosts)) +@pytest.mark.trio +async def test_join(): + async with PubsubFactory.create_batch_with_gossipsub( + 4, degree=4, degree_low=3, degree_high=5 + ) as pubsubs_gsub: + gossipsubs = [pubsub.router for pubsub in pubsubs_gsub] + hosts = [pubsub.host for pubsub in pubsubs_gsub] + hosts_indices = list(range(len(pubsubs_gsub))) - topic = "test_join" - central_node_index = 0 - # Remove index of central host from the indices - hosts_indices.remove(central_node_index) - num_subscribed_peer = 2 - subscribed_peer_indices = random.sample(hosts_indices, num_subscribed_peer) + topic = "test_join" + central_node_index = 0 + # Remove index of central host from the indices + hosts_indices.remove(central_node_index) + num_subscribed_peer = 2 + subscribed_peer_indices = random.sample(hosts_indices, num_subscribed_peer) - # All pubsub except the one of central node subscribe to topic - for i in subscribed_peer_indices: - await pubsubs_gsub[i].subscribe(topic) + # All pubsub except the one of central node subscribe to topic + for i in subscribed_peer_indices: + await pubsubs_gsub[i].subscribe(topic) - # Connect central host to all other hosts - await one_to_all_connect(hosts, central_node_index) + # Connect central host to all other hosts + await one_to_all_connect(hosts, central_node_index) - # Wait 2 seconds for heartbeat to allow mesh to connect - await asyncio.sleep(2) - - # Central node publish to the topic so that this topic - # is added to central node's fanout - # publish from the randomly chosen host - await pubsubs_gsub[central_node_index].publish(topic, b"data") - - # Check that the gossipsub of central node has fanout for the topic - assert topic in gossipsubs[central_node_index].fanout - # Check that the gossipsub of central node does not have a mesh for the topic - assert topic not in gossipsubs[central_node_index].mesh - - # Central node subscribes the topic - await pubsubs_gsub[central_node_index].subscribe(topic) - - await asyncio.sleep(2) - - # Check that the gossipsub of central node no longer has fanout for the topic - assert topic not in gossipsubs[central_node_index].fanout - - for i in hosts_indices: - if i in subscribed_peer_indices: - assert hosts[i].get_id() in gossipsubs[central_node_index].mesh[topic] - assert hosts[central_node_index].get_id() in gossipsubs[i].mesh[topic] - else: - assert hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic] - assert topic not in gossipsubs[i].mesh - - -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_leave(pubsubs_gsub): - gossipsub = pubsubs_gsub[0].router - topic = "test_leave" - - assert topic not in gossipsub.mesh - - await gossipsub.join(topic) - assert topic in gossipsub.mesh - - await gossipsub.leave(topic) - assert topic not in gossipsub.mesh - - # Test re-leave - await gossipsub.leave(topic) - - -@pytest.mark.parametrize("num_hosts", (2,)) -@pytest.mark.asyncio -async def test_handle_graft(pubsubs_gsub, hosts, event_loop, monkeypatch): - gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub) - - index_alice = 0 - id_alice = hosts[index_alice].get_id() - index_bob = 1 - id_bob = hosts[index_bob].get_id() - await connect(hosts[index_alice], hosts[index_bob]) - - # Wait 2 seconds for heartbeat to allow mesh to connect - await asyncio.sleep(2) - - topic = "test_handle_graft" - # Only lice subscribe to the topic - await gossipsubs[index_alice].join(topic) - - # Monkey patch bob's `emit_prune` function so we can - # check if it is called in `handle_graft` - event_emit_prune = asyncio.Event() - - async def emit_prune(topic, sender_peer_id): - event_emit_prune.set() - - monkeypatch.setattr(gossipsubs[index_bob], "emit_prune", emit_prune) - - # Check that alice is bob's peer but not his mesh peer - assert id_alice in gossipsubs[index_bob].peers_gossipsub - assert topic not in gossipsubs[index_bob].mesh - - await gossipsubs[index_alice].emit_graft(topic, id_bob) - - # Check that `emit_prune` is called - await asyncio.wait_for(event_emit_prune.wait(), timeout=1, loop=event_loop) - assert event_emit_prune.is_set() - - # Check that bob is alice's peer but not her mesh peer - assert topic in gossipsubs[index_alice].mesh - assert id_bob not in gossipsubs[index_alice].mesh[topic] - assert id_bob in gossipsubs[index_alice].peers_gossipsub - - await gossipsubs[index_bob].emit_graft(topic, id_alice) - - await asyncio.sleep(1) - - # Check that bob is now alice's mesh peer - assert id_bob in gossipsubs[index_alice].mesh[topic] - - -@pytest.mark.parametrize( - "num_hosts, gossipsub_params", ((2, GossipsubParams(heartbeat_interval=3)),) -) -@pytest.mark.asyncio -async def test_handle_prune(pubsubs_gsub, hosts): - gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub) - - index_alice = 0 - id_alice = hosts[index_alice].get_id() - index_bob = 1 - id_bob = hosts[index_bob].get_id() - - topic = "test_handle_prune" - for pubsub in pubsubs_gsub: - await pubsub.subscribe(topic) - - await connect(hosts[index_alice], hosts[index_bob]) - - # Wait 3 seconds for heartbeat to allow mesh to connect - await asyncio.sleep(3) - - # Check that they are each other's mesh peer - assert id_alice in gossipsubs[index_bob].mesh[topic] - assert id_bob in gossipsubs[index_alice].mesh[topic] - - # alice emit prune message to bob, alice should be removed - # from bob's mesh peer - await gossipsubs[index_alice].emit_prune(topic, id_bob) - - # FIXME: This test currently works because the heartbeat interval - # is increased to 3 seconds, so alice won't get add back into - # bob's mesh peer during heartbeat. - await asyncio.sleep(1) - - # Check that alice is no longer bob's mesh peer - assert id_alice not in gossipsubs[index_bob].mesh[topic] - assert id_bob in gossipsubs[index_alice].mesh[topic] - - -@pytest.mark.parametrize("num_hosts", (10,)) -@pytest.mark.asyncio -async def test_dense(num_hosts, pubsubs_gsub, hosts): - num_msgs = 5 - - # All pubsub subscribe to foobar - queues = [] - for pubsub in pubsubs_gsub: - q = await pubsub.subscribe("foobar") - - # Add each blocking queue to an array of blocking queues - queues.append(q) - - # Densely connect libp2p hosts in a random way - await dense_connect(hosts) - - # Wait 2 seconds for heartbeat to allow mesh to connect - await asyncio.sleep(2) - - for i in range(num_msgs): - msg_content = b"foo " + i.to_bytes(1, "big") - - # randomly pick a message origin - origin_idx = random.randint(0, num_hosts - 1) + # Wait 2 seconds for heartbeat to allow mesh to connect + await trio.sleep(2) + # Central node publish to the topic so that this topic + # is added to central node's fanout # publish from the randomly chosen host - await pubsubs_gsub[origin_idx].publish("foobar", msg_content) + await pubsubs_gsub[central_node_index].publish(topic, b"data") - await asyncio.sleep(0.5) - # Assert that all blocking queues receive the message - for queue in queues: - msg = await queue.get() - assert msg.data == msg_content + # Check that the gossipsub of central node has fanout for the topic + assert topic in gossipsubs[central_node_index].fanout + # Check that the gossipsub of central node does not have a mesh for the topic + assert topic not in gossipsubs[central_node_index].mesh + + # Central node subscribes the topic + await pubsubs_gsub[central_node_index].subscribe(topic) + + await trio.sleep(2) + + # Check that the gossipsub of central node no longer has fanout for the topic + assert topic not in gossipsubs[central_node_index].fanout + + for i in hosts_indices: + if i in subscribed_peer_indices: + assert hosts[i].get_id() in gossipsubs[central_node_index].mesh[topic] + assert hosts[central_node_index].get_id() in gossipsubs[i].mesh[topic] + else: + assert ( + hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic] + ) + assert topic not in gossipsubs[i].mesh -@pytest.mark.parametrize("num_hosts", (10,)) -@pytest.mark.asyncio -async def test_fanout(hosts, pubsubs_gsub): - num_msgs = 5 +@pytest.mark.trio +async def test_leave(): + async with PubsubFactory.create_batch_with_gossipsub(1) as pubsubs_gsub: + gossipsub = pubsubs_gsub[0].router + topic = "test_leave" - # All pubsub subscribe to foobar except for `pubsubs_gsub[0]` - queues = [] - for i in range(1, len(pubsubs_gsub)): - q = await pubsubs_gsub[i].subscribe("foobar") + assert topic not in gossipsub.mesh - # Add each blocking queue to an array of blocking queues - queues.append(q) + await gossipsub.join(topic) + assert topic in gossipsub.mesh - # Sparsely connect libp2p hosts in random way - await dense_connect(hosts) + await gossipsub.leave(topic) + assert topic not in gossipsub.mesh - # Wait 2 seconds for heartbeat to allow mesh to connect - await asyncio.sleep(2) - - topic = "foobar" - # Send messages with origin not subscribed - for i in range(num_msgs): - msg_content = b"foo " + i.to_bytes(1, "big") - - # Pick the message origin to the node that is not subscribed to 'foobar' - origin_idx = 0 - - # publish from the randomly chosen host - await pubsubs_gsub[origin_idx].publish(topic, msg_content) - - await asyncio.sleep(0.5) - # Assert that all blocking queues receive the message - for queue in queues: - msg = await queue.get() - assert msg.data == msg_content - - # Subscribe message origin - queues.insert(0, await pubsubs_gsub[0].subscribe(topic)) - - # Send messages again - for i in range(num_msgs): - msg_content = b"bar " + i.to_bytes(1, "big") - - # Pick the message origin to the node that is not subscribed to 'foobar' - origin_idx = 0 - - # publish from the randomly chosen host - await pubsubs_gsub[origin_idx].publish(topic, msg_content) - - await asyncio.sleep(0.5) - # Assert that all blocking queues receive the message - for queue in queues: - msg = await queue.get() - assert msg.data == msg_content + # Test re-leave + await gossipsub.leave(topic) -@pytest.mark.parametrize("num_hosts", (10,)) -@pytest.mark.asyncio +@pytest.mark.trio +async def test_handle_graft(monkeypatch): + async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub: + gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub) + + index_alice = 0 + id_alice = pubsubs_gsub[index_alice].my_id + index_bob = 1 + id_bob = pubsubs_gsub[index_bob].my_id + await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host) + + # Wait 2 seconds for heartbeat to allow mesh to connect + await trio.sleep(2) + + topic = "test_handle_graft" + # Only lice subscribe to the topic + await gossipsubs[index_alice].join(topic) + + # Monkey patch bob's `emit_prune` function so we can + # check if it is called in `handle_graft` + event_emit_prune = trio.Event() + + async def emit_prune(topic, sender_peer_id): + event_emit_prune.set() + + monkeypatch.setattr(gossipsubs[index_bob], "emit_prune", emit_prune) + + # Check that alice is bob's peer but not his mesh peer + assert id_alice in gossipsubs[index_bob].peers_gossipsub + assert topic not in gossipsubs[index_bob].mesh + + await gossipsubs[index_alice].emit_graft(topic, id_bob) + + # Check that `emit_prune` is called + await event_emit_prune.wait() + + # Check that bob is alice's peer but not her mesh peer + assert topic in gossipsubs[index_alice].mesh + assert id_bob not in gossipsubs[index_alice].mesh[topic] + assert id_bob in gossipsubs[index_alice].peers_gossipsub + + await gossipsubs[index_bob].emit_graft(topic, id_alice) + + await trio.sleep(1) + + # Check that bob is now alice's mesh peer + assert id_bob in gossipsubs[index_alice].mesh[topic] + + +@pytest.mark.trio +async def test_handle_prune(): + async with PubsubFactory.create_batch_with_gossipsub( + 2, heartbeat_interval=3 + ) as pubsubs_gsub: + gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub) + + index_alice = 0 + id_alice = pubsubs_gsub[index_alice].my_id + index_bob = 1 + id_bob = pubsubs_gsub[index_bob].my_id + + topic = "test_handle_prune" + for pubsub in pubsubs_gsub: + await pubsub.subscribe(topic) + + await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host) + + # Wait 3 seconds for heartbeat to allow mesh to connect + await trio.sleep(3) + + # Check that they are each other's mesh peer + assert id_alice in gossipsubs[index_bob].mesh[topic] + assert id_bob in gossipsubs[index_alice].mesh[topic] + + # alice emit prune message to bob, alice should be removed + # from bob's mesh peer + await gossipsubs[index_alice].emit_prune(topic, id_bob) + + # FIXME: This test currently works because the heartbeat interval + # is increased to 3 seconds, so alice won't get add back into + # bob's mesh peer during heartbeat. + await trio.sleep(1) + + # Check that alice is no longer bob's mesh peer + assert id_alice not in gossipsubs[index_bob].mesh[topic] + assert id_bob in gossipsubs[index_alice].mesh[topic] + + +@pytest.mark.trio +async def test_dense(): + async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub: + hosts = [pubsub.host for pubsub in pubsubs_gsub] + num_msgs = 5 + + # All pubsub subscribe to foobar + queues = [await pubsub.subscribe("foobar") for pubsub in pubsubs_gsub] + + # Densely connect libp2p hosts in a random way + await dense_connect(hosts) + + # Wait 2 seconds for heartbeat to allow mesh to connect + await trio.sleep(2) + + for i in range(num_msgs): + msg_content = b"foo " + i.to_bytes(1, "big") + + # randomly pick a message origin + origin_idx = random.randint(0, len(hosts) - 1) + + # publish from the randomly chosen host + await pubsubs_gsub[origin_idx].publish("foobar", msg_content) + + await trio.sleep(0.5) + # Assert that all blocking queues receive the message + for queue in queues: + msg = await queue.receive() + assert msg.data == msg_content + + +@pytest.mark.trio +async def test_fanout(): + async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub: + hosts = [pubsub.host for pubsub in pubsubs_gsub] + num_msgs = 5 + + # All pubsub subscribe to foobar except for `pubsubs_gsub[0]` + subs = [await pubsub.subscribe("foobar") for pubsub in pubsubs_gsub[1:]] + + # Sparsely connect libp2p hosts in random way + await dense_connect(hosts) + + # Wait 2 seconds for heartbeat to allow mesh to connect + await trio.sleep(2) + + topic = "foobar" + # Send messages with origin not subscribed + for i in range(num_msgs): + msg_content = b"foo " + i.to_bytes(1, "big") + + # Pick the message origin to the node that is not subscribed to 'foobar' + origin_idx = 0 + + # publish from the randomly chosen host + await pubsubs_gsub[origin_idx].publish(topic, msg_content) + + await trio.sleep(0.5) + # Assert that all blocking queues receive the message + for sub in subs: + msg = await sub.receive() + assert msg.data == msg_content + + # Subscribe message origin + subs.insert(0, await pubsubs_gsub[0].subscribe(topic)) + + # Send messages again + for i in range(num_msgs): + msg_content = b"bar " + i.to_bytes(1, "big") + + # Pick the message origin to the node that is not subscribed to 'foobar' + origin_idx = 0 + + # publish from the randomly chosen host + await pubsubs_gsub[origin_idx].publish(topic, msg_content) + + await trio.sleep(0.5) + # Assert that all blocking queues receive the message + for sub in subs: + msg = await sub.receive() + assert msg.data == msg_content + + +@pytest.mark.trio @pytest.mark.slow -async def test_fanout_maintenance(hosts, pubsubs_gsub): - num_msgs = 5 +async def test_fanout_maintenance(): + async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub: + hosts = [pubsub.host for pubsub in pubsubs_gsub] + num_msgs = 5 - # All pubsub subscribe to foobar - queues = [] - topic = "foobar" - for i in range(1, len(pubsubs_gsub)): - q = await pubsubs_gsub[i].subscribe(topic) + # All pubsub subscribe to foobar + queues = [] + topic = "foobar" + for i in range(1, len(pubsubs_gsub)): + q = await pubsubs_gsub[i].subscribe(topic) - # Add each blocking queue to an array of blocking queues - queues.append(q) + # Add each blocking queue to an array of blocking queues + queues.append(q) - # Sparsely connect libp2p hosts in random way - await dense_connect(hosts) + # Sparsely connect libp2p hosts in random way + await dense_connect(hosts) - # Wait 2 seconds for heartbeat to allow mesh to connect - await asyncio.sleep(2) + # Wait 2 seconds for heartbeat to allow mesh to connect + await trio.sleep(2) - # Send messages with origin not subscribed - for i in range(num_msgs): - msg_content = b"foo " + i.to_bytes(1, "big") + # Send messages with origin not subscribed + for i in range(num_msgs): + msg_content = b"foo " + i.to_bytes(1, "big") - # Pick the message origin to the node that is not subscribed to 'foobar' - origin_idx = 0 + # Pick the message origin to the node that is not subscribed to 'foobar' + origin_idx = 0 + + # publish from the randomly chosen host + await pubsubs_gsub[origin_idx].publish(topic, msg_content) + + await trio.sleep(0.5) + # Assert that all blocking queues receive the message + for queue in queues: + msg = await queue.receive() + assert msg.data == msg_content + + for sub in pubsubs_gsub: + await sub.unsubscribe(topic) + + queues = [] + + await trio.sleep(2) + + # Resub and repeat + for i in range(1, len(pubsubs_gsub)): + q = await pubsubs_gsub[i].subscribe(topic) + + # Add each blocking queue to an array of blocking queues + queues.append(q) + + await trio.sleep(2) + + # Check messages can still be sent + for i in range(num_msgs): + msg_content = b"bar " + i.to_bytes(1, "big") + + # Pick the message origin to the node that is not subscribed to 'foobar' + origin_idx = 0 + + # publish from the randomly chosen host + await pubsubs_gsub[origin_idx].publish(topic, msg_content) + + await trio.sleep(0.5) + # Assert that all blocking queues receive the message + for queue in queues: + msg = await queue.receive() + assert msg.data == msg_content + + +@pytest.mark.trio +async def test_gossip_propagation(): + async with PubsubFactory.create_batch_with_gossipsub( + 2, degree=1, degree_low=0, degree_high=2, gossip_window=50, gossip_history=100 + ) as pubsubs_gsub: + topic = "foo" + await pubsubs_gsub[0].subscribe(topic) + + # node 0 publish to topic + msg_content = b"foo_msg" # publish from the randomly chosen host - await pubsubs_gsub[origin_idx].publish(topic, msg_content) + await pubsubs_gsub[0].publish(topic, msg_content) - await asyncio.sleep(0.5) - # Assert that all blocking queues receive the message - for queue in queues: - msg = await queue.get() - assert msg.data == msg_content + # now node 1 subscribes + queue_1 = await pubsubs_gsub[1].subscribe(topic) - for sub in pubsubs_gsub: - await sub.unsubscribe(topic) + await connect(pubsubs_gsub[0].host, pubsubs_gsub[1].host) - queues = [] + # wait for gossip heartbeat + await trio.sleep(2) - await asyncio.sleep(2) - - # Resub and repeat - for i in range(1, len(pubsubs_gsub)): - q = await pubsubs_gsub[i].subscribe(topic) - - # Add each blocking queue to an array of blocking queues - queues.append(q) - - await asyncio.sleep(2) - - # Check messages can still be sent - for i in range(num_msgs): - msg_content = b"bar " + i.to_bytes(1, "big") - - # Pick the message origin to the node that is not subscribed to 'foobar' - origin_idx = 0 - - # publish from the randomly chosen host - await pubsubs_gsub[origin_idx].publish(topic, msg_content) - - await asyncio.sleep(0.5) - # Assert that all blocking queues receive the message - for queue in queues: - msg = await queue.get() - assert msg.data == msg_content - - -@pytest.mark.parametrize( - "num_hosts, gossipsub_params", - ( - ( - 2, - GossipsubParams( - degree=1, - degree_low=0, - degree_high=2, - gossip_window=50, - gossip_history=100, - ), - ), - ), -) -@pytest.mark.asyncio -async def test_gossip_propagation(hosts, pubsubs_gsub): - topic = "foo" - await pubsubs_gsub[0].subscribe(topic) - - # node 0 publish to topic - msg_content = b"foo_msg" - - # publish from the randomly chosen host - await pubsubs_gsub[0].publish(topic, msg_content) - - # now node 1 subscribes - queue_1 = await pubsubs_gsub[1].subscribe(topic) - - await connect(hosts[0], hosts[1]) - - # wait for gossip heartbeat - await asyncio.sleep(2) - - # should be able to read message - msg = await queue_1.get() - assert msg.data == msg_content + # should be able to read message + msg = await queue_1.receive() + assert msg.data == msg_content diff --git a/tests/pubsub/test_gossipsub_backward_compatibility.py b/tests/pubsub/test_gossipsub_backward_compatibility.py index d82fd229..08f0284b 100644 --- a/tests/pubsub/test_gossipsub_backward_compatibility.py +++ b/tests/pubsub/test_gossipsub_backward_compatibility.py @@ -3,25 +3,25 @@ import functools import pytest from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID -from libp2p.tools.factories import GossipsubFactory +from libp2p.tools.factories import PubsubFactory from libp2p.tools.pubsub.floodsub_integration_test_settings import ( floodsub_protocol_pytest_params, perform_test_from_obj, ) -@pytest.mark.asyncio -async def test_gossipsub_initialize_with_floodsub_protocol(): - GossipsubFactory(protocols=[FLOODSUB_PROTOCOL_ID]) - - @pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params) -@pytest.mark.asyncio +@pytest.mark.trio @pytest.mark.slow async def test_gossipsub_run_with_floodsub_tests(test_case_obj): await perform_test_from_obj( test_case_obj, functools.partial( - GossipsubFactory, degree=3, degree_low=2, degree_high=4, time_to_live=30 + PubsubFactory.create_batch_with_gossipsub, + protocols=[FLOODSUB_PROTOCOL_ID], + degree=3, + degree_low=2, + degree_high=4, + time_to_live=30, ), ) diff --git a/tests/pubsub/test_mcache.py b/tests/pubsub/test_mcache.py index e80ad27a..fb764b31 100644 --- a/tests/pubsub/test_mcache.py +++ b/tests/pubsub/test_mcache.py @@ -1,5 +1,3 @@ -import pytest - from libp2p.pubsub.mcache import MessageCache @@ -12,8 +10,7 @@ class Msg: self.from_id = from_id -@pytest.mark.asyncio -async def test_mcache(): +def test_mcache(): # Ported from: # https://github.com/libp2p/go-libp2p-pubsub/blob/51b7501433411b5096cac2b4994a36a68515fc03/mcache_test.go mcache = MessageCache(3, 5) diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index 22cea0c2..ea04788d 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -5,12 +5,11 @@ import pytest import trio from libp2p.exceptions import ValidationError -from libp2p.peer.id import ID from libp2p.pubsub.pb import rpc_pb2 +from libp2p.tools.constants import MAX_READ_LEN +from libp2p.tools.factories import IDFactory, PubsubFactory, net_stream_pair_factory from libp2p.tools.pubsub.utils import make_pubsub_msg from libp2p.tools.utils import connect -from libp2p.tools.constants import MAX_READ_LEN -from libp2p.tools.factories import PubsubFactory, net_stream_pair_factory, IDFactory from libp2p.utils import encode_varint_prefixed TESTING_TOPIC = "TEST_SUBSCRIBE" @@ -250,14 +249,14 @@ async def test_continuously_read_stream(monkeypatch, nursery, is_host_secure): async def mock_push_msg(msg_forwarder, msg): event_push_msg.set() - await trio.sleep(0) + await trio.hazmat.checkpoint() def mock_handle_subscription(origin_id, sub_message): event_handle_subscription.set() async def mock_handle_rpc(rpc, sender_peer_id): event_handle_rpc.set() - await trio.sleep(0) + await trio.hazmat.checkpoint() with monkeypatch.context() as m: m.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg) diff --git a/tests/stream_muxer/test_mplex_stream.py b/tests/stream_muxer/test_mplex_stream.py index e47af49d..55ee97bd 100644 --- a/tests/stream_muxer/test_mplex_stream.py +++ b/tests/stream_muxer/test_mplex_stream.py @@ -69,33 +69,10 @@ async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair): await stream_0.close() assert stream_0.event_local_closed.is_set() await trio.sleep(0.01) - print( - "!@# ", - stream_0.muxed_conn.event_shutting_down.is_set(), - stream_0.muxed_conn.event_closed.is_set(), - stream_1.muxed_conn.event_shutting_down.is_set(), - stream_1.muxed_conn.event_closed.is_set(), - ) # await trio.sleep(100000) await wait_all_tasks_blocked() - print( - "!@# ", - stream_0.muxed_conn.event_shutting_down.is_set(), - stream_0.muxed_conn.event_closed.is_set(), - stream_1.muxed_conn.event_shutting_down.is_set(), - stream_1.muxed_conn.event_closed.is_set(), - ) - print("!@# sleeping") - print("!@# result=", stream_1.event_remote_closed.is_set()) # await trio.sleep_forever() assert stream_1.event_remote_closed.is_set() - print( - "!@# ", - stream_0.muxed_conn.event_shutting_down.is_set(), - stream_0.muxed_conn.event_closed.is_set(), - stream_1.muxed_conn.event_shutting_down.is_set(), - stream_1.muxed_conn.event_closed.is_set(), - ) assert (await stream_1.read(MAX_READ_LEN)) == DATA with pytest.raises(MplexStreamEOF): await stream_1.read(MAX_READ_LEN) diff --git a/tests/transport/test_tcp.py b/tests/transport/test_tcp.py index c8fe6f21..abd58840 100644 --- a/tests/transport/test_tcp.py +++ b/tests/transport/test_tcp.py @@ -3,7 +3,7 @@ import pytest import trio from libp2p.network.connection.raw_connection import RawConnection -from libp2p.tools.constants import LISTEN_MADDR, MAX_READ_LEN +from libp2p.tools.constants import LISTEN_MADDR from libp2p.transport.tcp.tcp import TCP From 837a2495528c458b54bfdfaf3ab1216c551eb1e6 Mon Sep 17 00:00:00 2001 From: mhchia Date: Sat, 7 Dec 2019 00:14:01 +0800 Subject: [PATCH 20/81] Fix `security` module --- libp2p/pubsub/exceptions.py | 9 ++ libp2p/tools/factories.py | 24 +++++ setup.py | 1 - tests/security/test_secio.py | 111 +++++--------------- tests/security/test_security_multistream.py | 47 ++++----- 5 files changed, 84 insertions(+), 108 deletions(-) create mode 100644 libp2p/pubsub/exceptions.py diff --git a/libp2p/pubsub/exceptions.py b/libp2p/pubsub/exceptions.py new file mode 100644 index 00000000..a47446de --- /dev/null +++ b/libp2p/pubsub/exceptions.py @@ -0,0 +1,9 @@ +from libp2p.exceptions import BaseLibp2pError + + +class PubsubRouterError(BaseLibp2pError): + ... + + +class NoPubsubAttached(PubsubRouterError): + ... diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index 568a2762..e1798898 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -9,6 +9,9 @@ from libp2p import generate_new_rsa_identity, generate_peer_id_from from libp2p.crypto.keys import KeyPair from libp2p.host.basic_host import BasicHost from libp2p.host.host_interface import IHost +from libp2p.io.abc import ReadWriteCloser +from libp2p.network.connection.raw_connection import RawConnection +from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.swarm_connection import SwarmConn from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.swarm import Swarm @@ -51,6 +54,27 @@ def security_transport_factory( return {secio.ID: secio.Transport(key_pair)} +@asynccontextmanager +async def raw_conn_factory( + nursery: trio.Nursery +) -> AsyncIterator[Tuple[IRawConnection, IRawConnection]]: + conn_0 = None + conn_1 = None + + async def tcp_stream_handler(stream: ReadWriteCloser) -> None: + nonlocal conn_1 + conn_1 = RawConnection(stream, initiator=False) + await trio.sleep_forever() + + tcp_transport = TCP() + listener = tcp_transport.create_listener(tcp_stream_handler) + await listener.listen(LISTEN_MADDR, nursery) + listening_maddr = listener.multiaddrs[0] + conn_0 = await tcp_transport.dial(listening_maddr) + print("raw_conn_factory") + yield conn_0, conn_1 + + class SwarmFactory(factory.Factory): class Meta: model = Swarm diff --git a/setup.py b/setup.py index cbb8eaf8..edfd1aae 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,6 @@ extras_require = { "test": [ "factory-boy>=2.12.0,<3.0.0", "pytest>=4.6.3,<5.0.0", - "pytest-asyncio>=0.10.0,<1.0.0", "pytest-xdist>=1.30.0", "pytest-trio>=0.5.2", ], diff --git a/tests/security/test_secio.py b/tests/security/test_secio.py index c7808b46..d009a738 100644 --- a/tests/security/test_secio.py +++ b/tests/security/test_secio.py @@ -1,70 +1,15 @@ -import asyncio - import pytest +import trio from libp2p.crypto.secp256k1 import create_new_key_pair -from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID from libp2p.security.secio.transport import NONCE_SIZE, create_secure_session +from libp2p.tools.constants import MAX_READ_LEN +from libp2p.tools.factories import raw_conn_factory -class InMemoryConnection(IRawConnection): - def __init__(self, peer, is_initiator=False): - self.peer = peer - self.recv_queue = asyncio.Queue() - self.send_queue = asyncio.Queue() - self.is_initiator = is_initiator - - self.current_msg = None - self.current_position = 0 - - self.closed = False - - async def write(self, data: bytes) -> int: - if self.closed: - raise Exception("InMemoryConnection is closed for writing") - - await self.send_queue.put(data) - return len(data) - - async def read(self, n: int = -1) -> bytes: - """ - NOTE: have to buffer the current message and juggle packets - off the recv queue to satisfy the semantics of this function. - """ - if self.closed: - raise Exception("InMemoryConnection is closed for reading") - - if not self.current_msg: - self.current_msg = await self.recv_queue.get() - self.current_position = 0 - - if n < 0: - msg = self.current_msg - self.current_msg = None - return msg - - next_msg = self.current_msg[self.current_position : self.current_position + n] - self.current_position += n - if self.current_position == len(self.current_msg): - self.current_msg = None - return next_msg - - async def close(self) -> None: - self.closed = True - - -async def create_pipe(local_conn, remote_conn): - try: - while True: - next_msg = await local_conn.send_queue.get() - await remote_conn.recv_queue.put(next_msg) - except asyncio.CancelledError: - return - - -@pytest.mark.asyncio -async def test_create_secure_session(): +@pytest.mark.trio +async def test_create_secure_session(nursery): local_nonce = b"\x01" * NONCE_SIZE local_key_pair = create_new_key_pair(b"a") local_peer = ID.from_pubkey(local_key_pair.public_key) @@ -73,30 +18,32 @@ async def test_create_secure_session(): remote_key_pair = create_new_key_pair(b"b") remote_peer = ID.from_pubkey(remote_key_pair.public_key) - local_conn = InMemoryConnection(local_peer, is_initiator=True) - remote_conn = InMemoryConnection(remote_peer) + async with raw_conn_factory(nursery) as conns: + local_conn, remote_conn = conns - local_pipe_task = asyncio.create_task(create_pipe(local_conn, remote_conn)) - remote_pipe_task = asyncio.create_task(create_pipe(remote_conn, local_conn)) + local_secure_conn, remote_secure_conn = None, None - local_session_builder = create_secure_session( - local_nonce, local_peer, local_key_pair.private_key, local_conn, remote_peer - ) - remote_session_builder = create_secure_session( - remote_nonce, remote_peer, remote_key_pair.private_key, remote_conn - ) - local_secure_conn, remote_secure_conn = await asyncio.gather( - local_session_builder, remote_session_builder - ) + async def local_create_secure_session(): + nonlocal local_secure_conn + local_secure_conn = await create_secure_session( + local_nonce, + local_peer, + local_key_pair.private_key, + local_conn, + remote_peer, + ) - msg = b"abc" - await local_secure_conn.write(msg) - received_msg = await remote_secure_conn.read() - assert received_msg == msg + async def remote_create_secure_session(): + nonlocal remote_secure_conn + remote_secure_conn = await create_secure_session( + remote_nonce, remote_peer, remote_key_pair.private_key, remote_conn + ) - await asyncio.gather(local_secure_conn.close(), remote_secure_conn.close()) + async with trio.open_nursery() as nursery_1: + nursery_1.start_soon(local_create_secure_session) + nursery_1.start_soon(remote_create_secure_session) - local_pipe_task.cancel() - remote_pipe_task.cancel() - await local_pipe_task - await remote_pipe_task + msg = b"abc" + await local_secure_conn.write(msg) + received_msg = await remote_secure_conn.read(MAX_READ_LEN) + assert received_msg == msg diff --git a/tests/security/test_security_multistream.py b/tests/security/test_security_multistream.py index c4eb3ecb..5c751f92 100644 --- a/tests/security/test_security_multistream.py +++ b/tests/security/test_security_multistream.py @@ -1,6 +1,6 @@ -import asyncio - +from async_service import background_trio_service import pytest +import trio from libp2p import new_node from libp2p.crypto.rsa import create_new_key_pair @@ -24,42 +24,39 @@ noninitiator_key_pair = create_new_key_pair() async def perform_simple_test( assertion_func, transports_for_initiator, transports_for_noninitiator ): - # Create libp2p nodes and connect them, then secure the connection, then check # the proper security was chosen # TODO: implement -- note we need to introduce the notion of communicating over a raw connection # for testing, we do NOT want to communicate over a stream so we can't just create two nodes # and use their conn because our mplex will internally relay messages to a stream - node1 = await new_node( - key_pair=initiator_key_pair, sec_opt=transports_for_initiator - ) - node2 = await new_node( + node1 = new_node(key_pair=initiator_key_pair, sec_opt=transports_for_initiator) + node2 = new_node( key_pair=noninitiator_key_pair, sec_opt=transports_for_noninitiator ) + swarm1 = node1.get_network() + swarm2 = node2.get_network() + async with background_trio_service(swarm1), background_trio_service(swarm2): + await swarm1.listen(LISTEN_MADDR) + await swarm2.listen(LISTEN_MADDR) - await node1.get_network().listen(LISTEN_MADDR) - await node2.get_network().listen(LISTEN_MADDR) + await connect(node1, node2) - await connect(node1, node2) + # Wait a very short period to allow conns to be stored (since the functions + # storing the conns are async, they may happen at slightly different times + # on each node) + await trio.sleep(0.1) - # Wait a very short period to allow conns to be stored (since the functions - # storing the conns are async, they may happen at slightly different times - # on each node) - await asyncio.sleep(0.1) + # Get conns + node1_conn = node1.get_network().connections[peer_id_for_node(node2)] + node2_conn = node2.get_network().connections[peer_id_for_node(node1)] - # Get conns - node1_conn = node1.get_network().connections[peer_id_for_node(node2)] - node2_conn = node2.get_network().connections[peer_id_for_node(node1)] - - # Perform assertion - assertion_func(node1_conn.muxed_conn.secured_conn) - assertion_func(node2_conn.muxed_conn.secured_conn) - - # Success, terminate pending tasks. + # Perform assertion + assertion_func(node1_conn.muxed_conn.secured_conn) + assertion_func(node2_conn.muxed_conn.secured_conn) -@pytest.mark.asyncio +@pytest.mark.trio async def test_single_insecure_security_transport_succeeds(): transports_for_initiator = {"foo": InsecureTransport(initiator_key_pair)} transports_for_noninitiator = {"foo": InsecureTransport(noninitiator_key_pair)} @@ -72,7 +69,7 @@ async def test_single_insecure_security_transport_succeeds(): ) -@pytest.mark.asyncio +@pytest.mark.trio async def test_default_insecure_security(): transports_for_initiator = None transports_for_noninitiator = None From d847e78a83cc2bcb23e45da5e0ff490cab5da5aa Mon Sep 17 00:00:00 2001 From: mhchia Date: Sat, 7 Dec 2019 00:19:10 +0800 Subject: [PATCH 21/81] Add dep `async-service` --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index edfd1aae..f7977131 100644 --- a/setup.py +++ b/setup.py @@ -70,6 +70,7 @@ setuptools.setup( "pynacl==1.3.0", "trio-asyncio>=0.10.0", "trio>=0.13.0", + "async-service>=0.1.0a2,<0.2.0", ], extras_require=extras_require, packages=setuptools.find_packages(exclude=["tests", "tests.*"]), From fb0519129d7a3b4a891c0b8ac99b38308feba34d Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 17 Dec 2019 15:50:55 +0800 Subject: [PATCH 22/81] Refine `Mplex.close` and `SwarmConn.close` Ensure `close` cleans up things and cancel the service finally. --- libp2p/network/connection/swarm_connection.py | 35 +++++++++++-------- libp2p/network/swarm.py | 18 +++++++--- libp2p/pubsub/pubsub.py | 4 +-- libp2p/stream_muxer/abc.py | 1 + libp2p/stream_muxer/mplex/mplex.py | 6 ++-- libp2p/stream_muxer/mplex/mplex_stream.py | 4 --- libp2p/tools/factories.py | 3 +- libp2p/tools/utils.py | 3 ++ libp2p/transport/listener_interface.py | 4 +++ libp2p/transport/tcp/tcp.py | 16 ++++++--- tests/network/test_swarm_conn.py | 13 ++++--- tests/stream_muxer/test_mplex_conn.py | 10 ++---- tests/transport/test_tcp.py | 5 +-- 13 files changed, 71 insertions(+), 51 deletions(-) diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 1e310338..a4fc8be4 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -29,10 +29,19 @@ class SwarmConn(INetConn, Service): self.streams = set() self.event_closed = trio.Event() + @property + def is_closed(self) -> bool: + return self.event_closed.is_set() + async def close(self) -> None: if self.event_closed.is_set(): return self.event_closed.set() + await self._cleanup() + # Cancel service + await self.manager.stop() + + async def _cleanup(self) -> None: self.swarm.remove_conn(self) await self.muxed_conn.close() @@ -51,28 +60,23 @@ class SwarmConn(INetConn, Service): while self.manager.is_running: try: stream = await self.muxed_conn.accept_stream() - except MuxedConnUnavailable: - # If there is anything wrong in the MuxedConn, - # we should break the loop and close the connection. - break # Asynchronously handle the accepted stream, to avoid blocking the next stream. + except MuxedConnUnavailable: + break self.manager.run_task(self._handle_muxed_stream, stream) await self.close() - async def _call_stream_handler(self, net_stream: NetStream) -> None: - try: - await self.swarm.common_stream_handler(net_stream) - # TODO: More exact exceptions - except Exception: - # TODO: Emit logs. - # TODO: Clean up and remove the stream from SwarmConn if there is anything wrong. - self.remove_stream(net_stream) - async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None: net_stream = await self._add_stream(muxed_stream) if self.swarm.common_stream_handler is not None: - await self._call_stream_handler(net_stream) + try: + await self.swarm.common_stream_handler(net_stream) + # TODO: More exact exceptions + except Exception: + # TODO: Emit logs. + # TODO: Clean up and remove the stream from SwarmConn if there is anything wrong. + self.remove_stream(net_stream) async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: net_stream = NetStream(muxed_stream) @@ -84,7 +88,8 @@ class SwarmConn(INetConn, Service): await self.swarm.notify_disconnected(self) async def run(self) -> None: - await self._handle_new_streams() + self.manager.run_task(self._handle_new_streams) + await self.manager.wait_finished() async def new_stream(self) -> NetStream: muxed_stream = await self.muxed_conn.open_stream() diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 4bf86dde..cdb80a44 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -44,6 +44,7 @@ class Swarm(INetwork, Service): common_stream_handler: Optional[StreamHandlerFn] notifees: List[INotifee] + event_closed: trio.Event def __init__( self, @@ -62,6 +63,8 @@ class Swarm(INetwork, Service): # Create Notifee array self.notifees = [] + self.event_closed = trio.Event() + self.common_stream_handler = None async def run(self) -> None: @@ -227,10 +230,19 @@ class Swarm(INetwork, Service): return False async def close(self) -> None: - # TODO: Prevent from new listeners and conns being added. + if self.event_closed.is_set(): + return + self.event_closed.set() # Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501 + async with trio.open_nursery() as nursery: + for conn in self.connections.values(): + nursery.start_soon(conn.close) + async with trio.open_nursery() as nursery: + for listener in self.listeners.values(): + nursery.start_soon(listener.close) + + # Cancel tasks await self.manager.stop() - await self.manager.wait_finished() logger.debug("swarm successfully closed") async def close_peer(self, peer_id: ID) -> None: @@ -270,8 +282,6 @@ class Swarm(INetwork, Service): # Notifee - # TODO: Remeber the spawn notifying tasks and clean them up when closing. - def register_notifee(self, notifee: INotifee) -> None: """ :param notifee: object implementing Notifee interface diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 3370ea30..0c3b162c 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -58,11 +58,11 @@ class TopicValidator(NamedTuple): # TODO: Add interface for Pubsub -class BasePubsub(ABC): +class IPubsub(ABC): pass -class Pubsub(BasePubsub, Service): +class Pubsub(IPubsub, Service): host: IHost diff --git a/libp2p/stream_muxer/abc.py b/libp2p/stream_muxer/abc.py index 12a8f805..e34295cf 100644 --- a/libp2p/stream_muxer/abc.py +++ b/libp2p/stream_muxer/abc.py @@ -33,6 +33,7 @@ class IMuxedConn(ServiceAPI): async def close(self) -> None: """close connection.""" + @property @abstractmethod def is_closed(self) -> bool: """ diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index e23da00c..b7b3a3ae 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -91,7 +91,9 @@ class Mplex(IMuxedConn, Service): await self.secured_conn.close() # Blocked until `close` is finally set. await self.event_closed.wait() + await self.manager.stop() + @property def is_closed(self) -> bool: """ check connection is fully closed. @@ -213,10 +215,6 @@ class Mplex(IMuxedConn, Service): return channel_id, flag, message - @property - def _id(self) -> int: - return 0 if self.is_initiator else 1 - async def _handle_incoming_message(self) -> None: """ Read and handle a new incoming message. diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index eeefc422..011cd3ae 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -156,26 +156,22 @@ class MplexStream(IMuxedStream): if self.event_local_closed.is_set(): return - print(f"!@# stream.close: {self.muxed_conn._id}: step=0") flag = ( HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver ) # TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown. await self.muxed_conn.send_message(flag, None, self.stream_id) - print(f"!@# stream.close: {self.muxed_conn._id}: step=1") _is_remote_closed: bool async with self.close_lock: self.event_local_closed.set() _is_remote_closed = self.event_remote_closed.is_set() - print(f"!@# stream.close: {self.muxed_conn._id}: step=2") if _is_remote_closed: # Both sides are closed, we can safely remove the buffer from the dict. async with self.muxed_conn.streams_lock: if self.stream_id in self.muxed_conn.streams: del self.muxed_conn.streams[self.stream_id] - print(f"!@# stream.close: {self.muxed_conn._id}: step=3") async def reset(self) -> None: """closes both ends of the stream tells this remote side to hang up.""" diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index e1798898..a9eb6a53 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -69,9 +69,8 @@ async def raw_conn_factory( tcp_transport = TCP() listener = tcp_transport.create_listener(tcp_stream_handler) await listener.listen(LISTEN_MADDR, nursery) - listening_maddr = listener.multiaddrs[0] + listening_maddr = listener.get_addrs()[0] conn_0 = await tcp_transport.dial(listening_maddr) - print("raw_conn_factory") yield conn_0, conn_1 diff --git a/libp2p/tools/utils.py b/libp2p/tools/utils.py index a66155c3..216fdd82 100644 --- a/libp2p/tools/utils.py +++ b/libp2p/tools/utils.py @@ -39,3 +39,6 @@ def create_echo_stream_handler( await stream.write(resp.encode()) return echo_stream_handler + + +# TODO: Service `external_api` diff --git a/libp2p/transport/listener_interface.py b/libp2p/transport/listener_interface.py index 6d737233..d170d1de 100644 --- a/libp2p/transport/listener_interface.py +++ b/libp2p/transport/listener_interface.py @@ -22,3 +22,7 @@ class IListener(ABC): :return: return list of addrs """ + + @abstractmethod + async def close(self) -> None: + ... diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index 04d8874f..8c46a4aa 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -16,10 +16,10 @@ logger = logging.getLogger("libp2p.transport.tcp") class TCPListener(IListener): - multiaddrs: List[Multiaddr] + listeners: List[trio.SocketListener] def __init__(self, handler_function: THandler) -> None: - self.multiaddrs = [] + self.listeners = [] self.handler = handler_function # TODO: Get rid of `nursery`? @@ -50,8 +50,7 @@ class TCPListener(IListener): int(maddr.value_for_protocol("tcp")), maddr.value_for_protocol("ip4"), ) - socket = listeners[0].socket - self.multiaddrs.append(_multiaddr_from_socket(socket)) + self.listeners.extend(listeners) def get_addrs(self) -> Tuple[Multiaddr, ...]: """ @@ -59,7 +58,14 @@ class TCPListener(IListener): :return: return list of addrs """ - return tuple(self.multiaddrs) + return tuple( + _multiaddr_from_socket(listener.socket) for listener in self.listeners + ) + + async def close(self) -> None: + async with trio.open_nursery() as nursery: + for listener in self.listeners: + nursery.start_soon(listener.aclose) class TCP(ITransport): diff --git a/tests/network/test_swarm_conn.py b/tests/network/test_swarm_conn.py index e1c1285c..1bfd7d86 100644 --- a/tests/network/test_swarm_conn.py +++ b/tests/network/test_swarm_conn.py @@ -1,20 +1,23 @@ import pytest import trio +from trio.testing import wait_all_tasks_blocked @pytest.mark.trio async def test_swarm_conn_close(swarm_conn_pair): conn_0, conn_1 = swarm_conn_pair - assert not conn_0.event_closed.is_set() - assert not conn_1.event_closed.is_set() + assert not conn_0.is_closed + assert not conn_1.is_closed await conn_0.close() - await trio.sleep(0.01) + await trio.sleep(0.1) + await wait_all_tasks_blocked() + await conn_0.manager.wait_finished() - assert conn_0.event_closed.is_set() - assert conn_1.event_closed.is_set() + assert conn_0.is_closed + assert conn_1.is_closed assert conn_0 not in conn_0.swarm.connections.values() assert conn_1 not in conn_1.swarm.connections.values() diff --git a/tests/stream_muxer/test_mplex_conn.py b/tests/stream_muxer/test_mplex_conn.py index 4cedc36d..4bff2d61 100644 --- a/tests/stream_muxer/test_mplex_conn.py +++ b/tests/stream_muxer/test_mplex_conn.py @@ -8,10 +8,6 @@ async def test_mplex_conn(mplex_conn_pair): assert len(conn_0.streams) == 0 assert len(conn_1.streams) == 0 - assert not conn_0.event_shutting_down.is_set() - assert not conn_1.event_shutting_down.is_set() - assert not conn_0.event_closed.is_set() - assert not conn_1.event_closed.is_set() # Test: Open a stream, and both side get 1 more stream. stream_0 = await conn_0.open_stream() @@ -29,10 +25,8 @@ async def test_mplex_conn(mplex_conn_pair): # Sleep for a while for both side to handle `close`. await trio.sleep(0.01) # Test: Both side is closed. - assert conn_0.event_shutting_down.is_set() - assert conn_0.event_closed.is_set() - assert conn_1.event_shutting_down.is_set() - assert conn_1.event_closed.is_set() + assert conn_0.is_closed + assert conn_1.is_closed # Test: All streams should have been closed. assert stream_0.event_remote_closed.is_set() assert stream_0.event_reset.is_set() diff --git a/tests/transport/test_tcp.py b/tests/transport/test_tcp.py index abd58840..247b5f91 100644 --- a/tests/transport/test_tcp.py +++ b/tests/transport/test_tcp.py @@ -38,8 +38,9 @@ async def test_tcp_dial(nursery): listener = transport.create_listener(handler) await listener.listen(LISTEN_MADDR, nursery) - assert len(listener.multiaddrs) == 1 - listen_addr = listener.multiaddrs[0] + addrs = listener.get_addrs() + assert len(addrs) == 1 + listen_addr = addrs[0] raw_conn = await transport.dial(listen_addr) data = b"123" From 47d10e186f06fa20cf4bee2330569aa645bf3975 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 17 Dec 2019 18:17:28 +0800 Subject: [PATCH 23/81] Add `SubscriptionAPI` And `TrioSubscriptionAPI`, to make subscription io-agnostic. --- .../{pubsub_router_interface.py => abc.py} | 20 ++++- libp2p/pubsub/floodsub.py | 2 +- libp2p/pubsub/gossipsub.py | 2 +- libp2p/pubsub/pubsub.py | 25 +++--- libp2p/pubsub/subscription.py | 39 ++++++++++ libp2p/tools/factories.py | 2 +- libp2p/tools/pubsub/dummy_account_node.py | 2 +- .../floodsub_integration_test_settings.py | 2 +- tests/pubsub/test_floodsub.py | 7 +- tests/pubsub/test_gossipsub.py | 12 +-- tests/pubsub/test_pubsub.py | 4 +- tests/pubsub/test_subscription.py | 77 +++++++++++++++++++ 12 files changed, 158 insertions(+), 36 deletions(-) rename libp2p/pubsub/{pubsub_router_interface.py => abc.py} (86%) create mode 100644 libp2p/pubsub/subscription.py create mode 100644 tests/pubsub/test_subscription.py diff --git a/libp2p/pubsub/pubsub_router_interface.py b/libp2p/pubsub/abc.py similarity index 86% rename from libp2p/pubsub/pubsub_router_interface.py rename to libp2p/pubsub/abc.py index 99a9be75..19f9b2a6 100644 --- a/libp2p/pubsub/pubsub_router_interface.py +++ b/libp2p/pubsub/abc.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, AsyncContextManager, AsyncIterable, List from libp2p.peer.id import ID from libp2p.typing import TProtocol @@ -10,6 +10,11 @@ if TYPE_CHECKING: from .pubsub import Pubsub # noqa: F401 +# TODO: Add interface for Pubsub +class IPubsub(ABC): + pass + + class IPubsubRouter(ABC): @abstractmethod def get_protocols(self) -> List[TProtocol]: @@ -53,7 +58,6 @@ class IPubsubRouter(ABC): :param rpc: rpc message """ - # FIXME: Should be changed to type 'peer.ID' @abstractmethod async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None: """ @@ -80,3 +84,15 @@ class IPubsubRouter(ABC): :param topic: topic to leave """ + + +class ISubscriptionAPI( + AsyncContextManager["ISubscriptionAPI"], AsyncIterable[rpc_pb2.Message] +): + @abstractmethod + async def cancel(self) -> None: + ... + + @abstractmethod + async def get(self) -> rpc_pb2.Message: + ... diff --git a/libp2p/pubsub/floodsub.py b/libp2p/pubsub/floodsub.py index 9e323eb2..06300eec 100644 --- a/libp2p/pubsub/floodsub.py +++ b/libp2p/pubsub/floodsub.py @@ -8,9 +8,9 @@ from libp2p.peer.id import ID from libp2p.typing import TProtocol from libp2p.utils import encode_varint_prefixed +from .abc import IPubsubRouter from .pb import rpc_pb2 from .pubsub import Pubsub -from .pubsub_router_interface import IPubsubRouter PROTOCOL_ID = TProtocol("/floodsub/1.0.0") diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py index df0f83f4..df886db5 100644 --- a/libp2p/pubsub/gossipsub.py +++ b/libp2p/pubsub/gossipsub.py @@ -12,11 +12,11 @@ from libp2p.pubsub import floodsub from libp2p.typing import TProtocol from libp2p.utils import encode_varint_prefixed +from .abc import IPubsubRouter from .exceptions import NoPubsubAttached from .mcache import MessageCache from .pb import rpc_pb2 from .pubsub import Pubsub -from .pubsub_router_interface import IPubsubRouter PROTOCOL_ID = TProtocol("/meshsub/1.0.0") diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 0c3b162c..71b82f48 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -1,4 +1,3 @@ -from abc import ABC import logging import math import time @@ -30,12 +29,14 @@ from libp2p.peer.id import ID from libp2p.typing import TProtocol from libp2p.utils import encode_varint_prefixed, read_varint_prefixed_bytes +from .abc import IPubsub, ISubscriptionAPI from .pb import rpc_pb2 from .pubsub_notifee import PubsubNotifee +from .subscription import TrioSubscriptionAPI from .validators import signature_validator if TYPE_CHECKING: - from .pubsub_router_interface import IPubsubRouter # noqa: F401 + from .abc import IPubsubRouter # noqa: F401 from typing import Any # noqa: F401 @@ -57,11 +58,6 @@ class TopicValidator(NamedTuple): is_async: bool -# TODO: Add interface for Pubsub -class IPubsub(ABC): - pass - - class Pubsub(IPubsub, Service): host: IHost @@ -75,7 +71,7 @@ class Pubsub(IPubsub, Service): # TODO: Implement `trio.abc.Channel`? subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"] - subscribed_topics_receive: Dict[str, "trio.MemoryReceiveChannel[rpc_pb2.Message]"] + subscribed_topics_receive: Dict[str, "TrioSubscriptionAPI"] peer_topics: Dict[str, List[ID]] peers: Dict[ID, INetStream] @@ -380,10 +376,7 @@ class Pubsub(IPubsub, Service): # for each topic await self.subscribed_topics_send[topic].send(publish_message) - # TODO: Change to return an `AsyncIterable` to be I/O-agnostic? - async def subscribe( - self, topic_id: str - ) -> "trio.MemoryReceiveChannel[rpc_pb2.Message]": + async def subscribe(self, topic_id: str) -> ISubscriptionAPI: """ Subscribe ourself to a topic. @@ -396,14 +389,14 @@ class Pubsub(IPubsub, Service): if topic_id in self.topic_ids: return self.subscribed_topics_receive[topic_id] - # Map topic_id to a blocking channel channels: Tuple[ "trio.MemorySendChannel[rpc_pb2.Message]", "trio.MemoryReceiveChannel[rpc_pb2.Message]", ] = trio.open_memory_channel(math.inf) send_channel, receive_channel = channels + subscription = TrioSubscriptionAPI(receive_channel) self.subscribed_topics_send[topic_id] = send_channel - self.subscribed_topics_receive[topic_id] = receive_channel + self.subscribed_topics_receive[topic_id] = subscription # Create subscribe message packet: rpc_pb2.RPC = rpc_pb2.RPC() @@ -417,8 +410,8 @@ class Pubsub(IPubsub, Service): # Tell router we are joining this topic await self.router.join(topic_id) - # Return the trio channel for messages on this topic - return receive_channel + # Return the subscription for messages on this topic + return subscription async def unsubscribe(self, topic_id: str) -> None: """ diff --git a/libp2p/pubsub/subscription.py b/libp2p/pubsub/subscription.py new file mode 100644 index 00000000..1d88d09b --- /dev/null +++ b/libp2p/pubsub/subscription.py @@ -0,0 +1,39 @@ +from types import TracebackType +from typing import AsyncIterator, Optional, Type + +import trio + +from .abc import ISubscriptionAPI +from .pb import rpc_pb2 + + +class BaseSubscriptionAPI(ISubscriptionAPI): + async def __aenter__(self) -> "BaseSubscriptionAPI": + await trio.hazmat.checkpoint() + return self + + async def __aexit__( + self, + exc_type: "Optional[Type[BaseException]]", + exc_value: "Optional[BaseException]", + traceback: "Optional[TracebackType]", + ) -> None: + await self.cancel() + + +class TrioSubscriptionAPI(BaseSubscriptionAPI): + receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]" + + def __init__( + self, receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]" + ) -> None: + self.receive_channel = receive_channel + + async def cancel(self) -> None: + await self.receive_channel.aclose() + + def __aiter__(self) -> AsyncIterator[rpc_pb2.Message]: + return self.receive_channel.__aiter__() + + async def get(self) -> rpc_pb2.Message: + return await self.receive_channel.receive() diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index a9eb6a53..52208da4 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -17,10 +17,10 @@ from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.swarm import Swarm from libp2p.peer.id import ID from libp2p.peer.peerstore import PeerStore +from libp2p.pubsub.abc import IPubsubRouter from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.pubsub import Pubsub -from libp2p.pubsub.pubsub_router_interface import IPubsubRouter from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport import libp2p.security.secio.transport as secio diff --git a/libp2p/tools/pubsub/dummy_account_node.py b/libp2p/tools/pubsub/dummy_account_node.py index 5a61ed69..9079ac20 100644 --- a/libp2p/tools/pubsub/dummy_account_node.py +++ b/libp2p/tools/pubsub/dummy_account_node.py @@ -61,7 +61,7 @@ class DummyAccountNode(Service): async def handle_incoming_msgs(self) -> None: """Handle all incoming messages on the CRYPTO_TOPIC from peers.""" while True: - incoming = await self.subscription.receive() + incoming = await self.subscription.get() msg_comps = incoming.data.decode("utf-8").split(",") if msg_comps[0] == "send": diff --git a/libp2p/tools/pubsub/floodsub_integration_test_settings.py b/libp2p/tools/pubsub/floodsub_integration_test_settings.py index 58a5b242..0d25586e 100644 --- a/libp2p/tools/pubsub/floodsub_integration_test_settings.py +++ b/libp2p/tools/pubsub/floodsub_integration_test_settings.py @@ -250,7 +250,7 @@ async def perform_test_from_obj(obj, pubsub_factory) -> None: # Look at each node in each topic for node_id in topic_map[topic]: # Get message from subscription queue - msg = await queues_map[node_id][topic].receive() + msg = await queues_map[node_id][topic].get() assert data == msg.data # Check the message origin assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id diff --git a/tests/pubsub/test_floodsub.py b/tests/pubsub/test_floodsub.py index dbeb6833..148c001b 100644 --- a/tests/pubsub/test_floodsub.py +++ b/tests/pubsub/test_floodsub.py @@ -27,7 +27,7 @@ async def test_simple_two_nodes(): await pubsubs_fsub[0].publish(topic, data) - res_b = await sub_b.receive() + res_b = await sub_b.get() # Check that the msg received by node_b is the same # as the message sent by node_a @@ -75,12 +75,9 @@ async def test_lru_cache_two_nodes(monkeypatch): await trio.sleep(0.25) for index in expected_received_indices: - res_b = await sub_b.receive() + res_b = await sub_b.get() assert res_b.data == _make_testing_data(index) - with pytest.raises(trio.WouldBlock): - sub_b.receive_nowait() - @pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params) @pytest.mark.trio diff --git a/tests/pubsub/test_gossipsub.py b/tests/pubsub/test_gossipsub.py index b1ed3af1..e9d789a9 100644 --- a/tests/pubsub/test_gossipsub.py +++ b/tests/pubsub/test_gossipsub.py @@ -196,7 +196,7 @@ async def test_dense(): await trio.sleep(0.5) # Assert that all blocking queues receive the message for queue in queues: - msg = await queue.receive() + msg = await queue.get() assert msg.data == msg_content @@ -229,7 +229,7 @@ async def test_fanout(): await trio.sleep(0.5) # Assert that all blocking queues receive the message for sub in subs: - msg = await sub.receive() + msg = await sub.get() assert msg.data == msg_content # Subscribe message origin @@ -248,7 +248,7 @@ async def test_fanout(): await trio.sleep(0.5) # Assert that all blocking queues receive the message for sub in subs: - msg = await sub.receive() + msg = await sub.get() assert msg.data == msg_content @@ -287,7 +287,7 @@ async def test_fanout_maintenance(): await trio.sleep(0.5) # Assert that all blocking queues receive the message for queue in queues: - msg = await queue.receive() + msg = await queue.get() assert msg.data == msg_content for sub in pubsubs_gsub: @@ -319,7 +319,7 @@ async def test_fanout_maintenance(): await trio.sleep(0.5) # Assert that all blocking queues receive the message for queue in queues: - msg = await queue.receive() + msg = await queue.get() assert msg.data == msg_content @@ -346,5 +346,5 @@ async def test_gossip_propagation(): await trio.sleep(2) # should be able to read message - msg = await queue_1.receive() + msg = await queue_1.get() assert msg.data == msg_content diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index ea04788d..c4f00011 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -384,7 +384,7 @@ async def test_handle_talk(): len(pubsubs_fsub[0].topic_ids) == 1 and sub == pubsubs_fsub[0].subscribed_topics_receive[TESTING_TOPIC] ) - assert (await sub.receive()) == msg_0 + assert (await sub.get()) == msg_0 @pytest.mark.trio @@ -486,7 +486,7 @@ async def test_push_msg(monkeypatch): with trio.fail_after(0.1): await event.wait() # Test: Subscribers are notified when `push_msg` new messages. - assert (await sub.receive()) == msg_1 + assert (await sub.get()) == msg_1 with mock_router_publish() as event: # Test: add a topic validator and `push_msg` the message that diff --git a/tests/pubsub/test_subscription.py b/tests/pubsub/test_subscription.py new file mode 100644 index 00000000..c5a20eda --- /dev/null +++ b/tests/pubsub/test_subscription.py @@ -0,0 +1,77 @@ +import math + +import pytest +import trio + +from libp2p.pubsub.pb import rpc_pb2 +from libp2p.pubsub.subscription import TrioSubscriptionAPI + +GET_TIMEOUT = 0.001 + + +def make_trio_subscription(): + send_channel, receive_channel = trio.open_memory_channel(math.inf) + return send_channel, TrioSubscriptionAPI(receive_channel) + + +def make_pubsub_msg(): + return rpc_pb2.Message() + + +async def send_something(send_channel): + msg = make_pubsub_msg() + await send_channel.send(msg) + return msg + + +@pytest.mark.trio +async def test_trio_subscription_get(): + send_channel, sub = make_trio_subscription() + data_0 = await send_something(send_channel) + data_1 = await send_something(send_channel) + assert data_0 == await sub.get() + assert data_1 == await sub.get() + # No more message + with pytest.raises(trio.TooSlowError): + with trio.fail_after(GET_TIMEOUT): + await sub.get() + + +@pytest.mark.trio +async def test_trio_subscription_iter(): + send_channel, sub = make_trio_subscription() + received_data = [] + + async def iter_subscriptions(subscription): + async for data in sub: + received_data.append(data) + + async with trio.open_nursery() as nursery: + nursery.start_soon(iter_subscriptions, sub) + await send_something(send_channel) + await send_something(send_channel) + await send_channel.aclose() + + assert len(received_data) == 2 + + +@pytest.mark.trio +async def test_trio_subscription_cancel(): + send_channel, sub = make_trio_subscription() + await sub.cancel() + # Test: If the subscription is cancelled, `send_channel` should be broken. + with pytest.raises(trio.BrokenResourceError): + await send_something(send_channel) + # Test: No side effect when cancelled twice. + await sub.cancel() + + +@pytest.mark.trio +async def test_trio_subscription_async_context_manager(): + send_channel, sub = make_trio_subscription() + async with sub: + # Test: `sub` is not cancelled yet, so `send_something` works fine. + await send_something(send_channel) + # Test: `sub` is cancelled, `send_something` fails + with pytest.raises(trio.BrokenResourceError): + await send_something(send_channel) From 6fe5871d96af328d3b7f7136054039c97be5770b Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 24 Dec 2019 14:44:28 +0800 Subject: [PATCH 24/81] Use `async-exit-stack` over contextlib For `AsyncExitStack` --- libp2p/tools/factories.py | 3 +-- libp2p/tools/pubsub/dummy_account_node.py | 3 ++- setup.py | 1 + 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index 74fbb5fa..144c8bc0 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -1,7 +1,6 @@ -from contextlib import AsyncExitStack from typing import Any, AsyncIterator, Dict, List, Sequence, Tuple, cast -# NOTE: import ``asynccontextmanager`` from ``contextlib`` when support for python 3.6 is dropped. +from async_exit_stack import AsyncExitStack from async_generator import asynccontextmanager from async_service import background_trio_service import factory diff --git a/libp2p/tools/pubsub/dummy_account_node.py b/libp2p/tools/pubsub/dummy_account_node.py index 9079ac20..f29a9fd5 100644 --- a/libp2p/tools/pubsub/dummy_account_node.py +++ b/libp2p/tools/pubsub/dummy_account_node.py @@ -1,6 +1,7 @@ -from contextlib import AsyncExitStack, asynccontextmanager from typing import AsyncIterator, Dict, Tuple +from async_exit_stack import AsyncExitStack +from async_generator import asynccontextmanager from async_service import Service, background_trio_service from libp2p.host.host_interface import IHost diff --git a/setup.py b/setup.py index aaebd303..435f72ed 100644 --- a/setup.py +++ b/setup.py @@ -76,6 +76,7 @@ install_requires = [ "async_generator==1.10", "trio>=0.13.0", "async-service>=0.1.0a2,<0.2.0", + "async-exit-stack==1.0.1", ] From 3372c32432822dcbbee424ae9c5b86b9abe9f093 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 24 Dec 2019 18:03:18 +0800 Subject: [PATCH 25/81] Fix examples and modify `new_node` - Fix examples `chat.py` and `echo.py` - Use trio directly, instead of `trio-asyncio` - Remove redundant code - Change entry API `new_node` to `new_host_trio` --- examples/chat/chat.py | 98 ++++++++------------- examples/echo/echo.py | 95 ++++++++------------ libp2p/__init__.py | 84 +++++++++--------- libp2p/network/network_interface.py | 5 ++ libp2p/network/swarm.py | 5 +- libp2p/peer/peerinfo.py | 3 - tests/security/test_security_multistream.py | 22 +++-- 7 files changed, 129 insertions(+), 183 deletions(-) diff --git a/examples/chat/chat.py b/examples/chat/chat.py index fbfd89e3..80dcc862 100755 --- a/examples/chat/chat.py +++ b/examples/chat/chat.py @@ -1,13 +1,10 @@ import argparse -import asyncio import sys -import urllib.request import multiaddr import trio -import trio_asyncio -from libp2p import new_node +from libp2p import new_host_trio from libp2p.network.stream.net_stream_interface import INetStream from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.typing import TProtocol @@ -28,60 +25,48 @@ async def read_data(stream: INetStream) -> None: async def write_data(stream: INetStream) -> None: - loop = asyncio.get_event_loop() + async_f = trio.wrap_file(sys.stdin) while True: - line = await loop.run_in_executor(None, sys.stdin.readline) + line = await async_f.readline() await stream.write(line.encode()) -async def run(port: int, destination: str, localhost: bool) -> None: - if localhost: - ip = "127.0.0.1" - else: - ip = urllib.request.urlopen("https://v4.ident.me/").read().decode("utf8") - transport_opt = f"/ip4/{ip}/tcp/{port}" - host = new_node(transport_opt=[transport_opt]) +async def run(port: int, destination: str) -> None: + localhost_ip = "127.0.0.1" + listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") + async with new_host_trio( + listen_addrs=[listen_addr] + ) as host, trio.open_nursery() as nursery: + if not destination: # its the server - await trio_asyncio.run_asyncio( - host.get_network().listen, multiaddr.Multiaddr(transport_opt) - ) + async def stream_handler(stream: INetStream) -> None: + nursery.start_soon(read_data, stream) + nursery.start_soon(write_data, stream) - if not destination: # its the server + host.set_stream_handler(PROTOCOL_ID, stream_handler) - async def stream_handler(stream: INetStream) -> None: - asyncio.ensure_future(read_data(stream)) - asyncio.ensure_future(write_data(stream)) + print( + f"Run 'python ./examples/chat/chat.py " + f"-p {int(port) + 1} " + f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id().pretty()}' " + "on another console." + ) + print("Waiting for incoming connection...") - host.set_stream_handler(PROTOCOL_ID, stream_handler) + else: # its the client + maddr = multiaddr.Multiaddr(destination) + info = info_from_p2p_addr(maddr) + # Associate the peer with local ip address + await host.connect(info) + # Start a stream with the destination. + # Multiaddress of the destination peer is fetched from the peerstore using 'peerId'. + stream = await host.new_stream(info.peer_id, [PROTOCOL_ID]) - localhost_opt = " --localhost" if localhost else "" + nursery.start_soon(read_data, stream) + nursery.start_soon(write_data, stream) + print("Connected to peer %s" % info.addrs[0]) - print( - f"Run 'python ./examples/chat/chat.py" - + localhost_opt - + f" -p {int(port) + 1} -d /ip4/{ip}/tcp/{port}/p2p/{host.get_id().pretty()}'" - + " on another console." - ) - print("Waiting for incoming connection...") - - else: # its the client - maddr = multiaddr.Multiaddr(destination) - info = info_from_p2p_addr(maddr) - # Associate the peer with local ip address - await trio_asyncio.run_asyncio(host.connect, info) - - # Start a stream with the destination. - # Multiaddress of the destination peer is fetched from the peerstore using 'peerId'. - stream = await trio_asyncio.run_asyncio( - host.new_stream, *(info.peer_id, [PROTOCOL_ID]) - ) - - asyncio.ensure_future(read_data(stream)) - asyncio.ensure_future(write_data(stream)) - print("Connected to peer %s" % info.addrs[0]) - - stopped_event = trio.Event() - await stopped_event.wait() + await trio.sleep_forever() def main() -> None: @@ -95,11 +80,6 @@ def main() -> None: "/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q" ) parser = argparse.ArgumentParser(description=description) - parser.add_argument( - "--debug", - action="store_true", - help="generate the same node ID on every execution", - ) parser.add_argument( "-p", "--port", default=8000, type=int, help="source port number" ) @@ -109,19 +89,15 @@ def main() -> None: type=str, help=f"destination multiaddr string, e.g. {example_maddr}", ) - parser.add_argument( - "-l", - "--localhost", - dest="localhost", - action="store_true", - help="flag indicating if localhost should be used or an external IP", - ) args = parser.parse_args() if not args.port: raise RuntimeError("was not able to determine a local port") - trio_asyncio.run(run, *(args.port, args.destination, args.localhost)) + try: + trio.run(run, *(args.port, args.destination)) + except KeyboardInterrupt: + pass if __name__ == "__main__": diff --git a/examples/echo/echo.py b/examples/echo/echo.py index 3f3ed33e..b39f8138 100644 --- a/examples/echo/echo.py +++ b/examples/echo/echo.py @@ -1,10 +1,9 @@ import argparse -import asyncio -import urllib.request import multiaddr +import trio -from libp2p import new_node +from libp2p import new_host_trio from libp2p.crypto.secp256k1 import create_new_key_pair from libp2p.network.stream.net_stream_interface import INetStream from libp2p.peer.peerinfo import info_from_p2p_addr @@ -20,12 +19,9 @@ async def _echo_stream_handler(stream: INetStream) -> None: await stream.close() -async def run(port: int, destination: str, localhost: bool, seed: int = None) -> None: - if localhost: - ip = "127.0.0.1" - else: - ip = urllib.request.urlopen("https://v4.ident.me/").read().decode("utf8") - transport_opt = f"/ip4/{ip}/tcp/{port}" +async def run(port: int, destination: str, seed: int = None) -> None: + localhost_ip = "127.0.0.1" + listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") if seed: import random @@ -38,47 +34,44 @@ async def run(port: int, destination: str, localhost: bool, seed: int = None) -> secret = secrets.token_bytes(32) - host = await new_node( - key_pair=create_new_key_pair(secret), transport_opt=[transport_opt] - ) + async with new_host_trio( + listen_addrs=[listen_addr], key_pair=create_new_key_pair(secret) + ) as host: - print(f"I am {host.get_id().to_string()}") + print(f"I am {host.get_id().to_string()}") - await host.get_network().listen(multiaddr.Multiaddr(transport_opt)) + if not destination: # its the server - if not destination: # its the server + host.set_stream_handler(PROTOCOL_ID, _echo_stream_handler) - host.set_stream_handler(PROTOCOL_ID, _echo_stream_handler) + print( + f"Run 'python ./examples/echo/echo.py " + f"-p {int(port) + 1} " + f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id().pretty()}' " + "on another console." + ) + print("Waiting for incoming connections...") + await trio.sleep_forever() - localhost_opt = " --localhost" if localhost else "" + else: # its the client + maddr = multiaddr.Multiaddr(destination) + info = info_from_p2p_addr(maddr) + # Associate the peer with local ip address + await host.connect(info) - print( - f"Run 'python ./examples/echo/echo.py" - + localhost_opt - + f" -p {int(port) + 1} -d /ip4/{ip}/tcp/{port}/p2p/{host.get_id().pretty()}'" - + " on another console." - ) - print("Waiting for incoming connections...") + # Start a stream with the destination. + # Multiaddress of the destination peer is fetched from the peerstore using 'peerId'. + stream = await host.new_stream(info.peer_id, [PROTOCOL_ID]) - else: # its the client - maddr = multiaddr.Multiaddr(destination) - info = info_from_p2p_addr(maddr) - # Associate the peer with local ip address - await host.connect(info) + msg = b"hi, there!\n" - # Start a stream with the destination. - # Multiaddress of the destination peer is fetched from the peerstore using 'peerId'. - stream = await host.new_stream(info.peer_id, [PROTOCOL_ID]) + await stream.write(msg) + # Notify the other side about EOF + await stream.close() + response = await stream.read() - msg = b"hi, there!\n" - - await stream.write(msg) - # Notify the other side about EOF - await stream.close() - response = await stream.read() - - print(f"Sent: {msg}") - print(f"Got: {response}") + print(f"Sent: {msg}") + print(f"Got: {response}") def main() -> None: @@ -94,11 +87,6 @@ def main() -> None: "/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q" ) parser = argparse.ArgumentParser(description=description) - parser.add_argument( - "--debug", - action="store_true", - help="generate the same node ID on every execution", - ) parser.add_argument( "-p", "--port", default=8000, type=int, help="source port number" ) @@ -108,13 +96,6 @@ def main() -> None: type=str, help=f"destination multiaddr string, e.g. {example_maddr}", ) - parser.add_argument( - "-l", - "--localhost", - dest="localhost", - action="store_true", - help="flag indicating if localhost should be used or an external IP", - ) parser.add_argument( "-s", "--seed", @@ -126,16 +107,10 @@ def main() -> None: if not args.port: raise RuntimeError("was not able to determine a local port") - loop = asyncio.get_event_loop() try: - asyncio.ensure_future( - run(args.port, args.destination, args.localhost, args.seed) - ) - loop.run_forever() + trio.run(run, args.port, args.destination, args.seed) except KeyboardInterrupt: pass - finally: - loop.close() if __name__ == "__main__": diff --git a/libp2p/__init__.py b/libp2p/__init__.py index c0cd6cea..658fceb3 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -1,11 +1,14 @@ -from typing import Sequence +from typing import AsyncIterator, Sequence + +from async_generator import asynccontextmanager +from async_service import background_trio_service from libp2p.crypto.keys import KeyPair from libp2p.crypto.rsa import create_new_key_pair from libp2p.host.basic_host import BasicHost from libp2p.host.host_interface import IHost from libp2p.host.routed_host import RoutedHost -from libp2p.network.network_interface import INetwork +from libp2p.network.network_interface import INetwork, INetworkService from libp2p.network.swarm import Swarm from libp2p.peer.id import ID from libp2p.peer.peerstore import PeerStore @@ -30,28 +33,27 @@ def generate_peer_id_from(key_pair: KeyPair) -> ID: def initialize_default_swarm( - key_pair: KeyPair, - id_opt: ID = None, - transport_opt: Sequence[str] = None, + key_pair: KeyPair = None, muxer_opt: TMuxerOptions = None, sec_opt: TSecurityOptions = None, peerstore_opt: IPeerStore = None, -) -> Swarm: +) -> INetworkService: """ initialize swarm when no swarm is passed in. - :param id_opt: optional id for host - :param transport_opt: optional choice of transport upgrade + :param key_pair: optional choice of the ``KeyPair`` :param muxer_opt: optional choice of stream muxer :param sec_opt: optional choice of security upgrade :param peerstore_opt: optional peerstore :return: return a default swarm instance """ - if not id_opt: - id_opt = generate_peer_id_from(key_pair) + if not key_pair: + key_pair = generate_new_rsa_identity() - # TODO: Parse `transport_opt` to determine transport + id_opt = generate_peer_id_from(key_pair) + + # TODO: Parse `listen_addrs` to determine transport transport = TCP() muxer_transports_by_protocol = muxer_opt or {MPLEX_PROTOCOL_ID: Mplex} @@ -67,48 +69,17 @@ def initialize_default_swarm( # Store our key pair in peerstore peerstore.add_key_pair(id_opt, key_pair) - # TODO: Initialize discovery if not presented return Swarm(id_opt, peerstore, upgrader, transport) -def new_node( - key_pair: KeyPair = None, - swarm_opt: INetwork = None, - transport_opt: Sequence[str] = None, - muxer_opt: TMuxerOptions = None, - sec_opt: TSecurityOptions = None, - peerstore_opt: IPeerStore = None, - disc_opt: IPeerRouting = None, -) -> BasicHost: +def _new_host(swarm_opt: INetwork, disc_opt: IPeerRouting = None) -> IHost: """ - create new libp2p node. + create new libp2p host. - :param key_pair: key pair for deriving an identity :param swarm_opt: optional swarm - :param id_opt: optional id for host - :param transport_opt: optional choice of transport upgrade - :param muxer_opt: optional choice of stream muxer - :param sec_opt: optional choice of security upgrade - :param peerstore_opt: optional peerstore :param disc_opt: optional discovery :return: return a host instance """ - - if not key_pair: - key_pair = generate_new_rsa_identity() - - id_opt = generate_peer_id_from(key_pair) - - if not swarm_opt: - swarm_opt = initialize_default_swarm( - key_pair=key_pair, - id_opt=id_opt, - transport_opt=transport_opt, - muxer_opt=muxer_opt, - sec_opt=sec_opt, - peerstore_opt=peerstore_opt, - ) - # TODO enable support for other host type # TODO routing unimplemented host: IHost # If not explicitly typed, MyPy raises error @@ -118,3 +89,28 @@ def new_node( host = BasicHost(swarm_opt) return host + + +@asynccontextmanager +async def new_host_trio( + listen_addrs: Sequence[str], + key_pair: KeyPair = None, + swarm_opt: INetwork = None, + muxer_opt: TMuxerOptions = None, + sec_opt: TSecurityOptions = None, + peerstore_opt: IPeerStore = None, + disc_opt: IPeerRouting = None, +) -> AsyncIterator[IHost]: + swarm = initialize_default_swarm( + key_pair=key_pair, + muxer_opt=muxer_opt, + sec_opt=sec_opt, + peerstore_opt=peerstore_opt, + ) + async with background_trio_service(swarm): + await swarm.listen(*listen_addrs) + host = _new_host(swarm_opt=swarm, disc_opt=disc_opt) + yield host + + +# TODO: Support asyncio diff --git a/libp2p/network/network_interface.py b/libp2p/network/network_interface.py index 9e942831..c759a411 100644 --- a/libp2p/network/network_interface.py +++ b/libp2p/network/network_interface.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Dict, Sequence +from async_service import ServiceAPI from multiaddr import Multiaddr from libp2p.network.connection.net_connection_interface import INetConn @@ -70,3 +71,7 @@ class INetwork(ABC): @abstractmethod async def close_peer(self, peer_id: ID) -> None: pass + + +class INetworkService(INetwork, ServiceAPI): + ... diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 873f2399..9d19717b 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -1,7 +1,6 @@ import logging from typing import Dict, List, Optional -from async_service import Service from multiaddr import Multiaddr import trio @@ -25,14 +24,14 @@ from ..exceptions import MultiError from .connection.raw_connection import RawConnection from .connection.swarm_connection import SwarmConn from .exceptions import SwarmException -from .network_interface import INetwork +from .network_interface import INetworkService from .notifee_interface import INotifee from .stream.net_stream_interface import INetStream logger = logging.getLogger("libp2p.network.swarm") -class Swarm(INetwork, Service): +class Swarm(INetworkService): self_id: ID peerstore: IPeerStore diff --git a/libp2p/peer/peerinfo.py b/libp2p/peer/peerinfo.py index 4015ef97..889b6f61 100644 --- a/libp2p/peer/peerinfo.py +++ b/libp2p/peer/peerinfo.py @@ -25,9 +25,6 @@ def info_from_p2p_addr(addr: multiaddr.Multiaddr) -> PeerInfo: if not addr: raise InvalidAddrError("`addr` should not be `None`") - if not isinstance(addr, multiaddr.Multiaddr): - raise InvalidAddrError(f"`addr`={addr} should be of type `Multiaddr`") - parts = addr.split() if not parts: raise InvalidAddrError( diff --git a/tests/security/test_security_multistream.py b/tests/security/test_security_multistream.py index 5c751f92..8d466106 100644 --- a/tests/security/test_security_multistream.py +++ b/tests/security/test_security_multistream.py @@ -1,8 +1,7 @@ -from async_service import background_trio_service import pytest import trio -from libp2p import new_node +from libp2p import new_host_trio from libp2p.crypto.rsa import create_new_key_pair from libp2p.security.insecure.transport import InsecureSession, InsecureTransport from libp2p.tools.constants import LISTEN_MADDR @@ -30,16 +29,15 @@ async def perform_simple_test( # for testing, we do NOT want to communicate over a stream so we can't just create two nodes # and use their conn because our mplex will internally relay messages to a stream - node1 = new_node(key_pair=initiator_key_pair, sec_opt=transports_for_initiator) - node2 = new_node( - key_pair=noninitiator_key_pair, sec_opt=transports_for_noninitiator - ) - swarm1 = node1.get_network() - swarm2 = node2.get_network() - async with background_trio_service(swarm1), background_trio_service(swarm2): - await swarm1.listen(LISTEN_MADDR) - await swarm2.listen(LISTEN_MADDR) - + async with new_host_trio( + listen_addrs=[LISTEN_MADDR], + key_pair=initiator_key_pair, + sec_opt=transports_for_initiator, + ) as node1, new_host_trio( + listen_addrs=[LISTEN_MADDR], + key_pair=noninitiator_key_pair, + sec_opt=transports_for_noninitiator, + ) as node2: await connect(node1, node2) # Wait a very short period to allow conns to be stored (since the functions From 2287dc95befbc4fe23994c6e80e1516d4ecf8134 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 24 Dec 2019 18:08:33 +0800 Subject: [PATCH 26/81] Fix test for `info_from_p2p_addr` It is because I removed some checks in the function. This checks should be useless thanks to mypy --- tests/peer/test_peerinfo.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/peer/test_peerinfo.py b/tests/peer/test_peerinfo.py index 29c46887..deb760d4 100644 --- a/tests/peer/test_peerinfo.py +++ b/tests/peer/test_peerinfo.py @@ -25,8 +25,6 @@ def test_init_(): @pytest.mark.parametrize( "addr", ( - pytest.param(None), - pytest.param(random.randint(0, 255), id="random integer"), pytest.param(multiaddr.Multiaddr("/"), id="empty multiaddr"), pytest.param( multiaddr.Multiaddr("/ip4/127.0.0.1"), From 573c049d0f9290b0ce0908b2ed0f4500d07f6982 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 24 Dec 2019 18:31:39 +0800 Subject: [PATCH 27/81] Catch expections in `PubsubNotifee` Also, add lock to avoid resource race condition --- libp2p/pubsub/pubsub_notifee.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/libp2p/pubsub/pubsub_notifee.py b/libp2p/pubsub/pubsub_notifee.py index 7394736e..2d66da82 100644 --- a/libp2p/pubsub/pubsub_notifee.py +++ b/libp2p/pubsub/pubsub_notifee.py @@ -16,6 +16,7 @@ class PubsubNotifee(INotifee): initiator_peers_queue: "trio.MemorySendChannel[ID]" dead_peers_queue: "trio.MemorySendChannel[ID]" + dead_peers_queue_lock: trio.Lock def __init__( self, @@ -29,7 +30,9 @@ class PubsubNotifee(INotifee): can process dead peers after we disconnect from each other """ self.initiator_peers_queue = initiator_peers_queue + self.initiator_peers_queue_lock: trio.Lock() self.dead_peers_queue = dead_peers_queue + self.dead_peers_queue_lock: trio.Lock() async def opened_stream(self, network: INetwork, stream: INetStream) -> None: pass @@ -46,12 +49,17 @@ class PubsubNotifee(INotifee): :param network: network the connection was opened on :param conn: connection that was opened """ - try: - await self.initiator_peers_queue.send(conn.muxed_conn.peer_id) - except trio.BrokenResourceError: - # Raised when the receive channel is closed. - # TODO: Do something with loggers? - ... + async with self.initiator_peers_queue_lock: + try: + await self.initiator_peers_queue.send(conn.muxed_conn.peer_id) + except ( + trio.BrokenResourceError, + trio.ClosedResourceError, + trio.BusyResourceError, + ): + # Raised when the receive channel is closed. + # TODO: Do something with loggers? + ... async def disconnected(self, network: INetwork, conn: INetConn) -> None: """ @@ -61,7 +69,17 @@ class PubsubNotifee(INotifee): :param network: network the connection was opened on :param conn: connection that was opened """ - await self.dead_peers_queue.send(conn.muxed_conn.peer_id) + async with self.dead_peers_queue_lock: + try: + await self.dead_peers_queue.send(conn.muxed_conn.peer_id) + except ( + trio.BrokenResourceError, + trio.ClosedResourceError, + trio.BusyResourceError, + ): + # Raised when the receive channel is closed. + # TODO: Do something with loggers? + ... async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None: pass From 53dbb0aff19aef28d311f9dd275d725e7502e6b1 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 24 Dec 2019 18:37:59 +0800 Subject: [PATCH 28/81] Fix pubsub_notifee.py For wrong syntax and import --- libp2p/pubsub/pubsub_notifee.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libp2p/pubsub/pubsub_notifee.py b/libp2p/pubsub/pubsub_notifee.py index 2d66da82..be481901 100644 --- a/libp2p/pubsub/pubsub_notifee.py +++ b/libp2p/pubsub/pubsub_notifee.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING from multiaddr import Multiaddr +import trio from libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.network_interface import INetwork @@ -8,7 +9,6 @@ from libp2p.network.notifee_interface import INotifee from libp2p.network.stream.net_stream_interface import INetStream if TYPE_CHECKING: - import trio # noqa: F401 from libp2p.peer.id import ID # noqa: F401 @@ -30,9 +30,9 @@ class PubsubNotifee(INotifee): can process dead peers after we disconnect from each other """ self.initiator_peers_queue = initiator_peers_queue - self.initiator_peers_queue_lock: trio.Lock() + self.initiator_peers_queue_lock = trio.Lock() self.dead_peers_queue = dead_peers_queue - self.dead_peers_queue_lock: trio.Lock() + self.dead_peers_queue_lock = trio.Lock() async def opened_stream(self, network: INetwork, stream: INetStream) -> None: pass From 6ae3f5dc1b931de5b561bb5e7e6559c1ce99f3ec Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 24 Dec 2019 21:28:37 +0800 Subject: [PATCH 29/81] Add checkpoints in tests --- tests/pubsub/test_pubsub.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index 293b2c02..c05aecbb 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -154,6 +154,7 @@ async def test_get_msg_validators(): async def async_validator(peer_id, msg): nonlocal times_async_validator_called times_async_validator_called += 1 + await trio.hazmat.checkpoint() topic_1 = "TEST_VALIDATOR_1" topic_2 = "TEST_VALIDATOR_2" @@ -199,9 +200,11 @@ async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed): return False async def passed_async_validator(peer_id, msg): + await trio.hazmat.checkpoint() return True async def failed_async_validator(peer_id, msg): + await trio.hazmat.checkpoint() return False topic_1 = "TEST_SYNC_VALIDATOR" From fb6076c061eaaa7e3954332ee710027ecf9eaaf3 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 24 Dec 2019 21:50:42 +0800 Subject: [PATCH 30/81] Upgrade to 0.1.0a4 Probably it can solve the dag issue: https://github.com/ethereum/async-service/issues/12 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 435f72ed..382bb407 100644 --- a/setup.py +++ b/setup.py @@ -75,7 +75,7 @@ install_requires = [ "dataclasses>=0.7, <1;python_version<'3.7'", "async_generator==1.10", "trio>=0.13.0", - "async-service>=0.1.0a2,<0.2.0", + "async-service>=0.1.0a4,<0.2.0", "async-exit-stack==1.0.1", ] From 3c98b1973ddb8269d95a7de036cb9bd4dd0dcb90 Mon Sep 17 00:00:00 2001 From: mhchia Date: Thu, 26 Dec 2019 20:43:38 +0800 Subject: [PATCH 31/81] Remove useless conftest for pubsub --- tests/pubsub/conftest.py | 58 ---------------------------------------- 1 file changed, 58 deletions(-) delete mode 100644 tests/pubsub/conftest.py diff --git a/tests/pubsub/conftest.py b/tests/pubsub/conftest.py deleted file mode 100644 index 520fdf4b..00000000 --- a/tests/pubsub/conftest.py +++ /dev/null @@ -1,58 +0,0 @@ -import pytest - -from libp2p.tools.constants import GOSSIPSUB_PARAMS -from libp2p.tools.factories import FloodsubFactory, GossipsubFactory, PubsubFactory - - -@pytest.fixture -def is_strict_signing(): - return False - - -def _make_pubsubs(hosts, pubsub_routers, cache_size, is_strict_signing): - if len(pubsub_routers) != len(hosts): - raise ValueError( - f"lenght of pubsub_routers={pubsub_routers} should be equaled to the " - f"length of hosts={len(hosts)}" - ) - return tuple( - PubsubFactory( - host=host, - router=router, - cache_size=cache_size, - strict_signing=is_strict_signing, - ) - for host, router in zip(hosts, pubsub_routers) - ) - - -@pytest.fixture -def pubsub_cache_size(): - return None # default - - -@pytest.fixture -def gossipsub_params(): - return GOSSIPSUB_PARAMS - - -@pytest.fixture -def pubsubs_fsub(num_hosts, hosts, pubsub_cache_size, is_strict_signing): - floodsubs = FloodsubFactory.create_batch(num_hosts) - _pubsubs_fsub = _make_pubsubs( - hosts, floodsubs, pubsub_cache_size, is_strict_signing - ) - yield _pubsubs_fsub - # TODO: Clean up - - -@pytest.fixture -def pubsubs_gsub( - num_hosts, hosts, pubsub_cache_size, gossipsub_params, is_strict_signing -): - gossipsubs = GossipsubFactory.create_batch(num_hosts, **gossipsub_params._asdict()) - _pubsubs_gsub = _make_pubsubs( - hosts, gossipsubs, pubsub_cache_size, is_strict_signing - ) - yield _pubsubs_gsub - # TODO: Clean up From 68c84b273dd7b992f2a45c918d88e218310ef7a2 Mon Sep 17 00:00:00 2001 From: mhchia Date: Thu, 26 Dec 2019 20:44:10 +0800 Subject: [PATCH 32/81] Use `cls` over the name of the factory --- libp2p/tools/factories.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index 144c8bc0..ae99f345 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -238,7 +238,7 @@ class PubsubFactory(factory.Factory): async def create_and_start( cls, host: IHost, router: IPubsubRouter, cache_size: int, strict_signing: bool ) -> AsyncIterator[Pubsub]: - pubsub = PubsubFactory( + pubsub = cls( host=host, router=router, cache_size=cache_size, From 94f0fcb6ad9b50aa093ed7e2b0d56f3c75ac8125 Mon Sep 17 00:00:00 2001 From: mhchia Date: Thu, 26 Dec 2019 20:44:32 +0800 Subject: [PATCH 33/81] Iterate `dead_peer_receive_channel` with async for --- libp2p/pubsub/pubsub.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index e735d872..e0026de8 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -332,8 +332,7 @@ class Pubsub(IPubsub, Service): open a stream to the peer using a supported pubsub protocol pubsub protocols we support.""" async with self.peer_receive_channel: - while self.manager.is_running: - peer_id: ID = await self.peer_receive_channel.receive() + async for peer_id in self.peer_receive_channel: # Add Peer self.manager.run_task(self._handle_new_peer, peer_id) @@ -342,8 +341,7 @@ class Pubsub(IPubsub, Service): between that peer and remove peer info from pubsub and pubsub router.""" async with self.dead_peer_receive_channel: - while self.manager.is_running: - peer_id: ID = await self.dead_peer_receive_channel.receive() + async for peer_id in self.dead_peer_receive_channel: # Remove Peer self._handle_dead_peer(peer_id) From 000e777ac73da4210bf4fb925163e173b03cae17 Mon Sep 17 00:00:00 2001 From: mhchia Date: Thu, 26 Dec 2019 20:44:58 +0800 Subject: [PATCH 34/81] Try older async-service --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 382bb407..5f3a873d 100644 --- a/setup.py +++ b/setup.py @@ -75,7 +75,7 @@ install_requires = [ "dataclasses>=0.7, <1;python_version<'3.7'", "async_generator==1.10", "trio>=0.13.0", - "async-service>=0.1.0a4,<0.2.0", + "async-service==0.1.0a2", "async-exit-stack==1.0.1", ] From fe4354d377ed2dfa19d1746b024d379d108ecd5f Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 7 Jan 2020 14:14:34 +0800 Subject: [PATCH 35/81] Fix `tests_interop` - Remove pexpect - Use new version of `p2pclient`, which makes use of anyio - Clean up tests --- libp2p/tools/constants.py | 2 +- libp2p/tools/interop/constants.py | 1 - libp2p/tools/interop/daemon.py | 147 ++++-------------- libp2p/tools/interop/process.py | 66 ++++++++ libp2p/tools/interop/utils.py | 4 +- tests_interop/conftest.py | 216 +++++++++----------------- tests_interop/test_bindings.py | 42 +++--- tests_interop/test_echo.py | 148 ++++++++++-------- tests_interop/test_net_stream.py | 23 ++- tests_interop/test_pubsub.py | 241 +++++++++++++++--------------- tox.ini | 2 +- 11 files changed, 415 insertions(+), 477 deletions(-) create mode 100644 libp2p/tools/interop/process.py diff --git a/libp2p/tools/constants.py b/libp2p/tools/constants.py index 8c22d151..b1ad2652 100644 --- a/libp2p/tools/constants.py +++ b/libp2p/tools/constants.py @@ -7,7 +7,7 @@ from libp2p.pubsub import floodsub, gossipsub # Just a arbitrary large number. # It is used when calling `MplexStream.read(MAX_READ_LEN)`, # to avoid `MplexStream.read()`, which blocking reads until EOF. -MAX_READ_LEN = 2 ** 32 - 1 +MAX_READ_LEN = 65535 LISTEN_MADDR = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0") diff --git a/libp2p/tools/interop/constants.py b/libp2p/tools/interop/constants.py index 331e2843..8f039f1c 100644 --- a/libp2p/tools/interop/constants.py +++ b/libp2p/tools/interop/constants.py @@ -1,2 +1 @@ LOCALHOST_IP = "127.0.0.1" -PEXPECT_NEW_LINE = "\r\n" diff --git a/libp2p/tools/interop/daemon.py b/libp2p/tools/interop/daemon.py index 43cfd6db..f6e363e0 100644 --- a/libp2p/tools/interop/daemon.py +++ b/libp2p/tools/interop/daemon.py @@ -1,52 +1,22 @@ -import asyncio -import time -from typing import Any, Awaitable, Callable, List +from typing import AsyncIterator +from async_generator import asynccontextmanager import multiaddr from multiaddr import Multiaddr from p2pclient import Client -import pytest +import trio from libp2p.peer.id import ID from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr from .constants import LOCALHOST_IP from .envs import GO_BIN_PATH +from .process import BaseInteractiveProcess P2PD_PATH = GO_BIN_PATH / "p2pd" -TIMEOUT_DURATION = 30 - - -async def try_until_success( - coro_func: Callable[[], Awaitable[Any]], timeout: int = TIMEOUT_DURATION -) -> None: - """ - Keep running ``coro_func`` until either it succeed or time is up. - - All arguments of ``coro_func`` should be filled, i.e. it should be - called without arguments. - """ - t_start = time.monotonic() - while True: - result = await coro_func() - if result: - break - if (time.monotonic() - t_start) >= timeout: - # timeout - pytest.fail(f"{coro_func} is still failing after `{timeout}` seconds") - await asyncio.sleep(0.01) - - -class P2PDProcess: - proc: asyncio.subprocess.Process - cmd: str = str(P2PD_PATH) - args: List[Any] - is_proc_running: bool - - _tasks: List["asyncio.Future[Any]"] - +class P2PDProcess(BaseInteractiveProcess): def __init__( self, control_maddr: Multiaddr, @@ -75,74 +45,21 @@ class P2PDProcess: # - gossipsubHeartbeatInterval: GossipSubHeartbeatInitialDelay = 100 * time.Millisecond # noqa: E501 # - gossipsubHeartbeatInitialDelay: GossipSubHeartbeatInterval = 1 * time.Second # Referece: https://github.com/libp2p/go-libp2p-daemon/blob/b95e77dbfcd186ccf817f51e95f73f9fd5982600/p2pd/main.go#L348-L353 # noqa: E501 + self.proc = None + self.cmd = str(P2PD_PATH) self.args = args - self.is_proc_running = False - - self._tasks = [] - - async def wait_until_ready(self) -> None: - lines_head_pattern = (b"Control socket:", b"Peer ID:", b"Peer Addrs:") - lines_head_occurred = {line: False for line in lines_head_pattern} - - async def read_from_daemon_and_check() -> bool: - line = await self.proc.stdout.readline() - for head_pattern in lines_head_occurred: - if line.startswith(head_pattern): - lines_head_occurred[head_pattern] = True - return all([value for value in lines_head_occurred.values()]) - - await try_until_success(read_from_daemon_and_check) - # Sleep a little bit to ensure the listener is up after logs are emitted. - await asyncio.sleep(0.01) - - async def start_printing_logs(self) -> None: - async def _print_from_stream( - src_name: str, reader: asyncio.StreamReader - ) -> None: - while True: - line = await reader.readline() - if line != b"": - print(f"{src_name}\t: {line.rstrip().decode()}") - await asyncio.sleep(0.01) - - self._tasks.append( - asyncio.ensure_future(_print_from_stream("out", self.proc.stdout)) - ) - self._tasks.append( - asyncio.ensure_future(_print_from_stream("err", self.proc.stderr)) - ) - await asyncio.sleep(0) - - async def start(self) -> None: - if self.is_proc_running: - return - self.proc = await asyncio.subprocess.create_subprocess_exec( - self.cmd, - *self.args, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - bufsize=0, - ) - self.is_proc_running = True - await self.wait_until_ready() - await self.start_printing_logs() - - async def close(self) -> None: - if self.is_proc_running: - self.proc.terminate() - await self.proc.wait() - self.is_proc_running = False - for task in self._tasks: - task.cancel() + self.patterns = (b"Control socket:", b"Peer ID:", b"Peer Addrs:") + self.bytes_read = bytearray() + self.event_ready = trio.Event() class Daemon: - p2pd_proc: P2PDProcess + p2pd_proc: BaseInteractiveProcess control: Client peer_info: PeerInfo def __init__( - self, p2pd_proc: P2PDProcess, control: Client, peer_info: PeerInfo + self, p2pd_proc: BaseInteractiveProcess, control: Client, peer_info: PeerInfo ) -> None: self.p2pd_proc = p2pd_proc self.control = control @@ -164,6 +81,7 @@ class Daemon: await self.control.close() +@asynccontextmanager async def make_p2pd( daemon_control_port: int, client_callback_port: int, @@ -172,7 +90,7 @@ async def make_p2pd( is_gossipsub: bool = True, is_pubsub_signing: bool = False, is_pubsub_signing_strict: bool = False, -) -> Daemon: +) -> AsyncIterator[Daemon]: control_maddr = Multiaddr(f"/ip4/{LOCALHOST_IP}/tcp/{daemon_control_port}") p2pd_proc = P2PDProcess( control_maddr, @@ -185,21 +103,22 @@ async def make_p2pd( await p2pd_proc.start() client_callback_maddr = Multiaddr(f"/ip4/{LOCALHOST_IP}/tcp/{client_callback_port}") p2pc = Client(control_maddr, client_callback_maddr) - await p2pc.listen() - peer_id, maddrs = await p2pc.identify() - listen_maddr: Multiaddr = None - for maddr in maddrs: - try: - ip = maddr.value_for_protocol(multiaddr.protocols.P_IP4) - # NOTE: Check if this `maddr` uses `tcp`. - maddr.value_for_protocol(multiaddr.protocols.P_TCP) - except multiaddr.exceptions.ProtocolLookupError: - continue - if ip == LOCALHOST_IP: - listen_maddr = maddr - break - assert listen_maddr is not None, "no loopback maddr is found" - peer_info = info_from_p2p_addr( - listen_maddr.encapsulate(Multiaddr(f"/p2p/{peer_id.to_string()}")) - ) - return Daemon(p2pd_proc, p2pc, peer_info) + + async with p2pc.listen(): + peer_id, maddrs = await p2pc.identify() + listen_maddr: Multiaddr = None + for maddr in maddrs: + try: + ip = maddr.value_for_protocol(multiaddr.protocols.P_IP4) + # NOTE: Check if this `maddr` uses `tcp`. + maddr.value_for_protocol(multiaddr.protocols.P_TCP) + except multiaddr.exceptions.ProtocolLookupError: + continue + if ip == LOCALHOST_IP: + listen_maddr = maddr + break + assert listen_maddr is not None, "no loopback maddr is found" + peer_info = info_from_p2p_addr( + listen_maddr.encapsulate(Multiaddr(f"/p2p/{peer_id.to_string()}")) + ) + yield Daemon(p2pd_proc, p2pc, peer_info) diff --git a/libp2p/tools/interop/process.py b/libp2p/tools/interop/process.py new file mode 100644 index 00000000..0c17e51b --- /dev/null +++ b/libp2p/tools/interop/process.py @@ -0,0 +1,66 @@ +from abc import ABC, abstractmethod +import subprocess +from typing import Iterable, List + +import trio + +TIMEOUT_DURATION = 30 + + +class AbstractInterativeProcess(ABC): + @abstractmethod + async def start(self) -> None: + ... + + @abstractmethod + async def close(self) -> None: + ... + + +class BaseInteractiveProcess(AbstractInterativeProcess): + proc: trio.Process = None + cmd: str + args: List[str] + bytes_read: bytearray + patterns: Iterable[bytes] = None + event_ready: trio.Event + + async def wait_until_ready(self) -> None: + patterns_occurred = {pat: False for pat in self.patterns} + + async def read_from_daemon_and_check() -> None: + async for data in self.proc.stdout: + # TODO: It takes O(n^2), which is quite bad. + # But it should succeed in a few seconds. + self.bytes_read.extend(data) + for pat, occurred in patterns_occurred.items(): + if occurred: + continue + if pat in self.bytes_read: + patterns_occurred[pat] = True + if all([value for value in patterns_occurred.values()]): + return + + with trio.fail_after(TIMEOUT_DURATION): + await read_from_daemon_and_check() + self.event_ready.set() + # Sleep a little bit to ensure the listener is up after logs are emitted. + await trio.sleep(0.01) + + async def start(self) -> None: + if self.proc is not None: + return + # NOTE: Ignore type checks here since mypy complains about bufsize=0 + self.proc = await trio.open_process( # type: ignore + [self.cmd] + self.args, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, # Redirect stderr to stdout, which makes parsing easier + bufsize=0, + ) + await self.wait_until_ready() + + async def close(self) -> None: + if self.proc is None: + return + self.proc.terminate() + await self.proc.wait() diff --git a/libp2p/tools/interop/utils.py b/libp2p/tools/interop/utils.py index c9174179..ce05c8fb 100644 --- a/libp2p/tools/interop/utils.py +++ b/libp2p/tools/interop/utils.py @@ -1,7 +1,7 @@ -import asyncio from typing import Union from multiaddr import Multiaddr +import trio from libp2p.host.host_interface import IHost from libp2p.peer.id import ID @@ -50,7 +50,7 @@ async def connect(a: TDaemonOrHost, b: TDaemonOrHost) -> None: else: # isinstance(b, IHost) await a.connect(b_peer_info) # Allow additional sleep for both side to establish the connection. - await asyncio.sleep(0.1) + await trio.sleep(0.1) a_peer_info = _get_peer_info(a) diff --git a/tests_interop/conftest.py b/tests_interop/conftest.py index 08df614c..bad0054e 100644 --- a/tests_interop/conftest.py +++ b/tests_interop/conftest.py @@ -1,20 +1,13 @@ -import asyncio -import sys -from typing import Union - +import anyio +from async_exit_stack import AsyncExitStack from p2pclient.datastructures import StreamInfo -import pexpect +from p2pclient.utils import get_unused_tcp_port import pytest +import trio from libp2p.io.abc import ReadWriteCloser -from libp2p.tools.constants import GOSSIPSUB_PARAMS, LISTEN_MADDR -from libp2p.tools.factories import ( - FloodsubFactory, - GossipsubFactory, - HostFactory, - PubsubFactory, -) -from libp2p.tools.interop.daemon import Daemon, make_p2pd +from libp2p.tools.factories import HostFactory, PubsubFactory +from libp2p.tools.interop.daemon import make_p2pd from libp2p.tools.interop.utils import connect @@ -23,48 +16,6 @@ def is_host_secure(): return False -@pytest.fixture -def num_hosts(): - return 3 - - -@pytest.fixture -async def hosts(num_hosts, is_host_secure): - _hosts = HostFactory.create_batch(num_hosts, is_secure=is_host_secure) - await asyncio.gather( - *[_host.get_network().listen(LISTEN_MADDR) for _host in _hosts] - ) - try: - yield _hosts - finally: - # TODO: It's possible that `close` raises exceptions currently, - # due to the connection reset things. Though we don't care much about that when - # cleaning up the tasks, it is probably better to handle the exceptions properly. - await asyncio.gather( - *[_host.close() for _host in _hosts], return_exceptions=True - ) - - -@pytest.fixture -def proc_factory(): - procs = [] - - def call_proc(cmd, args, logfile=None, encoding=None): - if logfile is None: - logfile = sys.stdout - if encoding is None: - encoding = "utf-8" - proc = pexpect.spawn(cmd, args, logfile=logfile, encoding=encoding) - procs.append(proc) - return proc - - try: - yield call_proc - finally: - for proc in procs: - proc.close() - - @pytest.fixture def num_p2pds(): return 1 @@ -87,79 +38,60 @@ def is_pubsub_signing_strict(): @pytest.fixture async def p2pds( - num_p2pds, - is_host_secure, - is_gossipsub, - unused_tcp_port_factory, - is_pubsub_signing, - is_pubsub_signing_strict, + num_p2pds, is_host_secure, is_gossipsub, is_pubsub_signing, is_pubsub_signing_strict ): - p2pds: Union[Daemon, Exception] = await asyncio.gather( - *[ - make_p2pd( - unused_tcp_port_factory(), - unused_tcp_port_factory(), - is_host_secure, - is_gossipsub=is_gossipsub, - is_pubsub_signing=is_pubsub_signing, - is_pubsub_signing_strict=is_pubsub_signing_strict, + async with AsyncExitStack() as stack: + p2pds = [ + await stack.enter_async_context( + make_p2pd( + get_unused_tcp_port(), + get_unused_tcp_port(), + is_host_secure, + is_gossipsub=is_gossipsub, + is_pubsub_signing=is_pubsub_signing, + is_pubsub_signing_strict=is_pubsub_signing_strict, + ) ) for _ in range(num_p2pds) - ], - return_exceptions=True, - ) - p2pds_succeeded = tuple(p2pd for p2pd in p2pds if isinstance(p2pd, Daemon)) - if len(p2pds_succeeded) != len(p2pds): - # Not all succeeded. Close the succeeded ones and print the failed ones(exceptions). - await asyncio.gather(*[p2pd.close() for p2pd in p2pds_succeeded]) - exceptions = tuple(p2pd for p2pd in p2pds if isinstance(p2pd, Exception)) - raise Exception(f"not all p2pds succeed: first exception={exceptions[0]}") - try: - yield p2pds - finally: - await asyncio.gather(*[p2pd.close() for p2pd in p2pds]) + ] + try: + yield p2pds + finally: + for p2pd in p2pds: + await p2pd.close() @pytest.fixture -def pubsubs(num_hosts, hosts, is_gossipsub, is_pubsub_signing_strict): +async def pubsubs(num_hosts, is_host_secure, is_gossipsub, is_pubsub_signing_strict): if is_gossipsub: - routers = GossipsubFactory.create_batch(num_hosts, **GOSSIPSUB_PARAMS._asdict()) + yield PubsubFactory.create_batch_with_gossipsub( + num_hosts, is_secure=is_host_secure, strict_signing=is_pubsub_signing_strict + ) else: - routers = FloodsubFactory.create_batch(num_hosts) - _pubsubs = tuple( - PubsubFactory(host=host, router=router, strict_signing=is_pubsub_signing_strict) - for host, router in zip(hosts, routers) - ) - yield _pubsubs - # TODO: Clean up + yield PubsubFactory.create_batch_with_floodsub( + num_hosts, is_host_secure, strict_signing=is_pubsub_signing_strict + ) class DaemonStream(ReadWriteCloser): stream_info: StreamInfo - reader: asyncio.StreamReader - writer: asyncio.StreamWriter + stream: anyio.abc.SocketStream - def __init__( - self, - stream_info: StreamInfo, - reader: asyncio.StreamReader, - writer: asyncio.StreamWriter, - ) -> None: + def __init__(self, stream_info: StreamInfo, stream: anyio.abc.SocketStream) -> None: self.stream_info = stream_info - self.reader = reader - self.writer = writer + self.stream = stream async def close(self) -> None: - self.writer.close() - if sys.version_info < (3, 7): - return - await self.writer.wait_closed() + await self.stream.close() async def read(self, n: int = -1) -> bytes: - return await self.reader.read(n) + if n == -1: + return await self.stream.receive_some() + else: + return await self.stream.receive_some(n) - async def write(self, data: bytes) -> int: - return self.writer.write(data) + async def write(self, data: bytes) -> None: + return await self.stream.send_all(data) @pytest.fixture @@ -168,40 +100,38 @@ async def is_to_fail_daemon_stream(): @pytest.fixture -async def py_to_daemon_stream_pair(hosts, p2pds, is_to_fail_daemon_stream): - assert len(hosts) >= 1 - assert len(p2pds) >= 1 - host = hosts[0] - p2pd = p2pds[0] - protocol_id = "/protocol/id/123" - stream_py = None - stream_daemon = None - event_stream_handled = asyncio.Event() - await connect(host, p2pd) +async def py_to_daemon_stream_pair(p2pds, is_host_secure, is_to_fail_daemon_stream): + async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts: + assert len(p2pds) >= 1 + host = hosts[0] + p2pd = p2pds[0] + protocol_id = "/protocol/id/123" + stream_py = None + stream_daemon = None + event_stream_handled = trio.Event() + await connect(host, p2pd) - async def daemon_stream_handler(stream_info, reader, writer): - nonlocal stream_daemon - stream_daemon = DaemonStream(stream_info, reader, writer) - event_stream_handled.set() + async def daemon_stream_handler(stream_info, stream): + nonlocal stream_daemon + stream_daemon = DaemonStream(stream_info, stream) + event_stream_handled.set() + await trio.hazmat.checkpoint() - await p2pd.control.stream_handler(protocol_id, daemon_stream_handler) - # Sleep for a while to wait for the handler being registered. - await asyncio.sleep(0.01) + await p2pd.control.stream_handler(protocol_id, daemon_stream_handler) + # Sleep for a while to wait for the handler being registered. + await trio.sleep(0.01) - if is_to_fail_daemon_stream: - # FIXME: This is a workaround to make daemon reset the stream. - # We intentionally close the listener on the python side, it makes the connection from - # daemon to us fail, and therefore the daemon resets the opened stream on their side. - # Reference: https://github.com/libp2p/go-libp2p-daemon/blob/b95e77dbfcd186ccf817f51e95f73f9fd5982600/stream.go#L47-L50 # noqa: E501 - # We need it because we want to test against `stream_py` after the remote side(daemon) - # is reset. This should be removed after the API `stream.reset` is exposed in daemon - # some day. - listener = p2pds[0].control.control.listener - listener.close() - if sys.version_info[0:2] > (3, 6): - await listener.wait_closed() - stream_py = await host.new_stream(p2pd.peer_id, [protocol_id]) - if not is_to_fail_daemon_stream: - await event_stream_handled.wait() - # NOTE: If `is_to_fail_daemon_stream == True`, then `stream_daemon == None`. - yield stream_py, stream_daemon + if is_to_fail_daemon_stream: + # FIXME: This is a workaround to make daemon reset the stream. + # We intentionally close the listener on the python side, it makes the connection from + # daemon to us fail, and therefore the daemon resets the opened stream on their side. + # Reference: https://github.com/libp2p/go-libp2p-daemon/blob/b95e77dbfcd186ccf817f51e95f73f9fd5982600/stream.go#L47-L50 # noqa: E501 + # We need it because we want to test against `stream_py` after the remote side(daemon) + # is reset. This should be removed after the API `stream.reset` is exposed in daemon + # some day. + await p2pds[0].control.control.close() + stream_py = await host.new_stream(p2pd.peer_id, [protocol_id]) + if not is_to_fail_daemon_stream: + await event_stream_handled.wait() + # NOTE: If `is_to_fail_daemon_stream == True`, then `stream_daemon == None`. + yield stream_py, stream_daemon diff --git a/tests_interop/test_bindings.py b/tests_interop/test_bindings.py index 9e70aa21..87cdbb1b 100644 --- a/tests_interop/test_bindings.py +++ b/tests_interop/test_bindings.py @@ -1,26 +1,26 @@ -import asyncio - import pytest +import trio +from libp2p.tools.factories import HostFactory from libp2p.tools.interop.utils import connect -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_connect(hosts, p2pds): - p2pd = p2pds[0] - host = hosts[0] - assert len(await p2pd.control.list_peers()) == 0 - # Test: connect from Py - await connect(host, p2pd) - assert len(await p2pd.control.list_peers()) == 1 - # Test: `disconnect` from Py - await host.disconnect(p2pd.peer_id) - assert len(await p2pd.control.list_peers()) == 0 - # Test: connect from Go - await connect(p2pd, host) - assert len(host.get_network().connections) == 1 - # Test: `disconnect` from Go - await p2pd.control.disconnect(host.get_id()) - await asyncio.sleep(0.01) - assert len(host.get_network().connections) == 0 +@pytest.mark.trio +async def test_connect(is_host_secure, p2pds): + async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts: + p2pd = p2pds[0] + host = hosts[0] + assert len(await p2pd.control.list_peers()) == 0 + # Test: connect from Py + await connect(host, p2pd) + assert len(await p2pd.control.list_peers()) == 1 + # Test: `disconnect` from Py + await host.disconnect(p2pd.peer_id) + assert len(await p2pd.control.list_peers()) == 0 + # Test: connect from Go + await connect(p2pd, host) + assert len(host.get_network().connections) == 1 + # Test: `disconnect` from Go + await p2pd.control.disconnect(host.get_id()) + await trio.sleep(0.01) + assert len(host.get_network().connections) == 0 diff --git a/tests_interop/test_echo.py b/tests_interop/test_echo.py index 6ac867ba..d3cb72fd 100644 --- a/tests_interop/test_echo.py +++ b/tests_interop/test_echo.py @@ -1,82 +1,104 @@ -import asyncio +import random +import re from multiaddr import Multiaddr import pytest +import trio -from libp2p.peer.peerinfo import info_from_p2p_addr -from libp2p.tools.interop.constants import PEXPECT_NEW_LINE +from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr +from libp2p.tools.factories import HostFactory from libp2p.tools.interop.envs import GO_BIN_PATH +from libp2p.tools.interop.process import BaseInteractiveProcess from libp2p.typing import TProtocol ECHO_PATH = GO_BIN_PATH / "echo" ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0") -async def make_echo_proc( - proc_factory, port: int, is_secure: bool, destination: Multiaddr = None -): - args = [f"-l={port}"] - if not is_secure: - args.append("-insecure") - if destination is not None: - args.append(f"-d={str(destination)}") - echo_proc = proc_factory(str(ECHO_PATH), args) - await echo_proc.expect(r"I am ([\w\./]+)" + PEXPECT_NEW_LINE, async_=True) - maddr_str_ipfs = echo_proc.match.group(1) - maddr_str = maddr_str_ipfs.replace("ipfs", "p2p") - maddr = Multiaddr(maddr_str) - go_pinfo = info_from_p2p_addr(maddr) - if destination is None: - await echo_proc.expect("listening for connections", async_=True) - return echo_proc, go_pinfo +# FIXME: Change to a reasonable implementation +def unused_tcp_port_factory(): + return random.randint(1024, 65535) -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_insecure_conn_py_to_go( - hosts, proc_factory, is_host_secure, unused_tcp_port -): - go_proc, go_pinfo = await make_echo_proc( - proc_factory, unused_tcp_port, is_host_secure - ) +class EchoProcess(BaseInteractiveProcess): + port: int + _peer_info: PeerInfo - host = hosts[0] - await host.connect(go_pinfo) - await go_proc.expect("swarm listener accepted connection", async_=True) - s = await host.new_stream(go_pinfo.peer_id, [ECHO_PROTOCOL_ID]) + def __init__( + self, port: int, is_secure: bool, destination: Multiaddr = None + ) -> None: + args = [f"-l={port}"] + if not is_secure: + args.append("-insecure") + if destination is not None: + args.append(f"-d={str(destination)}") - await go_proc.expect("Got a new stream!", async_=True) - data = "data321123\n" - await s.write(data.encode()) - await go_proc.expect(f"read: {data[:-1]}", async_=True) - echoed_resp = await s.read(len(data)) - assert echoed_resp.decode() == data - await s.close() + patterns = [b"I am"] + if destination is None: + patterns.append(b"listening for connections") + + self.args = args + self.cmd = str(ECHO_PATH) + self.patterns = patterns + self.bytes_read = bytearray() + self.event_ready = trio.Event() + + self.port = port + self._peer_info = None + self.regex_pat = re.compile(br"I am ([\w\./]+)") + + @property + def peer_info(self) -> None: + if self._peer_info is not None: + return self._peer_info + if not self.event_ready.is_set(): + raise Exception("process is not ready yet. failed to parse the peer info") + # Example: + # b"I am /ip4/127.0.0.1/tcp/56171/ipfs/QmU41TRPs34WWqa1brJEojBLYZKrrBcJq9nyNfVvSrbZUJ\n" + m = re.search(br"I am ([\w\./]+)", self.bytes_read) + if m is None: + raise Exception("failed to find the pattern for the listening multiaddr") + maddr_bytes_str_ipfs = m.group(1) + maddr_str = maddr_bytes_str_ipfs.decode().replace("ipfs", "p2p") + maddr = Multiaddr(maddr_str) + self._peer_info = info_from_p2p_addr(maddr) + return self._peer_info -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_insecure_conn_go_to_py( - hosts, proc_factory, is_host_secure, unused_tcp_port -): - host = hosts[0] - expected_data = "Hello, world!\n" - reply_data = "Replyooo!\n" - event_handler_finished = asyncio.Event() +@pytest.mark.trio +async def test_insecure_conn_py_to_go(is_host_secure): + async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts: + go_proc = EchoProcess(unused_tcp_port_factory(), is_host_secure) + await go_proc.start() - async def _handle_echo(stream): - read_data = await stream.read(len(expected_data)) - assert read_data == expected_data.encode() - event_handler_finished.set() - await stream.write(reply_data.encode()) - await stream.close() + host = hosts[0] + peer_info = go_proc.peer_info + await host.connect(peer_info) + s = await host.new_stream(peer_info.peer_id, [ECHO_PROTOCOL_ID]) + data = "data321123\n" + await s.write(data.encode()) + echoed_resp = await s.read(len(data)) + assert echoed_resp.decode() == data + await s.close() - host.set_stream_handler(ECHO_PROTOCOL_ID, _handle_echo) - py_maddr = host.get_addrs()[0] - go_proc, _ = await make_echo_proc( - proc_factory, unused_tcp_port, is_host_secure, py_maddr - ) - await go_proc.expect("connect with peer", async_=True) - await go_proc.expect("opened stream", async_=True) - await event_handler_finished.wait() - await go_proc.expect(f"read reply: .*{reply_data.rstrip()}.*", async_=True) + +@pytest.mark.trio +async def test_insecure_conn_go_to_py(is_host_secure): + async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts: + host = hosts[0] + expected_data = "Hello, world!\n" + reply_data = "Replyooo!\n" + event_handler_finished = trio.Event() + + async def _handle_echo(stream): + read_data = await stream.read(len(expected_data)) + assert read_data == expected_data.encode() + event_handler_finished.set() + await stream.write(reply_data.encode()) + await stream.close() + + host.set_stream_handler(ECHO_PROTOCOL_ID, _handle_echo) + py_maddr = host.get_addrs()[0] + go_proc = EchoProcess(unused_tcp_port_factory(), is_host_secure, py_maddr) + await go_proc.start() + await event_handler_finished.wait() diff --git a/tests_interop/test_net_stream.py b/tests_interop/test_net_stream.py index 2c897d2b..59812a9b 100644 --- a/tests_interop/test_net_stream.py +++ b/tests_interop/test_net_stream.py @@ -1,6 +1,5 @@ -import asyncio - import pytest +import trio from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset from libp2p.tools.constants import MAX_READ_LEN @@ -8,7 +7,7 @@ from libp2p.tools.constants import MAX_READ_LEN DATA = b"data" -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_read_write(py_to_daemon_stream_pair, p2pds): stream_py, stream_daemon = py_to_daemon_stream_pair assert ( @@ -19,19 +18,19 @@ async def test_net_stream_read_write(py_to_daemon_stream_pair, p2pds): assert (await stream_daemon.read(MAX_READ_LEN)) == DATA -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_read_after_remote_closed(py_to_daemon_stream_pair, p2pds): stream_py, stream_daemon = py_to_daemon_stream_pair await stream_daemon.write(DATA) await stream_daemon.close() - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert (await stream_py.read(MAX_READ_LEN)) == DATA # EOF with pytest.raises(StreamEOF): await stream_py.read(MAX_READ_LEN) -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_read_after_local_reset(py_to_daemon_stream_pair, p2pds): stream_py, _ = py_to_daemon_stream_pair await stream_py.reset() @@ -40,15 +39,15 @@ async def test_net_stream_read_after_local_reset(py_to_daemon_stream_pair, p2pds @pytest.mark.parametrize("is_to_fail_daemon_stream", (True,)) -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_read_after_remote_reset(py_to_daemon_stream_pair, p2pds): stream_py, _ = py_to_daemon_stream_pair - await asyncio.sleep(0.01) + await trio.sleep(0.01) with pytest.raises(StreamReset): await stream_py.read(MAX_READ_LEN) -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_write_after_local_closed(py_to_daemon_stream_pair, p2pds): stream_py, _ = py_to_daemon_stream_pair await stream_py.write(DATA) @@ -57,7 +56,7 @@ async def test_net_stream_write_after_local_closed(py_to_daemon_stream_pair, p2p await stream_py.write(DATA) -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_write_after_local_reset(py_to_daemon_stream_pair, p2pds): stream_py, stream_daemon = py_to_daemon_stream_pair await stream_py.reset() @@ -66,9 +65,9 @@ async def test_net_stream_write_after_local_reset(py_to_daemon_stream_pair, p2pd @pytest.mark.parametrize("is_to_fail_daemon_stream", (True,)) -@pytest.mark.asyncio +@pytest.mark.trio async def test_net_stream_write_after_remote_reset(py_to_daemon_stream_pair, p2pds): stream_py, _ = py_to_daemon_stream_pair - await asyncio.sleep(0.01) + await trio.sleep(0.01) with pytest.raises(StreamClosed): await stream_py.write(DATA) diff --git a/tests_interop/test_pubsub.py b/tests_interop/test_pubsub.py index db42c7cd..f15d89e3 100644 --- a/tests_interop/test_pubsub.py +++ b/tests_interop/test_pubsub.py @@ -1,11 +1,15 @@ -import asyncio import functools +import math from p2pclient.pb import p2pd_pb2 import pytest +import trio +from libp2p.io.trio import TrioTCPStream from libp2p.peer.id import ID from libp2p.pubsub.pb import rpc_pb2 +from libp2p.pubsub.subscription import TrioSubscriptionAPI +from libp2p.tools.factories import PubsubFactory from libp2p.tools.interop.utils import connect from libp2p.utils import read_varint_prefixed_bytes @@ -13,26 +17,15 @@ TOPIC_0 = "ABALA" TOPIC_1 = "YOOOO" -async def p2pd_subscribe(p2pd, topic) -> "asyncio.Queue[rpc_pb2.Message]": - reader, writer = await p2pd.control.pubsub_subscribe(topic) +async def p2pd_subscribe(p2pd, topic, nursery): + stream = TrioTCPStream(await p2pd.control.pubsub_subscribe(topic)) + send_channel, receive_channel = trio.open_memory_channel(math.inf) - queue = asyncio.Queue() + sub = TrioSubscriptionAPI(receive_channel) async def _read_pubsub_msg() -> None: - writer_closed_task = asyncio.ensure_future(writer.wait_closed()) - while True: - done, pending = await asyncio.wait( - [read_varint_prefixed_bytes(reader), writer_closed_task], - return_when=asyncio.FIRST_COMPLETED, - ) - done_tasks = tuple(done) - if writer.is_closing(): - return - read_task = done_tasks[0] - # Sanity check - assert read_task._coro.__name__ == "read_varint_prefixed_bytes" - msg_bytes = read_task.result() + msg_bytes = await read_varint_prefixed_bytes(stream) ps_msg = p2pd_pb2.PSMessage() ps_msg.ParseFromString(msg_bytes) # Fill in the message used in py-libp2p @@ -44,11 +37,10 @@ async def p2pd_subscribe(p2pd, topic) -> "asyncio.Queue[rpc_pb2.Message]": signature=ps_msg.signature, key=ps_msg.key, ) - queue.put_nowait(msg) + await send_channel.send(msg) - asyncio.ensure_future(_read_pubsub_msg()) - await asyncio.sleep(0) - return queue + nursery.start_soon(_read_pubsub_msg) + return sub def validate_pubsub_msg(msg: rpc_pb2.Message, data: bytes, from_peer_id: ID) -> None: @@ -59,108 +51,119 @@ def validate_pubsub_msg(msg: rpc_pb2.Message, data: bytes, from_peer_id: ID) -> "is_pubsub_signing, is_pubsub_signing_strict", ((True, True), (False, False)) ) @pytest.mark.parametrize("is_gossipsub", (True, False)) -@pytest.mark.parametrize("num_hosts, num_p2pds", ((1, 2),)) -@pytest.mark.asyncio -async def test_pubsub(pubsubs, p2pds): - # - # Test: Recognize pubsub peers on connection. - # - py_pubsub = pubsubs[0] - # go0 <-> py <-> go1 - await connect(p2pds[0], py_pubsub.host) - await connect(py_pubsub.host, p2pds[1]) - py_peer_id = py_pubsub.host.get_id() - # Check pubsub peers - pubsub_peers_0 = await p2pds[0].control.pubsub_list_peers("") - assert len(pubsub_peers_0) == 1 and pubsub_peers_0[0] == py_peer_id - pubsub_peers_1 = await p2pds[1].control.pubsub_list_peers("") - assert len(pubsub_peers_1) == 1 and pubsub_peers_1[0] == py_peer_id - assert ( - len(py_pubsub.peers) == 2 - and p2pds[0].peer_id in py_pubsub.peers - and p2pds[1].peer_id in py_pubsub.peers - ) +@pytest.mark.parametrize("num_p2pds", (2,)) +@pytest.mark.trio +async def test_pubsub( + p2pds, is_gossipsub, is_host_secure, is_pubsub_signing_strict, nursery +): + pubsub_factory = None + if is_gossipsub: + pubsub_factory = PubsubFactory.create_batch_with_gossipsub + else: + pubsub_factory = PubsubFactory.create_batch_with_floodsub - # - # Test: `subscribe`. - # - # (name, topics) - # (go_0, [0, 1]) <-> (py, [0, 1]) <-> (go_1, [1]) - sub_py_topic_0 = await py_pubsub.subscribe(TOPIC_0) - sub_py_topic_1 = await py_pubsub.subscribe(TOPIC_1) - sub_go_0_topic_0 = await p2pd_subscribe(p2pds[0], TOPIC_0) - sub_go_0_topic_1 = await p2pd_subscribe(p2pds[0], TOPIC_1) - sub_go_1_topic_1 = await p2pd_subscribe(p2pds[1], TOPIC_1) - # Check topic peers - await asyncio.sleep(0.1) - # go_0 - go_0_topic_0_peers = await p2pds[0].control.pubsub_list_peers(TOPIC_0) - assert len(go_0_topic_0_peers) == 1 and py_peer_id == go_0_topic_0_peers[0] - go_0_topic_1_peers = await p2pds[0].control.pubsub_list_peers(TOPIC_1) - assert len(go_0_topic_1_peers) == 1 and py_peer_id == go_0_topic_1_peers[0] - # py - py_topic_0_peers = list(py_pubsub.peer_topics[TOPIC_0]) - assert len(py_topic_0_peers) == 1 and p2pds[0].peer_id == py_topic_0_peers[0] - # go_1 - go_1_topic_1_peers = await p2pds[1].control.pubsub_list_peers(TOPIC_1) - assert len(go_1_topic_1_peers) == 1 and py_peer_id == go_1_topic_1_peers[0] + async with pubsub_factory( + 1, is_secure=is_host_secure, strict_signing=is_pubsub_signing_strict + ) as pubsubs: + # + # Test: Recognize pubsub peers on connection. + # + py_pubsub = pubsubs[0] + # go0 <-> py <-> go1 + await connect(p2pds[0], py_pubsub.host) + await connect(py_pubsub.host, p2pds[1]) + py_peer_id = py_pubsub.host.get_id() + # Check pubsub peers + pubsub_peers_0 = await p2pds[0].control.pubsub_list_peers("") + assert len(pubsub_peers_0) == 1 and pubsub_peers_0[0] == py_peer_id + pubsub_peers_1 = await p2pds[1].control.pubsub_list_peers("") + assert len(pubsub_peers_1) == 1 and pubsub_peers_1[0] == py_peer_id + assert ( + len(py_pubsub.peers) == 2 + and p2pds[0].peer_id in py_pubsub.peers + and p2pds[1].peer_id in py_pubsub.peers + ) - # - # Test: `publish` - # - # 1. py publishes - # - 1.1. py publishes data_11 to topic_0, py and go_0 receives. - # - 1.2. py publishes data_12 to topic_1, all receive. - # 2. go publishes - # - 2.1. go_0 publishes data_21 to topic_0, py and go_0 receive. - # - 2.2. go_1 publishes data_22 to topic_1, all receive. + # + # Test: `subscribe`. + # + # (name, topics) + # (go_0, [0, 1]) <-> (py, [0, 1]) <-> (go_1, [1]) + sub_py_topic_0 = await py_pubsub.subscribe(TOPIC_0) + sub_py_topic_1 = await py_pubsub.subscribe(TOPIC_1) + sub_go_0_topic_0 = await p2pd_subscribe(p2pds[0], TOPIC_0, nursery) + sub_go_0_topic_1 = await p2pd_subscribe(p2pds[0], TOPIC_1, nursery) + sub_go_1_topic_1 = await p2pd_subscribe(p2pds[1], TOPIC_1, nursery) + # Check topic peers + await trio.sleep(0.1) + # go_0 + go_0_topic_0_peers = await p2pds[0].control.pubsub_list_peers(TOPIC_0) + assert len(go_0_topic_0_peers) == 1 and py_peer_id == go_0_topic_0_peers[0] + go_0_topic_1_peers = await p2pds[0].control.pubsub_list_peers(TOPIC_1) + assert len(go_0_topic_1_peers) == 1 and py_peer_id == go_0_topic_1_peers[0] + # py + py_topic_0_peers = list(py_pubsub.peer_topics[TOPIC_0]) + assert len(py_topic_0_peers) == 1 and p2pds[0].peer_id == py_topic_0_peers[0] + # go_1 + go_1_topic_1_peers = await p2pds[1].control.pubsub_list_peers(TOPIC_1) + assert len(go_1_topic_1_peers) == 1 and py_peer_id == go_1_topic_1_peers[0] - # 1.1. py publishes data_11 to topic_0, py and go_0 receives. - data_11 = b"data_11" - await py_pubsub.publish(TOPIC_0, data_11) - validate_11 = functools.partial( - validate_pubsub_msg, data=data_11, from_peer_id=py_peer_id - ) - validate_11(await sub_py_topic_0.get()) - validate_11(await sub_go_0_topic_0.get()) + # + # Test: `publish` + # + # 1. py publishes + # - 1.1. py publishes data_11 to topic_0, py and go_0 receives. + # - 1.2. py publishes data_12 to topic_1, all receive. + # 2. go publishes + # - 2.1. go_0 publishes data_21 to topic_0, py and go_0 receive. + # - 2.2. go_1 publishes data_22 to topic_1, all receive. - # 1.2. py publishes data_12 to topic_1, all receive. - data_12 = b"data_12" - validate_12 = functools.partial( - validate_pubsub_msg, data=data_12, from_peer_id=py_peer_id - ) - await py_pubsub.publish(TOPIC_1, data_12) - validate_12(await sub_py_topic_1.get()) - validate_12(await sub_go_0_topic_1.get()) - validate_12(await sub_go_1_topic_1.get()) + # 1.1. py publishes data_11 to topic_0, py and go_0 receives. + data_11 = b"data_11" + await py_pubsub.publish(TOPIC_0, data_11) + validate_11 = functools.partial( + validate_pubsub_msg, data=data_11, from_peer_id=py_peer_id + ) + validate_11(await sub_py_topic_0.get()) + validate_11(await sub_go_0_topic_0.get()) - # 2.1. go_0 publishes data_21 to topic_0, py and go_0 receive. - data_21 = b"data_21" - validate_21 = functools.partial( - validate_pubsub_msg, data=data_21, from_peer_id=p2pds[0].peer_id - ) - await p2pds[0].control.pubsub_publish(TOPIC_0, data_21) - validate_21(await sub_py_topic_0.get()) - validate_21(await sub_go_0_topic_0.get()) + # 1.2. py publishes data_12 to topic_1, all receive. + data_12 = b"data_12" + validate_12 = functools.partial( + validate_pubsub_msg, data=data_12, from_peer_id=py_peer_id + ) + await py_pubsub.publish(TOPIC_1, data_12) + validate_12(await sub_py_topic_1.get()) + validate_12(await sub_go_0_topic_1.get()) + validate_12(await sub_go_1_topic_1.get()) - # 2.2. go_1 publishes data_22 to topic_1, all receive. - data_22 = b"data_22" - validate_22 = functools.partial( - validate_pubsub_msg, data=data_22, from_peer_id=p2pds[1].peer_id - ) - await p2pds[1].control.pubsub_publish(TOPIC_1, data_22) - validate_22(await sub_py_topic_1.get()) - validate_22(await sub_go_0_topic_1.get()) - validate_22(await sub_go_1_topic_1.get()) + # 2.1. go_0 publishes data_21 to topic_0, py and go_0 receive. + data_21 = b"data_21" + validate_21 = functools.partial( + validate_pubsub_msg, data=data_21, from_peer_id=p2pds[0].peer_id + ) + await p2pds[0].control.pubsub_publish(TOPIC_0, data_21) + validate_21(await sub_py_topic_0.get()) + validate_21(await sub_go_0_topic_0.get()) - # - # Test: `unsubscribe` and re`subscribe` - # - await py_pubsub.unsubscribe(TOPIC_0) - await asyncio.sleep(0.1) - assert py_peer_id not in (await p2pds[0].control.pubsub_list_peers(TOPIC_0)) - assert py_peer_id not in (await p2pds[1].control.pubsub_list_peers(TOPIC_0)) - await py_pubsub.subscribe(TOPIC_0) - await asyncio.sleep(0.1) - assert py_peer_id in (await p2pds[0].control.pubsub_list_peers(TOPIC_0)) - assert py_peer_id in (await p2pds[1].control.pubsub_list_peers(TOPIC_0)) + # 2.2. go_1 publishes data_22 to topic_1, all receive. + data_22 = b"data_22" + validate_22 = functools.partial( + validate_pubsub_msg, data=data_22, from_peer_id=p2pds[1].peer_id + ) + await p2pds[1].control.pubsub_publish(TOPIC_1, data_22) + validate_22(await sub_py_topic_1.get()) + validate_22(await sub_go_0_topic_1.get()) + validate_22(await sub_go_1_topic_1.get()) + + # + # Test: `unsubscribe` and re`subscribe` + # + await py_pubsub.unsubscribe(TOPIC_0) + await trio.sleep(0.1) + assert py_peer_id not in (await p2pds[0].control.pubsub_list_peers(TOPIC_0)) + assert py_peer_id not in (await p2pds[1].control.pubsub_list_peers(TOPIC_0)) + await py_pubsub.subscribe(TOPIC_0) + await trio.sleep(0.1) + assert py_peer_id in (await p2pds[0].control.pubsub_list_peers(TOPIC_0)) + assert py_peer_id in (await p2pds[1].control.pubsub_list_peers(TOPIC_0)) diff --git a/tox.ini b/tox.ini index 21f46435..00af177f 100644 --- a/tox.ini +++ b/tox.ini @@ -12,7 +12,7 @@ envlist = combine_as_imports=False force_sort_within_sections=True include_trailing_comma=True -known_third_party=hypothesis,pytest,p2pclient,pexpect,factory +known_third_party=anyio,factory,p2pclient,pytest known_first_party=libp2p line_length=88 multi_line_output=3 From 52f85586b8929f635c21d0b856a3f4bfb455197f Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 7 Jan 2020 15:41:44 +0800 Subject: [PATCH 36/81] Fix docs --- docs/libp2p.pubsub.rst | 22 +++++++++++++++++++--- libp2p/kademlia/__init__.py | 2 -- 2 files changed, 19 insertions(+), 5 deletions(-) delete mode 100644 libp2p/kademlia/__init__.py diff --git a/docs/libp2p.pubsub.rst b/docs/libp2p.pubsub.rst index d217772b..38dce6ff 100644 --- a/docs/libp2p.pubsub.rst +++ b/docs/libp2p.pubsub.rst @@ -11,6 +11,22 @@ Subpackages Submodules ---------- +libp2p.pubsub.abc module +------------------------ + +.. automodule:: libp2p.pubsub.abc + :members: + :undoc-members: + :show-inheritance: + +libp2p.pubsub.exceptions module +------------------------------- + +.. automodule:: libp2p.pubsub.exceptions + :members: + :undoc-members: + :show-inheritance: + libp2p.pubsub.floodsub module ----------------------------- @@ -51,10 +67,10 @@ libp2p.pubsub.pubsub\_notifee module :undoc-members: :show-inheritance: -libp2p.pubsub.pubsub\_router\_interface module ----------------------------------------------- +libp2p.pubsub.subscription module +--------------------------------- -.. automodule:: libp2p.pubsub.pubsub_router_interface +.. automodule:: libp2p.pubsub.subscription :members: :undoc-members: :show-inheritance: diff --git a/libp2p/kademlia/__init__.py b/libp2p/kademlia/__init__.py deleted file mode 100644 index e85d6e0a..00000000 --- a/libp2p/kademlia/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Kademlia is a Python implementation of the Kademlia protocol.""" -__version__ = "2.0" From 4db043a26af3d9d1cfa794370bec70d851ed7cb6 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 7 Jan 2020 16:23:00 +0800 Subject: [PATCH 37/81] Remove pexpect from tox --- tox.ini | 1 - 1 file changed, 1 deletion(-) diff --git a/tox.ini b/tox.ini index 00af177f..9c703d8e 100644 --- a/tox.ini +++ b/tox.ini @@ -58,7 +58,6 @@ commands = [testenv:py37-interop] deps = p2pclient - pexpect passenv = CI TRAVIS TRAVIS_* GOPATH extras = test commands = From 45eeb4fba3df83139f9300ec3daa79b499d80861 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 7 Jan 2020 16:45:06 +0800 Subject: [PATCH 38/81] Change `notify_xxx` to sync functions Since we already have `Swarm.run_task`, we can just change notify functions to sync. --- libp2p/network/connection/swarm_connection.py | 14 ++++---- libp2p/network/swarm.py | 33 ++++++++----------- 2 files changed, 21 insertions(+), 26 deletions(-) diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index a4fc8be4..90cb8231 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -54,7 +54,7 @@ class SwarmConn(INetConn, Service): # before we cancel the stream handler tasks. await trio.sleep(0.1) - await self._notify_disconnected() + self._notify_disconnected() async def _handle_new_streams(self) -> None: while self.manager.is_running: @@ -68,7 +68,7 @@ class SwarmConn(INetConn, Service): await self.close() async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None: - net_stream = await self._add_stream(muxed_stream) + net_stream = self._add_stream(muxed_stream) if self.swarm.common_stream_handler is not None: try: await self.swarm.common_stream_handler(net_stream) @@ -78,14 +78,14 @@ class SwarmConn(INetConn, Service): # TODO: Clean up and remove the stream from SwarmConn if there is anything wrong. self.remove_stream(net_stream) - async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: + def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: net_stream = NetStream(muxed_stream) self.streams.add(net_stream) - await self.swarm.notify_opened_stream(net_stream) + self.swarm.notify_opened_stream(net_stream) return net_stream - async def _notify_disconnected(self) -> None: - await self.swarm.notify_disconnected(self) + def _notify_disconnected(self) -> None: + self.swarm.notify_disconnected(self) async def run(self) -> None: self.manager.run_task(self._handle_new_streams) @@ -93,7 +93,7 @@ class SwarmConn(INetConn, Service): async def new_stream(self) -> NetStream: muxed_stream = await self.muxed_conn.open_stream() - return await self._add_stream(muxed_stream) + return self._add_stream(muxed_stream) async def get_streams(self) -> Tuple[NetStream, ...]: return tuple(self.streams) diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 9d19717b..0904774d 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -250,7 +250,7 @@ class Swarm(INetworkService): await listener.listen(maddr, self.manager._task_nursery) # type: ignore # Call notifiers since event occurred - await self.notify_listen(maddr) + self.notify_listen(maddr) return True except IOError: @@ -297,7 +297,7 @@ class Swarm(INetworkService): # Store muxed_conn with peer id self.connections[muxed_conn.peer_id] = swarm_conn # Call notifiers since event occurred - self.manager.run_task(self.notify_connected, swarm_conn) + self.notify_connected(swarm_conn) await manager.wait_started() return swarm_conn @@ -320,27 +320,22 @@ class Swarm(INetworkService): """ self.notifees.append(notifee) - # TODO: Use `run_task`. - async def notify_opened_stream(self, stream: INetStream) -> None: - async with trio.open_nursery() as nursery: - for notifee in self.notifees: - nursery.start_soon(notifee.opened_stream, self, stream) + def notify_opened_stream(self, stream: INetStream) -> None: + for notifee in self.notifees: + self.manager.run_task(notifee.opened_stream, self, stream) # TODO: `notify_closed_stream` - async def notify_connected(self, conn: INetConn) -> None: - async with trio.open_nursery() as nursery: - for notifee in self.notifees: - nursery.start_soon(notifee.connected, self, conn) + def notify_connected(self, conn: INetConn) -> None: + for notifee in self.notifees: + self.manager.run_task(notifee.connected, self, conn) - async def notify_disconnected(self, conn: INetConn) -> None: - async with trio.open_nursery() as nursery: - for notifee in self.notifees: - nursery.start_soon(notifee.disconnected, self, conn) + def notify_disconnected(self, conn: INetConn) -> None: + for notifee in self.notifees: + self.manager.run_task(notifee.disconnected, self, conn) - async def notify_listen(self, multiaddr: Multiaddr) -> None: - async with trio.open_nursery() as nursery: - for notifee in self.notifees: - nursery.start_soon(notifee.listen, self, multiaddr) + def notify_listen(self, multiaddr: Multiaddr) -> None: + for notifee in self.notifees: + self.manager.run_task(notifee.listen, self, multiaddr) # TODO: `notify_listen_close` From eab59482c06c822ac450a28835396cdb4743c8d6 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 7 Jan 2020 16:45:59 +0800 Subject: [PATCH 39/81] Use the real `get_unused_tcp_port` To get rid of the fake one --- tests_interop/test_echo.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests_interop/test_echo.py b/tests_interop/test_echo.py index d3cb72fd..85810a7f 100644 --- a/tests_interop/test_echo.py +++ b/tests_interop/test_echo.py @@ -1,7 +1,7 @@ -import random import re from multiaddr import Multiaddr +from p2pclient.utils import get_unused_tcp_port import pytest import trio @@ -15,11 +15,6 @@ ECHO_PATH = GO_BIN_PATH / "echo" ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0") -# FIXME: Change to a reasonable implementation -def unused_tcp_port_factory(): - return random.randint(1024, 65535) - - class EchoProcess(BaseInteractiveProcess): port: int _peer_info: PeerInfo @@ -68,7 +63,7 @@ class EchoProcess(BaseInteractiveProcess): @pytest.mark.trio async def test_insecure_conn_py_to_go(is_host_secure): async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts: - go_proc = EchoProcess(unused_tcp_port_factory(), is_host_secure) + go_proc = EchoProcess(get_unused_tcp_port(), is_host_secure) await go_proc.start() host = hosts[0] @@ -99,6 +94,6 @@ async def test_insecure_conn_go_to_py(is_host_secure): host.set_stream_handler(ECHO_PROTOCOL_ID, _handle_echo) py_maddr = host.get_addrs()[0] - go_proc = EchoProcess(unused_tcp_port_factory(), is_host_secure, py_maddr) + go_proc = EchoProcess(get_unused_tcp_port(), is_host_secure, py_maddr) await go_proc.start() await event_handler_finished.wait() From eef241e70e4d26ec44a10db315ceca0421bb2013 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 7 Jan 2020 21:50:03 +0800 Subject: [PATCH 40/81] Make `Mplex` and `SwarmConn` not `Service` After second thoughts, they seem not a good candidate of `Service`. The shutdown logic becomes simpler by making them not `Service`. --- .../connection/net_connection_interface.py | 5 ++- libp2p/network/connection/swarm_connection.py | 19 ++++++------ libp2p/network/swarm.py | 31 +++---------------- libp2p/stream_muxer/abc.py | 13 +++++--- libp2p/stream_muxer/mplex/mplex.py | 15 +++++---- libp2p/stream_muxer/mplex/mplex_stream.py | 4 +-- tests/network/test_swarm_conn.py | 17 +++++----- 7 files changed, 43 insertions(+), 61 deletions(-) diff --git a/libp2p/network/connection/net_connection_interface.py b/libp2p/network/connection/net_connection_interface.py index e308ad65..f1bcac24 100644 --- a/libp2p/network/connection/net_connection_interface.py +++ b/libp2p/network/connection/net_connection_interface.py @@ -1,6 +1,8 @@ from abc import abstractmethod from typing import Tuple +import trio + from libp2p.io.abc import Closer from libp2p.network.stream.net_stream_interface import INetStream from libp2p.stream_muxer.abc import IMuxedConn @@ -8,11 +10,12 @@ from libp2p.stream_muxer.abc import IMuxedConn class INetConn(Closer): muxed_conn: IMuxedConn + event_started: trio.Event @abstractmethod async def new_stream(self) -> INetStream: ... @abstractmethod - async def get_streams(self) -> Tuple[INetStream, ...]: + def get_streams(self) -> Tuple[INetStream, ...]: ... diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 90cb8231..7cf3c0ff 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -1,6 +1,5 @@ from typing import TYPE_CHECKING, Set, Tuple -from async_service import Service import trio from libp2p.network.connection.net_connection_interface import INetConn @@ -17,10 +16,11 @@ Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee """ -class SwarmConn(INetConn, Service): +class SwarmConn(INetConn): muxed_conn: IMuxedConn swarm: "Swarm" streams: Set[NetStream] + event_started: trio.Event event_closed: trio.Event def __init__(self, muxed_conn: IMuxedConn, swarm: "Swarm") -> None: @@ -28,6 +28,7 @@ class SwarmConn(INetConn, Service): self.swarm = swarm self.streams = set() self.event_closed = trio.Event() + self.event_started = trio.Event() @property def is_closed(self) -> bool: @@ -38,8 +39,6 @@ class SwarmConn(INetConn, Service): return self.event_closed.set() await self._cleanup() - # Cancel service - await self.manager.stop() async def _cleanup(self) -> None: self.swarm.remove_conn(self) @@ -57,13 +56,14 @@ class SwarmConn(INetConn, Service): self._notify_disconnected() async def _handle_new_streams(self) -> None: - while self.manager.is_running: + self.event_started.set() + while True: try: stream = await self.muxed_conn.accept_stream() # Asynchronously handle the accepted stream, to avoid blocking the next stream. except MuxedConnUnavailable: break - self.manager.run_task(self._handle_muxed_stream, stream) + self.swarm.manager.run_task(self._handle_muxed_stream, stream) await self.close() @@ -87,15 +87,14 @@ class SwarmConn(INetConn, Service): def _notify_disconnected(self) -> None: self.swarm.notify_disconnected(self) - async def run(self) -> None: - self.manager.run_task(self._handle_new_streams) - await self.manager.wait_finished() + async def start(self) -> None: + await self._handle_new_streams() async def new_stream(self) -> NetStream: muxed_stream = await self.muxed_conn.open_stream() return self._add_stream(muxed_stream) - async def get_streams(self) -> Tuple[NetStream, ...]: + def get_streams(self) -> Tuple[NetStream, ...]: return tuple(self.streams) def remove_stream(self, stream: NetStream) -> None: diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 0904774d..45d85b19 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -2,7 +2,6 @@ import logging from typing import Dict, List, Optional from multiaddr import Multiaddr -import trio from libp2p.io.abc import ReadWriteCloser from libp2p.network.connection.net_connection_interface import INetConn @@ -44,7 +43,6 @@ class Swarm(INetworkService): common_stream_handler: Optional[StreamHandlerFn] notifees: List[INotifee] - event_closed: trio.Event def __init__( self, @@ -63,8 +61,6 @@ class Swarm(INetworkService): # Create Notifee array self.notifees = [] - self.event_closed = trio.Event() - self.common_stream_handler = None async def run(self) -> None: @@ -158,13 +154,11 @@ class Swarm(INetworkService): try: muxed_conn = await self.upgrader.upgrade_connection(secured_conn, peer_id) - self.manager.run_child_service(muxed_conn) except MuxerUpgradeFailure as error: error_msg = "fail to upgrade mux for peer %s" logger.debug(error_msg, peer_id) await secured_conn.close() raise SwarmException(error_msg % peer_id) from error - logger.debug("upgraded mux for peer %s", peer_id) swarm_conn = await self.add_conn(muxed_conn) @@ -226,7 +220,6 @@ class Swarm(INetworkService): muxed_conn = await self.upgrader.upgrade_connection( secured_conn, peer_id ) - self.manager.run_child_service(muxed_conn) except MuxerUpgradeFailure as error: error_msg = "fail to upgrade mux for peer %s" logger.debug(error_msg, peer_id) @@ -235,8 +228,8 @@ class Swarm(INetworkService): logger.debug("upgraded mux for peer %s", peer_id) await self.add_conn(muxed_conn) - logger.debug("successfully opened connection to peer %s", peer_id) + # NOTE: This is a intentional barrier to prevent from the handler exiting and # closing the connection. await self.manager.wait_finished() @@ -261,26 +254,12 @@ class Swarm(INetworkService): return False async def close(self) -> None: - if self.event_closed.is_set(): - return - self.event_closed.set() - # Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501 - async with trio.open_nursery() as nursery: - for conn in self.connections.values(): - nursery.start_soon(conn.close) - async with trio.open_nursery() as nursery: - for listener in self.listeners.values(): - nursery.start_soon(listener.close) - - # Cancel tasks await self.manager.stop() logger.debug("swarm successfully closed") async def close_peer(self, peer_id: ID) -> None: if peer_id not in self.connections: return - # TODO: Should be changed to close multisple connections, - # if we have several connections per peer in the future. connection = self.connections[peer_id] # NOTE: `connection.close` will delete `peer_id` from `self.connections` # and `notify_disconnected` for us. @@ -293,12 +272,14 @@ class Swarm(INetworkService): and start to monitor the connection for its new streams and disconnection.""" swarm_conn = SwarmConn(muxed_conn, self) - manager = self.manager.run_child_service(swarm_conn) + self.manager.run_task(muxed_conn.start) + await muxed_conn.event_started.wait() + self.manager.run_task(swarm_conn.start) + await swarm_conn.event_started.wait() # Store muxed_conn with peer id self.connections[muxed_conn.peer_id] = swarm_conn # Call notifiers since event occurred self.notify_connected(swarm_conn) - await manager.wait_started() return swarm_conn def remove_conn(self, swarm_conn: SwarmConn) -> None: @@ -307,8 +288,6 @@ class Swarm(INetworkService): peer_id = swarm_conn.muxed_conn.peer_id if peer_id not in self.connections: return - # TODO: Should be changed to remove the exact connection, - # if we have several connections per peer in the future. del self.connections[peer_id] # Notifee diff --git a/libp2p/stream_muxer/abc.py b/libp2p/stream_muxer/abc.py index e34295cf..82140ff4 100644 --- a/libp2p/stream_muxer/abc.py +++ b/libp2p/stream_muxer/abc.py @@ -1,18 +1,19 @@ -from abc import abstractmethod +from abc import ABC, abstractmethod -from async_service import ServiceAPI +import trio from libp2p.io.abc import ReadWriteCloser from libp2p.peer.id import ID from libp2p.security.secure_conn_interface import ISecureConn -class IMuxedConn(ServiceAPI): +class IMuxedConn(ABC): """ reference: https://github.com/libp2p/go-stream-muxer/blob/master/muxer.go """ peer_id: ID + event_started: trio.Event @abstractmethod def __init__(self, conn: ISecureConn, peer_id: ID) -> None: @@ -27,7 +28,11 @@ class IMuxedConn(ServiceAPI): @property @abstractmethod def is_initiator(self) -> bool: - pass + """if this connection is the initiator.""" + + @abstractmethod + async def start(self) -> None: + """start the multiplexer.""" @abstractmethod async def close(self) -> None: diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 6523b488..486fd3f5 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -2,7 +2,6 @@ import logging import math from typing import Dict, Optional, Tuple -from async_service import Service import trio from libp2p.exceptions import ParseError @@ -29,7 +28,7 @@ MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0") logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex") -class Mplex(IMuxedConn, Service): +class Mplex(IMuxedConn): """ reference: https://github.com/libp2p/go-mplex/blob/master/multiplex.go """ @@ -45,6 +44,7 @@ class Mplex(IMuxedConn, Service): event_shutting_down: trio.Event event_closed: trio.Event + event_started: trio.Event def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None: """ @@ -73,10 +73,10 @@ class Mplex(IMuxedConn, Service): self.new_stream_send_channel, self.new_stream_receive_channel = channels self.event_shutting_down = trio.Event() self.event_closed = trio.Event() + self.event_started = trio.Event() - async def run(self) -> None: - self.manager.run_task(self.handle_incoming) - await self.manager.wait_finished() + async def start(self) -> None: + await self.handle_incoming() @property def is_initiator(self) -> bool: @@ -91,7 +91,6 @@ class Mplex(IMuxedConn, Service): await self.secured_conn.close() # Blocked until `close` is finally set. await self.event_closed.wait() - await self.manager.stop() @property def is_closed(self) -> bool: @@ -178,8 +177,8 @@ class Mplex(IMuxedConn, Service): async def handle_incoming(self) -> None: """Read a message off of the secured connection and add it to the corresponding message buffer.""" - - while self.manager.is_running: + self.event_started.set() + while True: try: await self._handle_incoming_message() except MplexUnavailable as e: diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index fc3d2747..ae6f7ea0 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -188,9 +188,7 @@ class MplexStream(IMuxedStream): if self.is_initiator else HeaderTags.ResetReceiver ) - self.muxed_conn.manager.run_task( - self.muxed_conn.send_message, flag, None, self.stream_id - ) + await self.muxed_conn.send_message(flag, None, self.stream_id) self.event_local_closed.set() self.event_remote_closed.set() diff --git a/tests/network/test_swarm_conn.py b/tests/network/test_swarm_conn.py index 1bfd7d86..dc692f44 100644 --- a/tests/network/test_swarm_conn.py +++ b/tests/network/test_swarm_conn.py @@ -14,7 +14,6 @@ async def test_swarm_conn_close(swarm_conn_pair): await trio.sleep(0.1) await wait_all_tasks_blocked() - await conn_0.manager.wait_finished() assert conn_0.is_closed assert conn_1.is_closed @@ -26,22 +25,22 @@ async def test_swarm_conn_close(swarm_conn_pair): async def test_swarm_conn_streams(swarm_conn_pair): conn_0, conn_1 = swarm_conn_pair - assert len(await conn_0.get_streams()) == 0 - assert len(await conn_1.get_streams()) == 0 + assert len(conn_0.get_streams()) == 0 + assert len(conn_1.get_streams()) == 0 stream_0_0 = await conn_0.new_stream() await trio.sleep(0.01) - assert len(await conn_0.get_streams()) == 1 - assert len(await conn_1.get_streams()) == 1 + assert len(conn_0.get_streams()) == 1 + assert len(conn_1.get_streams()) == 1 stream_0_1 = await conn_0.new_stream() await trio.sleep(0.01) - assert len(await conn_0.get_streams()) == 2 - assert len(await conn_1.get_streams()) == 2 + assert len(conn_0.get_streams()) == 2 + assert len(conn_1.get_streams()) == 2 conn_0.remove_stream(stream_0_0) - assert len(await conn_0.get_streams()) == 1 + assert len(conn_0.get_streams()) == 1 conn_0.remove_stream(stream_0_1) - assert len(await conn_0.get_streams()) == 0 + assert len(conn_0.get_streams()) == 0 # Nothing happen if `stream_0_1` is not present or already removed. conn_0.remove_stream(stream_0_1) From 54871024cc00e3936b1821bf5784adae503b5519 Mon Sep 17 00:00:00 2001 From: mhchia Date: Thu, 16 Jan 2020 18:54:19 +0800 Subject: [PATCH 41/81] Pin the version of async-service to a4 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 5f3a873d..031d9112 100644 --- a/setup.py +++ b/setup.py @@ -75,7 +75,7 @@ install_requires = [ "dataclasses>=0.7, <1;python_version<'3.7'", "async_generator==1.10", "trio>=0.13.0", - "async-service==0.1.0a2", + "async-service>=0.1.0a4", "async-exit-stack==1.0.1", ] From 6c7aa301910e7284366c2adbdc5e9332ea93c400 Mon Sep 17 00:00:00 2001 From: mhchia Date: Sat, 18 Jan 2020 00:17:30 +0800 Subject: [PATCH 42/81] Add events in Pubsub To ensure `handle_peer_queue` and `handle_dead_peer_queue` are indeed run before the tests finish. Previously, we get errors when performing `iter_dag` after cancellation. This is because `handle_peer_queue` or `handle_dead_peer_queue` is not actually run before the Service is cancelled. --- libp2p/pubsub/pubsub.py | 8 ++++++++ libp2p/tools/factories.py | 2 ++ tests/pubsub/test_gossipsub.py | 1 + tests/pubsub/test_pubsub.py | 6 +++++- 4 files changed, 16 insertions(+), 1 deletion(-) diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index e0026de8..3df0ac29 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -86,6 +86,9 @@ class Pubsub(IPubsub, Service): strict_signing: bool sign_key: PrivateKey + event_handle_peer_queue_started: trio.Event + event_handle_dead_peer_queue_started: trio.Event + def __init__( self, host: IHost, @@ -159,6 +162,9 @@ class Pubsub(IPubsub, Service): self.counter = int(time.time()) + self.event_handle_peer_queue_started = trio.Event() + self.event_handle_dead_peer_queue_started = trio.Event() + async def run(self) -> None: self.manager.run_daemon_task(self.handle_peer_queue) self.manager.run_daemon_task(self.handle_dead_peer_queue) @@ -331,12 +337,14 @@ class Pubsub(IPubsub, Service): """Continuously read from peer queue and each time a new peer is found, open a stream to the peer using a supported pubsub protocol pubsub protocols we support.""" + self.event_handle_peer_queue_started.set() async with self.peer_receive_channel: async for peer_id in self.peer_receive_channel: # Add Peer self.manager.run_task(self._handle_new_peer, peer_id) async def handle_dead_peer_queue(self) -> None: + self.event_handle_dead_peer_queue_started.set() """Continuously read from dead peer channel and close the stream between that peer and remove peer info from pubsub and pubsub router.""" diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index ae99f345..a20b59fa 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -245,6 +245,8 @@ class PubsubFactory(factory.Factory): strict_signing=strict_signing, ) async with background_trio_service(pubsub): + await pubsub.event_handle_peer_queue_started.wait() + await pubsub.event_handle_dead_peer_queue_started.wait() yield pubsub @classmethod diff --git a/tests/pubsub/test_gossipsub.py b/tests/pubsub/test_gossipsub.py index 4630c85f..a423fbd6 100644 --- a/tests/pubsub/test_gossipsub.py +++ b/tests/pubsub/test_gossipsub.py @@ -106,6 +106,7 @@ async def test_handle_graft(monkeypatch): async def emit_prune(topic, sender_peer_id): event_emit_prune.set() + await trio.hazmat.checkpoint() monkeypatch.setattr(gossipsubs[index_bob], "emit_prune", emit_prune) diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index c05aecbb..4bea6dd2 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -103,6 +103,7 @@ async def test_set_and_remove_topic_validator(): async def async_validator(peer_id, msg): nonlocal is_async_validator_called is_async_validator_called = True + await trio.hazmat.checkpoint() topic = "TEST_VALIDATOR" @@ -237,6 +238,7 @@ async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed): @pytest.mark.trio async def test_continuously_read_stream(monkeypatch, nursery, is_host_secure): async def wait_for_event_occurring(event): + await trio.hazmat.checkpoint() with trio.fail_after(0.1): await event.wait() @@ -418,6 +420,7 @@ async def test_publish(monkeypatch): async def push_msg(msg_forwarder, msg): msg_forwarders.append(msg_forwarder) msgs.append(msg) + await trio.hazmat.checkpoint() async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: with monkeypatch.context() as m: @@ -454,7 +457,7 @@ async def test_push_msg(monkeypatch): async def router_publish(*args, **kwargs): event.set() - await trio.sleep(0) + await trio.hazmat.checkpoint() with monkeypatch.context() as m: m.setattr(pubsubs_fsub[0].router, "publish", router_publish) @@ -555,6 +558,7 @@ async def test_strict_signing_failed_validation(monkeypatch): # Use router publish to check if `push_msg` succeed. async def router_publish(*args, **kwargs): + await trio.hazmat.checkpoint() # The event will only be set if `push_msg` succeed. event.set() From f0c4254bbd174172e947af93dcc16db62c39c3d4 Mon Sep 17 00:00:00 2001 From: mhchia Date: Sat, 18 Jan 2020 00:31:39 +0800 Subject: [PATCH 43/81] Use `Service` instead of `ServiceAPI` To fix error with async-service==0.1.0a5 --- libp2p/network/network_interface.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libp2p/network/network_interface.py b/libp2p/network/network_interface.py index c759a411..156a71b9 100644 --- a/libp2p/network/network_interface.py +++ b/libp2p/network/network_interface.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Dict, Sequence -from async_service import ServiceAPI +from async_service import Service from multiaddr import Multiaddr from libp2p.network.connection.net_connection_interface import INetConn @@ -73,5 +73,5 @@ class INetwork(ABC): pass -class INetworkService(INetwork, ServiceAPI): +class INetworkService(INetwork, Service): ... From 6e01a7da31470408486037dba64d9fe32d9b9a7f Mon Sep 17 00:00:00 2001 From: mhchia Date: Sun, 26 Jan 2020 16:44:42 +0800 Subject: [PATCH 44/81] PR feedback: async with host.run() --- examples/chat/chat.py | 7 +-- examples/echo/echo.py | 7 +-- libp2p/__init__.py | 64 ++++++++------------- libp2p/host/basic_host.py | 26 +++++++-- libp2p/host/host_interface.py | 16 +++++- libp2p/host/routed_host.py | 4 +- tests/host/test_basic_host.py | 4 +- tests/security/test_security_multistream.py | 18 +++--- 8 files changed, 75 insertions(+), 71 deletions(-) diff --git a/examples/chat/chat.py b/examples/chat/chat.py index 80dcc862..41aad927 100755 --- a/examples/chat/chat.py +++ b/examples/chat/chat.py @@ -4,7 +4,7 @@ import sys import multiaddr import trio -from libp2p import new_host_trio +from libp2p import new_host from libp2p.network.stream.net_stream_interface import INetStream from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.typing import TProtocol @@ -34,9 +34,8 @@ async def write_data(stream: INetStream) -> None: async def run(port: int, destination: str) -> None: localhost_ip = "127.0.0.1" listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") - async with new_host_trio( - listen_addrs=[listen_addr] - ) as host, trio.open_nursery() as nursery: + host = new_host() + async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery: if not destination: # its the server async def stream_handler(stream: INetStream) -> None: diff --git a/examples/echo/echo.py b/examples/echo/echo.py index b39f8138..5ea8ab4a 100644 --- a/examples/echo/echo.py +++ b/examples/echo/echo.py @@ -3,7 +3,7 @@ import argparse import multiaddr import trio -from libp2p import new_host_trio +from libp2p import new_host from libp2p.crypto.secp256k1 import create_new_key_pair from libp2p.network.stream.net_stream_interface import INetStream from libp2p.peer.peerinfo import info_from_p2p_addr @@ -34,9 +34,8 @@ async def run(port: int, destination: str, seed: int = None) -> None: secret = secrets.token_bytes(32) - async with new_host_trio( - listen_addrs=[listen_addr], key_pair=create_new_key_pair(secret) - ) as host: + host = new_host(key_pair=create_new_key_pair(secret)) + async with host.run(listen_addrs=[listen_addr]): print(f"I am {host.get_id().to_string()}") diff --git a/libp2p/__init__.py b/libp2p/__init__.py index 658fceb3..4d91b9da 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -1,14 +1,9 @@ -from typing import AsyncIterator, Sequence - -from async_generator import asynccontextmanager -from async_service import background_trio_service - from libp2p.crypto.keys import KeyPair from libp2p.crypto.rsa import create_new_key_pair from libp2p.host.basic_host import BasicHost from libp2p.host.host_interface import IHost from libp2p.host.routed_host import RoutedHost -from libp2p.network.network_interface import INetwork, INetworkService +from libp2p.network.network_interface import INetworkService from libp2p.network.swarm import Swarm from libp2p.peer.id import ID from libp2p.peer.peerstore import PeerStore @@ -32,14 +27,14 @@ def generate_peer_id_from(key_pair: KeyPair) -> ID: return ID.from_pubkey(public_key) -def initialize_default_swarm( +def new_swarm( key_pair: KeyPair = None, muxer_opt: TMuxerOptions = None, sec_opt: TSecurityOptions = None, peerstore_opt: IPeerStore = None, ) -> INetworkService: """ - initialize swarm when no swarm is passed in. + Create a swarm instance based on the parameters. :param key_pair: optional choice of the ``KeyPair`` :param muxer_opt: optional choice of stream muxer @@ -48,7 +43,7 @@ def initialize_default_swarm( :return: return a default swarm instance """ - if not key_pair: + if key_pair is None: key_pair = generate_new_rsa_identity() id_opt = generate_peer_id_from(key_pair) @@ -72,45 +67,32 @@ def initialize_default_swarm( return Swarm(id_opt, peerstore, upgrader, transport) -def _new_host(swarm_opt: INetwork, disc_opt: IPeerRouting = None) -> IHost: - """ - create new libp2p host. - - :param swarm_opt: optional swarm - :param disc_opt: optional discovery - :return: return a host instance - """ - # TODO enable support for other host type - # TODO routing unimplemented - host: IHost # If not explicitly typed, MyPy raises error - if disc_opt: - host = RoutedHost(swarm_opt, disc_opt) - else: - host = BasicHost(swarm_opt) - - return host - - -@asynccontextmanager -async def new_host_trio( - listen_addrs: Sequence[str], +def new_host( key_pair: KeyPair = None, - swarm_opt: INetwork = None, muxer_opt: TMuxerOptions = None, sec_opt: TSecurityOptions = None, peerstore_opt: IPeerStore = None, disc_opt: IPeerRouting = None, -) -> AsyncIterator[IHost]: - swarm = initialize_default_swarm( +) -> IHost: + """ + Create a new libp2p host based on the given parameters. + + :param key_pair: optional choice of the ``KeyPair`` + :param muxer_opt: optional choice of stream muxer + :param sec_opt: optional choice of security upgrade + :param peerstore_opt: optional peerstore + :param disc_opt: optional discovery + :return: return a host instance + """ + swarm = new_swarm( key_pair=key_pair, muxer_opt=muxer_opt, sec_opt=sec_opt, peerstore_opt=peerstore_opt, ) - async with background_trio_service(swarm): - await swarm.listen(*listen_addrs) - host = _new_host(swarm_opt=swarm, disc_opt=disc_opt) - yield host - - -# TODO: Support asyncio + host: IHost + if disc_opt: + host = RoutedHost(swarm, disc_opt) + else: + host = BasicHost(swarm) + return host diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 253394e5..6386fb83 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -1,12 +1,14 @@ import logging -from typing import TYPE_CHECKING, List, Sequence +from typing import TYPE_CHECKING, AsyncIterator, List, Sequence +from async_generator import asynccontextmanager +from async_service import background_trio_service import multiaddr from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.host.defaults import get_default_protocols from libp2p.host.exceptions import StreamFailure -from libp2p.network.network_interface import INetwork +from libp2p.network.network_interface import INetworkService from libp2p.network.stream.net_stream_interface import INetStream from libp2p.peer.id import ID from libp2p.peer.peerinfo import PeerInfo @@ -39,7 +41,7 @@ class BasicHost(IHost): right after a stream is initialized. """ - _network: INetwork + _network: INetworkService peerstore: IPeerStore multiselect: Multiselect @@ -47,7 +49,7 @@ class BasicHost(IHost): def __init__( self, - network: INetwork, + network: INetworkService, default_protocols: "OrderedDict[TProtocol, StreamHandlerFn]" = None, ) -> None: self._network = network @@ -70,7 +72,7 @@ class BasicHost(IHost): def get_private_key(self) -> PrivateKey: return self.peerstore.privkey(self.get_id()) - def get_network(self) -> INetwork: + def get_network(self) -> INetworkService: """ :return: network instance of host """ @@ -101,6 +103,20 @@ class BasicHost(IHost): addrs.append(addr.encapsulate(p2p_part)) return addrs + @asynccontextmanager + async def run( + self, listen_addrs: Sequence[multiaddr.Multiaddr] + ) -> AsyncIterator[None]: + """ + run the host instance and listen to ``listen_addrs``. + + :param listen_addrs: a sequence of multiaddrs that we want to listen to + """ + network = self.get_network() + async with background_trio_service(network): + await network.listen(*listen_addrs) + yield + def set_stream_handler( self, protocol_id: TProtocol, stream_handler: StreamHandlerFn ) -> None: diff --git a/libp2p/host/host_interface.py b/libp2p/host/host_interface.py index 43f4ac40..59146e7f 100644 --- a/libp2p/host/host_interface.py +++ b/libp2p/host/host_interface.py @@ -1,10 +1,10 @@ from abc import ABC, abstractmethod -from typing import Any, List, Sequence +from typing import Any, AsyncContextManager, List, Sequence import multiaddr from libp2p.crypto.keys import PrivateKey, PublicKey -from libp2p.network.network_interface import INetwork +from libp2p.network.network_interface import INetworkService from libp2p.network.stream.net_stream_interface import INetStream from libp2p.peer.id import ID from libp2p.peer.peerinfo import PeerInfo @@ -31,7 +31,7 @@ class IHost(ABC): """ @abstractmethod - def get_network(self) -> INetwork: + def get_network(self) -> INetworkService: """ :return: network instance of host """ @@ -49,6 +49,16 @@ class IHost(ABC): :return: all the multiaddr addresses this host is listening to """ + @abstractmethod + def run( + self, listen_addrs: Sequence[multiaddr.Multiaddr] + ) -> AsyncContextManager[None]: + """ + run the host instance and listen to ``listen_addrs``. + + :param listen_addrs: a sequence of multiaddrs that we want to listen to + """ + @abstractmethod def set_stream_handler( self, protocol_id: TProtocol, stream_handler: StreamHandlerFn diff --git a/libp2p/host/routed_host.py b/libp2p/host/routed_host.py index 78b6fa54..91264c71 100644 --- a/libp2p/host/routed_host.py +++ b/libp2p/host/routed_host.py @@ -1,6 +1,6 @@ from libp2p.host.basic_host import BasicHost from libp2p.host.exceptions import ConnectionFailure -from libp2p.network.network_interface import INetwork +from libp2p.network.network_interface import INetworkService from libp2p.peer.peerinfo import PeerInfo from libp2p.routing.interfaces import IPeerRouting @@ -10,7 +10,7 @@ from libp2p.routing.interfaces import IPeerRouting class RoutedHost(BasicHost): _router: IPeerRouting - def __init__(self, network: INetwork, router: IPeerRouting): + def __init__(self, network: INetworkService, router: IPeerRouting): super().__init__(network) self._router = router diff --git a/tests/host/test_basic_host.py b/tests/host/test_basic_host.py index 1eec04a8..55605ed5 100644 --- a/tests/host/test_basic_host.py +++ b/tests/host/test_basic_host.py @@ -1,4 +1,4 @@ -from libp2p import initialize_default_swarm +from libp2p import new_swarm from libp2p.crypto.rsa import create_new_key_pair from libp2p.host.basic_host import BasicHost from libp2p.host.defaults import get_default_protocols @@ -6,7 +6,7 @@ from libp2p.host.defaults import get_default_protocols def test_default_protocols(): key_pair = create_new_key_pair() - swarm = initialize_default_swarm(key_pair) + swarm = new_swarm(key_pair) host = BasicHost(swarm) mux = host.get_mux() diff --git a/tests/security/test_security_multistream.py b/tests/security/test_security_multistream.py index 8d466106..cd968ac2 100644 --- a/tests/security/test_security_multistream.py +++ b/tests/security/test_security_multistream.py @@ -1,7 +1,7 @@ import pytest import trio -from libp2p import new_host_trio +from libp2p import new_host from libp2p.crypto.rsa import create_new_key_pair from libp2p.security.insecure.transport import InsecureSession, InsecureTransport from libp2p.tools.constants import LISTEN_MADDR @@ -29,15 +29,13 @@ async def perform_simple_test( # for testing, we do NOT want to communicate over a stream so we can't just create two nodes # and use their conn because our mplex will internally relay messages to a stream - async with new_host_trio( - listen_addrs=[LISTEN_MADDR], - key_pair=initiator_key_pair, - sec_opt=transports_for_initiator, - ) as node1, new_host_trio( - listen_addrs=[LISTEN_MADDR], - key_pair=noninitiator_key_pair, - sec_opt=transports_for_noninitiator, - ) as node2: + node1 = new_host(key_pair=initiator_key_pair, sec_opt=transports_for_initiator) + node2 = new_host( + key_pair=noninitiator_key_pair, sec_opt=transports_for_noninitiator + ) + async with node1.run(listen_addrs=[LISTEN_MADDR]), node2.run( + listen_addrs=[LISTEN_MADDR] + ): await connect(node1, node2) # Wait a very short period to allow conns to be stored (since the functions From 5b4b65faa8f81f74eabbb01f49ad197ae5023945 Mon Sep 17 00:00:00 2001 From: mhchia Date: Sun, 26 Jan 2020 23:03:38 +0800 Subject: [PATCH 45/81] Change default value of `read()` From `n = -1` to `n = None`, to comply with trio API --- libp2p/io/abc.py | 2 +- libp2p/io/msgio.py | 2 +- libp2p/io/trio.py | 15 +++------------ libp2p/network/connection/raw_connection.py | 2 +- libp2p/network/stream/net_stream.py | 2 +- libp2p/security/insecure/transport.py | 2 +- libp2p/security/secio/transport.py | 4 ++-- libp2p/stream_muxer/mplex/mplex_stream.py | 11 ++++++----- tests_interop/conftest.py | 7 ++----- 9 files changed, 18 insertions(+), 29 deletions(-) diff --git a/libp2p/io/abc.py b/libp2p/io/abc.py index eea7b72f..8f2c7582 100644 --- a/libp2p/io/abc.py +++ b/libp2p/io/abc.py @@ -8,7 +8,7 @@ class Closer(ABC): class Reader(ABC): @abstractmethod - async def read(self, n: int = -1) -> bytes: + async def read(self, n: int = None) -> bytes: ... diff --git a/libp2p/io/msgio.py b/libp2p/io/msgio.py index f60b0ff9..32c0d09e 100644 --- a/libp2p/io/msgio.py +++ b/libp2p/io/msgio.py @@ -54,7 +54,7 @@ class MsgIOReader(ReadCloser): self.read_closer = read_closer self.next_length = None - async def read(self, n: int = -1) -> bytes: + async def read(self, n: int = None) -> bytes: return await self.read_msg() async def read_msg(self) -> bytes: diff --git a/libp2p/io/trio.py b/libp2p/io/trio.py index 840c3bc8..b8571c89 100644 --- a/libp2p/io/trio.py +++ b/libp2p/io/trio.py @@ -26,22 +26,13 @@ class TrioTCPStream(ReadWriteCloser): await self.stream.send_all(data) except (trio.ClosedResourceError, trio.BrokenResourceError) as error: raise IOException from error - except trio.BusyResourceError as error: - # This should never happen, since we already access streams with read/write locks. - raise Exception( - "this should never happen " - "since we already access streams with read/write locks." - ) from error - async def read(self, n: int = -1) -> bytes: + async def read(self, n: int = None) -> bytes: async with self.read_lock: - if n == 0: - # Checkpoint - await trio.hazmat.checkpoint() + if n is not None and n == 0: return b"" - max_bytes = n if n != -1 else None try: - return await self.stream.receive_some(max_bytes) + return await self.stream.receive_some(n) except (trio.ClosedResourceError, trio.BrokenResourceError) as error: raise IOException from error except trio.BusyResourceError as error: diff --git a/libp2p/network/connection/raw_connection.py b/libp2p/network/connection/raw_connection.py index 25b1049c..69ef56ae 100644 --- a/libp2p/network/connection/raw_connection.py +++ b/libp2p/network/connection/raw_connection.py @@ -20,7 +20,7 @@ class RawConnection(IRawConnection): except IOException as error: raise RawConnError(error) - async def read(self, n: int = -1) -> bytes: + async def read(self, n: int = None) -> bytes: """ Read up to ``n`` bytes from the underlying stream. This call is delegated directly to the underlying ``self.reader``. diff --git a/libp2p/network/stream/net_stream.py b/libp2p/network/stream/net_stream.py index 7ab609d0..b2bac06a 100644 --- a/libp2p/network/stream/net_stream.py +++ b/libp2p/network/stream/net_stream.py @@ -37,7 +37,7 @@ class NetStream(INetStream): """ self.protocol_id = protocol_id - async def read(self, n: int = -1) -> bytes: + async def read(self, n: int = None) -> bytes: """ reads from stream. diff --git a/libp2p/security/insecure/transport.py b/libp2p/security/insecure/transport.py index 4199c612..abce868c 100644 --- a/libp2p/security/insecure/transport.py +++ b/libp2p/security/insecure/transport.py @@ -39,7 +39,7 @@ class InsecureSession(BaseSession): await self.conn.write(data) return len(data) - async def read(self, n: int = -1) -> bytes: + async def read(self, n: int = None) -> bytes: return await self.conn.read(n) async def close(self) -> None: diff --git a/libp2p/security/secio/transport.py b/libp2p/security/secio/transport.py index 08ab0e29..23359215 100644 --- a/libp2p/security/secio/transport.py +++ b/libp2p/security/secio/transport.py @@ -94,7 +94,7 @@ class SecureSession(BaseSession): data = self.buf.getbuffer()[self.low_watermark : self.high_watermark] - if n < 0: + if n is None: n = len(data) result = data[:n].tobytes() self.low_watermark += len(result) @@ -111,7 +111,7 @@ class SecureSession(BaseSession): self.low_watermark = 0 self.high_watermark = len(msg) - async def read(self, n: int = -1) -> bytes: + async def read(self, n: int = None) -> bytes: if n == 0: return bytes() diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index ae6f7ea0..933de204 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -81,22 +81,23 @@ class MplexStream(IMuxedStream): break return buf - async def read(self, n: int = -1) -> bytes: + async def read(self, n: int = None) -> bytes: """ Read up to n bytes. Read possibly returns fewer than `n` bytes, if - there are not enough bytes in the Mplex buffer. If `n == -1`, read + there are not enough bytes in the Mplex buffer. If `n is None`, read until EOF. :param n: number of bytes to read :return: bytes actually read """ - if n < 0 and n != -1: + if n is not None and n < 0: raise ValueError( - f"the number of bytes to read `n` must be positive or -1 to indicate read until EOF" + f"the number of bytes to read `n` must be non-negative or " + "`None` to indicate read until EOF" ) if self.event_reset.is_set(): raise MplexStreamReset - if n == -1: + if n is None: return await self._read_until_eof() if len(self._buf) == 0: data: bytes diff --git a/tests_interop/conftest.py b/tests_interop/conftest.py index bad0054e..067db83f 100644 --- a/tests_interop/conftest.py +++ b/tests_interop/conftest.py @@ -84,11 +84,8 @@ class DaemonStream(ReadWriteCloser): async def close(self) -> None: await self.stream.close() - async def read(self, n: int = -1) -> bytes: - if n == -1: - return await self.stream.receive_some() - else: - return await self.stream.receive_some(n) + async def read(self, n: int = None) -> bytes: + return await self.stream.receive_some(n) async def write(self, data: bytes) -> None: return await self.stream.send_all(data) From b85bab1a09bf9943b2ea9040387ce59b7b6c8528 Mon Sep 17 00:00:00 2001 From: mhchia Date: Sun, 26 Jan 2020 23:09:56 +0800 Subject: [PATCH 46/81] Don't catch `trio.BusyResourceError` --- libp2p/io/trio.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/libp2p/io/trio.py b/libp2p/io/trio.py index b8571c89..465e4eaa 100644 --- a/libp2p/io/trio.py +++ b/libp2p/io/trio.py @@ -35,12 +35,6 @@ class TrioTCPStream(ReadWriteCloser): return await self.stream.receive_some(n) except (trio.ClosedResourceError, trio.BrokenResourceError) as error: raise IOException from error - except trio.BusyResourceError as error: - # This should never happen, since we already access streams with read/write locks. - raise Exception( - "this should never happen " - "since we already access streams with read/write locks." - ) from error async def close(self) -> None: await self.stream.aclose() From ddfbf9ffc89f3a8819344959808d2f2f9b953a64 Mon Sep 17 00:00:00 2001 From: mhchia Date: Sun, 26 Jan 2020 23:54:29 +0800 Subject: [PATCH 47/81] Use `raise from` to reserve stacktrace --- libp2p/network/connection/raw_connection.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libp2p/network/connection/raw_connection.py b/libp2p/network/connection/raw_connection.py index 69ef56ae..2d8409f7 100644 --- a/libp2p/network/connection/raw_connection.py +++ b/libp2p/network/connection/raw_connection.py @@ -18,7 +18,7 @@ class RawConnection(IRawConnection): try: await self.stream.write(data) except IOException as error: - raise RawConnError(error) + raise RawConnError from error async def read(self, n: int = None) -> bytes: """ @@ -30,7 +30,7 @@ class RawConnection(IRawConnection): try: return await self.stream.read(n) except IOException as error: - raise RawConnError(error) + raise RawConnError from error async def close(self) -> None: await self.stream.close() From 42bc4d5d0609010c9e2fe346f0d3482c18cd2c42 Mon Sep 17 00:00:00 2001 From: mhchia Date: Sun, 26 Jan 2020 23:55:31 +0800 Subject: [PATCH 48/81] `INetworkService` implement `ServiceAPI` --- libp2p/network/network_interface.py | 4 ++-- libp2p/network/swarm.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/libp2p/network/network_interface.py b/libp2p/network/network_interface.py index 156a71b9..c759a411 100644 --- a/libp2p/network/network_interface.py +++ b/libp2p/network/network_interface.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Dict, Sequence -from async_service import Service +from async_service import ServiceAPI from multiaddr import Multiaddr from libp2p.network.connection.net_connection_interface import INetConn @@ -73,5 +73,5 @@ class INetwork(ABC): pass -class INetworkService(INetwork, Service): +class INetworkService(INetwork, ServiceAPI): ... diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 45d85b19..e03d65b7 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -1,6 +1,7 @@ import logging from typing import Dict, List, Optional +from async_service import Service from multiaddr import Multiaddr from libp2p.io.abc import ReadWriteCloser @@ -30,7 +31,7 @@ from .stream.net_stream_interface import INetStream logger = logging.getLogger("libp2p.network.swarm") -class Swarm(INetworkService): +class Swarm(Service, INetworkService): self_id: ID peerstore: IPeerStore From e3a1dd62e40dc5807f671c9698f6532a7aaa2a98 Mon Sep 17 00:00:00 2001 From: mhchia Date: Sun, 26 Jan 2020 23:56:19 +0800 Subject: [PATCH 49/81] Use new type hinting for trio channel --- libp2p/pubsub/pubsub.py | 14 +++----------- libp2p/stream_muxer/mplex/mplex.py | 9 ++------- 2 files changed, 5 insertions(+), 18 deletions(-) diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 3df0ac29..379361d8 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -71,7 +71,6 @@ class Pubsub(IPubsub, Service): seen_messages: LRU - # TODO: Implement `trio.abc.Channel`? subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"] subscribed_topics_receive: Dict[str, "TrioSubscriptionAPI"] @@ -112,12 +111,8 @@ class Pubsub(IPubsub, Service): # Attach this new Pubsub object to the router self.router.attach(self) - peer_channels: Tuple[ - "trio.MemorySendChannel[ID]", "trio.MemoryReceiveChannel[ID]" - ] = trio.open_memory_channel(0) - dead_peer_channels: Tuple[ - "trio.MemorySendChannel[ID]", "trio.MemoryReceiveChannel[ID]" - ] = trio.open_memory_channel(0) + peer_channels = trio.open_memory_channel[ID](0) + dead_peer_channels = trio.open_memory_channel[ID](0) # Only keep the receive channels in `Pubsub`. # Therefore, we can only close from the receive side. self.peer_receive_channel = peer_channels[1] @@ -404,10 +399,7 @@ class Pubsub(IPubsub, Service): if topic_id in self.topic_ids: return self.subscribed_topics_receive[topic_id] - channels: Tuple[ - "trio.MemorySendChannel[rpc_pb2.Message]", - "trio.MemoryReceiveChannel[rpc_pb2.Message]", - ] = trio.open_memory_channel(math.inf) + channels = trio.open_memory_channel[rpc_pb2.Message](math.inf) send_channel, receive_channel = channels subscription = TrioSubscriptionAPI(receive_channel) self.subscribed_topics_send[topic_id] = send_channel diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 486fd3f5..6f5f3fd8 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -66,10 +66,7 @@ class Mplex(IMuxedConn): self.streams = {} self.streams_lock = trio.Lock() self.streams_msg_channels = {} - channels: Tuple[ - "trio.MemorySendChannel[IMuxedStream]", - "trio.MemoryReceiveChannel[IMuxedStream]", - ] = trio.open_memory_channel(math.inf) + channels = trio.open_memory_channel[IMuxedStream](math.inf) self.new_stream_send_channel, self.new_stream_receive_channel = channels self.event_shutting_down = trio.Event() self.event_closed = trio.Event() @@ -114,9 +111,7 @@ class Mplex(IMuxedConn): async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream: # Use an unbounded buffer, to avoid `handle_incoming` being blocked when doing # `send_channel.send`. - channels: Tuple[ - "trio.MemorySendChannel[bytes]", "trio.MemoryReceiveChannel[bytes]" - ] = trio.open_memory_channel(math.inf) + channels = trio.open_memory_channel[bytes](math.inf) stream = MplexStream(name, stream_id, self, channels[1]) async with self.streams_lock: self.streams[stream_id] = stream From 92ea35e147b825a4880785a87b5ee7d34567996b Mon Sep 17 00:00:00 2001 From: mhchia Date: Mon, 27 Jan 2020 00:10:33 +0800 Subject: [PATCH 50/81] Fix `IPubsub` and add `IPubsub.wait_until_ready` --- libp2p/pubsub/abc.py | 66 +++++++++++++++++++++++++++++++++------ libp2p/pubsub/pubsub.py | 30 ++++++------------ libp2p/pubsub/typing.py | 9 ++++++ libp2p/tools/factories.py | 3 +- 4 files changed, 76 insertions(+), 32 deletions(-) create mode 100644 libp2p/pubsub/typing.py diff --git a/libp2p/pubsub/abc.py b/libp2p/pubsub/abc.py index 19f9b2a6..e4b75840 100644 --- a/libp2p/pubsub/abc.py +++ b/libp2p/pubsub/abc.py @@ -1,18 +1,35 @@ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, AsyncContextManager, AsyncIterable, List +from typing import ( + TYPE_CHECKING, + AsyncContextManager, + AsyncIterable, + KeysView, + List, + Tuple, +) + +from async_service import ServiceAPI from libp2p.peer.id import ID from libp2p.typing import TProtocol from .pb import rpc_pb2 +from .typing import ValidatorFn if TYPE_CHECKING: from .pubsub import Pubsub # noqa: F401 -# TODO: Add interface for Pubsub -class IPubsub(ABC): - pass +class ISubscriptionAPI( + AsyncContextManager["ISubscriptionAPI"], AsyncIterable[rpc_pb2.Message] +): + @abstractmethod + async def cancel(self) -> None: + ... + + @abstractmethod + async def get(self) -> rpc_pb2.Message: + ... class IPubsubRouter(ABC): @@ -86,13 +103,44 @@ class IPubsubRouter(ABC): """ -class ISubscriptionAPI( - AsyncContextManager["ISubscriptionAPI"], AsyncIterable[rpc_pb2.Message] -): +class IPubsub(ServiceAPI): + @property @abstractmethod - async def cancel(self) -> None: + def my_id(self) -> ID: + ... + + @property + @abstractmethod + def protocols(self) -> Tuple[TProtocol, ...]: + ... + + @property + @abstractmethod + def topic_ids(self) -> KeysView[str]: ... @abstractmethod - async def get(self) -> rpc_pb2.Message: + def set_topic_validator( + self, topic: str, validator: ValidatorFn, is_async_validator: bool + ) -> None: + ... + + @abstractmethod + def remove_topic_validator(self, topic: str) -> None: + ... + + @abstractmethod + async def wait_until_ready(self) -> None: + ... + + @abstractmethod + async def subscribe(self, topic_id: str) -> ISubscriptionAPI: + ... + + @abstractmethod + async def unsubscribe(self, topic_id: str) -> None: + ... + + @abstractmethod + async def publish(self, topic_id: str, data: bytes) -> None: ... diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 379361d8..94c33b41 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -1,19 +1,7 @@ import logging import math import time -from typing import ( - TYPE_CHECKING, - Awaitable, - Callable, - Dict, - KeysView, - List, - NamedTuple, - Set, - Tuple, - Union, - cast, -) +from typing import TYPE_CHECKING, Dict, KeysView, List, NamedTuple, Set, Tuple, cast from async_service import Service import base58 @@ -35,6 +23,7 @@ from .abc import IPubsub, ISubscriptionAPI from .pb import rpc_pb2 from .pubsub_notifee import PubsubNotifee from .subscription import TrioSubscriptionAPI +from .typing import AsyncValidatorFn, SyncValidatorFn, ValidatorFn from .validators import PUBSUB_SIGNING_PREFIX, signature_validator if TYPE_CHECKING: @@ -50,17 +39,12 @@ def get_msg_id(msg: rpc_pb2.Message) -> Tuple[bytes, bytes]: return (msg.seqno, msg.from_id) -SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool] -AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]] -ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn] - - class TopicValidator(NamedTuple): validator: ValidatorFn is_async: bool -class Pubsub(IPubsub, Service): +class Pubsub(Service, IPubsub): host: IHost @@ -290,6 +274,10 @@ class Pubsub(IPubsub, Service): await stream.reset() self._handle_dead_peer(peer_id) + async def wait_until_ready(self) -> None: + await self.event_handle_peer_queue_started.wait() + await self.event_handle_dead_peer_queue_started.wait() + async def _handle_new_peer(self, peer_id: ID) -> None: try: stream: INetStream = await self.host.new_stream(peer_id, self.protocols) @@ -332,18 +320,18 @@ class Pubsub(IPubsub, Service): """Continuously read from peer queue and each time a new peer is found, open a stream to the peer using a supported pubsub protocol pubsub protocols we support.""" - self.event_handle_peer_queue_started.set() async with self.peer_receive_channel: + self.event_handle_peer_queue_started.set() async for peer_id in self.peer_receive_channel: # Add Peer self.manager.run_task(self._handle_new_peer, peer_id) async def handle_dead_peer_queue(self) -> None: - self.event_handle_dead_peer_queue_started.set() """Continuously read from dead peer channel and close the stream between that peer and remove peer info from pubsub and pubsub router.""" async with self.dead_peer_receive_channel: + self.event_handle_dead_peer_queue_started.set() async for peer_id in self.dead_peer_receive_channel: # Remove Peer self._handle_dead_peer(peer_id) diff --git a/libp2p/pubsub/typing.py b/libp2p/pubsub/typing.py new file mode 100644 index 00000000..c352d529 --- /dev/null +++ b/libp2p/pubsub/typing.py @@ -0,0 +1,9 @@ +from typing import Awaitable, Callable, Union + +from libp2p.peer.id import ID + +from .pb import rpc_pb2 + +SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool] +AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]] +ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn] diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index a20b59fa..84e6af39 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -245,8 +245,7 @@ class PubsubFactory(factory.Factory): strict_signing=strict_signing, ) async with background_trio_service(pubsub): - await pubsub.event_handle_peer_queue_started.wait() - await pubsub.event_handle_dead_peer_queue_started.wait() + await pubsub.wait_until_ready() yield pubsub @classmethod From c3ba67ea876e2f5dec32c8285cb10e050e3e109c Mon Sep 17 00:00:00 2001 From: mhchia Date: Mon, 27 Jan 2020 14:30:44 +0800 Subject: [PATCH 51/81] Remove locks in `PubsubNotifee` - Change `open_memory_channel(0)` to `open_memory_channel(math.inf)`, to avoid `peer_queue.send` and `dead_peer_queue.send` blocking. This allows us to remove the locks. - Only catch `trio.BrokenResourceError`, which is caused by Pubsub when it's closing. --- libp2p/pubsub/pubsub.py | 10 ++++---- libp2p/pubsub/pubsub_notifee.py | 43 +++++++++++---------------------- 2 files changed, 19 insertions(+), 34 deletions(-) diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 94c33b41..03140742 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -95,15 +95,15 @@ class Pubsub(Service, IPubsub): # Attach this new Pubsub object to the router self.router.attach(self) - peer_channels = trio.open_memory_channel[ID](0) - dead_peer_channels = trio.open_memory_channel[ID](0) + peer_send, peer_receive = trio.open_memory_channel[ID](math.inf) + dead_peer_send, dead_peer_receive = trio.open_memory_channel[ID](math.inf) # Only keep the receive channels in `Pubsub`. # Therefore, we can only close from the receive side. - self.peer_receive_channel = peer_channels[1] - self.dead_peer_receive_channel = dead_peer_channels[1] + self.peer_receive_channel = peer_receive + self.dead_peer_receive_channel = dead_peer_receive # Register a notifee self.host.get_network().register_notifee( - PubsubNotifee(peer_channels[0], dead_peer_channels[0]) + PubsubNotifee(peer_send, dead_peer_send) ) # Register stream handlers for each pubsub router protocol to handle diff --git a/libp2p/pubsub/pubsub_notifee.py b/libp2p/pubsub/pubsub_notifee.py index be481901..08b4a5b2 100644 --- a/libp2p/pubsub/pubsub_notifee.py +++ b/libp2p/pubsub/pubsub_notifee.py @@ -16,7 +16,6 @@ class PubsubNotifee(INotifee): initiator_peers_queue: "trio.MemorySendChannel[ID]" dead_peers_queue: "trio.MemorySendChannel[ID]" - dead_peers_queue_lock: trio.Lock def __init__( self, @@ -30,15 +29,13 @@ class PubsubNotifee(INotifee): can process dead peers after we disconnect from each other """ self.initiator_peers_queue = initiator_peers_queue - self.initiator_peers_queue_lock = trio.Lock() self.dead_peers_queue = dead_peers_queue - self.dead_peers_queue_lock = trio.Lock() async def opened_stream(self, network: INetwork, stream: INetStream) -> None: - pass + ... async def closed_stream(self, network: INetwork, stream: INetStream) -> None: - pass + ... async def connected(self, network: INetwork, conn: INetConn) -> None: """ @@ -49,17 +46,11 @@ class PubsubNotifee(INotifee): :param network: network the connection was opened on :param conn: connection that was opened """ - async with self.initiator_peers_queue_lock: - try: - await self.initiator_peers_queue.send(conn.muxed_conn.peer_id) - except ( - trio.BrokenResourceError, - trio.ClosedResourceError, - trio.BusyResourceError, - ): - # Raised when the receive channel is closed. - # TODO: Do something with loggers? - ... + try: + await self.initiator_peers_queue.send(conn.muxed_conn.peer_id) + except trio.BrokenResourceError: + # The receive channel is closed by Pubsub. We should do nothing here. + ... async def disconnected(self, network: INetwork, conn: INetConn) -> None: """ @@ -69,20 +60,14 @@ class PubsubNotifee(INotifee): :param network: network the connection was opened on :param conn: connection that was opened """ - async with self.dead_peers_queue_lock: - try: - await self.dead_peers_queue.send(conn.muxed_conn.peer_id) - except ( - trio.BrokenResourceError, - trio.ClosedResourceError, - trio.BusyResourceError, - ): - # Raised when the receive channel is closed. - # TODO: Do something with loggers? - ... + try: + await self.dead_peers_queue.send(conn.muxed_conn.peer_id) + except trio.BrokenResourceError: + # The receive channel is closed by Pubsub. We should do nothing here. + ... async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None: - pass + ... async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None: - pass + ... From 095a848f3032e753d05ca2e058017c528302a6e2 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 28 Jan 2020 00:29:05 +0800 Subject: [PATCH 52/81] Add clean-up logics into TrioSubscriptionAPI Register an `unsubscribe_fn` when initializing the TrioSubscriptionAPI. `unsubscribe_fn` is called when subscription is unsubscribed. --- libp2p/pubsub/abc.py | 2 +- libp2p/pubsub/pubsub.py | 12 ++++++++--- libp2p/pubsub/subscription.py | 15 ++++++++++---- libp2p/pubsub/typing.py | 2 ++ tests/pubsub/test_pubsub.py | 34 ++++++++++++++++++++++++++++++- tests/pubsub/test_subscription.py | 21 ++++++++++++------- 6 files changed, 70 insertions(+), 16 deletions(-) diff --git a/libp2p/pubsub/abc.py b/libp2p/pubsub/abc.py index e4b75840..da37b6a1 100644 --- a/libp2p/pubsub/abc.py +++ b/libp2p/pubsub/abc.py @@ -24,7 +24,7 @@ class ISubscriptionAPI( AsyncContextManager["ISubscriptionAPI"], AsyncIterable[rpc_pb2.Message] ): @abstractmethod - async def cancel(self) -> None: + async def unsubscribe(self) -> None: ... @abstractmethod diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 03140742..26c4b4ff 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -1,3 +1,4 @@ +import functools import logging import math import time @@ -387,9 +388,14 @@ class Pubsub(Service, IPubsub): if topic_id in self.topic_ids: return self.subscribed_topics_receive[topic_id] - channels = trio.open_memory_channel[rpc_pb2.Message](math.inf) - send_channel, receive_channel = channels - subscription = TrioSubscriptionAPI(receive_channel) + send_channel, receive_channel = trio.open_memory_channel[rpc_pb2.Message]( + math.inf + ) + + subscription = TrioSubscriptionAPI( + receive_channel, + unsubscribe_fn=functools.partial(self.unsubscribe, topic_id), + ) self.subscribed_topics_send[topic_id] = send_channel self.subscribed_topics_receive[topic_id] = subscription diff --git a/libp2p/pubsub/subscription.py b/libp2p/pubsub/subscription.py index 1d88d09b..e3c926cc 100644 --- a/libp2p/pubsub/subscription.py +++ b/libp2p/pubsub/subscription.py @@ -5,6 +5,7 @@ import trio from .abc import ISubscriptionAPI from .pb import rpc_pb2 +from .typing import UnsubscribeFn class BaseSubscriptionAPI(ISubscriptionAPI): @@ -18,19 +19,25 @@ class BaseSubscriptionAPI(ISubscriptionAPI): exc_value: "Optional[BaseException]", traceback: "Optional[TracebackType]", ) -> None: - await self.cancel() + await self.unsubscribe() class TrioSubscriptionAPI(BaseSubscriptionAPI): receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]" + unsubscribe_fn: UnsubscribeFn def __init__( - self, receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]" + self, + receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]", + unsubscribe_fn: UnsubscribeFn, ) -> None: self.receive_channel = receive_channel + # Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427 + self.unsubscribe_fn = unsubscribe_fn # type: ignore - async def cancel(self) -> None: - await self.receive_channel.aclose() + async def unsubscribe(self) -> None: + # Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427 + await self.unsubscribe_fn() # type: ignore def __aiter__(self) -> AsyncIterator[rpc_pb2.Message]: return self.receive_channel.__aiter__() diff --git a/libp2p/pubsub/typing.py b/libp2p/pubsub/typing.py index c352d529..33297a9f 100644 --- a/libp2p/pubsub/typing.py +++ b/libp2p/pubsub/typing.py @@ -7,3 +7,5 @@ from .pb import rpc_pb2 SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool] AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]] ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn] + +UnsubscribeFn = Callable[[], Awaitable[None]] diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index 4bea6dd2..1e9d670a 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -413,7 +413,39 @@ async def test_message_all_peers(monkeypatch, is_host_secure): @pytest.mark.trio -async def test_publish(monkeypatch): +async def test_subscribe_and_publish(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + pubsub = pubsubs_fsub[0] + + list_data = [b"d0", b"d1"] + event_receive_data_started = trio.Event() + + async def publish_data(topic): + await event_receive_data_started.wait() + for data in list_data: + await pubsub.publish(topic, data) + + async def receive_data(topic): + i = 0 + event_receive_data_started.set() + assert topic not in pubsub.topic_ids + subscription = await pubsub.subscribe(topic) + async with subscription: + assert topic in pubsub.topic_ids + async for msg in subscription: + assert msg.data == list_data[i] + i += 1 + if i == len(list_data): + break + assert topic not in pubsub.topic_ids + + async with trio.open_nursery() as nursery: + nursery.start_soon(receive_data, TESTING_TOPIC) + nursery.start_soon(publish_data, TESTING_TOPIC) + + +@pytest.mark.trio +async def test_publish_push_msg_is_called(monkeypatch): msg_forwarders = [] msgs = [] diff --git a/tests/pubsub/test_subscription.py b/tests/pubsub/test_subscription.py index c5a20eda..a0a6c10c 100644 --- a/tests/pubsub/test_subscription.py +++ b/tests/pubsub/test_subscription.py @@ -11,7 +11,14 @@ GET_TIMEOUT = 0.001 def make_trio_subscription(): send_channel, receive_channel = trio.open_memory_channel(math.inf) - return send_channel, TrioSubscriptionAPI(receive_channel) + + async def unsubscribe_fn(): + await send_channel.aclose() + + return ( + send_channel, + TrioSubscriptionAPI(receive_channel, unsubscribe_fn=unsubscribe_fn), + ) def make_pubsub_msg(): @@ -56,14 +63,14 @@ async def test_trio_subscription_iter(): @pytest.mark.trio -async def test_trio_subscription_cancel(): +async def test_trio_subscription_unsubscribe(): send_channel, sub = make_trio_subscription() - await sub.cancel() - # Test: If the subscription is cancelled, `send_channel` should be broken. - with pytest.raises(trio.BrokenResourceError): + await sub.unsubscribe() + # Test: If the subscription is unsubscribed, `send_channel` should be closed. + with pytest.raises(trio.ClosedResourceError): await send_something(send_channel) # Test: No side effect when cancelled twice. - await sub.cancel() + await sub.unsubscribe() @pytest.mark.trio @@ -73,5 +80,5 @@ async def test_trio_subscription_async_context_manager(): # Test: `sub` is not cancelled yet, so `send_something` works fine. await send_something(send_channel) # Test: `sub` is cancelled, `send_something` fails - with pytest.raises(trio.BrokenResourceError): + with pytest.raises(trio.ClosedResourceError): await send_something(send_channel) From e57d01f360de9fb101adffa686cf5aff414fc8f8 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 28 Jan 2020 15:48:41 +0800 Subject: [PATCH 53/81] PR feedback - Use f-string - Fix wrongly indented comments - Add dep `trio-typing` --- examples/chat/chat.py | 2 +- libp2p/network/connection/swarm_connection.py | 2 +- setup.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/chat/chat.py b/examples/chat/chat.py index 41aad927..81f3891b 100755 --- a/examples/chat/chat.py +++ b/examples/chat/chat.py @@ -63,7 +63,7 @@ async def run(port: int, destination: str) -> None: nursery.start_soon(read_data, stream) nursery.start_soon(write_data, stream) - print("Connected to peer %s" % info.addrs[0]) + print(f"Connected to peer {info.addrs[0]}") await trio.sleep_forever() diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 7cf3c0ff..b84db8e4 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -60,9 +60,9 @@ class SwarmConn(INetConn): while True: try: stream = await self.muxed_conn.accept_stream() - # Asynchronously handle the accepted stream, to avoid blocking the next stream. except MuxedConnUnavailable: break + # Asynchronously handle the accepted stream, to avoid blocking the next stream. self.swarm.manager.run_task(self._handle_muxed_stream, stream) await self.close() diff --git a/setup.py b/setup.py index fe3b5aeb..2d36c70d 100644 --- a/setup.py +++ b/setup.py @@ -77,6 +77,7 @@ install_requires = [ "trio>=0.13.0", "async-service>=0.1.0a4", "async-exit-stack==1.0.1", + "trio-typing>=0.3.0,<1.0.0", ] From 1588be2be956e8c31219dfbf7adaa9d831b2315c Mon Sep 17 00:00:00 2001 From: mhchia Date: Fri, 31 Jan 2020 17:42:47 +0800 Subject: [PATCH 54/81] Change the channel size of peer queue Back to `0`, to avoid unlimited buffer size. --- libp2p/pubsub/pubsub.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 26c4b4ff..41cd9655 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -96,8 +96,8 @@ class Pubsub(Service, IPubsub): # Attach this new Pubsub object to the router self.router.attach(self) - peer_send, peer_receive = trio.open_memory_channel[ID](math.inf) - dead_peer_send, dead_peer_receive = trio.open_memory_channel[ID](math.inf) + peer_send, peer_receive = trio.open_memory_channel[ID](0) + dead_peer_send, dead_peer_receive = trio.open_memory_channel[ID](0) # Only keep the receive channels in `Pubsub`. # Therefore, we can only close from the receive side. self.peer_receive_channel = peer_receive From 05d5d045ea3460a02979a7410dbddb069e54cb9f Mon Sep 17 00:00:00 2001 From: mhchia Date: Sun, 2 Feb 2020 18:17:22 +0800 Subject: [PATCH 55/81] Fix pubsub interop: missing unsubscribe_fn --- tests_interop/test_pubsub.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests_interop/test_pubsub.py b/tests_interop/test_pubsub.py index f15d89e3..793f2446 100644 --- a/tests_interop/test_pubsub.py +++ b/tests_interop/test_pubsub.py @@ -21,7 +21,7 @@ async def p2pd_subscribe(p2pd, topic, nursery): stream = TrioTCPStream(await p2pd.control.pubsub_subscribe(topic)) send_channel, receive_channel = trio.open_memory_channel(math.inf) - sub = TrioSubscriptionAPI(receive_channel) + sub = TrioSubscriptionAPI(receive_channel, unsubscribe_fn=stream.close) async def _read_pubsub_msg() -> None: while True: From 22963a309968c9f5c318c65c8897d7cbfb19a207 Mon Sep 17 00:00:00 2001 From: mhchia Date: Sun, 2 Feb 2020 18:18:01 +0800 Subject: [PATCH 56/81] Fix trio-typing>=0.3,<0.4 To be consistent with trinity --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 2d36c70d..4ec5c515 100644 --- a/setup.py +++ b/setup.py @@ -77,7 +77,7 @@ install_requires = [ "trio>=0.13.0", "async-service>=0.1.0a4", "async-exit-stack==1.0.1", - "trio-typing>=0.3.0,<1.0.0", + "trio-typing>=0.3.0,<0.4.0", ] From 113696dce212fc5282012528f02a0429e1ef779c Mon Sep 17 00:00:00 2001 From: mhchia Date: Mon, 3 Feb 2020 16:04:32 +0800 Subject: [PATCH 57/81] TravisCI: use python `3.7` instead of `3.7-dev` --- .travis.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.travis.yml b/.travis.yml index 6b6b3ccc..1c7856d6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,16 +5,16 @@ matrix: - python: 3.6-dev dist: xenial env: TOXENV=py36-test - - python: 3.7-dev + - python: 3.7 dist: xenial env: TOXENV=py37-test - - python: 3.7-dev + - python: 3.7 dist: xenial env: TOXENV=lint - - python: 3.7-dev + - python: 3.7 dist: xenial env: TOXENV=docs - - python: 3.7-dev + - python: 3.7 dist: xenial env: TOXENV=py37-interop sudo: true From 5da102d1c9b456b6afa8fa8b7f0713c3c07dbe7b Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 4 Feb 2020 15:09:42 +0800 Subject: [PATCH 58/81] Ping protocol: move `with` statement out of `try` --- libp2p/host/ping.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/libp2p/host/ping.py b/libp2p/host/ping.py index 9e23f1cc..01102451 100644 --- a/libp2p/host/ping.py +++ b/libp2p/host/ping.py @@ -17,23 +17,23 @@ logger = logging.getLogger("libp2p.host.ping") async def _handle_ping(stream: INetStream, peer_id: PeerID) -> bool: """Return a boolean indicating if we expect more pings from the peer at ``peer_id``.""" - with trio.fail_after(RESP_TIMEOUT): - try: + try: + with trio.fail_after(RESP_TIMEOUT): payload = await stream.read(PING_LENGTH) - except trio.TooSlowError as error: - logger.debug("Timed out waiting for ping from %s: %s", peer_id, error) - raise - except StreamEOF: - logger.debug("Other side closed while waiting for ping from %s", peer_id) - return False - except StreamReset as error: - logger.debug( - "Other side reset while waiting for ping from %s: %s", peer_id, error - ) - raise - except Exception as error: - logger.debug("Error while waiting to read ping for %s: %s", peer_id, error) - raise + except trio.TooSlowError as error: + logger.debug("Timed out waiting for ping from %s: %s", peer_id, error) + raise + except StreamEOF: + logger.debug("Other side closed while waiting for ping from %s", peer_id) + return False + except StreamReset as error: + logger.debug( + "Other side reset while waiting for ping from %s: %s", peer_id, error + ) + raise + except Exception as error: + logger.debug("Error while waiting to read ping for %s: %s", peer_id, error) + raise logger.debug("Received ping from %s with data: 0x%s", peer_id, payload.hex()) From d483982acb46befdbec901d7de7f9bab76ab0004 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 4 Feb 2020 15:10:49 +0800 Subject: [PATCH 59/81] SwarmConn: don't catch exceptions in handler --- libp2p/network/connection/swarm_connection.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index b84db8e4..a5e22e70 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -20,7 +20,6 @@ class SwarmConn(INetConn): muxed_conn: IMuxedConn swarm: "Swarm" streams: Set[NetStream] - event_started: trio.Event event_closed: trio.Event def __init__(self, muxed_conn: IMuxedConn, swarm: "Swarm") -> None: @@ -72,10 +71,7 @@ class SwarmConn(INetConn): if self.swarm.common_stream_handler is not None: try: await self.swarm.common_stream_handler(net_stream) - # TODO: More exact exceptions - except Exception: - # TODO: Emit logs. - # TODO: Clean up and remove the stream from SwarmConn if there is anything wrong. + finally: self.remove_stream(net_stream) def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: From 3fc60cb312a74f2340e0c712bc799bda6df82c44 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 4 Feb 2020 17:04:28 +0800 Subject: [PATCH 60/81] SwarmConn: iterate `streams.copy` in `_cleanup` To avoid `RuntimeError` if `streams` is changed. --- libp2p/network/connection/swarm_connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index a5e22e70..0e930f50 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -46,7 +46,7 @@ class SwarmConn(INetConn): # This is just for cleaning up state. The connection has already been closed. # We *could* optimize this but it really isn't worth it. - for stream in self.streams: + for stream in self.streams.copy(): await stream.reset() # Force context switch for stream handlers to process the stream reset event we just emit # before we cancel the stream handler tasks. From 3a91f114abc2f301a742b7acc14655f289cbc9da Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 4 Feb 2020 17:05:53 +0800 Subject: [PATCH 61/81] Swarm: add `default_stream_handler` Advantage: - To avoid `None` checks - If users forget to register a stream handler for `Swarm`, with the default stream handler, opened streams aren't removed until `Swarm` finishes. --- libp2p/network/connection/swarm_connection.py | 11 ++++++----- libp2p/network/swarm.py | 17 +++++++++++++---- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 0e930f50..cc13dcca 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -68,11 +68,12 @@ class SwarmConn(INetConn): async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None: net_stream = self._add_stream(muxed_stream) - if self.swarm.common_stream_handler is not None: - try: - await self.swarm.common_stream_handler(net_stream) - finally: - self.remove_stream(net_stream) + try: + # Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427 + await self.swarm.common_stream_handler(net_stream) # type: ignore + finally: + # As long as `common_stream_handler`, remove the stream. + self.remove_stream(net_stream) def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: net_stream = NetStream(muxed_stream) diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index e03d65b7..9a0279d9 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, List, Optional +from typing import Dict, List from async_service import Service from multiaddr import Multiaddr @@ -31,6 +31,13 @@ from .stream.net_stream_interface import INetStream logger = logging.getLogger("libp2p.network.swarm") +def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn: + async def stream_handler(stream: INetStream) -> None: + await network.get_manager().wait_finished() + + return stream_handler + + class Swarm(Service, INetworkService): self_id: ID @@ -41,7 +48,7 @@ class Swarm(Service, INetworkService): # whereas in Go one `peer_id` may point to multiple connections. connections: Dict[ID, INetConn] listeners: Dict[str, IListener] - common_stream_handler: Optional[StreamHandlerFn] + common_stream_handler: StreamHandlerFn notifees: List[INotifee] @@ -62,7 +69,8 @@ class Swarm(Service, INetworkService): # Create Notifee array self.notifees = [] - self.common_stream_handler = None + # Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427 + self.common_stream_handler = create_default_stream_handler(self) # type: ignore async def run(self) -> None: await self.manager.wait_finished() @@ -71,7 +79,8 @@ class Swarm(Service, INetworkService): return self.self_id def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None: - self.common_stream_handler = stream_handler + # Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427 + self.common_stream_handler = stream_handler # type: ignore async def dial_peer(self, peer_id: ID) -> INetConn: """ From 7ae9de90023960cff99c558dc86a1902d1ca6361 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 4 Feb 2020 17:09:26 +0800 Subject: [PATCH 62/81] Fix handler in `net_stream_pair_factory` Change it to async function. It wasn't discovered since we caught all exceptions raised in stream handlers. --- libp2p/tools/factories.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index 84e6af39..67e26519 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -404,13 +404,18 @@ async def net_stream_pair_factory( stream_1: INetStream - # Just a proxy, we only care about the stream - def handler(stream: INetStream) -> None: + # Just a proxy, we only care about the stream. + # Add a barrier to avoid stream being removed. + event_handler_finished = trio.Event() + + async def handler(stream: INetStream) -> None: nonlocal stream_1 stream_1 = stream + await event_handler_finished.wait() async with host_pair_factory(is_secure) as hosts: hosts[1].set_stream_handler(protocol_id, handler) stream_0 = await hosts[0].new_stream(hosts[1].get_id(), [protocol_id]) yield stream_0, stream_1 + event_handler_finished.set() From 66975ae3f2a7b90338fd5c4ce661d4874eb192ec Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 4 Feb 2020 17:43:39 +0800 Subject: [PATCH 63/81] Pubsub: change `run_task` to `run_daemon_task` --- libp2p/pubsub/gossipsub.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py index 25c53d28..4d25c254 100644 --- a/libp2p/pubsub/gossipsub.py +++ b/libp2p/pubsub/gossipsub.py @@ -88,7 +88,7 @@ class GossipSub(IPubsubRouter, Service): async def run(self) -> None: if self.pubsub is None: raise NoPubsubAttached - self.manager.run_task(self.heartbeat) + self.manager.run_daemon_task(self.heartbeat) await self.manager.wait_finished() # Interface functions From 857bb34f4e5a1ded58281a9bc0212927a4b241b5 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 4 Feb 2020 17:44:10 +0800 Subject: [PATCH 64/81] Add checkpoints in `PubsubNotifee` Since some of the methods in `PubsubNotifee` are doing nothing, add checkpoints to yield control. --- libp2p/pubsub/pubsub_notifee.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/libp2p/pubsub/pubsub_notifee.py b/libp2p/pubsub/pubsub_notifee.py index 08b4a5b2..b32c1450 100644 --- a/libp2p/pubsub/pubsub_notifee.py +++ b/libp2p/pubsub/pubsub_notifee.py @@ -32,10 +32,10 @@ class PubsubNotifee(INotifee): self.dead_peers_queue = dead_peers_queue async def opened_stream(self, network: INetwork, stream: INetStream) -> None: - ... + await trio.hazmat.checkpoint() async def closed_stream(self, network: INetwork, stream: INetStream) -> None: - ... + await trio.hazmat.checkpoint() async def connected(self, network: INetwork, conn: INetConn) -> None: """ @@ -67,7 +67,7 @@ class PubsubNotifee(INotifee): ... async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None: - ... + await trio.hazmat.checkpoint() async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None: - ... + await trio.hazmat.checkpoint() From 89338914d335e2ccc7f5280746046cee5070f309 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 4 Feb 2020 17:45:56 +0800 Subject: [PATCH 65/81] Add comment for `serve_tcp` --- libp2p/transport/tcp/tcp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index c336130b..1004e288 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -38,6 +38,7 @@ class TCPListener(IListener): host: str, task_status: TaskStatus[Sequence[trio.SocketListener]] = None, ) -> None: + """Just a proxy function to add logging here.""" logger.debug("serve_tcp %s %s", host, port) await trio.serve_tcp(handler, port, host=host, task_status=task_status) From b007bb4d07d1dbfdc54287ef638c3956a7c1cafb Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 4 Feb 2020 17:46:30 +0800 Subject: [PATCH 66/81] Use the latest async-service --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 4ec5c515..e808b959 100644 --- a/setup.py +++ b/setup.py @@ -75,7 +75,7 @@ install_requires = [ "dataclasses>=0.7, <1;python_version<'3.7'", "async_generator==1.10", "trio>=0.13.0", - "async-service>=0.1.0a4", + "async-service>=0.1.0a6", "async-exit-stack==1.0.1", "trio-typing>=0.3.0,<0.4.0", ] From a7ba59bf9f758689b796f912d015d7cf0ff0e457 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 4 Feb 2020 20:45:58 +0800 Subject: [PATCH 67/81] Add a nursery in `Swarm` To avoid using the one in `Service` --- libp2p/network/swarm.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 9a0279d9..45180a98 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -1,8 +1,9 @@ import logging -from typing import Dict, List +from typing import Dict, List, Optional from async_service import Service from multiaddr import Multiaddr +import trio from libp2p.io.abc import ReadWriteCloser from libp2p.network.connection.net_connection_interface import INetConn @@ -49,6 +50,8 @@ class Swarm(Service, INetworkService): connections: Dict[ID, INetConn] listeners: Dict[str, IListener] common_stream_handler: StreamHandlerFn + listener_nursery: Optional[trio.Nursery] + event_listener_nursery_created: trio.Event notifees: List[INotifee] @@ -72,8 +75,21 @@ class Swarm(Service, INetworkService): # Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427 self.common_stream_handler = create_default_stream_handler(self) # type: ignore + self.listener_nursery = None + self.event_listener_nursery_created = trio.Event() + async def run(self) -> None: - await self.manager.wait_finished() + async with trio.open_nursery() as nursery: + # Create a nursery for listener tasks. + self.listener_nursery = nursery + self.event_listener_nursery_created.set() + try: + await self.manager.wait_finished() + finally: + # The service ended. Cancel listener tasks. + nursery.cancel_scope.cancel() + # Indicate that the nursery has been cancelled. + self.listener_nursery = None def get_peer_id(self) -> ID: return self.self_id @@ -207,6 +223,9 @@ class Swarm(Service, INetworkService): - Call listener listen with the multiaddr - Map multiaddr to listener """ + # We need to wait until `self.listener_nursery` is created. + await self.event_listener_nursery_created.wait() + for maddr in multiaddrs: if str(maddr) in self.listeners: return True @@ -250,7 +269,9 @@ class Swarm(Service, INetworkService): self.listeners[str(maddr)] = listener # TODO: `listener.listen` is not bounded with nursery. If we want to be # I/O agnostic, we should change the API. - await listener.listen(maddr, self.manager._task_nursery) # type: ignore + if self.listener_nursery is None: + raise SwarmException("swarm instance hasn't been run") + await listener.listen(maddr, self.listener_nursery) # Call notifiers since event occurred self.notify_listen(maddr) From 0548d285682d23e7ff2e003ab11ffd1b946d391e Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 4 Feb 2020 20:46:40 +0800 Subject: [PATCH 68/81] Fix: `StreamReset` in the stream handlers Since we don't catch `Exception` in the stream handlers, catch them in the stream handlers in the tests. --- libp2p/tools/utils.py | 11 +++++++++-- tests/libp2p/test_libp2p.py | 16 +++++++++++++--- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/libp2p/tools/utils.py b/libp2p/tools/utils.py index a66155c3..5a262b3b 100644 --- a/libp2p/tools/utils.py +++ b/libp2p/tools/utils.py @@ -1,6 +1,7 @@ from typing import Awaitable, Callable from libp2p.host.host_interface import IHost +from libp2p.network.stream.exceptions import StreamError from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.swarm import Swarm from libp2p.peer.peerinfo import info_from_p2p_addr @@ -33,9 +34,15 @@ def create_echo_stream_handler( ) -> Callable[[INetStream], Awaitable[None]]: async def echo_stream_handler(stream: INetStream) -> None: while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() + try: + read_string = (await stream.read(MAX_READ_LEN)).decode() + except StreamError: + break resp = ack_prefix + read_string - await stream.write(resp.encode()) + try: + await stream.write(resp.encode()) + except StreamError: + break return echo_stream_handler diff --git a/tests/libp2p/test_libp2p.py b/tests/libp2p/test_libp2p.py index 91fea586..99a60bd5 100644 --- a/tests/libp2p/test_libp2p.py +++ b/tests/libp2p/test_libp2p.py @@ -1,6 +1,7 @@ import multiaddr import pytest +from libp2p.network.stream.exceptions import StreamError from libp2p.tools.constants import MAX_READ_LEN from libp2p.tools.factories import HostFactory from libp2p.tools.utils import connect, create_echo_stream_handler @@ -42,13 +43,22 @@ async def test_double_response(is_host_secure): async def double_response_stream_handler(stream): while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() + try: + read_string = (await stream.read(MAX_READ_LEN)).decode() + except StreamError: + break response = ACK_STR_0 + read_string - await stream.write(response.encode()) + try: + await stream.write(response.encode()) + except StreamError: + break response = ACK_STR_1 + read_string - await stream.write(response.encode()) + try: + await stream.write(response.encode()) + except StreamError: + break hosts[1].set_stream_handler(PROTOCOL_ID_0, double_response_stream_handler) From f884bfa39ec6ccbf9c109996d2a2aa44110a898b Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 4 Feb 2020 21:57:11 +0800 Subject: [PATCH 69/81] SwarmConn: don't access `Swarm.manager` Open a local nursery instead. --- libp2p/network/connection/swarm_connection.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index cc13dcca..d21da482 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -56,13 +56,14 @@ class SwarmConn(INetConn): async def _handle_new_streams(self) -> None: self.event_started.set() - while True: - try: - stream = await self.muxed_conn.accept_stream() - except MuxedConnUnavailable: - break - # Asynchronously handle the accepted stream, to avoid blocking the next stream. - self.swarm.manager.run_task(self._handle_muxed_stream, stream) + async with trio.open_nursery() as nursery: + while True: + try: + stream = await self.muxed_conn.accept_stream() + except MuxedConnUnavailable: + break + # Asynchronously handle the accepted stream, to avoid blocking the next stream. + nursery.start_soon(self._handle_muxed_stream, stream) await self.close() From c0ab6095599c7a299070b623af645b44ce50eb02 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 4 Feb 2020 21:57:59 +0800 Subject: [PATCH 70/81] Mplex: catch `RawConnError` when writing Also, do nothing in `MplexStream.reset` if `MuxedConnUnavailable` is raised when sending the message. --- libp2p/stream_muxer/mplex/mplex.py | 8 +++++++- libp2p/stream_muxer/mplex/mplex_stream.py | 7 ++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 6f5f3fd8..70f26b3b 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -166,7 +166,13 @@ class Mplex(IMuxedConn): :param _bytes: byte array to write :return: length written """ - await self.secured_conn.write(_bytes) + try: + await self.secured_conn.write(_bytes) + except RawConnError as e: + raise MplexUnavailable( + "failed to write message to the underlying connection" + ) from e + return len(_bytes) async def handle_incoming(self) -> None: diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 933de204..79675749 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING import trio from libp2p.stream_muxer.abc import IMuxedStream +from libp2p.stream_muxer.exceptions import MuxedConnUnavailable from .constants import HeaderTags from .datastructures import StreamID @@ -189,7 +190,11 @@ class MplexStream(IMuxedStream): if self.is_initiator else HeaderTags.ResetReceiver ) - await self.muxed_conn.send_message(flag, None, self.stream_id) + # Try to send reset message to the other side. Ignore if there is anything wrong. + try: + await self.muxed_conn.send_message(flag, None, self.stream_id) + except MuxedConnUnavailable: + pass self.event_local_closed.set() self.event_remote_closed.set() From 13930ae718bd29042adb823b9922d218813c2126 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 4 Feb 2020 22:51:21 +0800 Subject: [PATCH 71/81] SwarmConn: perform `close` right away In `_handle_new_streams`, when the underlying muxed conn is unavailable, close `SwarmConn` itself right away, to reset all the streams. Therefore, the stream processed by `_handle_muxed_stream` are conscious of the fact that they are reset. It allows a more graceful clean up. --- libp2p/network/connection/swarm_connection.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index d21da482..f6e60784 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -61,12 +61,11 @@ class SwarmConn(INetConn): try: stream = await self.muxed_conn.accept_stream() except MuxedConnUnavailable: + await self.close() break # Asynchronously handle the accepted stream, to avoid blocking the next stream. nursery.start_soon(self._handle_muxed_stream, stream) - await self.close() - async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None: net_stream = self._add_stream(muxed_stream) try: From 12cb0d9ac48db078d15bd1651139c124ff55494d Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 4 Feb 2020 22:56:13 +0800 Subject: [PATCH 72/81] Swarm: change `notify_xxx` back to async func --- libp2p/network/connection/swarm_connection.py | 14 +++---- libp2p/network/swarm.py | 38 +++++++++++-------- 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index f6e60784..baa9df50 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -52,7 +52,7 @@ class SwarmConn(INetConn): # before we cancel the stream handler tasks. await trio.sleep(0.1) - self._notify_disconnected() + await self._notify_disconnected() async def _handle_new_streams(self) -> None: self.event_started.set() @@ -67,7 +67,7 @@ class SwarmConn(INetConn): nursery.start_soon(self._handle_muxed_stream, stream) async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None: - net_stream = self._add_stream(muxed_stream) + net_stream = await self._add_stream(muxed_stream) try: # Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427 await self.swarm.common_stream_handler(net_stream) # type: ignore @@ -75,21 +75,21 @@ class SwarmConn(INetConn): # As long as `common_stream_handler`, remove the stream. self.remove_stream(net_stream) - def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: + async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: net_stream = NetStream(muxed_stream) self.streams.add(net_stream) - self.swarm.notify_opened_stream(net_stream) + await self.swarm.notify_opened_stream(net_stream) return net_stream - def _notify_disconnected(self) -> None: - self.swarm.notify_disconnected(self) + async def _notify_disconnected(self) -> None: + await self.swarm.notify_disconnected(self) async def start(self) -> None: await self._handle_new_streams() async def new_stream(self) -> NetStream: muxed_stream = await self.muxed_conn.open_stream() - return self._add_stream(muxed_stream) + return await self._add_stream(muxed_stream) def get_streams(self) -> Tuple[NetStream, ...]: return tuple(self.streams) diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 45180a98..57ce1358 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -274,7 +274,7 @@ class Swarm(Service, INetworkService): await listener.listen(maddr, self.listener_nursery) # Call notifiers since event occurred - self.notify_listen(maddr) + await self.notify_listen(maddr) return True except IOError: @@ -310,7 +310,7 @@ class Swarm(Service, INetworkService): # Store muxed_conn with peer id self.connections[muxed_conn.peer_id] = swarm_conn # Call notifiers since event occurred - self.notify_connected(swarm_conn) + await self.notify_connected(swarm_conn) return swarm_conn def remove_conn(self, swarm_conn: SwarmConn) -> None: @@ -330,22 +330,28 @@ class Swarm(Service, INetworkService): """ self.notifees.append(notifee) - def notify_opened_stream(self, stream: INetStream) -> None: - for notifee in self.notifees: - self.manager.run_task(notifee.opened_stream, self, stream) + async def notify_opened_stream(self, stream: INetStream) -> None: + async with trio.open_nursery() as nursery: + for notifee in self.notifees: + nursery.start_soon(notifee.opened_stream, self, stream) - # TODO: `notify_closed_stream` + async def notify_connected(self, conn: INetConn) -> None: + async with trio.open_nursery() as nursery: + for notifee in self.notifees: + nursery.start_soon(notifee.connected, self, conn) - def notify_connected(self, conn: INetConn) -> None: - for notifee in self.notifees: - self.manager.run_task(notifee.connected, self, conn) + async def notify_disconnected(self, conn: INetConn) -> None: + async with trio.open_nursery() as nursery: + for notifee in self.notifees: + nursery.start_soon(notifee.disconnected, self, conn) - def notify_disconnected(self, conn: INetConn) -> None: - for notifee in self.notifees: - self.manager.run_task(notifee.disconnected, self, conn) + async def notify_listen(self, multiaddr: Multiaddr) -> None: + async with trio.open_nursery() as nursery: + for notifee in self.notifees: + nursery.start_soon(notifee.listen, self, multiaddr) - def notify_listen(self, multiaddr: Multiaddr) -> None: - for notifee in self.notifees: - self.manager.run_task(notifee.listen, self, multiaddr) + async def notify_closed_stream(self, stream: INetStream) -> None: + raise NotImplementedError - # TODO: `notify_listen_close` + async def notify_listen_close(self, multiaddr: Multiaddr) -> None: + raise NotImplementedError From 996b5cf15d7e719c4b247aefe3757d7cf54d3efc Mon Sep 17 00:00:00 2001 From: mhchia Date: Wed, 5 Feb 2020 17:05:30 +0800 Subject: [PATCH 73/81] Mplex: catch exceptions from `channel.send` --- libp2p/stream_muxer/mplex/mplex.py | 13 ++++++++----- tests/stream_muxer/test_mplex_conn.py | 2 -- tests/stream_muxer/test_mplex_stream.py | 2 -- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 70f26b3b..5e265b96 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -111,11 +111,11 @@ class Mplex(IMuxedConn): async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream: # Use an unbounded buffer, to avoid `handle_incoming` being blocked when doing # `send_channel.send`. - channels = trio.open_memory_channel[bytes](math.inf) - stream = MplexStream(name, stream_id, self, channels[1]) + send_channel, receive_channel = trio.open_memory_channel[bytes](math.inf) + stream = MplexStream(name, stream_id, self, receive_channel) async with self.streams_lock: self.streams[stream_id] = stream - self.streams_msg_channels[stream_id] = channels[0] + self.streams_msg_channels[stream_id] = send_channel return stream async def open_stream(self) -> IMuxedStream: @@ -269,7 +269,10 @@ class Mplex(IMuxedConn): if stream.event_remote_closed.is_set(): # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501 return - await send_channel.send(message) + try: + await send_channel.send(message) + except (trio.BrokenResourceError, trio.ClosedResourceError): + raise MplexUnavailable async def _handle_close(self, stream_id: StreamID) -> None: async with self.streams_lock: @@ -325,7 +328,7 @@ class Mplex(IMuxedConn): stream.event_local_closed.set() send_channel = self.streams_msg_channels[stream_id] await send_channel.aclose() - self.streams = None self.event_closed.set() + # FIXME: It's enough to just close one side. await self.new_stream_send_channel.aclose() await self.new_stream_receive_channel.aclose() diff --git a/tests/stream_muxer/test_mplex_conn.py b/tests/stream_muxer/test_mplex_conn.py index 4bff2d61..df1097dd 100644 --- a/tests/stream_muxer/test_mplex_conn.py +++ b/tests/stream_muxer/test_mplex_conn.py @@ -31,12 +31,10 @@ async def test_mplex_conn(mplex_conn_pair): assert stream_0.event_remote_closed.is_set() assert stream_0.event_reset.is_set() assert stream_0.event_local_closed.is_set() - assert conn_0.streams is None # Test: All streams on the other side are also closed. assert stream_1.event_remote_closed.is_set() assert stream_1.event_reset.is_set() assert stream_1.event_local_closed.is_set() - assert conn_1.streams is None # Test: No effect to close more than once between two side. await conn_0.close() diff --git a/tests/stream_muxer/test_mplex_stream.py b/tests/stream_muxer/test_mplex_stream.py index 55ee97bd..181daa09 100644 --- a/tests/stream_muxer/test_mplex_stream.py +++ b/tests/stream_muxer/test_mplex_stream.py @@ -69,9 +69,7 @@ async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair): await stream_0.close() assert stream_0.event_local_closed.is_set() await trio.sleep(0.01) - # await trio.sleep(100000) await wait_all_tasks_blocked() - # await trio.sleep_forever() assert stream_1.event_remote_closed.is_set() assert (await stream_1.read(MAX_READ_LEN)) == DATA with pytest.raises(MplexStreamEOF): From 1e7d5c73ee22c0290f5f4c276b0ba6c8e958bd9b Mon Sep 17 00:00:00 2001 From: mhchia Date: Wed, 5 Feb 2020 17:25:39 +0800 Subject: [PATCH 74/81] test_mplex_stream: refactor --- tests/stream_muxer/test_mplex_stream.py | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/tests/stream_muxer/test_mplex_stream.py b/tests/stream_muxer/test_mplex_stream.py index 181daa09..eeb76538 100644 --- a/tests/stream_muxer/test_mplex_stream.py +++ b/tests/stream_muxer/test_mplex_stream.py @@ -7,26 +7,16 @@ from libp2p.stream_muxer.mplex.exceptions import ( MplexStreamEOF, MplexStreamReset, ) -from libp2p.tools.constants import LISTEN_MADDR, MAX_READ_LEN -from libp2p.tools.factories import SwarmFactory -from libp2p.tools.utils import connect_swarm +from libp2p.tools.constants import MAX_READ_LEN DATA = b"data_123" @pytest.mark.trio -async def test_mplex_stream_read_write(): - async with SwarmFactory.create_batch_and_listen(False, 2) as swarms: - await swarms[0].listen(LISTEN_MADDR) - await swarms[1].listen(LISTEN_MADDR) - await connect_swarm(swarms[0], swarms[1]) - conn_0 = swarms[0].connections[swarms[1].get_peer_id()] - conn_1 = swarms[1].connections[swarms[0].get_peer_id()] - stream_0 = await conn_0.muxed_conn.open_stream() - await trio.sleep(1) - stream_1 = tuple(conn_1.muxed_conn.streams.values())[0] - await stream_0.write(DATA) - assert (await stream_1.read(MAX_READ_LEN)) == DATA +async def test_mplex_stream_read_write(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_0.write(DATA) + assert (await stream_1.read(MAX_READ_LEN)) == DATA @pytest.mark.trio From 64c9c48dac6870877bf5dbafc80efba998316c3c Mon Sep 17 00:00:00 2001 From: mhchia Date: Wed, 5 Feb 2020 19:48:02 +0800 Subject: [PATCH 75/81] Mplex: change new stream channel size To `0`, i.e. no unbuffered, to avoid growing buffer size. --- libp2p/stream_muxer/mplex/mplex.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 5e265b96..b3b45c05 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -66,7 +66,7 @@ class Mplex(IMuxedConn): self.streams = {} self.streams_lock = trio.Lock() self.streams_msg_channels = {} - channels = trio.open_memory_channel[IMuxedStream](math.inf) + channels = trio.open_memory_channel[IMuxedStream](0) self.new_stream_send_channel, self.new_stream_receive_channel = channels self.event_shutting_down = trio.Event() self.event_closed = trio.Event() From 1fff6ad6b4309fe29cf30e60041ffd305de93013 Mon Sep 17 00:00:00 2001 From: mhchia Date: Wed, 5 Feb 2020 20:31:18 +0800 Subject: [PATCH 76/81] Mplex: change message channel size to 8 To avoid infinity sized channel, and to conform to the go implementation. --- libp2p/stream_muxer/mplex/mplex.py | 19 +++++++++++++------ tests/stream_muxer/test_mplex_stream.py | 23 +++++++++++++++++++++++ 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index b3b45c05..5b1df77d 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -1,5 +1,4 @@ import logging -import math from typing import Dict, Optional, Tuple import trio @@ -24,6 +23,8 @@ from .exceptions import MplexUnavailable from .mplex_stream import MplexStream MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0") +# Ref: https://github.com/libp2p/go-mplex/blob/master/multiplex.go#L115 +MPLEX_MESSAGE_CHANNEL_SIZE = 8 logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex") @@ -109,9 +110,9 @@ class Mplex(IMuxedConn): return next_id async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream: - # Use an unbounded buffer, to avoid `handle_incoming` being blocked when doing - # `send_channel.send`. - send_channel, receive_channel = trio.open_memory_channel[bytes](math.inf) + send_channel, receive_channel = trio.open_memory_channel[bytes]( + MPLEX_MESSAGE_CHANNEL_SIZE + ) stream = MplexStream(name, stream_id, self, receive_channel) async with self.streams_lock: self.streams[stream_id] = stream @@ -145,7 +146,7 @@ class Mplex(IMuxedConn): """ sends a message over the connection. - :param header: header to use + :param flag: header to use :param data: data to send in the message :param stream_id: stream the message is in """ @@ -270,9 +271,15 @@ class Mplex(IMuxedConn): # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501 return try: - await send_channel.send(message) + send_channel.send_nowait(message) except (trio.BrokenResourceError, trio.ClosedResourceError): raise MplexUnavailable + except trio.WouldBlock: + # `send_channel` is full, reset this stream. + logger.warning( + "message channel of stream %s is full: stream is reset", stream_id + ) + await stream.reset() async def _handle_close(self, stream_id: StreamID) -> None: async with self.streams_lock: diff --git a/tests/stream_muxer/test_mplex_stream.py b/tests/stream_muxer/test_mplex_stream.py index eeb76538..3bc8bc1c 100644 --- a/tests/stream_muxer/test_mplex_stream.py +++ b/tests/stream_muxer/test_mplex_stream.py @@ -7,6 +7,7 @@ from libp2p.stream_muxer.mplex.exceptions import ( MplexStreamEOF, MplexStreamReset, ) +from libp2p.stream_muxer.mplex.mplex import MPLEX_MESSAGE_CHANNEL_SIZE from libp2p.tools.constants import MAX_READ_LEN DATA = b"data_123" @@ -19,6 +20,28 @@ async def test_mplex_stream_read_write(mplex_stream_pair): assert (await stream_1.read(MAX_READ_LEN)) == DATA +@pytest.mark.trio +async def test_mplex_stream_full_buffer(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + # Test: The message channel is of size `MPLEX_MESSAGE_CHANNEL_SIZE`. + # It should be fine to read even there are already `MPLEX_MESSAGE_CHANNEL_SIZE` + # messages arriving. + for _ in range(MPLEX_MESSAGE_CHANNEL_SIZE): + await stream_0.write(DATA) + await wait_all_tasks_blocked() + # Sanity check + assert MAX_READ_LEN >= MPLEX_MESSAGE_CHANNEL_SIZE * len(DATA) + assert (await stream_1.read(MAX_READ_LEN)) == MPLEX_MESSAGE_CHANNEL_SIZE * DATA + + # Test: Read after `MPLEX_MESSAGE_CHANNEL_SIZE + 1` messages has arrived, which + # exceeds the channel size. The stream should have been reset. + for _ in range(MPLEX_MESSAGE_CHANNEL_SIZE + 1): + await stream_0.write(DATA) + await wait_all_tasks_blocked() + with pytest.raises(MplexStreamReset): + await stream_1.read(MAX_READ_LEN) + + @pytest.mark.trio async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair): read_bytes = bytearray() From ba0fb8a833ffb736779af1d7a40a885128fd7a93 Mon Sep 17 00:00:00 2001 From: mhchia Date: Wed, 5 Feb 2020 20:36:42 +0800 Subject: [PATCH 77/81] Fix: use `pass` over `...`(Ellipse) Use `...`(Ellipse) only in abstract methods. --- libp2p/network/network_interface.py | 2 +- libp2p/pubsub/exceptions.py | 4 ++-- libp2p/pubsub/pubsub_notifee.py | 4 ++-- tests/transport/test_tcp.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/libp2p/network/network_interface.py b/libp2p/network/network_interface.py index c759a411..70fb7295 100644 --- a/libp2p/network/network_interface.py +++ b/libp2p/network/network_interface.py @@ -74,4 +74,4 @@ class INetwork(ABC): class INetworkService(INetwork, ServiceAPI): - ... + pass diff --git a/libp2p/pubsub/exceptions.py b/libp2p/pubsub/exceptions.py index a47446de..55afde04 100644 --- a/libp2p/pubsub/exceptions.py +++ b/libp2p/pubsub/exceptions.py @@ -2,8 +2,8 @@ from libp2p.exceptions import BaseLibp2pError class PubsubRouterError(BaseLibp2pError): - ... + pass class NoPubsubAttached(PubsubRouterError): - ... + pass diff --git a/libp2p/pubsub/pubsub_notifee.py b/libp2p/pubsub/pubsub_notifee.py index b32c1450..cf728843 100644 --- a/libp2p/pubsub/pubsub_notifee.py +++ b/libp2p/pubsub/pubsub_notifee.py @@ -50,7 +50,7 @@ class PubsubNotifee(INotifee): await self.initiator_peers_queue.send(conn.muxed_conn.peer_id) except trio.BrokenResourceError: # The receive channel is closed by Pubsub. We should do nothing here. - ... + pass async def disconnected(self, network: INetwork, conn: INetConn) -> None: """ @@ -64,7 +64,7 @@ class PubsubNotifee(INotifee): await self.dead_peers_queue.send(conn.muxed_conn.peer_id) except trio.BrokenResourceError: # The receive channel is closed by Pubsub. We should do nothing here. - ... + pass async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None: await trio.hazmat.checkpoint() diff --git a/tests/transport/test_tcp.py b/tests/transport/test_tcp.py index aca499ab..130b3cc4 100644 --- a/tests/transport/test_tcp.py +++ b/tests/transport/test_tcp.py @@ -13,7 +13,7 @@ async def test_tcp_listener(nursery): transport = TCP() async def handler(tcp_stream): - ... + pass listener = transport.create_listener(handler) assert len(listener.get_addrs()) == 0 From 7f8c0f11f690fe02bd8957fdbd7f7e8af8dd6490 Mon Sep 17 00:00:00 2001 From: mhchia Date: Wed, 5 Feb 2020 21:30:26 +0800 Subject: [PATCH 78/81] Pubsub: change channel size To `32` to conform to the go implementation. --- libp2p/pubsub/pubsub.py | 14 +++++++++++--- tests/pubsub/test_pubsub.py | 34 +++++++++++++++++++++++++++++++++- 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 41cd9655..af5c41d1 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -1,6 +1,5 @@ import functools import logging -import math import time from typing import TYPE_CHECKING, Dict, KeysView, List, NamedTuple, Set, Tuple, cast @@ -32,6 +31,9 @@ if TYPE_CHECKING: from typing import Any # noqa: F401 +# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/40e1c94708658b155f30cf99e4574f384756d83c/topic.go#L97 # noqa: E501 +SUBSCRIPTION_CHANNEL_SIZE = 32 + logger = logging.getLogger("libp2p.pubsub") @@ -373,7 +375,13 @@ class Pubsub(Service, IPubsub): # we are subscribed to a topic this message was sent for, # so add message to the subscription output queue # for each topic - await self.subscribed_topics_send[topic].send(publish_message) + try: + self.subscribed_topics_send[topic].send_nowait(publish_message) + except trio.WouldBlock: + # Channel is full, ignore this message. + logger.warning( + "fail to deliver message to subscription for topic %s", topic + ) async def subscribe(self, topic_id: str) -> ISubscriptionAPI: """ @@ -389,7 +397,7 @@ class Pubsub(Service, IPubsub): return self.subscribed_topics_receive[topic_id] send_channel, receive_channel = trio.open_memory_channel[rpc_pb2.Message]( - math.inf + SUBSCRIPTION_CHANNEL_SIZE ) subscription = TrioSubscriptionAPI( diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index 1e9d670a..6a22008c 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -6,7 +6,7 @@ import trio from libp2p.exceptions import ValidationError from libp2p.pubsub.pb import rpc_pb2 -from libp2p.pubsub.pubsub import PUBSUB_SIGNING_PREFIX +from libp2p.pubsub.pubsub import PUBSUB_SIGNING_PREFIX, SUBSCRIPTION_CHANNEL_SIZE from libp2p.tools.constants import MAX_READ_LEN from libp2p.tools.factories import IDFactory, PubsubFactory, net_stream_pair_factory from libp2p.tools.pubsub.utils import make_pubsub_msg @@ -444,6 +444,38 @@ async def test_subscribe_and_publish(): nursery.start_soon(publish_data, TESTING_TOPIC) +@pytest.mark.trio +async def test_subscribe_and_publish_full_channel(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + pubsub = pubsubs_fsub[0] + + extra_data_0 = b"extra_data_0" + extra_data_1 = b"extra_data_1" + + # Test: Subscription channel is of size `SUBSCRIPTION_CHANNEL_SIZE`. + # When the channel is full, new received messages are dropped. + # After the channel has empty slot, the channel can receive new messages. + + # Assume `SUBSCRIPTION_CHANNEL_SIZE` is smaller than `2**(4*8)`. + list_data = [i.to_bytes(4, "big") for i in range(SUBSCRIPTION_CHANNEL_SIZE)] + # Expect `extra_data_0` is dropped and `extra_data_1` is appended. + expected_list_data = list_data + [extra_data_1] + + subscription = await pubsub.subscribe(TESTING_TOPIC) + for data in list_data: + await pubsub.publish(TESTING_TOPIC, data) + + # Publish `extra_data_0` which should be dropped since the channel is already full. + await pubsub.publish(TESTING_TOPIC, extra_data_0) + # Consume a message and there is an empty slot in the channel. + assert (await subscription.get()).data == expected_list_data.pop(0) + # Publish `extra_data_1` which should be appended to the channel. + await pubsub.publish(TESTING_TOPIC, extra_data_1) + + for expected_data in expected_list_data: + assert (await subscription.get()).data == expected_data + + @pytest.mark.trio async def test_publish_push_msg_is_called(monkeypatch): msg_forwarders = [] From b7c2ec2187976e4f961c69884d09d0d049cf58e0 Mon Sep 17 00:00:00 2001 From: mhchia Date: Wed, 5 Feb 2020 21:31:04 +0800 Subject: [PATCH 79/81] Mplex: change the reference url To the commit hash, to make it more correct. --- libp2p/stream_muxer/mplex/mplex.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 5b1df77d..ddeb41f4 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -23,7 +23,7 @@ from .exceptions import MplexUnavailable from .mplex_stream import MplexStream MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0") -# Ref: https://github.com/libp2p/go-mplex/blob/master/multiplex.go#L115 +# Ref: https://github.com/libp2p/go-mplex/blob/414db61813d9ad3e6f4a7db5c1b1612de343ace9/multiplex.go#L115 # noqa: E501 MPLEX_MESSAGE_CHANNEL_SIZE = 8 logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex") From 5b03a7ad9f801fe7fa26c70515dadbe2e9bf799e Mon Sep 17 00:00:00 2001 From: mhchia Date: Wed, 5 Feb 2020 21:41:28 +0800 Subject: [PATCH 80/81] Mplex: only close the send of new stream channel --- libp2p/stream_muxer/mplex/mplex.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index ddeb41f4..4f62e152 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -137,7 +137,7 @@ class Mplex(IMuxedConn): """accepts a muxed stream opened by the other end.""" try: return await self.new_stream_receive_channel.receive() - except (trio.ClosedResourceError, trio.EndOfChannel): + except trio.EndOfChannel: raise MplexUnavailable async def send_message( @@ -254,7 +254,7 @@ class Mplex(IMuxedConn): mplex_stream = await self._initialize_stream(stream_id, message.decode()) try: await self.new_stream_send_channel.send(mplex_stream) - except (trio.BrokenResourceError, trio.ClosedResourceError): + except trio.ClosedResourceError: raise MplexUnavailable async def _handle_message(self, stream_id: StreamID, message: bytes) -> None: @@ -336,6 +336,4 @@ class Mplex(IMuxedConn): send_channel = self.streams_msg_channels[stream_id] await send_channel.aclose() self.event_closed.set() - # FIXME: It's enough to just close one side. await self.new_stream_send_channel.aclose() - await self.new_stream_receive_channel.aclose() From ddbedc6c154cff5c513ff149c0948c2394b0e6d8 Mon Sep 17 00:00:00 2001 From: mhchia Date: Wed, 5 Feb 2020 21:44:33 +0800 Subject: [PATCH 81/81] Pubsub: `handle_talk` - Change from async function to sync - Change the name to `notify_subscriptions`, which is clearer. --- libp2p/pubsub/pubsub.py | 5 ++--- tests/pubsub/test_pubsub.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index af5c41d1..ecfa544d 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -361,8 +361,7 @@ class Pubsub(Service, IPubsub): if origin_id in self.peer_topics[sub_message.topicid]: self.peer_topics[sub_message.topicid].discard(origin_id) - # FIXME(mhchia): Change the function name? - async def handle_talk(self, publish_message: rpc_pb2.Message) -> None: + def notify_subscriptions(self, publish_message: rpc_pb2.Message) -> None: """ Put incoming message from a peer onto my blocking queue. @@ -576,7 +575,7 @@ class Pubsub(Service, IPubsub): return self._mark_msg_seen(msg) - await self.handle_talk(msg) + self.notify_subscriptions(msg) await self.router.publish(msg_forwarder, msg) def _next_seqno(self) -> bytes: diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index 6a22008c..d6c29310 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -378,14 +378,14 @@ async def test_handle_talk(): data=b"1234", seqno=b"\x00" * 8, ) - await pubsubs_fsub[0].handle_talk(msg_0) + pubsubs_fsub[0].notify_subscriptions(msg_0) msg_1 = make_pubsub_msg( origin_id=pubsubs_fsub[0].my_id, topic_ids=["NOT_SUBSCRIBED"], data=b"1234", seqno=b"\x11" * 8, ) - await pubsubs_fsub[0].handle_talk(msg_1) + pubsubs_fsub[0].notify_subscriptions(msg_1) assert ( len(pubsubs_fsub[0].topic_ids) == 1 and sub == pubsubs_fsub[0].subscribed_topics_receive[TESTING_TOPIC]