From eb494e8682334f77d1cffaf5a2369281b8de0a85 Mon Sep 17 00:00:00 2001 From: mhchia Date: Sun, 1 Dec 2019 19:17:32 +0800 Subject: [PATCH] Fix ping protocol --- libp2p/host/ping.py | 35 ++++++++++++----------- libp2p/stream_muxer/mplex/mplex.py | 1 - libp2p/stream_muxer/mplex/mplex_stream.py | 3 +- tests/host/test_ping.py | 17 +++++------ 4 files changed, 29 insertions(+), 27 deletions(-) diff --git a/libp2p/host/ping.py b/libp2p/host/ping.py index 3144ef4d..589fc917 100644 --- a/libp2p/host/ping.py +++ b/libp2p/host/ping.py @@ -1,4 +1,4 @@ -import asyncio +import trio import logging from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset @@ -16,22 +16,23 @@ logger = logging.getLogger("libp2p.host.ping") async def _handle_ping(stream: INetStream, peer_id: PeerID) -> bool: """Return a boolean indicating if we expect more pings from the peer at ``peer_id``.""" - try: - payload = await asyncio.wait_for(stream.read(PING_LENGTH), RESP_TIMEOUT) - except asyncio.TimeoutError as error: - logger.debug("Timed out waiting for ping from %s: %s", peer_id, error) - raise - except StreamEOF: - logger.debug("Other side closed while waiting for ping from %s", peer_id) - return False - except StreamReset as error: - logger.debug( - "Other side reset while waiting for ping from %s: %s", peer_id, error - ) - raise - except Exception as error: - logger.debug("Error while waiting to read ping for %s: %s", peer_id, error) - raise + with trio.fail_after(RESP_TIMEOUT): + try: + payload = await stream.read(PING_LENGTH) + except trio.TooSlowError as error: + logger.debug("Timed out waiting for ping from %s: %s", peer_id, error) + raise + except StreamEOF: + logger.debug("Other side closed while waiting for ping from %s", peer_id) + return False + except StreamReset as error: + logger.debug( + "Other side reset while waiting for ping from %s: %s", peer_id, error + ) + raise + except Exception as error: + logger.debug("Error while waiting to read ping for %s: %s", peer_id, error) + raise logger.debug("Received ping from %s with data: 0x%s", peer_id, payload.hex()) diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index f93acea7..6d8a64e6 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -1,4 +1,3 @@ -import asyncio import logging import math from typing import Any # noqa: F401 diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 6ecc4077..eeefc422 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -104,6 +104,7 @@ class MplexStream(IMuxedStream): # and then return. try: data = self.incoming_data_channel.receive_nowait() + self._buf.extend(data) except trio.EndOfChannel: raise MplexStreamEOF except trio.WouldBlock: @@ -111,6 +112,7 @@ class MplexStream(IMuxedStream): # catch all kinds of errors here. try: data = await self.incoming_data_channel.receive() + self._buf.extend(data) except trio.EndOfChannel: if self.event_reset.is_set(): raise MplexStreamReset @@ -125,7 +127,6 @@ class MplexStream(IMuxedStream): "`incoming_data_channel` is closed but stream is not reset. " "This should never happen." ) from error - self._buf.extend(data) self._buf.extend(self._read_return_when_blocked()) payload = self._buf[:n] self._buf = self._buf[len(payload) :] diff --git a/tests/host/test_ping.py b/tests/host/test_ping.py index 1bd02f0f..29135141 100644 --- a/tests/host/test_ping.py +++ b/tests/host/test_ping.py @@ -1,4 +1,4 @@ -import asyncio +import trio import secrets import pytest @@ -7,12 +7,13 @@ from libp2p.host.ping import ID, PING_LENGTH from libp2p.tools.factories import host_pair_factory -@pytest.mark.asyncio -async def test_ping_once(): - async with host_pair_factory() as (host_a, host_b): +@pytest.mark.trio +async def test_ping_once(is_host_secure): + async with host_pair_factory(is_host_secure) as (host_a, host_b): stream = await host_b.new_stream(host_a.get_id(), (ID,)) some_ping = secrets.token_bytes(PING_LENGTH) await stream.write(some_ping) + await trio.sleep(0.01) some_pong = await stream.read(PING_LENGTH) assert some_ping == some_pong await stream.close() @@ -21,9 +22,9 @@ async def test_ping_once(): SOME_PING_COUNT = 3 -@pytest.mark.asyncio -async def test_ping_several(): - async with host_pair_factory() as (host_a, host_b): +@pytest.mark.trio +async def test_ping_several(is_host_secure): + async with host_pair_factory(is_host_secure) as (host_a, host_b): stream = await host_b.new_stream(host_a.get_id(), (ID,)) for _ in range(SOME_PING_COUNT): some_ping = secrets.token_bytes(PING_LENGTH) @@ -33,5 +34,5 @@ async def test_ping_several(): # NOTE: simulate some time to sleep to mirror a real # world usage where a peer sends pings on some periodic interval # NOTE: this interval can be `0` for this test. - await asyncio.sleep(0) + await trio.sleep(0) await stream.close()