diff --git a/examples/chat/chat.py b/examples/chat/chat.py index 8f10cd30..fbfd89e3 100755 --- a/examples/chat/chat.py +++ b/examples/chat/chat.py @@ -1,11 +1,11 @@ import argparse import asyncio -import trio_asyncio -import trio import sys import urllib.request import multiaddr +import trio +import trio_asyncio from libp2p import new_node from libp2p.network.stream.net_stream_interface import INetStream @@ -42,7 +42,9 @@ async def run(port: int, destination: str, localhost: bool) -> None: transport_opt = f"/ip4/{ip}/tcp/{port}" host = new_node(transport_opt=[transport_opt]) - await trio_asyncio.run_asyncio(host.get_network().listen,multiaddr.Multiaddr(transport_opt) ) + await trio_asyncio.run_asyncio( + host.get_network().listen, multiaddr.Multiaddr(transport_opt) + ) if not destination: # its the server @@ -70,7 +72,9 @@ async def run(port: int, destination: str, localhost: bool) -> None: # Start a stream with the destination. # Multiaddress of the destination peer is fetched from the peerstore using 'peerId'. - stream = await trio_asyncio.run_asyncio(host.new_stream, *(info.peer_id, [PROTOCOL_ID])) + stream = await trio_asyncio.run_asyncio( + host.new_stream, *(info.peer_id, [PROTOCOL_ID]) + ) asyncio.ensure_future(read_data(stream)) asyncio.ensure_future(write_data(stream)) @@ -119,5 +123,6 @@ def main() -> None: trio_asyncio.run(run, *(args.port, args.destination, args.localhost)) + if __name__ == "__main__": main() diff --git a/libp2p/io/trio.py b/libp2p/io/trio.py index 6e88f01d..7d0584e4 100644 --- a/libp2p/io/trio.py +++ b/libp2p/io/trio.py @@ -1,9 +1,10 @@ -import trio -from trio import SocketStream -from libp2p.io.abc import ReadWriteCloser -from libp2p.io.exceptions import IOException import logging +import trio +from trio import SocketStream + +from libp2p.io.abc import ReadWriteCloser +from libp2p.io.exceptions import IOException logger = logging.getLogger("libp2p.io.trio") diff --git a/libp2p/network/connection/raw_connection.py b/libp2p/network/connection/raw_connection.py index 50b28984..2bdb3b10 100644 --- a/libp2p/network/connection/raw_connection.py +++ b/libp2p/network/connection/raw_connection.py @@ -1,9 +1,10 @@ import trio +from libp2p.io.abc import ReadWriteCloser from libp2p.io.exceptions import IOException + from .exceptions import RawConnError from .raw_connection_interface import IRawConnection -from libp2p.io.abc import ReadWriteCloser class RawConnection(IRawConnection): diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 9b89fa56..e54ad6f7 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -3,6 +3,7 @@ import logging from typing import Dict, List, Optional from multiaddr import Multiaddr +import trio from libp2p.io.abc import ReadWriteCloser from libp2p.network.connection.net_connection_interface import INetConn @@ -69,7 +70,7 @@ class Swarm(INetwork): def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None: self.common_stream_handler = stream_handler - async def dial_peer(self, peer_id: ID) -> INetConn: + async def dial_peer(self, peer_id: ID, nursery) -> INetConn: """ dial_peer try to create a connection to peer_id. @@ -121,6 +122,7 @@ class Swarm(INetwork): try: muxed_conn = await self.upgrader.upgrade_connection(secured_conn, peer_id) + muxed_conn.run(nursery) except MuxerUpgradeFailure as error: error_msg = "fail to upgrade mux for peer %s" logger.debug(error_msg, peer_id) @@ -135,7 +137,7 @@ class Swarm(INetwork): return swarm_conn - async def new_stream(self, peer_id: ID) -> INetStream: + async def new_stream(self, peer_id: ID, nursery) -> INetStream: """ :param peer_id: peer_id of destination :param protocol_id: protocol id @@ -144,7 +146,7 @@ class Swarm(INetwork): """ logger.debug("attempting to open a stream to peer %s", peer_id) - swarm_conn = await self.dial_peer(peer_id) + swarm_conn = await self.dial_peer(peer_id, nursery) net_stream = await swarm_conn.new_stream() logger.debug("successfully opened a stream to peer %s", peer_id) @@ -183,11 +185,11 @@ class Swarm(INetwork): raise SwarmException() from error peer_id = secured_conn.get_remote_peer() - try: muxed_conn = await self.upgrader.upgrade_connection( secured_conn, peer_id ) + muxed_conn.run(nursery) except MuxerUpgradeFailure as error: error_msg = "fail to upgrade mux for peer %s" logger.debug(error_msg, peer_id) @@ -198,6 +200,8 @@ class Swarm(INetwork): await self.add_conn(muxed_conn) logger.debug("successfully opened connection to peer %s", peer_id) + event = trio.Event() + await event.wait() try: # Success diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 1f81c2f2..b6df526b 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -3,6 +3,8 @@ import logging from typing import Any # noqa: F401 from typing import Awaitable, Dict, List, Optional, Tuple +import trio + from libp2p.exceptions import ParseError from libp2p.io.exceptions import IncompleteReadError from libp2p.network.connection.exceptions import RawConnError @@ -41,8 +43,6 @@ class Mplex(IMuxedConn): event_shutting_down: asyncio.Event event_closed: asyncio.Event - _tasks: List["asyncio.Future[Any]"] - def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None: """ create a new muxed connection. @@ -66,10 +66,8 @@ class Mplex(IMuxedConn): self.event_shutting_down = asyncio.Event() self.event_closed = asyncio.Event() - self._tasks = [] - - # Kick off reading - self._tasks.append(asyncio.ensure_future(self.handle_incoming())) + def run(self, nursery): + nursery.start_soon(self.handle_incoming) @property def is_initiator(self) -> bool: @@ -123,7 +121,6 @@ class Mplex(IMuxedConn): await self.send_message(HeaderTags.NewStream, name.encode(), stream_id) return stream - async def accept_stream(self) -> IMuxedStream: """accepts a muxed stream opened by the other end.""" return await self.new_stream_queue.get() @@ -169,7 +166,7 @@ class Mplex(IMuxedConn): logger.debug("mplex unavailable while waiting for incoming: %s", e) break # Force context switch - await asyncio.sleep(0) + await trio.sleep(0) # If we enter here, it means this connection is shutting down. # We should clean things up. await self._cleanup() @@ -184,9 +181,7 @@ class Mplex(IMuxedConn): # FIXME: No timeout is used in Go implementation. try: header = await decode_uvarint_from_stream(self.secured_conn) - message = await asyncio.wait_for( - read_varint_prefixed_bytes(self.secured_conn), timeout=5 - ) + message = await read_varint_prefixed_bytes(self.secured_conn) except (ParseError, RawConnError, IncompleteReadError) as error: raise MplexUnavailable( "failed to read messages correctly from the underlying connection" diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 7659d32f..77458cbd 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -1,8 +1,10 @@ -import trio import asyncio from typing import TYPE_CHECKING +import trio + from libp2p.stream_muxer.abc import IMuxedStream +from libp2p.utils import IQueue, TrioQueue from .constants import HeaderTags from .datastructures import StreamID @@ -26,7 +28,7 @@ class MplexStream(IMuxedStream): close_lock: trio.Lock # NOTE: `dataIn` is size of 8 in Go implementation. - incoming_data: "asyncio.Queue[bytes]" + incoming_data: IQueue[bytes] event_local_closed: trio.Event event_remote_closed: trio.Event @@ -50,69 +52,13 @@ class MplexStream(IMuxedStream): self.event_remote_closed = trio.Event() self.event_reset = trio.Event() self.close_lock = trio.Lock() - self.incoming_data = asyncio.Queue() + self.incoming_data = TrioQueue() self._buf = bytearray() @property def is_initiator(self) -> bool: return self.stream_id.is_initiator - async def _wait_for_data(self) -> None: - task_event_reset = asyncio.ensure_future(self.event_reset.wait()) - task_incoming_data_get = asyncio.ensure_future(self.incoming_data.get()) - task_event_remote_closed = asyncio.ensure_future( - self.event_remote_closed.wait() - ) - done, pending = await asyncio.wait( # type: ignore - [ # type: ignore - task_event_reset, - task_incoming_data_get, - task_event_remote_closed, - ], - return_when=asyncio.FIRST_COMPLETED, - ) - for fut in pending: - fut.cancel() - - if task_event_reset in done: - if self.event_reset.is_set(): - raise MplexStreamReset - else: - # However, it is abnormal that `Event.wait` is unblocked without any of the flag - # is set. The task is probably cancelled. - raise Exception( - "Should not enter here. " - f"It is probably because {task_event_remote_closed} is cancelled." - ) - - if task_incoming_data_get in done: - data = task_incoming_data_get.result() - self._buf.extend(data) - return - - if task_event_remote_closed in done: - if self.event_remote_closed.is_set(): - raise MplexStreamEOF - else: - # However, it is abnormal that `Event.wait` is unblocked without any of the flag - # is set. The task is probably cancelled. - raise Exception( - "Should not enter here. " - f"It is probably because {task_event_remote_closed} is cancelled." - ) - - # TODO: Handle timeout when deadline is used. - - async def _read_until_eof(self) -> bytes: - while True: - try: - await self._wait_for_data() - except MplexStreamEOF: - break - payload = self._buf - self._buf = self._buf[len(payload) :] - return bytes(payload) - async def read(self, n: int = -1) -> bytes: """ Read up to n bytes. Read possibly returns fewer than `n` bytes, if @@ -128,20 +74,7 @@ class MplexStream(IMuxedStream): ) if self.event_reset.is_set(): raise MplexStreamReset - if n == -1: - return await self._read_until_eof() - if len(self._buf) == 0 and self.incoming_data.empty(): - await self._wait_for_data() - # Now we are sure we have something to read. - # Try to put enough incoming data into `self._buf`. - while len(self._buf) < n: - try: - self._buf.extend(self.incoming_data.get_nowait()) - except asyncio.QueueEmpty: - break - payload = self._buf[:n] - self._buf = self._buf[len(payload) :] - return bytes(payload) + return await self.incoming_data.get() async def write(self, data: bytes) -> int: """ diff --git a/libp2p/tools/utils.py b/libp2p/tools/utils.py index 89f864eb..d1c266a0 100644 --- a/libp2p/tools/utils.py +++ b/libp2p/tools/utils.py @@ -17,7 +17,7 @@ from libp2p.typing import StreamHandlerFn, TProtocol from .constants import MAX_READ_LEN -async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm) -> None: +async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm, nursery: trio.Nursery) -> None: peer_id = swarm_1.get_peer_id() addrs = tuple( addr @@ -25,7 +25,7 @@ async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm) -> None: for addr in transport.get_addrs() ) swarm_0.peerstore.add_addrs(peer_id, addrs, 10000) - await swarm_0.dial_peer(peer_id) + await swarm_0.dial_peer(peer_id, nursery) assert swarm_0.get_peer_id() in swarm_1.connections assert swarm_1.get_peer_id() in swarm_0.connections @@ -43,7 +43,9 @@ async def set_up_nodes_by_transport_opt( nodes_list = [] for transport_opt in transport_opt_list: node = new_node(transport_opt=transport_opt) - await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0]), nursery=nursery) + await node.get_network().listen( + multiaddr.Multiaddr(transport_opt[0]), nursery=nursery + ) nodes_list.append(node) return tuple(nodes_list) diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index 8838b504..0bc1a962 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -1,18 +1,18 @@ import asyncio -import trio +import logging from socket import socket from typing import List from multiaddr import Multiaddr +import trio +from libp2p.io.trio import TrioReadWriteCloser from libp2p.network.connection.raw_connection import RawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.transport.exceptions import OpenConnectionError from libp2p.transport.listener_interface import IListener from libp2p.transport.transport_interface import ITransport from libp2p.transport.typing import THandler -from libp2p.io.trio import TrioReadWriteCloser -import logging logger = logging.getLogger("libp2p.transport.tcp") @@ -44,11 +44,9 @@ class TCPListener(IListener): listeners = await nursery.start( serve_tcp, - *( - handler, - int(maddr.value_for_protocol("tcp")), - maddr.value_for_protocol("ip4"), - ), + handler, + int(maddr.value_for_protocol("tcp")), + maddr.value_for_protocol("ip4"), ) # self.server = await asyncio.start_server( # self.handler, @@ -57,7 +55,6 @@ class TCPListener(IListener): # ) socket = listeners[0].socket self.multiaddrs.append(_multiaddr_from_socket(socket)) - logger.debug("Multiaddrs %s", self.multiaddrs) return True diff --git a/libp2p/transport/typing.py b/libp2p/transport/typing.py index 0f42335e..d68a8aa4 100644 --- a/libp2p/transport/typing.py +++ b/libp2p/transport/typing.py @@ -1,10 +1,9 @@ - from typing import Awaitable, Callable, Mapping, Type +from libp2p.io.abc import ReadWriteCloser from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.stream_muxer.abc import IMuxedConn from libp2p.typing import TProtocol -from libp2p.io.abc import ReadWriteCloser THandler = Callable[[ReadWriteCloser], Awaitable[None]] TSecurityOptions = Mapping[TProtocol, ISecureTransport] diff --git a/libp2p/utils.py b/libp2p/utils.py index bce9e580..27880c48 100644 --- a/libp2p/utils.py +++ b/libp2p/utils.py @@ -1,14 +1,14 @@ import itertools import math +from typing import Generic, TypeVar + +import trio from libp2p.exceptions import ParseError from libp2p.io.abc import Reader from .io.utils import read_exactly -from typing import Generic, TypeVar -import trio - # Unsigned LEB128(varint codec) # Reference: https://github.com/ethereum/py-wasm/blob/master/wasm/parsers/leb128.py diff --git a/tests/libp2p/test_libp2p.py b/tests/libp2p/test_libp2p.py index 9f78dce5..628e008a 100644 --- a/tests/libp2p/test_libp2p.py +++ b/tests/libp2p/test_libp2p.py @@ -1,6 +1,6 @@ -import trio import multiaddr import pytest +import trio from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.tools.constants import MAX_READ_LEN @@ -24,11 +24,11 @@ async def test_simple_messages(nursery): # Associate the peer with local ip address (see default parameters of Libp2p()) node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) messages = ["hello" + str(x) for x in range(10)] for message in messages: + await stream.write(message.encode()) response = (await stream.read(MAX_READ_LEN)).decode() diff --git a/tests/stream_muxer/test_mplex_stream.py b/tests/stream_muxer/test_mplex_stream.py index f3458d8f..27c7f45e 100644 --- a/tests/stream_muxer/test_mplex_stream.py +++ b/tests/stream_muxer/test_mplex_stream.py @@ -1,20 +1,31 @@ import asyncio import pytest +import trio from libp2p.stream_muxer.mplex.exceptions import ( MplexStreamClosed, MplexStreamEOF, MplexStreamReset, ) -from libp2p.tools.constants import MAX_READ_LEN +from libp2p.tools.constants import MAX_READ_LEN, LISTEN_MADDR +from libp2p.tools.factories import SwarmFactory +from libp2p.tools.utils import connect_swarm DATA = b"data_123" -@pytest.mark.asyncio -async def test_mplex_stream_read_write(mplex_stream_pair): - stream_0, stream_1 = mplex_stream_pair +@pytest.mark.trio +async def test_mplex_stream_read_write(nursery): + swarm0, swarm1 = SwarmFactory(), SwarmFactory() + await swarm0.listen(LISTEN_MADDR, nursery=nursery) + await swarm1.listen(LISTEN_MADDR, nursery=nursery) + await connect_swarm(swarm0, swarm1, nursery) + conn_0 = swarm0.connections[swarm1.get_peer_id()] + conn_1 = swarm1.connections[swarm0.get_peer_id()] + stream_0 = await conn_0.muxed_conn.open_stream() + await trio.sleep(1) + stream_1 = tuple(conn_1.muxed_conn.streams.values())[0] await stream_0.write(DATA) assert (await stream_1.read(MAX_READ_LEN)) == DATA diff --git a/tests/test_utils.py b/tests/test_utils.py index a6284371..7bd807d5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,6 @@ -import trio import pytest +import trio + from libp2p.utils import TrioQueue @@ -16,4 +17,3 @@ async def test_trio_queue(): result = await nursery.start(queue_get) assert result == 123 -