reorg test structure to match tox and CI jobs, drop bumpversion for bump-my-version and move config to pyproject.toml, fix docs building

This commit is contained in:
pacrob
2024-04-08 09:06:04 -06:00
committed by Paul Robinson
parent 1206fbef3d
commit eea065fb57
62 changed files with 319 additions and 160 deletions

View File

@ -0,0 +1,27 @@
from libp2p.crypto.ed25519 import (
create_new_key_pair,
)
from libp2p.crypto.serialization import (
deserialize_private_key,
deserialize_public_key,
)
def test_public_key_serialize_deserialize_round_trip():
key_pair = create_new_key_pair()
public_key = key_pair.public_key
public_key_bytes = public_key.serialize()
another_public_key = deserialize_public_key(public_key_bytes)
assert public_key == another_public_key
def test_private_key_serialize_deserialize_round_trip():
key_pair = create_new_key_pair()
private_key = key_pair.private_key
private_key_bytes = private_key.serialize()
another_private_key = deserialize_private_key(private_key_bytes)
assert private_key == another_private_key

View File

@ -0,0 +1,27 @@
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.crypto.serialization import (
deserialize_private_key,
deserialize_public_key,
)
def test_public_key_serialize_deserialize_round_trip():
key_pair = create_new_key_pair()
public_key = key_pair.public_key
public_key_bytes = public_key.serialize()
another_public_key = deserialize_public_key(public_key_bytes)
assert public_key == another_public_key
def test_private_key_serialize_deserialize_round_trip():
key_pair = create_new_key_pair()
private_key = key_pair.private_key
private_key_bytes = private_key.serialize()
another_private_key = deserialize_private_key(private_key_bytes)
assert private_key == another_private_key

View File

@ -0,0 +1,112 @@
import pytest
import trio
from libp2p.host.exceptions import (
StreamFailure,
)
from libp2p.peer.peerinfo import (
info_from_p2p_addr,
)
from libp2p.tools.factories import (
HostFactory,
)
from libp2p.tools.utils import (
MAX_READ_LEN,
)
PROTOCOL_ID = "/chat/1.0.0"
async def hello_world(host_a, host_b):
hello_world_from_host_a = b"hello world from host a"
hello_world_from_host_b = b"hello world from host b"
async def stream_handler(stream):
read = await stream.read(len(hello_world_from_host_b))
assert read == hello_world_from_host_b
await stream.write(hello_world_from_host_a)
await stream.close()
host_a.set_stream_handler(PROTOCOL_ID, stream_handler)
# Start a stream with the destination.
# Multiaddress of the destination peer is fetched from the peerstore using 'peerId'.
stream = await host_b.new_stream(host_a.get_id(), [PROTOCOL_ID])
await stream.write(hello_world_from_host_b)
read = await stream.read(MAX_READ_LEN)
assert read == hello_world_from_host_a
await stream.close()
async def connect_write(host_a, host_b):
messages = ["data %d" % i for i in range(5)]
received = []
async def stream_handler(stream):
for message in messages:
received.append((await stream.read(len(message))).decode())
host_a.set_stream_handler(PROTOCOL_ID, stream_handler)
# Start a stream with the destination.
# Multiaddress of the destination peer is fetched from the peerstore using 'peerId'.
stream = await host_b.new_stream(host_a.get_id(), [PROTOCOL_ID])
for message in messages:
await stream.write(message.encode())
# Reader needs time due to async reads
await trio.sleep(2)
await stream.close()
assert received == messages
async def connect_read(host_a, host_b):
messages = [b"data %d" % i for i in range(5)]
async def stream_handler(stream):
for message in messages:
await stream.write(message)
await stream.close()
host_a.set_stream_handler(PROTOCOL_ID, stream_handler)
# Start a stream with the destination.
# Multiaddress of the destination peer is fetched from the peerstore using 'peerId'.
stream = await host_b.new_stream(host_a.get_id(), [PROTOCOL_ID])
received = []
for message in messages:
received.append(await stream.read(len(message)))
await stream.close()
assert received == messages
async def no_common_protocol(host_a, host_b):
messages = [b"data %d" % i for i in range(5)]
async def stream_handler(stream):
for message in messages:
await stream.write(message)
await stream.close()
host_a.set_stream_handler(PROTOCOL_ID, stream_handler)
# try to creates a new new with a procotol not known by the other host
with pytest.raises(StreamFailure):
await host_b.new_stream(host_a.get_id(), ["/fakeproto/0.0.1"])
@pytest.mark.parametrize(
"test", [(hello_world), (connect_write), (connect_read), (no_common_protocol)]
)
@pytest.mark.trio
async def test_chat(test, security_protocol):
print("!@# ", security_protocol)
async with HostFactory.create_batch_and_listen(
2, security_protocol=security_protocol
) as hosts:
addr = hosts[0].get_addrs()[0]
info = info_from_p2p_addr(addr)
await hosts[1].connect(info)
await test(hosts[0], hosts[1])

View File

@ -0,0 +1,24 @@
from libp2p import (
new_swarm,
)
from libp2p.crypto.rsa import (
create_new_key_pair,
)
from libp2p.host.basic_host import (
BasicHost,
)
from libp2p.host.defaults import (
get_default_protocols,
)
def test_default_protocols():
key_pair = create_new_key_pair()
swarm = new_swarm(key_pair)
host = BasicHost(swarm)
mux = host.get_mux()
handlers = mux.handlers
# NOTE: comparing keys for equality as handlers may be closures that do not compare
# in the way this test is concerned with
assert handlers.keys() == get_default_protocols(host).keys()

View File

@ -0,0 +1,49 @@
import secrets
import pytest
import trio
from libp2p.host.ping import (
ID,
PING_LENGTH,
)
from libp2p.tools.factories import (
host_pair_factory,
)
@pytest.mark.trio
async def test_ping_once(security_protocol):
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
stream = await host_b.new_stream(host_a.get_id(), (ID,))
some_ping = secrets.token_bytes(PING_LENGTH)
await stream.write(some_ping)
await trio.sleep(0.01)
some_pong = await stream.read(PING_LENGTH)
assert some_ping == some_pong
await stream.close()
SOME_PING_COUNT = 3
@pytest.mark.trio
async def test_ping_several(security_protocol):
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
stream = await host_b.new_stream(host_a.get_id(), (ID,))
for _ in range(SOME_PING_COUNT):
some_ping = secrets.token_bytes(PING_LENGTH)
await stream.write(some_ping)
some_pong = await stream.read(PING_LENGTH)
assert some_ping == some_pong
# NOTE: simulate some time to sleep to mirror a real
# world usage where a peer sends pings on some periodic interval
# NOTE: this interval can be `0` for this test.
await trio.sleep(0)
await stream.close()

View File

@ -0,0 +1,32 @@
import pytest
from libp2p.host.exceptions import (
ConnectionFailure,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.tools.factories import (
HostFactory,
RoutedHostFactory,
)
@pytest.mark.trio
async def test_host_routing_success():
async with RoutedHostFactory.create_batch_and_listen(2) as hosts:
# forces to use routing as no addrs are provided
await hosts[0].connect(PeerInfo(hosts[1].get_id(), []))
await hosts[1].connect(PeerInfo(hosts[0].get_id(), []))
@pytest.mark.trio
async def test_host_routing_fail():
async with RoutedHostFactory.create_batch_and_listen(
2
) as routed_hosts, HostFactory.create_batch_and_listen(1) as basic_hosts:
# routing fails because host_c does not use routing
with pytest.raises(ConnectionFailure):
await routed_hosts[0].connect(PeerInfo(basic_hosts[0].get_id(), []))
with pytest.raises(ConnectionFailure):
await routed_hosts[1].connect(PeerInfo(basic_hosts[0].get_id(), []))

View File

@ -0,0 +1,27 @@
import pytest
from libp2p.identity.identify.pb.identify_pb2 import (
Identify,
)
from libp2p.identity.identify.protocol import (
ID,
_mk_identify_protobuf,
)
from libp2p.tools.factories import (
host_pair_factory,
)
@pytest.mark.trio
async def test_identify_protocol(security_protocol):
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read()
await stream.close()
identify_response = Identify()
identify_response.ParseFromString(response)
assert identify_response == _mk_identify_protobuf(host_a)

View File

@ -0,0 +1,29 @@
import pytest
from libp2p.tools.factories import (
net_stream_pair_factory,
swarm_conn_pair_factory,
swarm_pair_factory,
)
@pytest.fixture
async def net_stream_pair(security_protocol):
async with net_stream_pair_factory(
security_protocol=security_protocol
) as net_stream_pair:
yield net_stream_pair
@pytest.fixture
async def swarm_pair(security_protocol):
async with swarm_pair_factory(security_protocol=security_protocol) as swarms:
yield swarms
@pytest.fixture
async def swarm_conn_pair(security_protocol):
async with swarm_conn_pair_factory(
security_protocol=security_protocol
) as swarm_conn_pair:
yield swarm_conn_pair

View File

@ -0,0 +1,120 @@
import pytest
import trio
from libp2p.network.stream.exceptions import (
StreamClosed,
StreamEOF,
StreamReset,
)
from libp2p.tools.constants import (
MAX_READ_LEN,
)
DATA = b"data_123"
@pytest.mark.trio
async def test_net_stream_read_write(net_stream_pair):
stream_0, stream_1 = net_stream_pair
assert (
stream_0.protocol_id is not None
and stream_0.protocol_id == stream_1.protocol_id
)
await stream_0.write(DATA)
assert (await stream_1.read(MAX_READ_LEN)) == DATA
@pytest.mark.trio
async def test_net_stream_read_until_eof(net_stream_pair):
read_bytes = bytearray()
stream_0, stream_1 = net_stream_pair
async def read_until_eof():
read_bytes.extend(await stream_1.read())
async with trio.open_nursery() as nursery:
nursery.start_soon(read_until_eof)
expected_data = bytearray()
# Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA)
expected_data.extend(DATA)
await trio.sleep(0.01)
assert len(read_bytes) == 0
# Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA)
expected_data.extend(DATA)
await trio.sleep(0.01)
assert len(read_bytes) == 0
# Test: Close the stream, `read` returns, and receive previous sent data.
await stream_0.close()
await trio.sleep(0.01)
assert read_bytes == expected_data
@pytest.mark.trio
async def test_net_stream_read_after_remote_closed(net_stream_pair):
stream_0, stream_1 = net_stream_pair
await stream_0.write(DATA)
await stream_0.close()
await trio.sleep(0.01)
assert (await stream_1.read(MAX_READ_LEN)) == DATA
with pytest.raises(StreamEOF):
await stream_1.read(MAX_READ_LEN)
@pytest.mark.trio
async def test_net_stream_read_after_local_reset(net_stream_pair):
stream_0, stream_1 = net_stream_pair
await stream_0.reset()
with pytest.raises(StreamReset):
await stream_0.read(MAX_READ_LEN)
@pytest.mark.trio
async def test_net_stream_read_after_remote_reset(net_stream_pair):
stream_0, stream_1 = net_stream_pair
await stream_0.write(DATA)
await stream_0.reset()
# Sleep to let `stream_1` receive the message.
await trio.sleep(0.01)
with pytest.raises(StreamReset):
await stream_1.read(MAX_READ_LEN)
@pytest.mark.trio
async def test_net_stream_read_after_remote_closed_and_reset(net_stream_pair):
stream_0, stream_1 = net_stream_pair
await stream_0.write(DATA)
await stream_0.close()
await stream_0.reset()
# Sleep to let `stream_1` receive the message.
await trio.sleep(0.01)
assert (await stream_1.read(MAX_READ_LEN)) == DATA
@pytest.mark.trio
async def test_net_stream_write_after_local_closed(net_stream_pair):
stream_0, stream_1 = net_stream_pair
await stream_0.write(DATA)
await stream_0.close()
with pytest.raises(StreamClosed):
await stream_0.write(DATA)
@pytest.mark.trio
async def test_net_stream_write_after_local_reset(net_stream_pair):
stream_0, stream_1 = net_stream_pair
await stream_0.reset()
with pytest.raises(StreamClosed):
await stream_0.write(DATA)
@pytest.mark.trio
async def test_net_stream_write_after_remote_reset(net_stream_pair):
stream_0, stream_1 = net_stream_pair
await stream_1.reset()
await trio.sleep(0.01)
with pytest.raises(StreamClosed):
await stream_0.write(DATA)

View File

@ -0,0 +1,126 @@
"""
Test Notify and Notifee by ensuring that the proper events get called, and that
the stream passed into opened_stream is correct.
Note: Listen event does not get hit because MyNotifee is passed
into network after network has already started listening
TODO: Add tests for closed_stream, listen_close when those
features are implemented in swarm
"""
import enum
from async_service import (
background_trio_service,
)
import pytest
import trio
from libp2p.network.notifee_interface import (
INotifee,
)
from libp2p.tools.constants import (
LISTEN_MADDR,
)
from libp2p.tools.factories import (
SwarmFactory,
)
from libp2p.tools.utils import (
connect_swarm,
)
class Event(enum.Enum):
OpenedStream = 0
ClosedStream = 1 # Not implemented
Connected = 2
Disconnected = 3
Listen = 4
ListenClose = 5 # Not implemented
class MyNotifee(INotifee):
def __init__(self, events):
self.events = events
async def opened_stream(self, network, stream):
self.events.append(Event.OpenedStream)
async def closed_stream(self, network, stream):
# TODO: It is not implemented yet.
pass
async def connected(self, network, conn):
self.events.append(Event.Connected)
async def disconnected(self, network, conn):
self.events.append(Event.Disconnected)
async def listen(self, network, _multiaddr):
self.events.append(Event.Listen)
async def listen_close(self, network, _multiaddr):
# TODO: It is not implemented yet.
pass
@pytest.mark.trio
async def test_notify(security_protocol):
swarms = [SwarmFactory(security_protocol=security_protocol) for _ in range(2)]
events_0_0 = []
events_1_0 = []
events_0_without_listen = []
# Run swarms.
async with background_trio_service(swarms[0]), background_trio_service(swarms[1]):
# Register events before listening, to allow `MyNotifee` is notified with the
# event `listen`.
swarms[0].register_notifee(MyNotifee(events_0_0))
swarms[1].register_notifee(MyNotifee(events_1_0))
# Listen
async with trio.open_nursery() as nursery:
nursery.start_soon(swarms[0].listen, LISTEN_MADDR)
nursery.start_soon(swarms[1].listen, LISTEN_MADDR)
swarms[0].register_notifee(MyNotifee(events_0_without_listen))
# Connected
await connect_swarm(swarms[0], swarms[1])
# OpenedStream: first
await swarms[0].new_stream(swarms[1].get_peer_id())
# OpenedStream: second
await swarms[0].new_stream(swarms[1].get_peer_id())
# OpenedStream: third, but different direction.
await swarms[1].new_stream(swarms[0].get_peer_id())
await trio.sleep(0.01)
# TODO: Check `ClosedStream` and `ListenClose` events after they are ready.
# Disconnected
await swarms[0].close_peer(swarms[1].get_peer_id())
await trio.sleep(0.01)
# Connected again, but different direction.
await connect_swarm(swarms[1], swarms[0])
await trio.sleep(0.01)
# Disconnected again, but different direction.
await swarms[1].close_peer(swarms[0].get_peer_id())
await trio.sleep(0.01)
expected_events_without_listen = [
Event.Connected,
Event.OpenedStream,
Event.OpenedStream,
Event.OpenedStream,
Event.Disconnected,
Event.Connected,
Event.Disconnected,
]
expected_events = [Event.Listen] + expected_events_without_listen
assert events_0_0 == expected_events
assert events_1_0 == expected_events
assert events_0_without_listen == expected_events_without_listen

View File

@ -0,0 +1,158 @@
from multiaddr import (
Multiaddr,
)
import pytest
import trio
from trio.testing import (
wait_all_tasks_blocked,
)
from libp2p.network.exceptions import (
SwarmException,
)
from libp2p.tools.factories import (
SwarmFactory,
)
from libp2p.tools.utils import (
connect_swarm,
)
@pytest.mark.trio
async def test_swarm_dial_peer(security_protocol):
async with SwarmFactory.create_batch_and_listen(
3, security_protocol=security_protocol
) as swarms:
# Test: No addr found.
with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
# Test: len(addr) in the peerstore is 0.
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000)
with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
# Test: Succeed if addrs of the peer_id are present in the peerstore.
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
await swarms[0].dial_peer(swarms[1].get_peer_id())
assert swarms[0].get_peer_id() in swarms[1].connections
assert swarms[1].get_peer_id() in swarms[0].connections
# Test: Reuse connections when we already have ones with a peer.
conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
assert conn is conn_to_1
@pytest.mark.trio
async def test_swarm_close_peer(security_protocol):
async with SwarmFactory.create_batch_and_listen(
3, security_protocol=security_protocol
) as swarms:
# 0 <> 1 <> 2
await connect_swarm(swarms[0], swarms[1])
await connect_swarm(swarms[1], swarms[2])
# peer 1 closes peer 0
await swarms[1].close_peer(swarms[0].get_peer_id())
await trio.sleep(0.01)
await wait_all_tasks_blocked()
# 0 1 <> 2
assert len(swarms[0].connections) == 0
assert (
len(swarms[1].connections) == 1
and swarms[2].get_peer_id() in swarms[1].connections
)
# peer 1 is closed by peer 2
await swarms[2].close_peer(swarms[1].get_peer_id())
await trio.sleep(0.01)
# 0 1 2
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
await connect_swarm(swarms[0], swarms[1])
# 0 <> 1 2
assert (
len(swarms[0].connections) == 1
and swarms[1].get_peer_id() in swarms[0].connections
)
assert (
len(swarms[1].connections) == 1
and swarms[0].get_peer_id() in swarms[1].connections
)
# peer 0 closes peer 1
await swarms[0].close_peer(swarms[1].get_peer_id())
await trio.sleep(0.01)
# 0 1 2
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
@pytest.mark.trio
async def test_swarm_remove_conn(swarm_pair):
swarm_0, swarm_1 = swarm_pair
conn_0 = swarm_0.connections[swarm_1.get_peer_id()]
swarm_0.remove_conn(conn_0)
assert swarm_1.get_peer_id() not in swarm_0.connections
# Test: Remove twice. There should not be errors.
swarm_0.remove_conn(conn_0)
assert swarm_1.get_peer_id() not in swarm_0.connections
@pytest.mark.trio
async def test_swarm_multiaddr(security_protocol):
async with SwarmFactory.create_batch_and_listen(
3, security_protocol=security_protocol
) as swarms:
def clear():
swarms[0].peerstore.clear_addrs(swarms[1].get_peer_id())
clear()
# No addresses
with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
clear()
# Wrong addresses
swarms[0].peerstore.add_addrs(
swarms[1].get_peer_id(), [Multiaddr("/ip4/0.0.0.0/tcp/9999")], 10000
)
with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
clear()
# Multiple wrong addresses
swarms[0].peerstore.add_addrs(
swarms[1].get_peer_id(),
[Multiaddr("/ip4/0.0.0.0/tcp/9999"), Multiaddr("/ip4/0.0.0.0/tcp/9998")],
10000,
)
with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
# Test one address
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs[:1], 10000)
await swarms[0].dial_peer(swarms[1].get_peer_id())
# Test multiple addresses
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs + addrs, 10000)
await swarms[0].dial_peer(swarms[1].get_peer_id())

View File

@ -0,0 +1,48 @@
import pytest
import trio
from trio.testing import (
wait_all_tasks_blocked,
)
@pytest.mark.trio
async def test_swarm_conn_close(swarm_conn_pair):
conn_0, conn_1 = swarm_conn_pair
assert not conn_0.is_closed
assert not conn_1.is_closed
await conn_0.close()
await trio.sleep(0.1)
await wait_all_tasks_blocked()
assert conn_0.is_closed
assert conn_1.is_closed
assert conn_0 not in conn_0.swarm.connections.values()
assert conn_1 not in conn_1.swarm.connections.values()
@pytest.mark.trio
async def test_swarm_conn_streams(swarm_conn_pair):
conn_0, conn_1 = swarm_conn_pair
assert len(conn_0.get_streams()) == 0
assert len(conn_1.get_streams()) == 0
stream_0_0 = await conn_0.new_stream()
await trio.sleep(0.01)
assert len(conn_0.get_streams()) == 1
assert len(conn_1.get_streams()) == 1
stream_0_1 = await conn_0.new_stream()
await trio.sleep(0.01)
assert len(conn_0.get_streams()) == 2
assert len(conn_1.get_streams()) == 2
conn_0.remove_stream(stream_0_0)
assert len(conn_0.get_streams()) == 1
conn_0.remove_stream(stream_0_1)
assert len(conn_0.get_streams()) == 0
# Nothing happen if `stream_0_1` is not present or already removed.
conn_0.remove_stream(stream_0_1)

View File

@ -0,0 +1,61 @@
import pytest
from libp2p.peer.peerstore import (
PeerStore,
PeerStoreError,
)
# Testing methods from IAddrBook base class.
def test_addrs_empty():
with pytest.raises(PeerStoreError):
store = PeerStore()
val = store.addrs("peer")
assert not val
def test_add_addr_single():
store = PeerStore()
store.add_addr("peer1", "/foo", 10)
store.add_addr("peer1", "/bar", 10)
store.add_addr("peer2", "/baz", 10)
assert store.addrs("peer1") == ["/foo", "/bar"]
assert store.addrs("peer2") == ["/baz"]
def test_add_addrs_multiple():
store = PeerStore()
store.add_addrs("peer1", ["/foo1", "/bar1"], 10)
store.add_addrs("peer2", ["/foo2"], 10)
assert store.addrs("peer1") == ["/foo1", "/bar1"]
assert store.addrs("peer2") == ["/foo2"]
def test_clear_addrs():
store = PeerStore()
store.add_addrs("peer1", ["/foo1", "/bar1"], 10)
store.add_addrs("peer2", ["/foo2"], 10)
store.clear_addrs("peer1")
assert store.addrs("peer1") == []
assert store.addrs("peer2") == ["/foo2"]
store.add_addrs("peer1", ["/foo1", "/bar1"], 10)
assert store.addrs("peer1") == ["/foo1", "/bar1"]
def test_peers_with_addrs():
store = PeerStore()
store.add_addrs("peer1", [], 10)
store.add_addrs("peer2", ["/foo"], 10)
store.add_addrs("peer3", ["/bar"], 10)
assert set(store.peers_with_addrs()) == {"peer2", "peer3"}
store.clear_addrs("peer2")
assert set(store.peers_with_addrs()) == {"peer3"}

View File

@ -0,0 +1,47 @@
import base64
import Crypto.PublicKey.RSA as RSA
from libp2p.crypto.pb import crypto_pb2 as pb
from libp2p.crypto.rsa import (
RSAPrivateKey,
)
from libp2p.peer.id import (
ID,
)
# ``PRIVATE_KEY_PROTOBUF_SERIALIZATION`` is a protobuf holding an RSA private key.
PRIVATE_KEY_PROTOBUF_SERIALIZATION = """
CAAS4AQwggJcAgEAAoGBAL7w+Wc4VhZhCdM/+Hccg5Nrf4q9NXWwJylbSrXz/unFS24wyk6pEk0zi3W
7li+vSNVO+NtJQw9qGNAMtQKjVTP+3Vt/jfQRnQM3s6awojtjueEWuLYVt62z7mofOhCtj+VwIdZNBo
/EkLZ0ETfcvN5LVtLYa8JkXybnOPsLvK+PAgMBAAECgYBdk09HDM7zzL657uHfzfOVrdslrTCj6p5mo
DzvCxLkkjIzYGnlPuqfNyGjozkpSWgSUc+X+EGLLl3WqEOVdWJtbM61fewEHlRTM5JzScvwrJ39t7o6
CCAjKA0cBWBd6UWgbN/t53RoWvh9HrA2AW5YrT0ZiAgKe9y7EMUaENVJ8QJBAPhpdmb4ZL4Fkm4OKia
NEcjzn6mGTlZtef7K/0oRC9+2JkQnCuf6HBpaRhJoCJYg7DW8ZY+AV6xClKrgjBOfERMCQQDExhnzu2
dsQ9k8QChBlpHO0TRbZBiQfC70oU31kM1AeLseZRmrxv9Yxzdl8D693NNWS2JbKOXl0kMHHcuGQLMVA
kBZ7WvkmPV3aPL6jnwp2pXepntdVnaTiSxJ1dkXShZ/VSSDNZMYKY306EtHrIu3NZHtXhdyHKcggDXr
qkBrdgErAkAlpGPojUwemOggr4FD8sLX1ot2hDJyyV7OK2FXfajWEYJyMRL1Gm9Uk1+Un53RAkJneqp
JGAzKpyttXBTIDO51AkEA98KTiROMnnU8Y6Mgcvr68/SMIsvCYMt9/mtwSBGgl80VaTQ5Hpaktl6Xbh
VUt5Wv0tRxlXZiViCGCD1EtrrwTw==
""".replace(
"\n", ""
)
EXPECTED_PEER_ID = "QmRK3JgmVEGiewxWbhpXLJyjWuGuLeSTMTndA1coMHEy5o"
# NOTE: this test checks that we can recreate the expected peer id given a private key
# serialization, taken from the Go implementation of libp2p.
def test_peer_id_interop():
private_key_protobuf_bytes = base64.b64decode(PRIVATE_KEY_PROTOBUF_SERIALIZATION)
private_key_protobuf = pb.PrivateKey()
private_key_protobuf.ParseFromString(private_key_protobuf_bytes)
private_key_data = private_key_protobuf.data
private_key_impl = RSA.import_key(private_key_data)
private_key = RSAPrivateKey(private_key_impl)
public_key = private_key.get_public_key()
peer_id = ID.from_pubkey(public_key)
assert peer_id == EXPECTED_PEER_ID

View File

@ -0,0 +1,110 @@
import random
import base58
import multihash
from libp2p.crypto.rsa import (
create_new_key_pair,
)
import libp2p.peer.id as PeerID
from libp2p.peer.id import (
ID,
)
ALPHABETS = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
# ensure we are not in "debug" mode for the following tests
PeerID.FRIENDLY_IDS = False
def test_eq_impl_for_bytes():
random_id_string = ""
for _ in range(10):
random_id_string += random.choice(ALPHABETS)
peer_id = ID(random_id_string.encode())
assert peer_id == random_id_string.encode()
def test_pretty():
random_id_string = ""
for _ in range(10):
random_id_string += random.choice(ALPHABETS)
peer_id = ID(random_id_string.encode())
actual = peer_id.pretty()
expected = base58.b58encode(random_id_string).decode()
assert actual == expected
def test_str_less_than_10():
random_id_string = ""
for _ in range(5):
random_id_string += random.choice(ALPHABETS)
peer_id = base58.b58encode(random_id_string).decode()
expected = peer_id
actual = ID(random_id_string.encode()).__str__()
assert actual == expected
def test_str_more_than_10():
random_id_string = ""
for _ in range(10):
random_id_string += random.choice(ALPHABETS)
peer_id = base58.b58encode(random_id_string).decode()
expected = peer_id
actual = ID(random_id_string.encode()).__str__()
assert actual == expected
def test_eq_true():
random_id_string = ""
for _ in range(10):
random_id_string += random.choice(ALPHABETS)
peer_id = ID(random_id_string.encode())
assert peer_id == base58.b58encode(random_id_string).decode()
assert peer_id == random_id_string.encode()
assert peer_id == ID(random_id_string.encode())
def test_eq_false():
peer_id = ID("efgh")
other = ID("abcd")
assert peer_id != other
def test_id_to_base58():
random_id_string = ""
for _ in range(10):
random_id_string += random.choice(ALPHABETS)
expected = base58.b58encode(random_id_string).decode()
actual = ID(random_id_string.encode()).to_base58()
assert actual == expected
def test_id_from_base58():
random_id_string = ""
for _ in range(10):
random_id_string += random.choice(ALPHABETS)
expected = ID(base58.b58decode(random_id_string))
actual = ID.from_base58(random_id_string.encode())
assert actual == expected
def test_id_from_public_key():
key_pair = create_new_key_pair()
public_key = key_pair.public_key
key_bin = public_key.serialize()
algo = multihash.Func.sha2_256
mh_digest = multihash.digest(key_bin, algo)
expected = ID(mh_digest.encode())
actual = ID.from_pubkey(public_key)
assert actual == expected

View File

@ -0,0 +1,54 @@
import random
import multiaddr
import pytest
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
InvalidAddrError,
PeerInfo,
info_from_p2p_addr,
)
ALPHABETS = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
VALID_MULTI_ADDR_STR = "/ip4/127.0.0.1/tcp/8000/p2p/3YgLAeMKSAPcGqZkAt8mREqhQXmJT8SN8VCMN4T6ih4GNX9wvK8mWJnWZ1qA2mLdCQ" # noqa: E501
def test_init_():
random_addrs = [random.randint(0, 255) for r in range(4)]
random_id_string = ""
for _ in range(10):
random_id_string += random.SystemRandom().choice(ALPHABETS)
peer_id = ID(random_id_string.encode())
peer_info = PeerInfo(peer_id, random_addrs)
assert peer_info.peer_id == peer_id
assert peer_info.addrs == random_addrs
@pytest.mark.parametrize(
"addr",
(
pytest.param(multiaddr.Multiaddr("/"), id="empty multiaddr"),
pytest.param(
multiaddr.Multiaddr("/ip4/127.0.0.1"),
id="multiaddr without peer_id(p2p protocol)",
),
),
)
def test_info_from_p2p_addr_invalid(addr):
with pytest.raises(InvalidAddrError):
info_from_p2p_addr(addr)
def test_info_from_p2p_addr_valid():
m_addr = multiaddr.Multiaddr(VALID_MULTI_ADDR_STR)
info = info_from_p2p_addr(m_addr)
assert (
info.peer_id.pretty()
== "3YgLAeMKSAPcGqZkAt8mREqhQXmJT8SN8VCMN4T6ih4GNX9wvK8mWJnWZ1qA2mLdCQ"
)
assert len(info.addrs) == 1
assert str(info.addrs[0]) == "/ip4/127.0.0.1/tcp/8000"

View File

@ -0,0 +1,46 @@
import pytest
from libp2p.peer.peerstore import (
PeerStore,
PeerStoreError,
)
# Testing methods from IPeerMetadata base class.
def test_get_empty():
with pytest.raises(PeerStoreError):
store = PeerStore()
val = store.get("peer", "key")
assert not val
def test_put_get_simple():
store = PeerStore()
store.put("peer", "key", "val")
assert store.get("peer", "key") == "val"
def test_put_get_update():
store = PeerStore()
store.put("peer", "key1", "val1")
store.put("peer", "key2", "val2")
store.put("peer", "key2", "new val2")
assert store.get("peer", "key1") == "val1"
assert store.get("peer", "key2") == "new val2"
def test_put_get_two_peers():
store = PeerStore()
store.put("peer1", "key1", "val1")
store.put("peer2", "key1", "val1 prime")
assert store.get("peer1", "key1") == "val1"
assert store.get("peer2", "key1") == "val1 prime"
# Try update
store.put("peer2", "key1", "new val1")
assert store.get("peer1", "key1") == "val1"
assert store.get("peer2", "key1") == "new val1"

View File

@ -0,0 +1,62 @@
import pytest
from libp2p.peer.peerstore import (
PeerStore,
PeerStoreError,
)
# Testing methods from IPeerStore base class.
def test_peer_info_empty():
store = PeerStore()
with pytest.raises(PeerStoreError):
store.peer_info("peer")
def test_peer_info_basic():
store = PeerStore()
store.add_addr("peer", "/foo", 10)
info = store.peer_info("peer")
assert info.peer_id == "peer"
assert info.addrs == ["/foo"]
def test_add_get_protocols_basic():
store = PeerStore()
store.add_protocols("peer1", ["p1", "p2"])
store.add_protocols("peer2", ["p3"])
assert set(store.get_protocols("peer1")) == {"p1", "p2"}
assert set(store.get_protocols("peer2")) == {"p3"}
def test_add_get_protocols_extend():
store = PeerStore()
store.add_protocols("peer1", ["p1", "p2"])
store.add_protocols("peer1", ["p3"])
assert set(store.get_protocols("peer1")) == {"p1", "p2", "p3"}
def test_set_protocols():
store = PeerStore()
store.add_protocols("peer1", ["p1", "p2"])
store.add_protocols("peer2", ["p3"])
store.set_protocols("peer1", ["p4"])
store.set_protocols("peer2", [])
assert set(store.get_protocols("peer1")) == {"p4"}
assert set(store.get_protocols("peer2")) == set()
# Test with methods from other Peer interfaces.
def test_peers():
store = PeerStore()
store.add_protocols("peer1", [])
store.put("peer2", "key", "val")
store.add_addr("peer3", "/foo", 10)
assert set(store.peer_ids()) == {"peer1", "peer2", "peer3"}

View File

@ -0,0 +1,102 @@
import pytest
from libp2p.host.exceptions import (
StreamFailure,
)
from libp2p.tools.factories import (
HostFactory,
)
from libp2p.tools.utils import (
create_echo_stream_handler,
)
PROTOCOL_ECHO = "/echo/1.0.0"
PROTOCOL_POTATO = "/potato/1.0.0"
PROTOCOL_FOO = "/foo/1.0.0"
PROTOCOL_ROCK = "/rock/1.0.0"
ACK_PREFIX = "ack:"
async def perform_simple_test(
expected_selected_protocol,
protocols_for_client,
protocols_with_handlers,
security_protocol,
):
async with HostFactory.create_batch_and_listen(
2, security_protocol=security_protocol
) as hosts:
for protocol in protocols_with_handlers:
hosts[1].set_stream_handler(
protocol, create_echo_stream_handler(ACK_PREFIX)
)
# Associate the peer with local ip address (see default parameters of Libp2p())
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
stream = await hosts[0].new_stream(hosts[1].get_id(), protocols_for_client)
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
expected_resp = "ack:" + message
await stream.write(message.encode())
response = (await stream.read(len(expected_resp))).decode()
assert response == expected_resp
assert expected_selected_protocol == stream.get_protocol()
@pytest.mark.trio
async def test_single_protocol_succeeds(security_protocol):
expected_selected_protocol = PROTOCOL_ECHO
await perform_simple_test(
expected_selected_protocol,
[expected_selected_protocol],
[expected_selected_protocol],
security_protocol,
)
@pytest.mark.trio
async def test_single_protocol_fails(security_protocol):
with pytest.raises(StreamFailure):
await perform_simple_test(
"", [PROTOCOL_ECHO], [PROTOCOL_POTATO], security_protocol
)
# Cleanup not reached on error
@pytest.mark.trio
async def test_multiple_protocol_first_is_valid_succeeds(security_protocol):
expected_selected_protocol = PROTOCOL_ECHO
protocols_for_client = [PROTOCOL_ECHO, PROTOCOL_POTATO]
protocols_for_listener = [PROTOCOL_FOO, PROTOCOL_ECHO]
await perform_simple_test(
expected_selected_protocol,
protocols_for_client,
protocols_for_listener,
security_protocol,
)
@pytest.mark.trio
async def test_multiple_protocol_second_is_valid_succeeds(security_protocol):
expected_selected_protocol = PROTOCOL_FOO
protocols_for_client = [PROTOCOL_ROCK, PROTOCOL_FOO]
protocols_for_listener = [PROTOCOL_FOO, PROTOCOL_ECHO]
await perform_simple_test(
expected_selected_protocol,
protocols_for_client,
protocols_for_listener,
security_protocol,
)
@pytest.mark.trio
async def test_multiple_protocol_fails(security_protocol):
protocols_for_client = [PROTOCOL_ROCK, PROTOCOL_FOO, "/bar/1.0.0"]
protocols_for_listener = ["/aspyn/1.0.0", "/rob/1.0.0", "/zx/1.0.0", "/alex/1.0.0"]
with pytest.raises(StreamFailure):
await perform_simple_test(
"", protocols_for_client, protocols_for_listener, security_protocol
)

View File

@ -0,0 +1,197 @@
import pytest
import trio
from libp2p.tools.pubsub.dummy_account_node import (
DummyAccountNode,
)
from libp2p.tools.utils import (
connect,
)
async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
"""
Helper function to allow for easy construction of custom tests for dummy
account nodes in various network topologies.
:param num_nodes: number of nodes in the test
:param adjacency_map: adjacency map defining each node and its list of neighbors
:param action_func: function to execute that includes actions by the nodes,
such as send crypto and set crypto
:param assertion_func: assertions for testing the results of the actions are correct
"""
async with DummyAccountNode.create(num_nodes) as dummy_nodes:
# Create connections between nodes according to `adjacency_map`
async with trio.open_nursery() as nursery:
for source_num in adjacency_map:
target_nums = adjacency_map[source_num]
for target_num in target_nums:
nursery.start_soon(
connect,
dummy_nodes[source_num].host,
dummy_nodes[target_num].host,
)
# Allow time for network creation to take place
await trio.sleep(0.25)
# Perform action function
await action_func(dummy_nodes)
# Allow time for action function to be performed (i.e. messages to propogate)
await trio.sleep(1)
# Perform assertion function
for dummy_node in dummy_nodes:
assertion_func(dummy_node)
# Success, terminate pending tasks.
@pytest.mark.trio
async def test_simple_two_nodes():
num_nodes = 2
adj_map = {0: [1]}
async def action_func(dummy_nodes):
await dummy_nodes[0].publish_set_crypto("aspyn", 10)
def assertion_func(dummy_node):
assert dummy_node.get_balance("aspyn") == 10
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.trio
async def test_simple_three_nodes_line_topography():
num_nodes = 3
adj_map = {0: [1], 1: [2]}
async def action_func(dummy_nodes):
await dummy_nodes[0].publish_set_crypto("aspyn", 10)
def assertion_func(dummy_node):
assert dummy_node.get_balance("aspyn") == 10
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.trio
async def test_simple_three_nodes_triangle_topography():
num_nodes = 3
adj_map = {0: [1, 2], 1: [2]}
async def action_func(dummy_nodes):
await dummy_nodes[0].publish_set_crypto("aspyn", 20)
def assertion_func(dummy_node):
assert dummy_node.get_balance("aspyn") == 20
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.trio
async def test_simple_seven_nodes_tree_topography():
num_nodes = 7
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
async def action_func(dummy_nodes):
await dummy_nodes[0].publish_set_crypto("aspyn", 20)
def assertion_func(dummy_node):
assert dummy_node.get_balance("aspyn") == 20
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.trio
async def test_set_then_send_from_root_seven_nodes_tree_topography():
num_nodes = 7
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
async def action_func(dummy_nodes):
await dummy_nodes[0].publish_set_crypto("aspyn", 20)
await trio.sleep(0.25)
await dummy_nodes[0].publish_send_crypto("aspyn", "alex", 5)
def assertion_func(dummy_node):
assert dummy_node.get_balance("aspyn") == 15
assert dummy_node.get_balance("alex") == 5
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.trio
async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography():
num_nodes = 7
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
async def action_func(dummy_nodes):
await dummy_nodes[6].publish_set_crypto("aspyn", 20)
await trio.sleep(0.25)
await dummy_nodes[4].publish_send_crypto("aspyn", "alex", 5)
def assertion_func(dummy_node):
assert dummy_node.get_balance("aspyn") == 15
assert dummy_node.get_balance("alex") == 5
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.trio
async def test_simple_five_nodes_ring_topography():
num_nodes = 5
adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
async def action_func(dummy_nodes):
await dummy_nodes[0].publish_set_crypto("aspyn", 20)
def assertion_func(dummy_node):
assert dummy_node.get_balance("aspyn") == 20
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.trio
async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography():
num_nodes = 5
adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
async def action_func(dummy_nodes):
await dummy_nodes[0].publish_set_crypto("alex", 20)
await trio.sleep(0.25)
await dummy_nodes[3].publish_send_crypto("alex", "rob", 12)
def assertion_func(dummy_node):
assert dummy_node.get_balance("alex") == 8
assert dummy_node.get_balance("rob") == 12
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.trio
@pytest.mark.slow
async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography():
num_nodes = 5
adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
async def action_func(dummy_nodes):
await dummy_nodes[0].publish_set_crypto("alex", 20)
await trio.sleep(1)
await dummy_nodes[1].publish_send_crypto("alex", "rob", 3)
await trio.sleep(1)
await dummy_nodes[2].publish_send_crypto("rob", "aspyn", 2)
await trio.sleep(1)
await dummy_nodes[3].publish_send_crypto("aspyn", "zx", 1)
await trio.sleep(1)
await dummy_nodes[4].publish_send_crypto("zx", "raul", 1)
def assertion_func(dummy_node):
assert dummy_node.get_balance("alex") == 17
assert dummy_node.get_balance("rob") == 1
assert dummy_node.get_balance("aspyn") == 1
assert dummy_node.get_balance("zx") == 0
assert dummy_node.get_balance("raul") == 1
await perform_test(num_nodes, adj_map, action_func, assertion_func)

View File

@ -0,0 +1,95 @@
import functools
import pytest
import trio
from libp2p.peer.id import (
ID,
)
from libp2p.tools.factories import (
PubsubFactory,
)
from libp2p.tools.pubsub.floodsub_integration_test_settings import (
floodsub_protocol_pytest_params,
perform_test_from_obj,
)
from libp2p.tools.utils import (
connect,
)
@pytest.mark.trio
async def test_simple_two_nodes():
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
topic = "my_topic"
data = b"some data"
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
await trio.sleep(0.25)
sub_b = await pubsubs_fsub[1].subscribe(topic)
# Sleep to let a know of b's subscription
await trio.sleep(0.25)
await pubsubs_fsub[0].publish(topic, data)
res_b = await sub_b.get()
# Check that the msg received by node_b is the same
# as the message sent by node_a
assert ID(res_b.from_id) == pubsubs_fsub[0].host.get_id()
assert res_b.data == data
assert res_b.topicIDs == [topic]
@pytest.mark.trio
async def test_lru_cache_two_nodes():
# two nodes with cache_size of 4
# Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`.
def get_msg_id(msg):
# Originally it is `(msg.seqno, msg.from_id)`
return (msg.data, msg.from_id)
async with PubsubFactory.create_batch_with_floodsub(
2, cache_size=4, msg_id_constructor=get_msg_id
) as pubsubs_fsub:
# `node_a` send the following messages to node_b
message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
# `node_b` should only receive the following
expected_received_indices = [1, 2, 3, 4, 5, 1]
topic = "my_topic"
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
await trio.sleep(0.25)
sub_b = await pubsubs_fsub[1].subscribe(topic)
await trio.sleep(0.25)
def _make_testing_data(i: int) -> bytes:
num_int_bytes = 4
if i >= 2 ** (num_int_bytes * 8):
raise ValueError("integer is too large to be serialized")
return b"data" + i.to_bytes(num_int_bytes, "big")
for index in message_indices:
await pubsubs_fsub[0].publish(topic, _make_testing_data(index))
await trio.sleep(0.25)
for index in expected_received_indices:
res_b = await sub_b.get()
assert res_b.data == _make_testing_data(index)
@pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params)
@pytest.mark.trio
@pytest.mark.slow
async def test_gossipsub_run_with_floodsub_tests(test_case_obj, security_protocol):
await perform_test_from_obj(
test_case_obj,
functools.partial(
PubsubFactory.create_batch_with_floodsub,
security_protocol=security_protocol,
),
)

View File

@ -0,0 +1,488 @@
import random
import pytest
import trio
from libp2p.pubsub.gossipsub import (
PROTOCOL_ID,
)
from libp2p.tools.factories import (
IDFactory,
PubsubFactory,
)
from libp2p.tools.pubsub.utils import (
dense_connect,
one_to_all_connect,
)
from libp2p.tools.utils import (
connect,
)
@pytest.mark.trio
async def test_join():
async with PubsubFactory.create_batch_with_gossipsub(
4, degree=4, degree_low=3, degree_high=5
) as pubsubs_gsub:
gossipsubs = [pubsub.router for pubsub in pubsubs_gsub]
hosts = [pubsub.host for pubsub in pubsubs_gsub]
hosts_indices = list(range(len(pubsubs_gsub)))
topic = "test_join"
central_node_index = 0
# Remove index of central host from the indices
hosts_indices.remove(central_node_index)
num_subscribed_peer = 2
subscribed_peer_indices = random.sample(hosts_indices, num_subscribed_peer)
# All pubsub except the one of central node subscribe to topic
for i in subscribed_peer_indices:
await pubsubs_gsub[i].subscribe(topic)
# Connect central host to all other hosts
await one_to_all_connect(hosts, central_node_index)
# Wait 2 seconds for heartbeat to allow mesh to connect
await trio.sleep(2)
# Central node publish to the topic so that this topic
# is added to central node's fanout
# publish from the randomly chosen host
await pubsubs_gsub[central_node_index].publish(topic, b"data")
# Check that the gossipsub of central node has fanout for the topic
assert topic in gossipsubs[central_node_index].fanout
# Check that the gossipsub of central node does not have a mesh for the topic
assert topic not in gossipsubs[central_node_index].mesh
# Central node subscribes the topic
await pubsubs_gsub[central_node_index].subscribe(topic)
await trio.sleep(2)
# Check that the gossipsub of central node no longer has fanout for the topic
assert topic not in gossipsubs[central_node_index].fanout
for i in hosts_indices:
if i in subscribed_peer_indices:
assert hosts[i].get_id() in gossipsubs[central_node_index].mesh[topic]
assert hosts[central_node_index].get_id() in gossipsubs[i].mesh[topic]
else:
assert (
hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic]
)
assert topic not in gossipsubs[i].mesh
@pytest.mark.trio
async def test_leave():
async with PubsubFactory.create_batch_with_gossipsub(1) as pubsubs_gsub:
gossipsub = pubsubs_gsub[0].router
topic = "test_leave"
assert topic not in gossipsub.mesh
await gossipsub.join(topic)
assert topic in gossipsub.mesh
await gossipsub.leave(topic)
assert topic not in gossipsub.mesh
# Test re-leave
await gossipsub.leave(topic)
@pytest.mark.trio
async def test_handle_graft(monkeypatch):
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
index_alice = 0
id_alice = pubsubs_gsub[index_alice].my_id
index_bob = 1
id_bob = pubsubs_gsub[index_bob].my_id
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
# Wait 2 seconds for heartbeat to allow mesh to connect
await trio.sleep(2)
topic = "test_handle_graft"
# Only lice subscribe to the topic
await gossipsubs[index_alice].join(topic)
# Monkey patch bob's `emit_prune` function so we can
# check if it is called in `handle_graft`
event_emit_prune = trio.Event()
async def emit_prune(topic, sender_peer_id):
event_emit_prune.set()
await trio.lowlevel.checkpoint()
monkeypatch.setattr(gossipsubs[index_bob], "emit_prune", emit_prune)
# Check that alice is bob's peer but not his mesh peer
assert gossipsubs[index_bob].peer_protocol[id_alice] == PROTOCOL_ID
assert topic not in gossipsubs[index_bob].mesh
await gossipsubs[index_alice].emit_graft(topic, id_bob)
# Check that `emit_prune` is called
await event_emit_prune.wait()
# Check that bob is alice's peer but not her mesh peer
assert topic in gossipsubs[index_alice].mesh
assert id_bob not in gossipsubs[index_alice].mesh[topic]
assert gossipsubs[index_alice].peer_protocol[id_bob] == PROTOCOL_ID
await gossipsubs[index_bob].emit_graft(topic, id_alice)
await trio.sleep(1)
# Check that bob is now alice's mesh peer
assert id_bob in gossipsubs[index_alice].mesh[topic]
@pytest.mark.trio
async def test_handle_prune():
async with PubsubFactory.create_batch_with_gossipsub(
2, heartbeat_interval=3
) as pubsubs_gsub:
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
index_alice = 0
id_alice = pubsubs_gsub[index_alice].my_id
index_bob = 1
id_bob = pubsubs_gsub[index_bob].my_id
topic = "test_handle_prune"
for pubsub in pubsubs_gsub:
await pubsub.subscribe(topic)
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
# Wait for heartbeat to allow mesh to connect
await trio.sleep(1)
# Check that they are each other's mesh peer
assert id_alice in gossipsubs[index_bob].mesh[topic]
assert id_bob in gossipsubs[index_alice].mesh[topic]
# alice emit prune message to bob, alice should be removed
# from bob's mesh peer
await gossipsubs[index_alice].emit_prune(topic, id_bob)
# `emit_prune` does not remove bob from alice's mesh peers
assert id_bob in gossipsubs[index_alice].mesh[topic]
# NOTE: We increase `heartbeat_interval` to 3 seconds so that bob will not
# add alice back to his mesh after heartbeat.
# Wait for bob to `handle_prune`
await trio.sleep(0.1)
# Check that alice is no longer bob's mesh peer
assert id_alice not in gossipsubs[index_bob].mesh[topic]
@pytest.mark.trio
async def test_dense():
async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
hosts = [pubsub.host for pubsub in pubsubs_gsub]
num_msgs = 5
# All pubsub subscribe to foobar
queues = [await pubsub.subscribe("foobar") for pubsub in pubsubs_gsub]
# Densely connect libp2p hosts in a random way
await dense_connect(hosts)
# Wait 2 seconds for heartbeat to allow mesh to connect
await trio.sleep(2)
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# randomly pick a message origin
origin_idx = random.randint(0, len(hosts) - 1)
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish("foobar", msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
@pytest.mark.trio
async def test_fanout():
async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
hosts = [pubsub.host for pubsub in pubsubs_gsub]
num_msgs = 5
# All pubsub subscribe to foobar except for `pubsubs_gsub[0]`
subs = [await pubsub.subscribe("foobar") for pubsub in pubsubs_gsub[1:]]
# Sparsely connect libp2p hosts in random way
await dense_connect(hosts)
# Wait 2 seconds for heartbeat to allow mesh to connect
await trio.sleep(2)
topic = "foobar"
# Send messages with origin not subscribed
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for sub in subs:
msg = await sub.get()
assert msg.data == msg_content
# Subscribe message origin
subs.insert(0, await pubsubs_gsub[0].subscribe(topic))
# Send messages again
for i in range(num_msgs):
msg_content = b"bar " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for sub in subs:
msg = await sub.get()
assert msg.data == msg_content
@pytest.mark.trio
@pytest.mark.slow
async def test_fanout_maintenance():
async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
hosts = [pubsub.host for pubsub in pubsubs_gsub]
num_msgs = 5
# All pubsub subscribe to foobar
queues = []
topic = "foobar"
for i in range(1, len(pubsubs_gsub)):
q = await pubsubs_gsub[i].subscribe(topic)
# Add each blocking queue to an array of blocking queues
queues.append(q)
# Sparsely connect libp2p hosts in random way
await dense_connect(hosts)
# Wait 2 seconds for heartbeat to allow mesh to connect
await trio.sleep(2)
# Send messages with origin not subscribed
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
for sub in pubsubs_gsub:
await sub.unsubscribe(topic)
queues = []
await trio.sleep(2)
# Resub and repeat
for i in range(1, len(pubsubs_gsub)):
q = await pubsubs_gsub[i].subscribe(topic)
# Add each blocking queue to an array of blocking queues
queues.append(q)
await trio.sleep(2)
# Check messages can still be sent
for i in range(num_msgs):
msg_content = b"bar " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
@pytest.mark.trio
async def test_gossip_propagation():
async with PubsubFactory.create_batch_with_gossipsub(
2, degree=1, degree_low=0, degree_high=2, gossip_window=50, gossip_history=100
) as pubsubs_gsub:
topic = "foo"
queue_0 = await pubsubs_gsub[0].subscribe(topic)
# node 0 publish to topic
msg_content = b"foo_msg"
# publish from the randomly chosen host
await pubsubs_gsub[0].publish(topic, msg_content)
await trio.sleep(0.5)
# Assert that the blocking queues receive the message
msg = await queue_0.get()
assert msg.data == msg_content
@pytest.mark.parametrize("initial_mesh_peer_count", (7, 10, 13))
@pytest.mark.trio
async def test_mesh_heartbeat(initial_mesh_peer_count, monkeypatch):
async with PubsubFactory.create_batch_with_gossipsub(
1, heartbeat_initial_delay=100
) as pubsubs_gsub:
# It's difficult to set up the initial peer subscription condition.
# Ideally I would like to have initial mesh peer count that's below
# ``GossipSubDegree`` so I can test if `mesh_heartbeat` return correct peers to
# GRAFT. The problem is that I can not set it up so that we have peers subscribe
# to the topic but not being part of our mesh peers (as these peers are the
# peers to GRAFT). So I monkeypatch the peer subscriptions and our mesh peers.
total_peer_count = 14
topic = "TEST_MESH_HEARTBEAT"
fake_peer_ids = [IDFactory() for _ in range(total_peer_count)]
peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids}
monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol)
peer_topics = {topic: set(fake_peer_ids)}
# Monkeypatch the peer subscriptions
monkeypatch.setattr(pubsubs_gsub[0], "peer_topics", peer_topics)
mesh_peer_indices = random.sample(
range(total_peer_count), initial_mesh_peer_count
)
mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices]
router_mesh = {topic: set(mesh_peers)}
# Monkeypatch our mesh peers
monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh)
peers_to_graft, peers_to_prune = pubsubs_gsub[0].router.mesh_heartbeat()
if initial_mesh_peer_count > pubsubs_gsub[0].router.degree:
# If number of initial mesh peers is more than `GossipSubDegree`,
# we should PRUNE mesh peers
assert len(peers_to_graft) == 0
assert (
len(peers_to_prune)
== initial_mesh_peer_count - pubsubs_gsub[0].router.degree
)
for peer in peers_to_prune:
assert peer in mesh_peers
elif initial_mesh_peer_count < pubsubs_gsub[0].router.degree:
# If number of initial mesh peers is less than `GossipSubDegree`,
# we should GRAFT more peers
assert len(peers_to_prune) == 0
assert (
len(peers_to_graft)
== pubsubs_gsub[0].router.degree - initial_mesh_peer_count
)
for peer in peers_to_graft:
assert peer not in mesh_peers
else:
assert len(peers_to_prune) == 0 and len(peers_to_graft) == 0
@pytest.mark.parametrize("initial_peer_count", (1, 4, 7))
@pytest.mark.trio
async def test_gossip_heartbeat(initial_peer_count, monkeypatch):
async with PubsubFactory.create_batch_with_gossipsub(
1, heartbeat_initial_delay=100
) as pubsubs_gsub:
# The problem is that I can not set it up so that we have peers subscribe to the
# topic but not being part of our mesh peers (as these peers are the peers to
# GRAFT). So I monkeypatch the peer subscriptions and our mesh peers.
total_peer_count = 28
topic_mesh = "TEST_GOSSIP_HEARTBEAT_1"
topic_fanout = "TEST_GOSSIP_HEARTBEAT_2"
fake_peer_ids = [IDFactory() for _ in range(total_peer_count)]
peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids}
monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol)
topic_mesh_peer_count = 14
# Split into mesh peers and fanout peers
peer_topics = {
topic_mesh: set(fake_peer_ids[:topic_mesh_peer_count]),
topic_fanout: set(fake_peer_ids[topic_mesh_peer_count:]),
}
# Monkeypatch the peer subscriptions
monkeypatch.setattr(pubsubs_gsub[0], "peer_topics", peer_topics)
mesh_peer_indices = random.sample(
range(topic_mesh_peer_count), initial_peer_count
)
mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices]
router_mesh = {topic_mesh: set(mesh_peers)}
# Monkeypatch our mesh peers
monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh)
fanout_peer_indices = random.sample(
range(topic_mesh_peer_count, total_peer_count), initial_peer_count
)
fanout_peers = [fake_peer_ids[i] for i in fanout_peer_indices]
router_fanout = {topic_fanout: set(fanout_peers)}
# Monkeypatch our fanout peers
monkeypatch.setattr(pubsubs_gsub[0].router, "fanout", router_fanout)
def window(topic):
if topic == topic_mesh:
return [topic_mesh]
elif topic == topic_fanout:
return [topic_fanout]
else:
return []
# Monkeypatch the memory cache messages
monkeypatch.setattr(pubsubs_gsub[0].router.mcache, "window", window)
peers_to_gossip = pubsubs_gsub[0].router.gossip_heartbeat()
# If our mesh peer count is less than `GossipSubDegree`, we should gossip to up
# to `GossipSubDegree` peers (exclude mesh peers).
if topic_mesh_peer_count - initial_peer_count < pubsubs_gsub[0].router.degree:
# The same goes for fanout so it's two times the number of peers to gossip.
assert len(peers_to_gossip) == 2 * (
topic_mesh_peer_count - initial_peer_count
)
elif (
topic_mesh_peer_count - initial_peer_count >= pubsubs_gsub[0].router.degree
):
assert len(peers_to_gossip) == 2 * (pubsubs_gsub[0].router.degree)
for peer in peers_to_gossip:
if peer in peer_topics[topic_mesh]:
# Check that the peer to gossip to is not in our mesh peers
assert peer not in mesh_peers
assert topic_mesh in peers_to_gossip[peer]
elif peer in peer_topics[topic_fanout]:
# Check that the peer to gossip to is not in our fanout peers
assert peer not in fanout_peers
assert topic_fanout in peers_to_gossip[peer]

View File

@ -0,0 +1,31 @@
import functools
import pytest
from libp2p.tools.constants import (
FLOODSUB_PROTOCOL_ID,
)
from libp2p.tools.factories import (
PubsubFactory,
)
from libp2p.tools.pubsub.floodsub_integration_test_settings import (
floodsub_protocol_pytest_params,
perform_test_from_obj,
)
@pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params)
@pytest.mark.trio
@pytest.mark.slow
async def test_gossipsub_run_with_floodsub_tests(test_case_obj):
await perform_test_from_obj(
test_case_obj,
functools.partial(
PubsubFactory.create_batch_with_gossipsub,
protocols=[FLOODSUB_PROTOCOL_ID],
degree=3,
degree_low=2,
degree_high=4,
time_to_live=30,
),
)

View File

@ -0,0 +1,130 @@
from libp2p.pubsub.mcache import (
MessageCache,
)
class Msg:
__slots__ = ["topicIDs", "seqno", "from_id"]
def __init__(self, topicIDs, seqno, from_id):
self.topicIDs = topicIDs
self.seqno = seqno
self.from_id = from_id
def test_mcache():
# Ported from:
# https://github.com/libp2p/go-libp2p-pubsub/blob/51b7501433411b5096cac2b4994a36a68515fc03/mcache_test.go
mcache = MessageCache(3, 5)
msgs = []
for i in range(60):
msgs.append(Msg(["test"], i, "test"))
for i in range(10):
mcache.put(msgs[i])
for i in range(10):
msg = msgs[i]
mid = (msg.seqno, msg.from_id)
get_msg = mcache.get(mid)
# successful read
assert get_msg == msg
gids = mcache.window("test")
assert len(gids) == 10
for i in range(10):
msg = msgs[i]
mid = (msg.seqno, msg.from_id)
assert mid == gids[i]
mcache.shift()
for i in range(10, 20):
mcache.put(msgs[i])
for i in range(20):
msg = msgs[i]
mid = (msg.seqno, msg.from_id)
get_msg = mcache.get(mid)
assert get_msg == msg
gids = mcache.window("test")
assert len(gids) == 20
for i in range(10):
msg = msgs[i]
mid = (msg.seqno, msg.from_id)
assert mid == gids[10 + i]
for i in range(10, 20):
msg = msgs[i]
mid = (msg.seqno, msg.from_id)
assert mid == gids[i - 10]
mcache.shift()
for i in range(20, 30):
mcache.put(msgs[i])
mcache.shift()
for i in range(30, 40):
mcache.put(msgs[i])
mcache.shift()
for i in range(40, 50):
mcache.put(msgs[i])
mcache.shift()
for i in range(50, 60):
mcache.put(msgs[i])
assert len(mcache.msgs) == 50
for i in range(10):
msg = msgs[i]
mid = (msg.seqno, msg.from_id)
get_msg = mcache.get(mid)
# Should be evicted from cache
assert not get_msg
for i in range(10, 60):
msg = msgs[i]
mid = (msg.seqno, msg.from_id)
get_msg = mcache.get(mid)
assert get_msg == msg
gids = mcache.window("test")
assert len(gids) == 30
for i in range(10):
msg = msgs[50 + i]
mid = (msg.seqno, msg.from_id)
assert mid == gids[i]
for i in range(10, 20):
msg = msgs[30 + i]
mid = (msg.seqno, msg.from_id)
assert mid == gids[i]
for i in range(20, 30):
msg = msgs[10 + i]
mid = (msg.seqno, msg.from_id)
assert mid == gids[i]

View File

@ -0,0 +1,684 @@
from contextlib import (
contextmanager,
)
from typing import (
NamedTuple,
)
import pytest
import trio
from libp2p.exceptions import (
ValidationError,
)
from libp2p.pubsub.pb import (
rpc_pb2,
)
from libp2p.pubsub.pubsub import (
PUBSUB_SIGNING_PREFIX,
SUBSCRIPTION_CHANNEL_SIZE,
)
from libp2p.tools.constants import (
MAX_READ_LEN,
)
from libp2p.tools.factories import (
IDFactory,
PubsubFactory,
net_stream_pair_factory,
)
from libp2p.tools.pubsub.utils import (
make_pubsub_msg,
)
from libp2p.tools.utils import (
connect,
)
from libp2p.utils import (
encode_varint_prefixed,
)
TESTING_TOPIC = "TEST_SUBSCRIBE"
TESTING_DATA = b"data"
@pytest.mark.trio
async def test_subscribe_and_unsubscribe():
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids
@pytest.mark.trio
async def test_re_subscribe():
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids
@pytest.mark.trio
async def test_re_unsubscribe():
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
# Unsubscribe from topic we didn't even subscribe to
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].topic_ids
await pubsubs_fsub[0].unsubscribe("NOT_MY_TOPIC")
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].topic_ids
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids
@pytest.mark.trio
async def test_peers_subscribe():
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
# Yield to let 0 notify 1
await trio.sleep(1)
assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
# Yield to let 0 notify 1
await trio.sleep(1)
assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
@pytest.mark.trio
async def test_get_hello_packet():
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
def _get_hello_packet_topic_ids():
packet = pubsubs_fsub[0].get_hello_packet()
return tuple(sub.topicid for sub in packet.subscriptions)
# Test: No subscription, so there should not be any topic ids in the
# hello packet.
assert len(_get_hello_packet_topic_ids()) == 0
# Test: After subscriptions, topic ids should be in the hello packet.
topic_ids = ["t", "o", "p", "i", "c"]
for topic in topic_ids:
await pubsubs_fsub[0].subscribe(topic)
topic_ids_in_hello = _get_hello_packet_topic_ids()
for topic in topic_ids:
assert topic in topic_ids_in_hello
@pytest.mark.trio
async def test_set_and_remove_topic_validator():
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
is_sync_validator_called = False
def sync_validator(peer_id, msg):
nonlocal is_sync_validator_called
is_sync_validator_called = True
is_async_validator_called = False
async def async_validator(peer_id, msg):
nonlocal is_async_validator_called
is_async_validator_called = True
await trio.lowlevel.checkpoint()
topic = "TEST_VALIDATOR"
assert topic not in pubsubs_fsub[0].topic_validators
# Register sync validator
pubsubs_fsub[0].set_topic_validator(topic, sync_validator, False)
assert topic in pubsubs_fsub[0].topic_validators
topic_validator = pubsubs_fsub[0].topic_validators[topic]
assert not topic_validator.is_async
# Validate with sync validator
topic_validator.validator(peer_id=IDFactory(), msg="msg")
assert is_sync_validator_called
assert not is_async_validator_called
# Register with async validator
pubsubs_fsub[0].set_topic_validator(topic, async_validator, True)
is_sync_validator_called = False
assert topic in pubsubs_fsub[0].topic_validators
topic_validator = pubsubs_fsub[0].topic_validators[topic]
assert topic_validator.is_async
# Validate with async validator
await topic_validator.validator(peer_id=IDFactory(), msg="msg")
assert is_async_validator_called
assert not is_sync_validator_called
# Remove validator
pubsubs_fsub[0].remove_topic_validator(topic)
assert topic not in pubsubs_fsub[0].topic_validators
@pytest.mark.trio
async def test_get_msg_validators():
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
times_sync_validator_called = 0
def sync_validator(peer_id, msg):
nonlocal times_sync_validator_called
times_sync_validator_called += 1
times_async_validator_called = 0
async def async_validator(peer_id, msg):
nonlocal times_async_validator_called
times_async_validator_called += 1
await trio.lowlevel.checkpoint()
topic_1 = "TEST_VALIDATOR_1"
topic_2 = "TEST_VALIDATOR_2"
topic_3 = "TEST_VALIDATOR_3"
# Register sync validator for topic 1 and 2
pubsubs_fsub[0].set_topic_validator(topic_1, sync_validator, False)
pubsubs_fsub[0].set_topic_validator(topic_2, sync_validator, False)
# Register async validator for topic 3
pubsubs_fsub[0].set_topic_validator(topic_3, async_validator, True)
msg = make_pubsub_msg(
origin_id=pubsubs_fsub[0].my_id,
topic_ids=[topic_1, topic_2, topic_3],
data=b"1234",
seqno=b"\x00" * 8,
)
topic_validators = pubsubs_fsub[0].get_msg_validators(msg)
for topic_validator in topic_validators:
if topic_validator.is_async:
await topic_validator.validator(peer_id=IDFactory(), msg="msg")
else:
topic_validator.validator(peer_id=IDFactory(), msg="msg")
assert times_sync_validator_called == 2
assert times_async_validator_called == 1
@pytest.mark.parametrize(
"is_topic_1_val_passed, is_topic_2_val_passed",
((False, True), (True, False), (True, True)),
)
@pytest.mark.trio
async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
def passed_sync_validator(peer_id, msg):
return True
def failed_sync_validator(peer_id, msg):
return False
async def passed_async_validator(peer_id, msg):
await trio.lowlevel.checkpoint()
return True
async def failed_async_validator(peer_id, msg):
await trio.lowlevel.checkpoint()
return False
topic_1 = "TEST_SYNC_VALIDATOR"
topic_2 = "TEST_ASYNC_VALIDATOR"
if is_topic_1_val_passed:
pubsubs_fsub[0].set_topic_validator(topic_1, passed_sync_validator, False)
else:
pubsubs_fsub[0].set_topic_validator(topic_1, failed_sync_validator, False)
if is_topic_2_val_passed:
pubsubs_fsub[0].set_topic_validator(topic_2, passed_async_validator, True)
else:
pubsubs_fsub[0].set_topic_validator(topic_2, failed_async_validator, True)
msg = make_pubsub_msg(
origin_id=pubsubs_fsub[0].my_id,
topic_ids=[topic_1, topic_2],
data=b"1234",
seqno=b"\x00" * 8,
)
if is_topic_1_val_passed and is_topic_2_val_passed:
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
else:
with pytest.raises(ValidationError):
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
@pytest.mark.trio
async def test_continuously_read_stream(monkeypatch, nursery, security_protocol):
async def wait_for_event_occurring(event):
await trio.lowlevel.checkpoint()
with trio.fail_after(0.1):
await event.wait()
class Events(NamedTuple):
push_msg: trio.Event
handle_subscription: trio.Event
handle_rpc: trio.Event
@contextmanager
def mock_methods():
event_push_msg = trio.Event()
event_handle_subscription = trio.Event()
event_handle_rpc = trio.Event()
async def mock_push_msg(msg_forwarder, msg):
event_push_msg.set()
await trio.lowlevel.checkpoint()
def mock_handle_subscription(origin_id, sub_message):
event_handle_subscription.set()
async def mock_handle_rpc(rpc, sender_peer_id):
event_handle_rpc.set()
await trio.lowlevel.checkpoint()
with monkeypatch.context() as m:
m.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg)
m.setattr(pubsubs_fsub[0], "handle_subscription", mock_handle_subscription)
m.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc)
yield Events(event_push_msg, event_handle_subscription, event_handle_rpc)
async with PubsubFactory.create_batch_with_floodsub(
1, security_protocol=security_protocol
) as pubsubs_fsub, net_stream_pair_factory(
security_protocol=security_protocol
) as stream_pair:
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
# Kick off the task `continuously_read_stream`
nursery.start_soon(pubsubs_fsub[0].continuously_read_stream, stream_pair[0])
# Test: `push_msg` is called when publishing to a subscribed topic.
publish_subscribed_topic = rpc_pb2.RPC(
publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])]
)
with mock_methods() as events:
await stream_pair[1].write(
encode_varint_prefixed(publish_subscribed_topic.SerializeToString())
)
await wait_for_event_occurring(events.push_msg)
# Make sure the other events are not emitted.
with pytest.raises(trio.TooSlowError):
await wait_for_event_occurring(events.handle_subscription)
with pytest.raises(trio.TooSlowError):
await wait_for_event_occurring(events.handle_rpc)
# Test: `push_msg` is not called when publishing to a topic-not-subscribed.
publish_not_subscribed_topic = rpc_pb2.RPC(
publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])]
)
with mock_methods() as events:
await stream_pair[1].write(
encode_varint_prefixed(publish_not_subscribed_topic.SerializeToString())
)
with pytest.raises(trio.TooSlowError):
await wait_for_event_occurring(events.push_msg)
# Test: `handle_subscription` is called when a subscription message is received.
subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()])
with mock_methods() as events:
await stream_pair[1].write(
encode_varint_prefixed(subscription_msg.SerializeToString())
)
await wait_for_event_occurring(events.handle_subscription)
# Make sure the other events are not emitted.
with pytest.raises(trio.TooSlowError):
await wait_for_event_occurring(events.push_msg)
with pytest.raises(trio.TooSlowError):
await wait_for_event_occurring(events.handle_rpc)
# Test: `handle_rpc` is called when a control message is received.
control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage())
with mock_methods() as events:
await stream_pair[1].write(
encode_varint_prefixed(control_msg.SerializeToString())
)
await wait_for_event_occurring(events.handle_rpc)
# Make sure the other events are not emitted.
with pytest.raises(trio.TooSlowError):
await wait_for_event_occurring(events.push_msg)
with pytest.raises(trio.TooSlowError):
await wait_for_event_occurring(events.handle_subscription)
# TODO: Add the following tests after they are aligned with Go.
# (Issue #191: https://github.com/libp2p/py-libp2p/issues/191)
# - `test_stream_handler`
# - `test_handle_peer_queue`
@pytest.mark.trio
async def test_handle_subscription():
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
assert len(pubsubs_fsub[0].peer_topics) == 0
sub_msg_0 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=TESTING_TOPIC)
peer_ids = [IDFactory() for _ in range(2)]
# Test: One peer is subscribed
pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_0)
assert (
len(pubsubs_fsub[0].peer_topics) == 1
and TESTING_TOPIC in pubsubs_fsub[0].peer_topics
)
assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 1
assert peer_ids[0] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
# Test: Another peer is subscribed
pubsubs_fsub[0].handle_subscription(peer_ids[1], sub_msg_0)
assert len(pubsubs_fsub[0].peer_topics) == 1
assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 2
assert peer_ids[1] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
# Test: Subscribe to another topic
another_topic = "ANOTHER_TOPIC"
sub_msg_1 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=another_topic)
pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_1)
assert len(pubsubs_fsub[0].peer_topics) == 2
assert another_topic in pubsubs_fsub[0].peer_topics
assert peer_ids[0] in pubsubs_fsub[0].peer_topics[another_topic]
# Test: unsubscribe
unsub_msg = rpc_pb2.RPC.SubOpts(subscribe=False, topicid=TESTING_TOPIC)
pubsubs_fsub[0].handle_subscription(peer_ids[0], unsub_msg)
assert peer_ids[0] not in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
@pytest.mark.trio
async def test_handle_talk():
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
msg_0 = make_pubsub_msg(
origin_id=pubsubs_fsub[0].my_id,
topic_ids=[TESTING_TOPIC],
data=b"1234",
seqno=b"\x00" * 8,
)
pubsubs_fsub[0].notify_subscriptions(msg_0)
msg_1 = make_pubsub_msg(
origin_id=pubsubs_fsub[0].my_id,
topic_ids=["NOT_SUBSCRIBED"],
data=b"1234",
seqno=b"\x11" * 8,
)
pubsubs_fsub[0].notify_subscriptions(msg_1)
assert (
len(pubsubs_fsub[0].topic_ids) == 1
and sub == pubsubs_fsub[0].subscribed_topics_receive[TESTING_TOPIC]
)
assert (await sub.get()) == msg_0
@pytest.mark.trio
async def test_message_all_peers(monkeypatch, security_protocol):
async with PubsubFactory.create_batch_with_floodsub(
1, security_protocol=security_protocol
) as pubsubs_fsub, net_stream_pair_factory(
security_protocol=security_protocol
) as stream_pair:
peer_id = IDFactory()
mock_peers = {peer_id: stream_pair[0]}
with monkeypatch.context() as m:
m.setattr(pubsubs_fsub[0], "peers", mock_peers)
empty_rpc = rpc_pb2.RPC()
empty_rpc_bytes = empty_rpc.SerializeToString()
empty_rpc_bytes_len_prefixed = encode_varint_prefixed(empty_rpc_bytes)
await pubsubs_fsub[0].message_all_peers(empty_rpc_bytes)
assert (
await stream_pair[1].read(MAX_READ_LEN)
) == empty_rpc_bytes_len_prefixed
@pytest.mark.trio
async def test_subscribe_and_publish():
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
pubsub = pubsubs_fsub[0]
list_data = [b"d0", b"d1"]
event_receive_data_started = trio.Event()
async def publish_data(topic):
await event_receive_data_started.wait()
for data in list_data:
await pubsub.publish(topic, data)
async def receive_data(topic):
i = 0
event_receive_data_started.set()
assert topic not in pubsub.topic_ids
subscription = await pubsub.subscribe(topic)
async with subscription:
assert topic in pubsub.topic_ids
async for msg in subscription:
assert msg.data == list_data[i]
i += 1
if i == len(list_data):
break
assert topic not in pubsub.topic_ids
async with trio.open_nursery() as nursery:
nursery.start_soon(receive_data, TESTING_TOPIC)
nursery.start_soon(publish_data, TESTING_TOPIC)
@pytest.mark.trio
async def test_subscribe_and_publish_full_channel():
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
pubsub = pubsubs_fsub[0]
extra_data_0 = b"extra_data_0"
extra_data_1 = b"extra_data_1"
# Test: Subscription channel is of size `SUBSCRIPTION_CHANNEL_SIZE`.
# When the channel is full, new received messages are dropped.
# After the channel has empty slot, the channel can receive new messages.
# Assume `SUBSCRIPTION_CHANNEL_SIZE` is smaller than `2**(4*8)`.
list_data = [i.to_bytes(4, "big") for i in range(SUBSCRIPTION_CHANNEL_SIZE)]
# Expect `extra_data_0` is dropped and `extra_data_1` is appended.
expected_list_data = list_data + [extra_data_1]
subscription = await pubsub.subscribe(TESTING_TOPIC)
for data in list_data:
await pubsub.publish(TESTING_TOPIC, data)
# Publish `extra_data_0` which should be dropped since the channel is
# already full.
await pubsub.publish(TESTING_TOPIC, extra_data_0)
# Consume a message and there is an empty slot in the channel.
assert (await subscription.get()).data == expected_list_data.pop(0)
# Publish `extra_data_1` which should be appended to the channel.
await pubsub.publish(TESTING_TOPIC, extra_data_1)
for expected_data in expected_list_data:
assert (await subscription.get()).data == expected_data
@pytest.mark.trio
async def test_publish_push_msg_is_called(monkeypatch):
msg_forwarders = []
msgs = []
async def push_msg(msg_forwarder, msg):
msg_forwarders.append(msg_forwarder)
msgs.append(msg)
await trio.lowlevel.checkpoint()
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
with monkeypatch.context() as m:
m.setattr(pubsubs_fsub[0], "push_msg", push_msg)
await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
assert (
len(msgs) == 2
), "`push_msg` should be called every time `publish` is called"
assert (msg_forwarders[0] == msg_forwarders[1]) and (
msg_forwarders[1] == pubsubs_fsub[0].my_id
)
assert (
msgs[0].seqno != msgs[1].seqno
), "`seqno` should be different every time"
@pytest.mark.trio
async def test_push_msg(monkeypatch):
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
msg_0 = make_pubsub_msg(
origin_id=pubsubs_fsub[0].my_id,
topic_ids=[TESTING_TOPIC],
data=TESTING_DATA,
seqno=b"\x00" * 8,
)
@contextmanager
def mock_router_publish():
event = trio.Event()
async def router_publish(*args, **kwargs):
event.set()
await trio.lowlevel.checkpoint()
with monkeypatch.context() as m:
m.setattr(pubsubs_fsub[0].router, "publish", router_publish)
yield event
with mock_router_publish() as event:
# Test: `msg` is not seen before `push_msg`, and is seen after `push_msg`.
assert not pubsubs_fsub[0]._is_msg_seen(msg_0)
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0)
assert pubsubs_fsub[0]._is_msg_seen(msg_0)
# Test: Ensure `router.publish` is called in `push_msg`
with trio.fail_after(0.1):
await event.wait()
with mock_router_publish() as event:
# Test: `push_msg` the message again and it will be reject.
# `router_publish` is not called then.
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0)
await trio.sleep(0.01)
assert not event.is_set()
sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
# Test: `push_msg` succeeds with another unseen msg.
msg_1 = make_pubsub_msg(
origin_id=pubsubs_fsub[0].my_id,
topic_ids=[TESTING_TOPIC],
data=TESTING_DATA,
seqno=b"\x11" * 8,
)
assert not pubsubs_fsub[0]._is_msg_seen(msg_1)
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_1)
assert pubsubs_fsub[0]._is_msg_seen(msg_1)
with trio.fail_after(0.1):
await event.wait()
# Test: Subscribers are notified when `push_msg` new messages.
assert (await sub.get()) == msg_1
with mock_router_publish() as event:
# Test: add a topic validator and `push_msg` the message that
# does not pass the validation.
# `router_publish` is not called then.
def failed_sync_validator(peer_id, msg):
return False
pubsubs_fsub[0].set_topic_validator(
TESTING_TOPIC, failed_sync_validator, False
)
msg_2 = make_pubsub_msg(
origin_id=pubsubs_fsub[0].my_id,
topic_ids=[TESTING_TOPIC],
data=TESTING_DATA,
seqno=b"\x22" * 8,
)
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_2)
await trio.sleep(0.01)
assert not event.is_set()
@pytest.mark.trio
async def test_strict_signing():
async with PubsubFactory.create_batch_with_floodsub(
2, strict_signing=True
) as pubsubs_fsub:
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
await pubsubs_fsub[1].subscribe(TESTING_TOPIC)
await trio.sleep(1)
await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
await trio.sleep(1)
assert len(pubsubs_fsub[0].seen_messages) == 1
assert len(pubsubs_fsub[1].seen_messages) == 1
@pytest.mark.trio
async def test_strict_signing_failed_validation(monkeypatch):
async with PubsubFactory.create_batch_with_floodsub(
2, strict_signing=True
) as pubsubs_fsub:
msg = make_pubsub_msg(
origin_id=pubsubs_fsub[0].my_id,
topic_ids=[TESTING_TOPIC],
data=TESTING_DATA,
seqno=b"\x00" * 8,
)
priv_key = pubsubs_fsub[0].sign_key
signature = priv_key.sign(
PUBSUB_SIGNING_PREFIX.encode() + msg.SerializeToString()
)
event = trio.Event()
def _is_msg_seen(msg):
return False
# Use router publish to check if `push_msg` succeed.
async def router_publish(*args, **kwargs):
await trio.lowlevel.checkpoint()
# The event will only be set if `push_msg` succeed.
event.set()
monkeypatch.setattr(pubsubs_fsub[0], "_is_msg_seen", _is_msg_seen)
monkeypatch.setattr(pubsubs_fsub[0].router, "publish", router_publish)
# Test: no signature attached in `msg`
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg)
await trio.sleep(0.01)
assert not event.is_set()
# Test: `msg.key` does not match `msg.from_id`
msg.key = pubsubs_fsub[1].host.get_public_key().serialize()
msg.signature = signature
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg)
await trio.sleep(0.01)
assert not event.is_set()
# Test: invalid signature
msg.key = pubsubs_fsub[0].host.get_public_key().serialize()
msg.signature = b"\x12" * 100
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg)
await trio.sleep(0.01)
assert not event.is_set()
# Finally, assert the signature indeed will pass validation
msg.key = pubsubs_fsub[0].host.get_public_key().serialize()
msg.signature = signature
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg)
await trio.sleep(0.01)
assert event.is_set()

View File

@ -0,0 +1,88 @@
import math
import pytest
import trio
from libp2p.pubsub.pb import (
rpc_pb2,
)
from libp2p.pubsub.subscription import (
TrioSubscriptionAPI,
)
GET_TIMEOUT = 0.001
def make_trio_subscription():
send_channel, receive_channel = trio.open_memory_channel(math.inf)
async def unsubscribe_fn():
await send_channel.aclose()
return (
send_channel,
TrioSubscriptionAPI(receive_channel, unsubscribe_fn=unsubscribe_fn),
)
def make_pubsub_msg():
return rpc_pb2.Message()
async def send_something(send_channel):
msg = make_pubsub_msg()
await send_channel.send(msg)
return msg
@pytest.mark.trio
async def test_trio_subscription_get():
send_channel, sub = make_trio_subscription()
data_0 = await send_something(send_channel)
data_1 = await send_something(send_channel)
assert data_0 == await sub.get()
assert data_1 == await sub.get()
# No more message
with pytest.raises(trio.TooSlowError):
with trio.fail_after(GET_TIMEOUT):
await sub.get()
@pytest.mark.trio
async def test_trio_subscription_iter():
send_channel, sub = make_trio_subscription()
received_data = []
async def iter_subscriptions(subscription):
async for data in sub:
received_data.append(data)
async with trio.open_nursery() as nursery:
nursery.start_soon(iter_subscriptions, sub)
await send_something(send_channel)
await send_something(send_channel)
await send_channel.aclose()
assert len(received_data) == 2
@pytest.mark.trio
async def test_trio_subscription_unsubscribe():
send_channel, sub = make_trio_subscription()
await sub.unsubscribe()
# Test: If the subscription is unsubscribed, `send_channel` should be closed.
with pytest.raises(trio.ClosedResourceError):
await send_something(send_channel)
# Test: No side effect when cancelled twice.
await sub.unsubscribe()
@pytest.mark.trio
async def test_trio_subscription_async_context_manager():
send_channel, sub = make_trio_subscription()
async with sub:
# Test: `sub` is not cancelled yet, so `send_something` works fine.
await send_something(send_channel)
# Test: `sub` is cancelled, `send_something` fails
with pytest.raises(trio.ClosedResourceError):
await send_something(send_channel)

View File

@ -0,0 +1,32 @@
import pytest
from libp2p.security.noise.io import (
MAX_NOISE_MESSAGE_LEN,
NoisePacketReadWriter,
)
from libp2p.tools.factories import (
raw_conn_factory,
)
@pytest.mark.parametrize(
"noise_msg",
(b"", b"data", pytest.param(b"A" * MAX_NOISE_MESSAGE_LEN, id="maximum length")),
)
@pytest.mark.trio
async def test_noise_msg_read_write_round_trip(nursery, noise_msg):
async with raw_conn_factory(nursery) as conns:
reader, writer = (
NoisePacketReadWriter(conns[0]),
NoisePacketReadWriter(conns[1]),
)
await writer.write_msg(noise_msg)
assert (await reader.read_msg()) == noise_msg
@pytest.mark.trio
async def test_noise_msg_write_too_long(nursery):
async with raw_conn_factory(nursery) as conns:
writer = NoisePacketReadWriter(conns[0])
with pytest.raises(ValueError):
await writer.write_msg(b"1" * (MAX_NOISE_MESSAGE_LEN + 1))

View File

@ -0,0 +1,38 @@
import pytest
from libp2p.security.noise.messages import (
NoiseHandshakePayload,
)
from libp2p.tools.factories import (
noise_conn_factory,
noise_handshake_payload_factory,
)
DATA_0 = b"data_0"
DATA_1 = b"1" * 1000
DATA_2 = b"data_2"
@pytest.mark.trio
async def test_noise_transport(nursery):
async with noise_conn_factory(nursery):
pass
@pytest.mark.trio
async def test_noise_connection(nursery):
async with noise_conn_factory(nursery) as conns:
local_conn, remote_conn = conns
await local_conn.write(DATA_0)
await local_conn.write(DATA_1)
assert DATA_0 == (await remote_conn.read(len(DATA_0)))
assert DATA_1 == (await remote_conn.read(len(DATA_1)))
await local_conn.write(DATA_2)
assert DATA_2 == (await remote_conn.read(len(DATA_2)))
def test_noise_handshake_payload():
payload = noise_handshake_payload_factory()
payload_serialized = payload.serialize()
payload_deserialized = NoiseHandshakePayload.deserialize(payload_serialized)
assert payload == payload_deserialized

View File

@ -0,0 +1,60 @@
import pytest
import trio
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.peer.id import (
ID,
)
from libp2p.security.secio.transport import (
NONCE_SIZE,
create_secure_session,
)
from libp2p.tools.constants import (
MAX_READ_LEN,
)
from libp2p.tools.factories import (
raw_conn_factory,
)
@pytest.mark.trio
async def test_create_secure_session(nursery):
local_nonce = b"\x01" * NONCE_SIZE
local_key_pair = create_new_key_pair(b"a")
local_peer = ID.from_pubkey(local_key_pair.public_key)
remote_nonce = b"\x02" * NONCE_SIZE
remote_key_pair = create_new_key_pair(b"b")
remote_peer = ID.from_pubkey(remote_key_pair.public_key)
async with raw_conn_factory(nursery) as conns:
local_conn, remote_conn = conns
local_secure_conn, remote_secure_conn = None, None
async def local_create_secure_session():
nonlocal local_secure_conn
local_secure_conn = await create_secure_session(
local_nonce,
local_peer,
local_key_pair.private_key,
local_conn,
remote_peer,
)
async def remote_create_secure_session():
nonlocal remote_secure_conn
remote_secure_conn = await create_secure_session(
remote_nonce, remote_peer, remote_key_pair.private_key, remote_conn
)
async with trio.open_nursery() as nursery_1:
nursery_1.start_soon(local_create_secure_session)
nursery_1.start_soon(remote_create_secure_session)
msg = b"abc"
await local_secure_conn.write(msg)
received_msg = await remote_secure_conn.read(MAX_READ_LEN)
assert received_msg == msg

View File

@ -0,0 +1,58 @@
import pytest
from libp2p.crypto.rsa import (
create_new_key_pair,
)
from libp2p.security.insecure.transport import (
PLAINTEXT_PROTOCOL_ID,
InsecureSession,
)
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
from libp2p.security.secio.transport import ID as SECIO_PROTOCOL_ID
from libp2p.security.secure_session import (
SecureSession,
)
from libp2p.tools.factories import (
host_pair_factory,
)
initiator_key_pair = create_new_key_pair()
noninitiator_key_pair = create_new_key_pair()
async def perform_simple_test(assertion_func, security_protocol):
async with host_pair_factory(security_protocol=security_protocol) as hosts:
conn_0 = hosts[0].get_network().connections[hosts[1].get_id()]
conn_1 = hosts[1].get_network().connections[hosts[0].get_id()]
# Perform assertion
assertion_func(conn_0.muxed_conn.secured_conn)
assertion_func(conn_1.muxed_conn.secured_conn)
@pytest.mark.trio
@pytest.mark.parametrize(
"security_protocol, transport_type",
(
(PLAINTEXT_PROTOCOL_ID, InsecureSession),
(SECIO_PROTOCOL_ID, SecureSession),
(NOISE_PROTOCOL_ID, SecureSession),
),
)
@pytest.mark.trio
async def test_single_insecure_security_transport_succeeds(
security_protocol, transport_type
):
def assertion_func(conn):
assert isinstance(conn, transport_type)
await perform_simple_test(assertion_func, security_protocol)
@pytest.mark.trio
async def test_default_insecure_security():
def assertion_func(conn):
assert isinstance(conn, InsecureSession)
await perform_simple_test(assertion_func, None)

View File

@ -0,0 +1,24 @@
import pytest
from libp2p.tools.factories import (
mplex_conn_pair_factory,
mplex_stream_pair_factory,
)
@pytest.fixture
async def mplex_conn_pair(security_protocol):
async with mplex_conn_pair_factory(
security_protocol=security_protocol
) as mplex_conn_pair:
assert mplex_conn_pair[0].is_initiator
assert not mplex_conn_pair[1].is_initiator
yield mplex_conn_pair[0], mplex_conn_pair[1]
@pytest.fixture
async def mplex_stream_pair(security_protocol):
async with mplex_stream_pair_factory(
security_protocol=security_protocol
) as mplex_stream_pair:
yield mplex_stream_pair

View File

@ -0,0 +1,41 @@
import pytest
import trio
@pytest.mark.trio
async def test_mplex_conn(mplex_conn_pair):
conn_0, conn_1 = mplex_conn_pair
assert len(conn_0.streams) == 0
assert len(conn_1.streams) == 0
# Test: Open a stream, and both side get 1 more stream.
stream_0 = await conn_0.open_stream()
await trio.sleep(0.01)
assert len(conn_0.streams) == 1
assert len(conn_1.streams) == 1
# Test: From another side.
stream_1 = await conn_1.open_stream()
await trio.sleep(0.01)
assert len(conn_0.streams) == 2
assert len(conn_1.streams) == 2
# Close from one side.
await conn_0.close()
# Sleep for a while for both side to handle `close`.
await trio.sleep(0.01)
# Test: Both side is closed.
assert conn_0.is_closed
assert conn_1.is_closed
# Test: All streams should have been closed.
assert stream_0.event_remote_closed.is_set()
assert stream_0.event_reset.is_set()
assert stream_0.event_local_closed.is_set()
# Test: All streams on the other side are also closed.
assert stream_1.event_remote_closed.is_set()
assert stream_1.event_reset.is_set()
assert stream_1.event_local_closed.is_set()
# Test: No effect to close more than once between two side.
await conn_0.close()
await conn_1.close()

View File

@ -0,0 +1,215 @@
import pytest
import trio
from trio.testing import (
wait_all_tasks_blocked,
)
from libp2p.stream_muxer.mplex.exceptions import (
MplexStreamClosed,
MplexStreamEOF,
MplexStreamReset,
)
from libp2p.stream_muxer.mplex.mplex import (
MPLEX_MESSAGE_CHANNEL_SIZE,
)
from libp2p.tools.constants import (
MAX_READ_LEN,
)
DATA = b"data_123"
@pytest.mark.trio
async def test_mplex_stream_read_write(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_0.write(DATA)
assert (await stream_1.read(MAX_READ_LEN)) == DATA
@pytest.mark.trio
async def test_mplex_stream_full_buffer(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
# Test: The message channel is of size `MPLEX_MESSAGE_CHANNEL_SIZE`.
# It should be fine to read even there are already `MPLEX_MESSAGE_CHANNEL_SIZE`
# messages arriving.
for _ in range(MPLEX_MESSAGE_CHANNEL_SIZE):
await stream_0.write(DATA)
await wait_all_tasks_blocked()
# Sanity check
assert MAX_READ_LEN >= MPLEX_MESSAGE_CHANNEL_SIZE * len(DATA)
assert (await stream_1.read(MAX_READ_LEN)) == MPLEX_MESSAGE_CHANNEL_SIZE * DATA
# Test: Read after `MPLEX_MESSAGE_CHANNEL_SIZE + 1` messages has arrived, which
# exceeds the channel size. The stream should have been reset.
for _ in range(MPLEX_MESSAGE_CHANNEL_SIZE + 1):
await stream_0.write(DATA)
await wait_all_tasks_blocked()
with pytest.raises(MplexStreamReset):
await stream_1.read(MAX_READ_LEN)
@pytest.mark.trio
async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair):
read_bytes = bytearray()
stream_0, stream_1 = mplex_stream_pair
async def read_until_eof():
read_bytes.extend(await stream_1.read())
expected_data = bytearray()
async with trio.open_nursery() as nursery:
nursery.start_soon(read_until_eof)
# Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA)
expected_data.extend(DATA)
await trio.sleep(0.01)
assert len(read_bytes) == 0
# Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA)
expected_data.extend(DATA)
await trio.sleep(0.01)
assert len(read_bytes) == 0
# Test: Close the stream, `read` returns, and receive previous sent data.
await stream_0.close()
assert read_bytes == expected_data
@pytest.mark.trio
async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
assert not stream_1.event_remote_closed.is_set()
await stream_0.write(DATA)
assert not stream_0.event_local_closed.is_set()
await trio.sleep(0.01)
await wait_all_tasks_blocked()
await stream_0.close()
assert stream_0.event_local_closed.is_set()
await trio.sleep(0.01)
await wait_all_tasks_blocked()
assert stream_1.event_remote_closed.is_set()
assert (await stream_1.read(MAX_READ_LEN)) == DATA
with pytest.raises(MplexStreamEOF):
await stream_1.read(MAX_READ_LEN)
@pytest.mark.trio
async def test_mplex_stream_read_after_local_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_0.reset()
with pytest.raises(MplexStreamReset):
await stream_0.read(MAX_READ_LEN)
@pytest.mark.trio
async def test_mplex_stream_read_after_remote_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_0.write(DATA)
await stream_0.reset()
# Sleep to let `stream_1` receive the message.
await trio.sleep(0.1)
await wait_all_tasks_blocked()
with pytest.raises(MplexStreamReset):
await stream_1.read(MAX_READ_LEN)
@pytest.mark.trio
async def test_mplex_stream_read_after_remote_closed_and_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_0.write(DATA)
await stream_0.close()
await stream_0.reset()
# Sleep to let `stream_1` receive the message.
await trio.sleep(0.01)
assert (await stream_1.read(MAX_READ_LEN)) == DATA
@pytest.mark.trio
async def test_mplex_stream_write_after_local_closed(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_0.write(DATA)
await stream_0.close()
with pytest.raises(MplexStreamClosed):
await stream_0.write(DATA)
@pytest.mark.trio
async def test_mplex_stream_write_after_local_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_0.reset()
with pytest.raises(MplexStreamClosed):
await stream_0.write(DATA)
@pytest.mark.trio
async def test_mplex_stream_write_after_remote_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_1.reset()
await trio.sleep(0.01)
with pytest.raises(MplexStreamClosed):
await stream_0.write(DATA)
@pytest.mark.trio
async def test_mplex_stream_both_close(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
# Flags are not set initially.
assert not stream_0.event_local_closed.is_set()
assert not stream_1.event_local_closed.is_set()
assert not stream_0.event_remote_closed.is_set()
assert not stream_1.event_remote_closed.is_set()
# Streams are present in their `mplex_conn`.
assert stream_0 in stream_0.muxed_conn.streams.values()
assert stream_1 in stream_1.muxed_conn.streams.values()
# Test: Close one side.
await stream_0.close()
await trio.sleep(0.01)
assert stream_0.event_local_closed.is_set()
assert not stream_1.event_local_closed.is_set()
assert not stream_0.event_remote_closed.is_set()
assert stream_1.event_remote_closed.is_set()
# Streams are still present in their `mplex_conn`.
assert stream_0 in stream_0.muxed_conn.streams.values()
assert stream_1 in stream_1.muxed_conn.streams.values()
# Test: Close the other side.
await stream_1.close()
await trio.sleep(0.01)
# Both sides are closed.
assert stream_0.event_local_closed.is_set()
assert stream_1.event_local_closed.is_set()
assert stream_0.event_remote_closed.is_set()
assert stream_1.event_remote_closed.is_set()
# Streams are removed from their `mplex_conn`.
assert stream_0 not in stream_0.muxed_conn.streams.values()
assert stream_1 not in stream_1.muxed_conn.streams.values()
# Test: Reset after both close.
await stream_0.reset()
@pytest.mark.trio
async def test_mplex_stream_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_0.reset()
await trio.sleep(0.01)
# Both sides are closed.
assert stream_0.event_local_closed.is_set()
assert stream_1.event_local_closed.is_set()
assert stream_0.event_remote_closed.is_set()
assert stream_1.event_remote_closed.is_set()
# Streams are removed from their `mplex_conn`.
assert stream_0 not in stream_0.muxed_conn.streams.values()
assert stream_1 not in stream_1.muxed_conn.streams.values()
# `close` should do nothing.
await stream_0.close()
await stream_1.close()
# `reset` should do nothing as well.
await stream_0.reset()
await stream_1.reset()

View File

@ -0,0 +1,309 @@
import multiaddr
import pytest
from libp2p.network.stream.exceptions import (
StreamError,
)
from libp2p.tools.constants import (
MAX_READ_LEN,
)
from libp2p.tools.factories import (
HostFactory,
)
from libp2p.tools.utils import (
connect,
create_echo_stream_handler,
)
from libp2p.typing import (
TProtocol,
)
PROTOCOL_ID_0 = TProtocol("/echo/0")
PROTOCOL_ID_1 = TProtocol("/echo/1")
PROTOCOL_ID_2 = TProtocol("/echo/2")
PROTOCOL_ID_3 = TProtocol("/echo/3")
ACK_STR_0 = "ack_0:"
ACK_STR_1 = "ack_1:"
ACK_STR_2 = "ack_2:"
ACK_STR_3 = "ack_3:"
@pytest.mark.trio
async def test_simple_messages(security_protocol):
async with HostFactory.create_batch_and_listen(
2, security_protocol=security_protocol
) as hosts:
hosts[1].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
)
# Associate the peer with local ip address (see default parameters of Libp2p())
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
stream = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
await stream.write(message.encode())
response = (await stream.read(MAX_READ_LEN)).decode()
assert response == (ACK_STR_0 + message)
@pytest.mark.trio
async def test_double_response(security_protocol):
async with HostFactory.create_batch_and_listen(
2, security_protocol=security_protocol
) as hosts:
async def double_response_stream_handler(stream):
while True:
try:
read_string = (await stream.read(MAX_READ_LEN)).decode()
except StreamError:
break
response = ACK_STR_0 + read_string
try:
await stream.write(response.encode())
except StreamError:
break
response = ACK_STR_1 + read_string
try:
await stream.write(response.encode())
except StreamError:
break
hosts[1].set_stream_handler(PROTOCOL_ID_0, double_response_stream_handler)
# Associate the peer with local ip address (see default parameters of Libp2p())
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
stream = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
await stream.write(message.encode())
response1 = (await stream.read(MAX_READ_LEN)).decode()
assert response1 == (ACK_STR_0 + message)
response2 = (await stream.read(MAX_READ_LEN)).decode()
assert response2 == (ACK_STR_1 + message)
@pytest.mark.trio
async def test_multiple_streams(security_protocol):
# hosts[0] should be able to open a stream with hosts[1] and then vice versa.
# Stream IDs should be generated uniquely so that the stream state is not
# overwritten
async with HostFactory.create_batch_and_listen(
2, security_protocol=security_protocol
) as hosts:
hosts[0].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
)
hosts[1].set_stream_handler(
PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1)
)
# Associate the peer with local ip address (see default parameters of Libp2p())
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)
stream_a = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1])
stream_b = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_0])
# A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
a_message = message + "_a"
b_message = message + "_b"
await stream_a.write(a_message.encode())
await stream_b.write(b_message.encode())
response_a = (await stream_a.read(MAX_READ_LEN)).decode()
response_b = (await stream_b.read(MAX_READ_LEN)).decode()
assert response_a == (ACK_STR_1 + a_message) and response_b == (
ACK_STR_0 + b_message
)
@pytest.mark.trio
async def test_multiple_streams_same_initiator_different_protocols(security_protocol):
async with HostFactory.create_batch_and_listen(
2, security_protocol=security_protocol
) as hosts:
hosts[1].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
)
hosts[1].set_stream_handler(
PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1)
)
hosts[1].set_stream_handler(
PROTOCOL_ID_2, create_echo_stream_handler(ACK_STR_2)
)
# Associate the peer with local ip address (see default parameters of Libp2p())
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)
# Open streams to hosts[1] over echo_a1 echo_a2 echo_a3 protocols
stream_a1 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
stream_a2 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1])
stream_a3 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_2])
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
a1_message = message + "_a1"
a2_message = message + "_a2"
a3_message = message + "_a3"
await stream_a1.write(a1_message.encode())
await stream_a2.write(a2_message.encode())
await stream_a3.write(a3_message.encode())
response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode()
response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode()
response_a3 = (await stream_a3.read(MAX_READ_LEN)).decode()
assert (
response_a1 == (ACK_STR_0 + a1_message)
and response_a2 == (ACK_STR_1 + a2_message)
and response_a3 == (ACK_STR_2 + a3_message)
)
# Success, terminate pending tasks.
@pytest.mark.trio
async def test_multiple_streams_two_initiators(security_protocol):
async with HostFactory.create_batch_and_listen(
2, security_protocol=security_protocol
) as hosts:
hosts[0].set_stream_handler(
PROTOCOL_ID_2, create_echo_stream_handler(ACK_STR_2)
)
hosts[0].set_stream_handler(
PROTOCOL_ID_3, create_echo_stream_handler(ACK_STR_3)
)
hosts[1].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
)
hosts[1].set_stream_handler(
PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1)
)
# Associate the peer with local ip address (see default parameters of Libp2p())
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)
stream_a1 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
stream_a2 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1])
stream_b1 = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_2])
stream_b2 = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_3])
# A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
a1_message = message + "_a1"
a2_message = message + "_a2"
b1_message = message + "_b1"
b2_message = message + "_b2"
await stream_a1.write(a1_message.encode())
await stream_a2.write(a2_message.encode())
await stream_b1.write(b1_message.encode())
await stream_b2.write(b2_message.encode())
response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode()
response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode()
response_b1 = (await stream_b1.read(MAX_READ_LEN)).decode()
response_b2 = (await stream_b2.read(MAX_READ_LEN)).decode()
assert (
response_a1 == (ACK_STR_0 + a1_message)
and response_a2 == (ACK_STR_1 + a2_message)
and response_b1 == (ACK_STR_2 + b1_message)
and response_b2 == (ACK_STR_3 + b2_message)
)
@pytest.mark.trio
async def test_triangle_nodes_connection(security_protocol):
async with HostFactory.create_batch_and_listen(
3, security_protocol=security_protocol
) as hosts:
hosts[0].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
)
hosts[1].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
)
hosts[2].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
)
# Associate the peer with local ip address (see default parameters of Libp2p())
# Associate all permutations
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
hosts[0].get_peerstore().add_addrs(hosts[2].get_id(), hosts[2].get_addrs(), 10)
hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)
hosts[1].get_peerstore().add_addrs(hosts[2].get_id(), hosts[2].get_addrs(), 10)
hosts[2].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)
hosts[2].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
stream_0_to_1 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
stream_0_to_2 = await hosts[0].new_stream(hosts[2].get_id(), [PROTOCOL_ID_0])
stream_1_to_0 = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_0])
stream_1_to_2 = await hosts[1].new_stream(hosts[2].get_id(), [PROTOCOL_ID_0])
stream_2_to_0 = await hosts[2].new_stream(hosts[0].get_id(), [PROTOCOL_ID_0])
stream_2_to_1 = await hosts[2].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
messages = ["hello" + str(x) for x in range(5)]
streams = [
stream_0_to_1,
stream_0_to_2,
stream_1_to_0,
stream_1_to_2,
stream_2_to_0,
stream_2_to_1,
]
for message in messages:
for stream in streams:
await stream.write(message.encode())
response = (await stream.read(MAX_READ_LEN)).decode()
assert response == (ACK_STR_0 + message)
@pytest.mark.trio
async def test_host_connect(security_protocol):
async with HostFactory.create_batch_and_listen(
2, security_protocol=security_protocol
) as hosts:
assert len(hosts[0].get_peerstore().peer_ids()) == 1
await connect(hosts[0], hosts[1])
assert len(hosts[0].get_peerstore().peer_ids()) == 2
await connect(hosts[0], hosts[1])
# make sure we don't do double connection
assert len(hosts[0].get_peerstore().peer_ids()) == 2
assert hosts[1].get_id() in hosts[0].get_peerstore().peer_ids()
ma_node_b = multiaddr.Multiaddr("/p2p/%s" % hosts[1].get_id().pretty())
for addr in hosts[0].get_peerstore().addrs(hosts[1].get_id()):
assert addr.encapsulate(ma_node_b) in hosts[1].get_addrs()

View File

@ -0,0 +1,63 @@
from multiaddr import (
Multiaddr,
)
import pytest
import trio
from libp2p.network.connection.raw_connection import (
RawConnection,
)
from libp2p.tools.constants import (
LISTEN_MADDR,
)
from libp2p.transport.exceptions import (
OpenConnectionError,
)
from libp2p.transport.tcp.tcp import (
TCP,
)
@pytest.mark.trio
async def test_tcp_listener(nursery):
transport = TCP()
async def handler(tcp_stream):
pass
listener = transport.create_listener(handler)
assert len(listener.get_addrs()) == 0
await listener.listen(LISTEN_MADDR, nursery)
assert len(listener.get_addrs()) == 1
await listener.listen(LISTEN_MADDR, nursery)
assert len(listener.get_addrs()) == 2
@pytest.mark.trio
async def test_tcp_dial(nursery):
transport = TCP()
raw_conn_other_side = None
event = trio.Event()
async def handler(tcp_stream):
nonlocal raw_conn_other_side
raw_conn_other_side = RawConnection(tcp_stream, False)
event.set()
await trio.sleep_forever()
# Test: `OpenConnectionError` is raised when trying to dial to a port which
# no one is not listening to.
with pytest.raises(OpenConnectionError):
await transport.dial(Multiaddr("/ip4/127.0.0.1/tcp/1"))
listener = transport.create_listener(handler)
await listener.listen(LISTEN_MADDR, nursery)
addrs = listener.get_addrs()
assert len(addrs) == 1
listen_addr = addrs[0]
raw_conn = await transport.dial(listen_addr)
await event.wait()
data = b"123"
await raw_conn_other_side.write(data)
assert (await raw_conn.read(len(data))) == data