diff --git a/libp2p/io/trio.py b/libp2p/io/trio.py new file mode 100644 index 00000000..6e88f01d --- /dev/null +++ b/libp2p/io/trio.py @@ -0,0 +1,32 @@ +import trio +from trio import SocketStream +from libp2p.io.abc import ReadWriteCloser +from libp2p.io.exceptions import IOException +import logging + + +logger = logging.getLogger("libp2p.io.trio") + + +class TrioReadWriteCloser(ReadWriteCloser): + stream: SocketStream + + def __init__(self, stream: SocketStream) -> None: + self.stream = stream + + async def write(self, data: bytes) -> None: + """Raise `RawConnError` if the underlying connection breaks.""" + try: + await self.stream.send_all(data) + except (trio.ClosedResourceError, trio.BrokenResourceError) as error: + raise IOException(error) + + async def read(self, n: int = -1) -> bytes: + max_bytes = n if n != -1 else None + try: + return await self.stream.receive_some(max_bytes) + except (trio.ClosedResourceError, trio.BrokenResourceError) as error: + raise IOException(error) + + async def close(self) -> None: + await self.stream.aclose() diff --git a/libp2p/network/connection/raw_connection.py b/libp2p/network/connection/raw_connection.py index 08d22055..50b28984 100644 --- a/libp2p/network/connection/raw_connection.py +++ b/libp2p/network/connection/raw_connection.py @@ -1,42 +1,25 @@ -import asyncio +import trio +from libp2p.io.exceptions import IOException from .exceptions import RawConnError from .raw_connection_interface import IRawConnection +from libp2p.io.abc import ReadWriteCloser class RawConnection(IRawConnection): - reader: asyncio.StreamReader - writer: asyncio.StreamWriter + read_write_closer: ReadWriteCloser is_initiator: bool - _drain_lock: asyncio.Lock - - def __init__( - self, - reader: asyncio.StreamReader, - writer: asyncio.StreamWriter, - initiator: bool, - ) -> None: - self.reader = reader - self.writer = writer + def __init__(self, read_write_closer: ReadWriteCloser, initiator: bool) -> None: + self.read_write_closer = read_write_closer self.is_initiator = initiator - self._drain_lock = asyncio.Lock() - async def write(self, data: bytes) -> None: """Raise `RawConnError` if the underlying connection breaks.""" try: - self.writer.write(data) - except ConnectionResetError as error: + await self.read_write_closer.write(data) + except IOException as error: raise RawConnError(error) - # Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102 # noqa: E501 - # Use a lock to serialize drain() calls. Circumvents this bug: - # https://bugs.python.org/issue29930 - async with self._drain_lock: - try: - await self.writer.drain() - except ConnectionResetError as error: - raise RawConnError(error) async def read(self, n: int = -1) -> bytes: """ @@ -46,10 +29,9 @@ class RawConnection(IRawConnection): Raise `RawConnError` if the underlying connection breaks """ try: - return await self.reader.read(n) - except ConnectionResetError as error: + return await self.read_write_closer.read(n) + except IOException as error: raise RawConnError(error) async def close(self) -> None: - self.writer.close() - await self.writer.wait_closed() + await self.read_write_closer.close() diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 7bb40cee..9b89fa56 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -4,6 +4,7 @@ from typing import Dict, List, Optional from multiaddr import Multiaddr +from libp2p.io.abc import ReadWriteCloser from libp2p.network.connection.net_connection_interface import INetConn from libp2p.peer.id import ID from libp2p.peer.peerstore import PeerStoreError @@ -149,7 +150,7 @@ class Swarm(INetwork): logger.debug("successfully opened a stream to peer %s", peer_id) return net_stream - async def listen(self, *multiaddrs: Multiaddr) -> bool: + async def listen(self, *multiaddrs: Multiaddr, nursery) -> bool: """ :param multiaddrs: one or many multiaddrs to start listening on :return: true if at least one success @@ -167,15 +168,8 @@ class Swarm(INetwork): if str(maddr) in self.listeners: return True - async def conn_handler( - reader: asyncio.StreamReader, writer: asyncio.StreamWriter - ) -> None: - connection_info = writer.get_extra_info("peername") - # TODO make a proper multiaddr - peer_addr = f"/ip4/{connection_info[0]}/tcp/{connection_info[1]}" - logger.debug("inbound connection at %s", peer_addr) - # logger.debug("inbound connection request", peer_id) - raw_conn = RawConnection(reader, writer, False) + async def conn_handler(read_write_closer: ReadWriteCloser) -> None: + raw_conn = RawConnection(read_write_closer, False) # Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure # the conn and then mux the conn @@ -185,14 +179,10 @@ class Swarm(INetwork): raw_conn, ID(b""), False ) except SecurityUpgradeFailure as error: - error_msg = "fail to upgrade security for peer at %s" - logger.debug(error_msg, peer_addr) await raw_conn.close() - raise SwarmException(error_msg % peer_addr) from error + raise SwarmException() from error peer_id = secured_conn.get_remote_peer() - logger.debug("upgraded security for peer at %s", peer_addr) - logger.debug("identified peer at %s as %s", peer_addr, peer_id) try: muxed_conn = await self.upgrader.upgrade_connection( @@ -213,7 +203,7 @@ class Swarm(INetwork): # Success listener = self.transport.create_listener(conn_handler) self.listeners[str(maddr)] = listener - await listener.listen(maddr) + await listener.listen(maddr, nursery) # Call notifiers since event occurred self.notify_listen(maddr) diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 1a43c7cb..1f81c2f2 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -123,27 +123,10 @@ class Mplex(IMuxedConn): await self.send_message(HeaderTags.NewStream, name.encode(), stream_id) return stream - async def _wait_until_shutting_down_or_closed(self, coro: Awaitable[Any]) -> Any: - task_coro = asyncio.ensure_future(coro) - task_wait_closed = asyncio.ensure_future(self.event_closed.wait()) - task_wait_shutting_down = asyncio.ensure_future(self.event_shutting_down.wait()) - done, pending = await asyncio.wait( - [task_coro, task_wait_closed, task_wait_shutting_down], - return_when=asyncio.FIRST_COMPLETED, - ) - for fut in pending: - fut.cancel() - if task_wait_closed in done: - raise MplexUnavailable("Mplex is closed") - if task_wait_shutting_down in done: - raise MplexUnavailable("Mplex is shutting down") - return task_coro.result() async def accept_stream(self) -> IMuxedStream: """accepts a muxed stream opened by the other end.""" - return await self._wait_until_shutting_down_or_closed( - self.new_stream_queue.get() - ) + return await self.new_stream_queue.get() async def send_message( self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID @@ -163,9 +146,7 @@ class Mplex(IMuxedConn): _bytes = header + encode_varint_prefixed(data) - return await self._wait_until_shutting_down_or_closed( - self.write_to_stream(_bytes) - ) + return await self.write_to_stream(_bytes) async def write_to_stream(self, _bytes: bytes) -> int: """ @@ -226,9 +207,7 @@ class Mplex(IMuxedConn): :raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down. """ - channel_id, flag, message = await self._wait_until_shutting_down_or_closed( - self.read_message() - ) + channel_id, flag, message = await self.read_message() stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1)) if flag == HeaderTags.NewStream.value: @@ -258,9 +237,7 @@ class Mplex(IMuxedConn): f"received NewStream message for existing stream: {stream_id}" ) mplex_stream = await self._initialize_stream(stream_id, message.decode()) - await self._wait_until_shutting_down_or_closed( - self.new_stream_queue.put(mplex_stream) - ) + await self.new_stream_queue.put(mplex_stream) async def _handle_message(self, stream_id: StreamID, message: bytes) -> None: async with self.streams_lock: @@ -274,9 +251,7 @@ class Mplex(IMuxedConn): if stream.event_remote_closed.is_set(): # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501 return - await self._wait_until_shutting_down_or_closed( - stream.incoming_data.put(message) - ) + await stream.incoming_data.put(message) async def _handle_close(self, stream_id: StreamID) -> None: async with self.streams_lock: diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index f080d3cf..7659d32f 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -1,3 +1,4 @@ +import trio import asyncio from typing import TYPE_CHECKING @@ -22,14 +23,14 @@ class MplexStream(IMuxedStream): read_deadline: int write_deadline: int - close_lock: asyncio.Lock + close_lock: trio.Lock # NOTE: `dataIn` is size of 8 in Go implementation. incoming_data: "asyncio.Queue[bytes]" - event_local_closed: asyncio.Event - event_remote_closed: asyncio.Event - event_reset: asyncio.Event + event_local_closed: trio.Event + event_remote_closed: trio.Event + event_reset: trio.Event _buf: bytearray @@ -45,10 +46,10 @@ class MplexStream(IMuxedStream): self.muxed_conn = muxed_conn self.read_deadline = None self.write_deadline = None - self.event_local_closed = asyncio.Event() - self.event_remote_closed = asyncio.Event() - self.event_reset = asyncio.Event() - self.close_lock = asyncio.Lock() + self.event_local_closed = trio.Event() + self.event_remote_closed = trio.Event() + self.event_reset = trio.Event() + self.close_lock = trio.Lock() self.incoming_data = asyncio.Queue() self._buf = bytearray() @@ -199,10 +200,11 @@ class MplexStream(IMuxedStream): if self.is_initiator else HeaderTags.ResetReceiver ) - asyncio.ensure_future( - self.muxed_conn.send_message(flag, None, self.stream_id) - ) - await asyncio.sleep(0) + async with trio.open_nursery() as nursery: + nursery.start_soon( + self.muxed_conn.send_message, flag, None, self.stream_id + ) + await trio.sleep(0) self.event_local_closed.set() self.event_remote_closed.set() diff --git a/libp2p/tools/utils.py b/libp2p/tools/utils.py index 84e3edf9..89f864eb 100644 --- a/libp2p/tools/utils.py +++ b/libp2p/tools/utils.py @@ -1,3 +1,4 @@ +import trio from typing import List, Sequence, Tuple import multiaddr @@ -37,12 +38,12 @@ async def connect(node1: IHost, node2: IHost) -> None: async def set_up_nodes_by_transport_opt( - transport_opt_list: Sequence[Sequence[str]] + transport_opt_list: Sequence[Sequence[str]], nursery: trio.Nursery ) -> Tuple[BasicHost, ...]: nodes_list = [] for transport_opt in transport_opt_list: - node = await new_node(transport_opt=transport_opt) - await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0])) + node = new_node(transport_opt=transport_opt) + await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0]), nursery=nursery) nodes_list.append(node) return tuple(nodes_list) diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index 7470510d..8838b504 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -1,4 +1,5 @@ import asyncio +import trio from socket import socket from typing import List @@ -10,6 +11,10 @@ from libp2p.transport.exceptions import OpenConnectionError from libp2p.transport.listener_interface import IListener from libp2p.transport.transport_interface import ITransport from libp2p.transport.typing import THandler +from libp2p.io.trio import TrioReadWriteCloser +import logging + +logger = logging.getLogger("libp2p.transport.tcp") class TCPListener(IListener): @@ -21,20 +26,38 @@ class TCPListener(IListener): self.server = None self.handler = handler_function - async def listen(self, maddr: Multiaddr) -> bool: + async def listen(self, maddr: Multiaddr, nursery) -> bool: """ put listener in listening mode and wait for incoming connections. :param maddr: maddr of peer :return: return True if successful """ - self.server = await asyncio.start_server( - self.handler, - maddr.value_for_protocol("ip4"), - maddr.value_for_protocol("tcp"), + + async def serve_tcp(handler, port, host, task_status=None): + logger.debug("serve_tcp %s %s", host, port) + await trio.serve_tcp(handler, port, host=host, task_status=task_status) + + async def handler(stream): + read_write_closer = TrioReadWriteCloser(stream) + await self.handler(read_write_closer) + + listeners = await nursery.start( + serve_tcp, + *( + handler, + int(maddr.value_for_protocol("tcp")), + maddr.value_for_protocol("ip4"), + ), ) - socket = self.server.sockets[0] + # self.server = await asyncio.start_server( + # self.handler, + # maddr.value_for_protocol("ip4"), + # maddr.value_for_protocol("tcp"), + # ) + socket = listeners[0].socket self.multiaddrs.append(_multiaddr_from_socket(socket)) + logger.debug("Multiaddrs %s", self.multiaddrs) return True @@ -69,12 +92,10 @@ class TCP(ITransport): self.host = maddr.value_for_protocol("ip4") self.port = int(maddr.value_for_protocol("tcp")) - try: - reader, writer = await asyncio.open_connection(self.host, self.port) - except (ConnectionAbortedError, ConnectionRefusedError) as error: - raise OpenConnectionError(error) + stream = await trio.open_tcp_stream(self.host, self.port) + read_write_closer = TrioReadWriteCloser(stream) - return RawConnection(reader, writer, True) + return RawConnection(read_write_closer, True) def create_listener(self, handler_function: THandler) -> TCPListener: """ diff --git a/libp2p/transport/typing.py b/libp2p/transport/typing.py index f9b31dcb..0f42335e 100644 --- a/libp2p/transport/typing.py +++ b/libp2p/transport/typing.py @@ -1,11 +1,12 @@ -from asyncio import StreamReader, StreamWriter + from typing import Awaitable, Callable, Mapping, Type from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.stream_muxer.abc import IMuxedConn from libp2p.typing import TProtocol +from libp2p.io.abc import ReadWriteCloser -THandler = Callable[[StreamReader, StreamWriter], Awaitable[None]] +THandler = Callable[[ReadWriteCloser], Awaitable[None]] TSecurityOptions = Mapping[TProtocol, ISecureTransport] TMuxerClass = Type[IMuxedConn] TMuxerOptions = Mapping[TProtocol, TMuxerClass] diff --git a/tests/libp2p/test_libp2p.py b/tests/libp2p/test_libp2p.py index 330250c4..9f78dce5 100644 --- a/tests/libp2p/test_libp2p.py +++ b/tests/libp2p/test_libp2p.py @@ -1,3 +1,4 @@ +import trio import multiaddr import pytest @@ -6,10 +7,10 @@ from libp2p.tools.constants import MAX_READ_LEN from libp2p.tools.utils import set_up_nodes_by_transport_opt -@pytest.mark.asyncio -async def test_simple_messages(): +@pytest.mark.trio +async def test_simple_messages(nursery): transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list) + (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list, nursery) async def stream_handler(stream): while True: @@ -23,6 +24,7 @@ async def test_simple_messages(): # Associate the peer with local ip address (see default parameters of Libp2p()) node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) + stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) messages = ["hello" + str(x) for x in range(10)]