diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index d12d6721..7469d33c 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -168,7 +168,11 @@ class BasicHost(IHost): protocol, handler = await self.multiselect.negotiate( MultiselectCommunicator(net_stream) ) - except MultiselectError: + except MultiselectError as error: + peer_id = net_stream.muxed_conn.peer_id + logger.debug( + "failed to accept a stream from peer %s, error=%s", peer_id, error + ) await net_stream.reset() return net_stream.set_protocol(protocol) diff --git a/libp2p/identity/identify/protocol.py b/libp2p/identity/identify/protocol.py index 390c0de2..87d946c0 100644 --- a/libp2p/identity/identify/protocol.py +++ b/libp2p/identity/identify/protocol.py @@ -19,24 +19,28 @@ def _multiaddr_to_bytes(maddr: Multiaddr) -> bytes: return maddr.to_bytes() +def _mk_identify_protobuf(host: IHost) -> Identify: + public_key = host.get_public_key() + laddrs = host.get_addrs() + protocols = host.get_mux().get_protocols() + + return Identify( + protocol_version=PROTOCOL_VERSION, + agent_version=AGENT_VERSION, + public_key=public_key.serialize(), + listen_addrs=map(_multiaddr_to_bytes, laddrs), + # TODO send observed address from ``stream`` + observed_addr=b"", + protocols=protocols, + ) + + def identify_handler_for(host: IHost) -> StreamHandlerFn: async def handle_identify(stream: INetStream) -> None: peer_id = stream.muxed_conn.peer_id logger.debug("received a request for %s from %s", ID, peer_id) - public_key = host.get_public_key() - laddrs = host.get_addrs() - protocols = host.get_mux().get_protocols() - - protobuf = Identify( - protocol_version=PROTOCOL_VERSION, - agent_version=AGENT_VERSION, - public_key=public_key.serialize(), - listen_addrs=map(_multiaddr_to_bytes, laddrs), - # TODO send observed address from ``stream`` - observed_addr=b"", - protocols=protocols, - ) + protobuf = _mk_identify_protobuf(host) response = protobuf.SerializeToString() await stream.write(response) diff --git a/libp2p/security/secio/transport.py b/libp2p/security/secio/transport.py index 6a4c8673..cf807bf4 100644 --- a/libp2p/security/secio/transport.py +++ b/libp2p/security/secio/transport.py @@ -111,6 +111,9 @@ class SecureSession(BaseSession): self.high_watermark = len(msg) async def read(self, n: int = -1) -> bytes: + if n == 0: + return bytes() + data_from_buffer = self._drain(n) if len(data_from_buffer) > 0: return data_from_buffer diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index c8d7a209..2228360c 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -1,4 +1,5 @@ import asyncio +import logging from typing import Any # noqa: F401 from typing import Awaitable, Dict, List, Optional, Tuple @@ -23,6 +24,8 @@ from .mplex_stream import MplexStream MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0") +logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex") + class Mplex(IMuxedConn): """ @@ -181,7 +184,8 @@ class Mplex(IMuxedConn): while True: try: await self._handle_incoming_message() - except MplexUnavailable: + except MplexUnavailable as e: + logger.debug("mplex unavailable while waiting for incoming: %s", e) break # Force context switch await asyncio.sleep(0) diff --git a/tests/factories.py b/tests/factories.py index d59f2278..b1ac527f 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -1,4 +1,5 @@ import asyncio +from contextlib import asynccontextmanager from typing import Dict, Tuple import factory @@ -163,6 +164,14 @@ async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]: return hosts[0], hosts[1] +@asynccontextmanager +async def pair_of_connected_hosts(is_secure=True): + a, b = await host_pair_factory(is_secure) + yield a, b + close_tasks = (a.close(), b.close()) + await asyncio.gather(*close_tasks) + + async def swarm_conn_pair_factory( is_secure: bool, muxer_opt: TMuxerOptions = None ) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]: diff --git a/tests/host/test_ping.py b/tests/host/test_ping.py index a5296f13..37975a86 100644 --- a/tests/host/test_ping.py +++ b/tests/host/test_ping.py @@ -4,25 +4,18 @@ import secrets import pytest from libp2p.host.ping import ID, PING_LENGTH -from libp2p.peer.peerinfo import info_from_p2p_addr -from tests.utils import set_up_nodes_by_transport_opt +from tests.factories import pair_of_connected_hosts @pytest.mark.asyncio async def test_ping_once(): - transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - (host_a, host_b) = await set_up_nodes_by_transport_opt(transport_opt_list) - - addr = host_a.get_addrs()[0] - info = info_from_p2p_addr(addr) - await host_b.connect(info) - - stream = await host_b.new_stream(host_a.get_id(), (ID,)) - some_ping = secrets.token_bytes(PING_LENGTH) - await stream.write(some_ping) - some_pong = await stream.read(PING_LENGTH) - assert some_ping == some_pong - await stream.close() + async with pair_of_connected_hosts() as (host_a, host_b): + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + some_ping = secrets.token_bytes(PING_LENGTH) + await stream.write(some_ping) + some_pong = await stream.read(PING_LENGTH) + assert some_ping == some_pong + await stream.close() SOME_PING_COUNT = 3 @@ -30,21 +23,15 @@ SOME_PING_COUNT = 3 @pytest.mark.asyncio async def test_ping_several(): - transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - (host_a, host_b) = await set_up_nodes_by_transport_opt(transport_opt_list) - - addr = host_a.get_addrs()[0] - info = info_from_p2p_addr(addr) - await host_b.connect(info) - - stream = await host_b.new_stream(host_a.get_id(), (ID,)) - for _ in range(SOME_PING_COUNT): - some_ping = secrets.token_bytes(PING_LENGTH) - await stream.write(some_ping) - some_pong = await stream.read(PING_LENGTH) - assert some_ping == some_pong - # NOTE: simulate some time to sleep to mirror a real - # world usage where a peer sends pings on some periodic interval - # NOTE: this interval can be `0` for this test. - await asyncio.sleep(0) - await stream.close() + async with pair_of_connected_hosts() as (host_a, host_b): + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + for _ in range(SOME_PING_COUNT): + some_ping = secrets.token_bytes(PING_LENGTH) + await stream.write(some_ping) + some_pong = await stream.read(PING_LENGTH) + assert some_ping == some_pong + # NOTE: simulate some time to sleep to mirror a real + # world usage where a peer sends pings on some periodic interval + # NOTE: this interval can be `0` for this test. + await asyncio.sleep(0) + await stream.close() diff --git a/tests/identity/identify/test_protocol.py b/tests/identity/identify/test_protocol.py new file mode 100644 index 00000000..e36e7ca5 --- /dev/null +++ b/tests/identity/identify/test_protocol.py @@ -0,0 +1,17 @@ +import pytest + +from libp2p.identity.identify.pb.identify_pb2 import Identify +from libp2p.identity.identify.protocol import ID, _mk_identify_protobuf +from tests.factories import pair_of_connected_hosts + + +@pytest.mark.asyncio +async def test_identify_protocol(): + async with pair_of_connected_hosts() as (host_a, host_b): + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + response = await stream.read() + await stream.close() + + identify_response = Identify() + identify_response.ParseFromString(response) + assert identify_response == _mk_identify_protobuf(host_a)