Mirror of https://github.com/varun-r-mallya/py-libp2p.git, synced 2025-12-31 20:36:24 +00:00
reorg test structure to match tox and CI jobs, drop bumpversion for bump-my-version and move config to pyproject.toml, fix docs building
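For context on the tooling change: bump-my-version reads its configuration from pyproject.toml rather than a separate .bumpversion.cfg, and the new tests/core/ prefix visible throughout the diff presumably mirrors a tox/CI job of the same name. A minimal sketch of such a [tool.bumpversion] table follows; the version number and file path are illustrative assumptions, not values taken from this commit:

[tool.bumpversion]
current_version = "0.2.0"  # illustrative, not the project's actual version
commit = true
tag = true

[[tool.bumpversion.files]]
filename = "libp2p/__init__.py"  # assumed location of the version string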
tests/core/crypto/test_ed25519.py (new file)
@@ -0,0 +1,27 @@
from libp2p.crypto.ed25519 import (
    create_new_key_pair,
)
from libp2p.crypto.serialization import (
    deserialize_private_key,
    deserialize_public_key,
)


def test_public_key_serialize_deserialize_round_trip():
    key_pair = create_new_key_pair()
    public_key = key_pair.public_key

    public_key_bytes = public_key.serialize()
    another_public_key = deserialize_public_key(public_key_bytes)

    assert public_key == another_public_key


def test_private_key_serialize_deserialize_round_trip():
    key_pair = create_new_key_pair()
    private_key = key_pair.private_key

    private_key_bytes = private_key.serialize()
    another_private_key = deserialize_private_key(private_key_bytes)

    assert private_key == another_private_key
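Both this file and its secp256k1 twin below rely on serialize() producing the self-describing libp2p protobuf key encoding, which is what lets the generic deserialize_* helpers recover the right key type. A minimal sketch of that property; that the keys expose get_type() is an assumption based on libp2p.crypto.keys:

from libp2p.crypto.ed25519 import create_new_key_pair
from libp2p.crypto.serialization import deserialize_public_key

public_key = create_new_key_pair().public_key
restored = deserialize_public_key(public_key.serialize())
# The deserializer reads the key type back out of the protobuf envelope.
assert restored.get_type() == public_key.get_type()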
tests/core/crypto/test_secp256k1.py (new file)
@@ -0,0 +1,27 @@
from libp2p.crypto.secp256k1 import (
    create_new_key_pair,
)
from libp2p.crypto.serialization import (
    deserialize_private_key,
    deserialize_public_key,
)


def test_public_key_serialize_deserialize_round_trip():
    key_pair = create_new_key_pair()
    public_key = key_pair.public_key

    public_key_bytes = public_key.serialize()
    another_public_key = deserialize_public_key(public_key_bytes)

    assert public_key == another_public_key


def test_private_key_serialize_deserialize_round_trip():
    key_pair = create_new_key_pair()
    private_key = key_pair.private_key

    private_key_bytes = private_key.serialize()
    another_private_key = deserialize_private_key(private_key_bytes)

    assert private_key == another_private_key
tests/core/examples/test_examples.py (new file)
@@ -0,0 +1,112 @@
import pytest
import trio

from libp2p.host.exceptions import (
    StreamFailure,
)
from libp2p.peer.peerinfo import (
    info_from_p2p_addr,
)
from libp2p.tools.factories import (
    HostFactory,
)
from libp2p.tools.utils import (
    MAX_READ_LEN,
)

PROTOCOL_ID = "/chat/1.0.0"


async def hello_world(host_a, host_b):
    hello_world_from_host_a = b"hello world from host a"
    hello_world_from_host_b = b"hello world from host b"

    async def stream_handler(stream):
        read = await stream.read(len(hello_world_from_host_b))
        assert read == hello_world_from_host_b
        await stream.write(hello_world_from_host_a)
        await stream.close()

    host_a.set_stream_handler(PROTOCOL_ID, stream_handler)

    # Start a stream with the destination.
    # The multiaddress of the destination peer is fetched from the peerstore using 'peerId'.
    stream = await host_b.new_stream(host_a.get_id(), [PROTOCOL_ID])
    await stream.write(hello_world_from_host_b)
    read = await stream.read(MAX_READ_LEN)
    assert read == hello_world_from_host_a
    await stream.close()


async def connect_write(host_a, host_b):
    messages = ["data %d" % i for i in range(5)]
    received = []

    async def stream_handler(stream):
        for message in messages:
            received.append((await stream.read(len(message))).decode())

    host_a.set_stream_handler(PROTOCOL_ID, stream_handler)

    # Start a stream with the destination.
    # The multiaddress of the destination peer is fetched from the peerstore using 'peerId'.
    stream = await host_b.new_stream(host_a.get_id(), [PROTOCOL_ID])
    for message in messages:
        await stream.write(message.encode())

    # Reader needs time due to async reads
    await trio.sleep(2)

    await stream.close()
    assert received == messages


async def connect_read(host_a, host_b):
    messages = [b"data %d" % i for i in range(5)]

    async def stream_handler(stream):
        for message in messages:
            await stream.write(message)
        await stream.close()

    host_a.set_stream_handler(PROTOCOL_ID, stream_handler)

    # Start a stream with the destination.
    # The multiaddress of the destination peer is fetched from the peerstore using 'peerId'.
    stream = await host_b.new_stream(host_a.get_id(), [PROTOCOL_ID])
    received = []
    for message in messages:
        received.append(await stream.read(len(message)))
    await stream.close()
    assert received == messages


async def no_common_protocol(host_a, host_b):
    messages = [b"data %d" % i for i in range(5)]

    async def stream_handler(stream):
        for message in messages:
            await stream.write(message)
        await stream.close()

    host_a.set_stream_handler(PROTOCOL_ID, stream_handler)

    # Try to create a new stream with a protocol not known by the other host.
    with pytest.raises(StreamFailure):
        await host_b.new_stream(host_a.get_id(), ["/fakeproto/0.0.1"])


@pytest.mark.parametrize(
    "test", [(hello_world), (connect_write), (connect_read), (no_common_protocol)]
)
@pytest.mark.trio
async def test_chat(test, security_protocol):
    print("!@# ", security_protocol)
    async with HostFactory.create_batch_and_listen(
        2, security_protocol=security_protocol
    ) as hosts:
        addr = hosts[0].get_addrs()[0]
        info = info_from_p2p_addr(addr)
        await hosts[1].connect(info)

        await test(hosts[0], hosts[1])
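Outside the pytest-trio harness, the same scenarios can be driven directly with trio.run. A minimal sketch, assuming create_batch_and_listen may be called without an explicit security_protocol and falls back to a default:

import trio

from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.tools.factories import HostFactory


async def main():
    async with HostFactory.create_batch_and_listen(2) as hosts:
        # Connect host 1 to host 0, exactly as test_chat does.
        info = info_from_p2p_addr(hosts[0].get_addrs()[0])
        await hosts[1].connect(info)
        await hello_world(hosts[0], hosts[1])  # or any other scenario above


trio.run(main)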
tests/core/host/test_basic_host.py (new file)
@@ -0,0 +1,24 @@
from libp2p import (
    new_swarm,
)
from libp2p.crypto.rsa import (
    create_new_key_pair,
)
from libp2p.host.basic_host import (
    BasicHost,
)
from libp2p.host.defaults import (
    get_default_protocols,
)


def test_default_protocols():
    key_pair = create_new_key_pair()
    swarm = new_swarm(key_pair)
    host = BasicHost(swarm)

    mux = host.get_mux()
    handlers = mux.handlers
    # NOTE: comparing keys for equality as handlers may be closures that do not compare
    # in the way this test is concerned with
    assert handlers.keys() == get_default_protocols(host).keys()
tests/core/host/test_ping.py (new file)
@@ -0,0 +1,49 @@
import secrets

import pytest
import trio

from libp2p.host.ping import (
    ID,
    PING_LENGTH,
)
from libp2p.tools.factories import (
    host_pair_factory,
)


@pytest.mark.trio
async def test_ping_once(security_protocol):
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        stream = await host_b.new_stream(host_a.get_id(), (ID,))
        some_ping = secrets.token_bytes(PING_LENGTH)
        await stream.write(some_ping)
        await trio.sleep(0.01)
        some_pong = await stream.read(PING_LENGTH)
        assert some_ping == some_pong
        await stream.close()


SOME_PING_COUNT = 3


@pytest.mark.trio
async def test_ping_several(security_protocol):
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        stream = await host_b.new_stream(host_a.get_id(), (ID,))
        for _ in range(SOME_PING_COUNT):
            some_ping = secrets.token_bytes(PING_LENGTH)
            await stream.write(some_ping)
            some_pong = await stream.read(PING_LENGTH)
            assert some_ping == some_pong
            # NOTE: sleep to mirror real-world usage, where a peer sends
            # pings on some periodic interval.
            # NOTE: this interval can be `0` for this test.
            await trio.sleep(0)
        await stream.close()
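Note that neither test registers a ping handler on host_a: the handler is presumably installed by default (compare get_default_protocols in tests/core/host/test_basic_host.py above). A sketch of serving ping explicitly; that libp2p.host.ping exports a handle_ping stream handler is an assumption here, not shown in this diff:

import trio

from libp2p.host.ping import ID, handle_ping  # handle_ping is assumed
from libp2p.tools.factories import HostFactory


async def serve_ping():
    async with HostFactory.create_batch_and_listen(1) as hosts:
        hosts[0].set_stream_handler(ID, handle_ping)
        await trio.sleep_forever()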
tests/core/host/test_routed_host.py (new file)
@@ -0,0 +1,32 @@
import pytest

from libp2p.host.exceptions import (
    ConnectionFailure,
)
from libp2p.peer.peerinfo import (
    PeerInfo,
)
from libp2p.tools.factories import (
    HostFactory,
    RoutedHostFactory,
)


@pytest.mark.trio
async def test_host_routing_success():
    async with RoutedHostFactory.create_batch_and_listen(2) as hosts:
        # Force the use of routing, since no addrs are provided.
        await hosts[0].connect(PeerInfo(hosts[1].get_id(), []))
        await hosts[1].connect(PeerInfo(hosts[0].get_id(), []))


@pytest.mark.trio
async def test_host_routing_fail():
    async with RoutedHostFactory.create_batch_and_listen(
        2
    ) as routed_hosts, HostFactory.create_batch_and_listen(1) as basic_hosts:
        # Routing fails because the basic host does not use routing.
        with pytest.raises(ConnectionFailure):
            await routed_hosts[0].connect(PeerInfo(basic_hosts[0].get_id(), []))
        with pytest.raises(ConnectionFailure):
            await routed_hosts[1].connect(PeerInfo(basic_hosts[0].get_id(), []))
tests/core/identity/identify/test_protocol.py (new file)
@@ -0,0 +1,27 @@
import pytest

from libp2p.identity.identify.pb.identify_pb2 import (
    Identify,
)
from libp2p.identity.identify.protocol import (
    ID,
    _mk_identify_protobuf,
)
from libp2p.tools.factories import (
    host_pair_factory,
)


@pytest.mark.trio
async def test_identify_protocol(security_protocol):
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        stream = await host_b.new_stream(host_a.get_id(), (ID,))
        response = await stream.read()
        await stream.close()

        identify_response = Identify()
        identify_response.ParseFromString(response)
        assert identify_response == _mk_identify_protobuf(host_a)
tests/core/network/conftest.py (new file)
@@ -0,0 +1,29 @@
import pytest

from libp2p.tools.factories import (
    net_stream_pair_factory,
    swarm_conn_pair_factory,
    swarm_pair_factory,
)


@pytest.fixture
async def net_stream_pair(security_protocol):
    async with net_stream_pair_factory(
        security_protocol=security_protocol
    ) as net_stream_pair:
        yield net_stream_pair


@pytest.fixture
async def swarm_pair(security_protocol):
    async with swarm_pair_factory(security_protocol=security_protocol) as swarms:
        yield swarms


@pytest.fixture
async def swarm_conn_pair(security_protocol):
    async with swarm_conn_pair_factory(
        security_protocol=security_protocol
    ) as swarm_conn_pair:
        yield swarm_conn_pair
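The security_protocol fixture that these fixtures (and most tests in this diff) request is not included in the commit; presumably it lives in a shared conftest.py and is parametrized over the available security transports. A hypothetical sketch, with illustrative protocol IDs:

import pytest


@pytest.fixture(params=[None, "/plaintext/2.0.0", "/noise"])
def security_protocol(request):
    # None lets the factories fall back to their default security transport.
    return request.param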
tests/core/network/test_net_stream.py (new file)
@@ -0,0 +1,120 @@
import pytest
import trio

from libp2p.network.stream.exceptions import (
    StreamClosed,
    StreamEOF,
    StreamReset,
)
from libp2p.tools.constants import (
    MAX_READ_LEN,
)

DATA = b"data_123"


@pytest.mark.trio
async def test_net_stream_read_write(net_stream_pair):
    stream_0, stream_1 = net_stream_pair
    assert (
        stream_0.protocol_id is not None
        and stream_0.protocol_id == stream_1.protocol_id
    )
    await stream_0.write(DATA)
    assert (await stream_1.read(MAX_READ_LEN)) == DATA


@pytest.mark.trio
async def test_net_stream_read_until_eof(net_stream_pair):
    read_bytes = bytearray()
    stream_0, stream_1 = net_stream_pair

    async def read_until_eof():
        read_bytes.extend(await stream_1.read())

    async with trio.open_nursery() as nursery:
        nursery.start_soon(read_until_eof)
        expected_data = bytearray()

        # Test: `read` doesn't return before `close` is called.
        await stream_0.write(DATA)
        expected_data.extend(DATA)
        await trio.sleep(0.01)
        assert len(read_bytes) == 0
        # Test: `read` still doesn't return after a second write.
        await stream_0.write(DATA)
        expected_data.extend(DATA)
        await trio.sleep(0.01)
        assert len(read_bytes) == 0

        # Test: after the stream is closed, `read` returns all previously sent data.
        await stream_0.close()
        await trio.sleep(0.01)
        assert read_bytes == expected_data


@pytest.mark.trio
async def test_net_stream_read_after_remote_closed(net_stream_pair):
    stream_0, stream_1 = net_stream_pair
    await stream_0.write(DATA)
    await stream_0.close()
    await trio.sleep(0.01)
    assert (await stream_1.read(MAX_READ_LEN)) == DATA
    with pytest.raises(StreamEOF):
        await stream_1.read(MAX_READ_LEN)


@pytest.mark.trio
async def test_net_stream_read_after_local_reset(net_stream_pair):
    stream_0, stream_1 = net_stream_pair
    await stream_0.reset()
    with pytest.raises(StreamReset):
        await stream_0.read(MAX_READ_LEN)


@pytest.mark.trio
async def test_net_stream_read_after_remote_reset(net_stream_pair):
    stream_0, stream_1 = net_stream_pair
    await stream_0.write(DATA)
    await stream_0.reset()
    # Sleep to let `stream_1` receive the message.
    await trio.sleep(0.01)
    with pytest.raises(StreamReset):
        await stream_1.read(MAX_READ_LEN)


@pytest.mark.trio
async def test_net_stream_read_after_remote_closed_and_reset(net_stream_pair):
    stream_0, stream_1 = net_stream_pair
    await stream_0.write(DATA)
    await stream_0.close()
    await stream_0.reset()
    # Sleep to let `stream_1` receive the message.
    await trio.sleep(0.01)
    assert (await stream_1.read(MAX_READ_LEN)) == DATA


@pytest.mark.trio
async def test_net_stream_write_after_local_closed(net_stream_pair):
    stream_0, stream_1 = net_stream_pair
    await stream_0.write(DATA)
    await stream_0.close()
    with pytest.raises(StreamClosed):
        await stream_0.write(DATA)


@pytest.mark.trio
async def test_net_stream_write_after_local_reset(net_stream_pair):
    stream_0, stream_1 = net_stream_pair
    await stream_0.reset()
    with pytest.raises(StreamClosed):
        await stream_0.write(DATA)


@pytest.mark.trio
async def test_net_stream_write_after_remote_reset(net_stream_pair):
    stream_0, stream_1 = net_stream_pair
    await stream_1.reset()
    await trio.sleep(0.01)
    with pytest.raises(StreamClosed):
        await stream_0.write(DATA)
tests/core/network/test_notify.py (new file)
@@ -0,0 +1,126 @@
"""
Test Notify and Notifee by ensuring that the proper events get called, and that
the stream passed into opened_stream is correct.

Note: a notifee registered only after the network has started listening does not
receive the `listen` event (see `events_0_without_listen` below).

TODO: Add tests for closed_stream and listen_close when those
features are implemented in swarm.
"""
import enum

from async_service import (
    background_trio_service,
)
import pytest
import trio

from libp2p.network.notifee_interface import (
    INotifee,
)
from libp2p.tools.constants import (
    LISTEN_MADDR,
)
from libp2p.tools.factories import (
    SwarmFactory,
)
from libp2p.tools.utils import (
    connect_swarm,
)


class Event(enum.Enum):
    OpenedStream = 0
    ClosedStream = 1  # Not implemented
    Connected = 2
    Disconnected = 3
    Listen = 4
    ListenClose = 5  # Not implemented


class MyNotifee(INotifee):
    def __init__(self, events):
        self.events = events

    async def opened_stream(self, network, stream):
        self.events.append(Event.OpenedStream)

    async def closed_stream(self, network, stream):
        # TODO: It is not implemented yet.
        pass

    async def connected(self, network, conn):
        self.events.append(Event.Connected)

    async def disconnected(self, network, conn):
        self.events.append(Event.Disconnected)

    async def listen(self, network, _multiaddr):
        self.events.append(Event.Listen)

    async def listen_close(self, network, _multiaddr):
        # TODO: It is not implemented yet.
        pass


@pytest.mark.trio
async def test_notify(security_protocol):
    swarms = [SwarmFactory(security_protocol=security_protocol) for _ in range(2)]

    events_0_0 = []
    events_1_0 = []
    events_0_without_listen = []
    # Run swarms.
    async with background_trio_service(swarms[0]), background_trio_service(swarms[1]):
        # Register notifees before listening, so that `MyNotifee` is notified
        # of the `listen` event.
        swarms[0].register_notifee(MyNotifee(events_0_0))
        swarms[1].register_notifee(MyNotifee(events_1_0))

        # Listen
        async with trio.open_nursery() as nursery:
            nursery.start_soon(swarms[0].listen, LISTEN_MADDR)
            nursery.start_soon(swarms[1].listen, LISTEN_MADDR)

        swarms[0].register_notifee(MyNotifee(events_0_without_listen))

        # Connected
        await connect_swarm(swarms[0], swarms[1])
        # OpenedStream: first
        await swarms[0].new_stream(swarms[1].get_peer_id())
        # OpenedStream: second
        await swarms[0].new_stream(swarms[1].get_peer_id())
        # OpenedStream: third, but in the other direction.
        await swarms[1].new_stream(swarms[0].get_peer_id())

        await trio.sleep(0.01)

        # TODO: Check `ClosedStream` and `ListenClose` events after they are ready.

        # Disconnected
        await swarms[0].close_peer(swarms[1].get_peer_id())
        await trio.sleep(0.01)

        # Connected again, but in the other direction.
        await connect_swarm(swarms[1], swarms[0])
        await trio.sleep(0.01)

        # Disconnected again, but in the other direction.
        await swarms[1].close_peer(swarms[0].get_peer_id())
        await trio.sleep(0.01)

        expected_events_without_listen = [
            Event.Connected,
            Event.OpenedStream,
            Event.OpenedStream,
            Event.OpenedStream,
            Event.Disconnected,
            Event.Connected,
            Event.Disconnected,
        ]
        expected_events = [Event.Listen] + expected_events_without_listen

        assert events_0_0 == expected_events
        assert events_1_0 == expected_events
        assert events_0_without_listen == expected_events_without_listen
tests/core/network/test_swarm.py (new file)
@@ -0,0 +1,158 @@
from multiaddr import (
    Multiaddr,
)
import pytest
import trio
from trio.testing import (
    wait_all_tasks_blocked,
)

from libp2p.network.exceptions import (
    SwarmException,
)
from libp2p.tools.factories import (
    SwarmFactory,
)
from libp2p.tools.utils import (
    connect_swarm,
)


@pytest.mark.trio
async def test_swarm_dial_peer(security_protocol):
    async with SwarmFactory.create_batch_and_listen(
        3, security_protocol=security_protocol
    ) as swarms:
        # Test: No addr found.
        with pytest.raises(SwarmException):
            await swarms[0].dial_peer(swarms[1].get_peer_id())

        # Test: The peer has zero addrs in the peerstore.
        swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000)
        with pytest.raises(SwarmException):
            await swarms[0].dial_peer(swarms[1].get_peer_id())

        # Test: Succeed if addrs of the peer_id are present in the peerstore.
        addrs = tuple(
            addr
            for transport in swarms[1].listeners.values()
            for addr in transport.get_addrs()
        )
        swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
        await swarms[0].dial_peer(swarms[1].get_peer_id())
        assert swarms[0].get_peer_id() in swarms[1].connections
        assert swarms[1].get_peer_id() in swarms[0].connections

        # Test: Reuse existing connections when we already have one with a peer.
        conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
        conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
        assert conn is conn_to_1


@pytest.mark.trio
async def test_swarm_close_peer(security_protocol):
    async with SwarmFactory.create_batch_and_listen(
        3, security_protocol=security_protocol
    ) as swarms:
        # 0 <> 1 <> 2
        await connect_swarm(swarms[0], swarms[1])
        await connect_swarm(swarms[1], swarms[2])

        # peer 1 closes peer 0
        await swarms[1].close_peer(swarms[0].get_peer_id())
        await trio.sleep(0.01)
        await wait_all_tasks_blocked()
        # 0   1 <> 2
        assert len(swarms[0].connections) == 0
        assert (
            len(swarms[1].connections) == 1
            and swarms[2].get_peer_id() in swarms[1].connections
        )

        # peer 1 is closed by peer 2
        await swarms[2].close_peer(swarms[1].get_peer_id())
        await trio.sleep(0.01)
        # 0   1   2
        assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0

        await connect_swarm(swarms[0], swarms[1])
        # 0 <> 1   2
        assert (
            len(swarms[0].connections) == 1
            and swarms[1].get_peer_id() in swarms[0].connections
        )
        assert (
            len(swarms[1].connections) == 1
            and swarms[0].get_peer_id() in swarms[1].connections
        )
        # peer 0 closes peer 1
        await swarms[0].close_peer(swarms[1].get_peer_id())
        await trio.sleep(0.01)
        # 0   1   2
        assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0


@pytest.mark.trio
async def test_swarm_remove_conn(swarm_pair):
    swarm_0, swarm_1 = swarm_pair
    conn_0 = swarm_0.connections[swarm_1.get_peer_id()]
    swarm_0.remove_conn(conn_0)
    assert swarm_1.get_peer_id() not in swarm_0.connections
    # Test: Remove twice. There should not be errors.
    swarm_0.remove_conn(conn_0)
    assert swarm_1.get_peer_id() not in swarm_0.connections


@pytest.mark.trio
async def test_swarm_multiaddr(security_protocol):
    async with SwarmFactory.create_batch_and_listen(
        3, security_protocol=security_protocol
    ) as swarms:

        def clear():
            swarms[0].peerstore.clear_addrs(swarms[1].get_peer_id())

        clear()
        # No addresses
        with pytest.raises(SwarmException):
            await swarms[0].dial_peer(swarms[1].get_peer_id())

        clear()
        # Wrong address
        swarms[0].peerstore.add_addrs(
            swarms[1].get_peer_id(), [Multiaddr("/ip4/0.0.0.0/tcp/9999")], 10000
        )

        with pytest.raises(SwarmException):
            await swarms[0].dial_peer(swarms[1].get_peer_id())

        clear()
        # Multiple wrong addresses
        swarms[0].peerstore.add_addrs(
            swarms[1].get_peer_id(),
            [Multiaddr("/ip4/0.0.0.0/tcp/9999"), Multiaddr("/ip4/0.0.0.0/tcp/9998")],
            10000,
        )

        with pytest.raises(SwarmException):
            await swarms[0].dial_peer(swarms[1].get_peer_id())

        # Test one address
        addrs = tuple(
            addr
            for transport in swarms[1].listeners.values()
            for addr in transport.get_addrs()
        )

        swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs[:1], 10000)
        await swarms[0].dial_peer(swarms[1].get_peer_id())

        # Test multiple addresses
        addrs = tuple(
            addr
            for transport in swarms[1].listeners.values()
            for addr in transport.get_addrs()
        )

        swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs + addrs, 10000)
        await swarms[0].dial_peer(swarms[1].get_peer_id())
tests/core/network/test_swarm_conn.py (new file)
@@ -0,0 +1,48 @@
import pytest
import trio
from trio.testing import (
    wait_all_tasks_blocked,
)


@pytest.mark.trio
async def test_swarm_conn_close(swarm_conn_pair):
    conn_0, conn_1 = swarm_conn_pair

    assert not conn_0.is_closed
    assert not conn_1.is_closed

    await conn_0.close()

    await trio.sleep(0.1)
    await wait_all_tasks_blocked()

    assert conn_0.is_closed
    assert conn_1.is_closed
    assert conn_0 not in conn_0.swarm.connections.values()
    assert conn_1 not in conn_1.swarm.connections.values()


@pytest.mark.trio
async def test_swarm_conn_streams(swarm_conn_pair):
    conn_0, conn_1 = swarm_conn_pair

    assert len(conn_0.get_streams()) == 0
    assert len(conn_1.get_streams()) == 0

    stream_0_0 = await conn_0.new_stream()
    await trio.sleep(0.01)
    assert len(conn_0.get_streams()) == 1
    assert len(conn_1.get_streams()) == 1

    stream_0_1 = await conn_0.new_stream()
    await trio.sleep(0.01)
    assert len(conn_0.get_streams()) == 2
    assert len(conn_1.get_streams()) == 2

    conn_0.remove_stream(stream_0_0)
    assert len(conn_0.get_streams()) == 1
    conn_0.remove_stream(stream_0_1)
    assert len(conn_0.get_streams()) == 0
    # Nothing happens if `stream_0_1` is not present or already removed.
    conn_0.remove_stream(stream_0_1)
tests/core/peer/test_addrbook.py (new file)
@@ -0,0 +1,61 @@
import pytest

from libp2p.peer.peerstore import (
    PeerStore,
    PeerStoreError,
)

# Testing methods from the IAddrBook base class.


def test_addrs_empty():
    with pytest.raises(PeerStoreError):
        store = PeerStore()
        val = store.addrs("peer")
        assert not val


def test_add_addr_single():
    store = PeerStore()
    store.add_addr("peer1", "/foo", 10)
    store.add_addr("peer1", "/bar", 10)
    store.add_addr("peer2", "/baz", 10)

    assert store.addrs("peer1") == ["/foo", "/bar"]
    assert store.addrs("peer2") == ["/baz"]


def test_add_addrs_multiple():
    store = PeerStore()
    store.add_addrs("peer1", ["/foo1", "/bar1"], 10)
    store.add_addrs("peer2", ["/foo2"], 10)

    assert store.addrs("peer1") == ["/foo1", "/bar1"]
    assert store.addrs("peer2") == ["/foo2"]


def test_clear_addrs():
    store = PeerStore()
    store.add_addrs("peer1", ["/foo1", "/bar1"], 10)
    store.add_addrs("peer2", ["/foo2"], 10)
    store.clear_addrs("peer1")

    assert store.addrs("peer1") == []
    assert store.addrs("peer2") == ["/foo2"]

    store.add_addrs("peer1", ["/foo1", "/bar1"], 10)

    assert store.addrs("peer1") == ["/foo1", "/bar1"]


def test_peers_with_addrs():
    store = PeerStore()
    store.add_addrs("peer1", [], 10)
    store.add_addrs("peer2", ["/foo"], 10)
    store.add_addrs("peer3", ["/bar"], 10)

    assert set(store.peers_with_addrs()) == {"peer2", "peer3"}

    store.clear_addrs("peer2")

    assert set(store.peers_with_addrs()) == {"peer3"}
tests/core/peer/test_interop.py (new file)
@@ -0,0 +1,47 @@
import base64

import Crypto.PublicKey.RSA as RSA

from libp2p.crypto.pb import crypto_pb2 as pb
from libp2p.crypto.rsa import (
    RSAPrivateKey,
)
from libp2p.peer.id import (
    ID,
)

# ``PRIVATE_KEY_PROTOBUF_SERIALIZATION`` is a protobuf holding an RSA private key.
PRIVATE_KEY_PROTOBUF_SERIALIZATION = """
CAAS4AQwggJcAgEAAoGBAL7w+Wc4VhZhCdM/+Hccg5Nrf4q9NXWwJylbSrXz/unFS24wyk6pEk0zi3W
7li+vSNVO+NtJQw9qGNAMtQKjVTP+3Vt/jfQRnQM3s6awojtjueEWuLYVt62z7mofOhCtj+VwIdZNBo
/EkLZ0ETfcvN5LVtLYa8JkXybnOPsLvK+PAgMBAAECgYBdk09HDM7zzL657uHfzfOVrdslrTCj6p5mo
DzvCxLkkjIzYGnlPuqfNyGjozkpSWgSUc+X+EGLLl3WqEOVdWJtbM61fewEHlRTM5JzScvwrJ39t7o6
CCAjKA0cBWBd6UWgbN/t53RoWvh9HrA2AW5YrT0ZiAgKe9y7EMUaENVJ8QJBAPhpdmb4ZL4Fkm4OKia
NEcjzn6mGTlZtef7K/0oRC9+2JkQnCuf6HBpaRhJoCJYg7DW8ZY+AV6xClKrgjBOfERMCQQDExhnzu2
dsQ9k8QChBlpHO0TRbZBiQfC70oU31kM1AeLseZRmrxv9Yxzdl8D693NNWS2JbKOXl0kMHHcuGQLMVA
kBZ7WvkmPV3aPL6jnwp2pXepntdVnaTiSxJ1dkXShZ/VSSDNZMYKY306EtHrIu3NZHtXhdyHKcggDXr
qkBrdgErAkAlpGPojUwemOggr4FD8sLX1ot2hDJyyV7OK2FXfajWEYJyMRL1Gm9Uk1+Un53RAkJneqp
JGAzKpyttXBTIDO51AkEA98KTiROMnnU8Y6Mgcvr68/SMIsvCYMt9/mtwSBGgl80VaTQ5Hpaktl6Xbh
VUt5Wv0tRxlXZiViCGCD1EtrrwTw==
""".replace(
    "\n", ""
)

EXPECTED_PEER_ID = "QmRK3JgmVEGiewxWbhpXLJyjWuGuLeSTMTndA1coMHEy5o"


# NOTE: this test checks that we can recreate the expected peer id given a private key
# serialization, taken from the Go implementation of libp2p.
def test_peer_id_interop():
    private_key_protobuf_bytes = base64.b64decode(PRIVATE_KEY_PROTOBUF_SERIALIZATION)
    private_key_protobuf = pb.PrivateKey()
    private_key_protobuf.ParseFromString(private_key_protobuf_bytes)

    private_key_data = private_key_protobuf.data

    private_key_impl = RSA.import_key(private_key_data)
    private_key = RSAPrivateKey(private_key_impl)
    public_key = private_key.get_public_key()

    peer_id = ID.from_pubkey(public_key)
    assert peer_id == EXPECTED_PEER_ID
tests/core/peer/test_peerid.py (new file)
@@ -0,0 +1,110 @@
import random

import base58
import multihash

from libp2p.crypto.rsa import (
    create_new_key_pair,
)
import libp2p.peer.id as PeerID
from libp2p.peer.id import (
    ID,
)

ALPHABETS = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"

# Ensure we are not in "debug" mode for the following tests.
PeerID.FRIENDLY_IDS = False


def test_eq_impl_for_bytes():
    random_id_string = ""
    for _ in range(10):
        random_id_string += random.choice(ALPHABETS)
    peer_id = ID(random_id_string.encode())
    assert peer_id == random_id_string.encode()


def test_pretty():
    random_id_string = ""
    for _ in range(10):
        random_id_string += random.choice(ALPHABETS)
    peer_id = ID(random_id_string.encode())
    actual = peer_id.pretty()
    expected = base58.b58encode(random_id_string).decode()

    assert actual == expected


def test_str_less_than_10():
    random_id_string = ""
    for _ in range(5):
        random_id_string += random.choice(ALPHABETS)
    peer_id = base58.b58encode(random_id_string).decode()
    expected = peer_id
    actual = ID(random_id_string.encode()).__str__()

    assert actual == expected


def test_str_more_than_10():
    random_id_string = ""
    for _ in range(10):
        random_id_string += random.choice(ALPHABETS)
    peer_id = base58.b58encode(random_id_string).decode()
    expected = peer_id
    actual = ID(random_id_string.encode()).__str__()

    assert actual == expected


def test_eq_true():
    random_id_string = ""
    for _ in range(10):
        random_id_string += random.choice(ALPHABETS)
    peer_id = ID(random_id_string.encode())

    assert peer_id == base58.b58encode(random_id_string).decode()
    assert peer_id == random_id_string.encode()
    assert peer_id == ID(random_id_string.encode())


def test_eq_false():
    peer_id = ID(b"efgh")
    other = ID(b"abcd")

    assert peer_id != other


def test_id_to_base58():
    random_id_string = ""
    for _ in range(10):
        random_id_string += random.choice(ALPHABETS)
    expected = base58.b58encode(random_id_string).decode()
    actual = ID(random_id_string.encode()).to_base58()

    assert actual == expected


def test_id_from_base58():
    random_id_string = ""
    for _ in range(10):
        random_id_string += random.choice(ALPHABETS)
    expected = ID(base58.b58decode(random_id_string))
    actual = ID.from_base58(random_id_string.encode())

    assert actual == expected


def test_id_from_public_key():
    key_pair = create_new_key_pair()
    public_key = key_pair.public_key

    key_bin = public_key.serialize()
    algo = multihash.Func.sha2_256
    mh_digest = multihash.digest(key_bin, algo)
    expected = ID(mh_digest.encode())

    actual = ID.from_pubkey(public_key)

    assert actual == expected
tests/core/peer/test_peerinfo.py (new file)
@@ -0,0 +1,54 @@
import random

import multiaddr
import pytest

from libp2p.peer.id import (
    ID,
)
from libp2p.peer.peerinfo import (
    InvalidAddrError,
    PeerInfo,
    info_from_p2p_addr,
)

ALPHABETS = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
VALID_MULTI_ADDR_STR = "/ip4/127.0.0.1/tcp/8000/p2p/3YgLAeMKSAPcGqZkAt8mREqhQXmJT8SN8VCMN4T6ih4GNX9wvK8mWJnWZ1qA2mLdCQ"  # noqa: E501


def test_init_():
    random_addrs = [random.randint(0, 255) for _ in range(4)]
    random_id_string = ""
    for _ in range(10):
        random_id_string += random.SystemRandom().choice(ALPHABETS)
    peer_id = ID(random_id_string.encode())
    peer_info = PeerInfo(peer_id, random_addrs)

    assert peer_info.peer_id == peer_id
    assert peer_info.addrs == random_addrs


@pytest.mark.parametrize(
    "addr",
    (
        pytest.param(multiaddr.Multiaddr("/"), id="empty multiaddr"),
        pytest.param(
            multiaddr.Multiaddr("/ip4/127.0.0.1"),
            id="multiaddr without peer_id (p2p protocol)",
        ),
    ),
)
def test_info_from_p2p_addr_invalid(addr):
    with pytest.raises(InvalidAddrError):
        info_from_p2p_addr(addr)


def test_info_from_p2p_addr_valid():
    m_addr = multiaddr.Multiaddr(VALID_MULTI_ADDR_STR)
    info = info_from_p2p_addr(m_addr)
    assert (
        info.peer_id.pretty()
        == "3YgLAeMKSAPcGqZkAt8mREqhQXmJT8SN8VCMN4T6ih4GNX9wvK8mWJnWZ1qA2mLdCQ"
    )
    assert len(info.addrs) == 1
    assert str(info.addrs[0]) == "/ip4/127.0.0.1/tcp/8000"
tests/core/peer/test_peermetadata.py (new file)
@@ -0,0 +1,46 @@
import pytest

from libp2p.peer.peerstore import (
    PeerStore,
    PeerStoreError,
)

# Testing methods from the IPeerMetadata base class.


def test_get_empty():
    with pytest.raises(PeerStoreError):
        store = PeerStore()
        val = store.get("peer", "key")
        assert not val


def test_put_get_simple():
    store = PeerStore()
    store.put("peer", "key", "val")
    assert store.get("peer", "key") == "val"


def test_put_get_update():
    store = PeerStore()
    store.put("peer", "key1", "val1")
    store.put("peer", "key2", "val2")
    store.put("peer", "key2", "new val2")

    assert store.get("peer", "key1") == "val1"
    assert store.get("peer", "key2") == "new val2"


def test_put_get_two_peers():
    store = PeerStore()
    store.put("peer1", "key1", "val1")
    store.put("peer2", "key1", "val1 prime")

    assert store.get("peer1", "key1") == "val1"
    assert store.get("peer2", "key1") == "val1 prime"

    # Try an update.
    store.put("peer2", "key1", "new val1")

    assert store.get("peer1", "key1") == "val1"
    assert store.get("peer2", "key1") == "new val1"
tests/core/peer/test_peerstore.py (new file)
@@ -0,0 +1,62 @@
import pytest

from libp2p.peer.peerstore import (
    PeerStore,
    PeerStoreError,
)

# Testing methods from the IPeerStore base class.


def test_peer_info_empty():
    store = PeerStore()
    with pytest.raises(PeerStoreError):
        store.peer_info("peer")


def test_peer_info_basic():
    store = PeerStore()
    store.add_addr("peer", "/foo", 10)
    info = store.peer_info("peer")

    assert info.peer_id == "peer"
    assert info.addrs == ["/foo"]


def test_add_get_protocols_basic():
    store = PeerStore()
    store.add_protocols("peer1", ["p1", "p2"])
    store.add_protocols("peer2", ["p3"])

    assert set(store.get_protocols("peer1")) == {"p1", "p2"}
    assert set(store.get_protocols("peer2")) == {"p3"}


def test_add_get_protocols_extend():
    store = PeerStore()
    store.add_protocols("peer1", ["p1", "p2"])
    store.add_protocols("peer1", ["p3"])

    assert set(store.get_protocols("peer1")) == {"p1", "p2", "p3"}


def test_set_protocols():
    store = PeerStore()
    store.add_protocols("peer1", ["p1", "p2"])
    store.add_protocols("peer2", ["p3"])

    store.set_protocols("peer1", ["p4"])
    store.set_protocols("peer2", [])

    assert set(store.get_protocols("peer1")) == {"p4"}
    assert set(store.get_protocols("peer2")) == set()


# Test with methods from other Peer interfaces.
def test_peers():
    store = PeerStore()
    store.add_protocols("peer1", [])
    store.put("peer2", "key", "val")
    store.add_addr("peer3", "/foo", 10)

    assert set(store.peer_ids()) == {"peer1", "peer2", "peer3"}
tests/core/protocol_muxer/test_protocol_muxer.py (new file)
@@ -0,0 +1,102 @@
import pytest

from libp2p.host.exceptions import (
    StreamFailure,
)
from libp2p.tools.factories import (
    HostFactory,
)
from libp2p.tools.utils import (
    create_echo_stream_handler,
)

PROTOCOL_ECHO = "/echo/1.0.0"
PROTOCOL_POTATO = "/potato/1.0.0"
PROTOCOL_FOO = "/foo/1.0.0"
PROTOCOL_ROCK = "/rock/1.0.0"

ACK_PREFIX = "ack:"


async def perform_simple_test(
    expected_selected_protocol,
    protocols_for_client,
    protocols_with_handlers,
    security_protocol,
):
    async with HostFactory.create_batch_and_listen(
        2, security_protocol=security_protocol
    ) as hosts:
        for protocol in protocols_with_handlers:
            hosts[1].set_stream_handler(
                protocol, create_echo_stream_handler(ACK_PREFIX)
            )

        # Associate the peer with a local ip address (see the default parameters
        # of Libp2p()).
        hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
        stream = await hosts[0].new_stream(hosts[1].get_id(), protocols_for_client)
        messages = ["hello" + str(x) for x in range(10)]
        for message in messages:
            expected_resp = ACK_PREFIX + message
            await stream.write(message.encode())
            response = (await stream.read(len(expected_resp))).decode()
            assert response == expected_resp

        assert expected_selected_protocol == stream.get_protocol()


@pytest.mark.trio
async def test_single_protocol_succeeds(security_protocol):
    expected_selected_protocol = PROTOCOL_ECHO
    await perform_simple_test(
        expected_selected_protocol,
        [expected_selected_protocol],
        [expected_selected_protocol],
        security_protocol,
    )


@pytest.mark.trio
async def test_single_protocol_fails(security_protocol):
    with pytest.raises(StreamFailure):
        await perform_simple_test(
            "", [PROTOCOL_ECHO], [PROTOCOL_POTATO], security_protocol
        )

    # Cleanup is not reached on error.


@pytest.mark.trio
async def test_multiple_protocol_first_is_valid_succeeds(security_protocol):
    expected_selected_protocol = PROTOCOL_ECHO
    protocols_for_client = [PROTOCOL_ECHO, PROTOCOL_POTATO]
    protocols_for_listener = [PROTOCOL_FOO, PROTOCOL_ECHO]
    await perform_simple_test(
        expected_selected_protocol,
        protocols_for_client,
        protocols_for_listener,
        security_protocol,
    )


@pytest.mark.trio
async def test_multiple_protocol_second_is_valid_succeeds(security_protocol):
    expected_selected_protocol = PROTOCOL_FOO
    protocols_for_client = [PROTOCOL_ROCK, PROTOCOL_FOO]
    protocols_for_listener = [PROTOCOL_FOO, PROTOCOL_ECHO]
    await perform_simple_test(
        expected_selected_protocol,
        protocols_for_client,
        protocols_for_listener,
        security_protocol,
    )


@pytest.mark.trio
async def test_multiple_protocol_fails(security_protocol):
    protocols_for_client = [PROTOCOL_ROCK, PROTOCOL_FOO, "/bar/1.0.0"]
    protocols_for_listener = ["/aspyn/1.0.0", "/rob/1.0.0", "/zx/1.0.0", "/alex/1.0.0"]
    with pytest.raises(StreamFailure):
        await perform_simple_test(
            "", protocols_for_client, protocols_for_listener, security_protocol
        )
tests/core/pubsub/test_dummyaccount_demo.py (new file)
@@ -0,0 +1,197 @@
import pytest
import trio

from libp2p.tools.pubsub.dummy_account_node import (
    DummyAccountNode,
)
from libp2p.tools.utils import (
    connect,
)


async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
    """
    Helper function to allow for easy construction of custom tests for dummy
    account nodes in various network topologies.

    :param num_nodes: number of nodes in the test
    :param adjacency_map: adjacency map defining each node and its list of neighbors
    :param action_func: function to execute that includes actions by the nodes,
        such as send crypto and set crypto
    :param assertion_func: assertions for testing that the results of the actions
        are correct
    """

    async with DummyAccountNode.create(num_nodes) as dummy_nodes:
        # Create connections between nodes according to `adjacency_map`
        async with trio.open_nursery() as nursery:
            for source_num in adjacency_map:
                target_nums = adjacency_map[source_num]
                for target_num in target_nums:
                    nursery.start_soon(
                        connect,
                        dummy_nodes[source_num].host,
                        dummy_nodes[target_num].host,
                    )

        # Allow time for network creation to take place
        await trio.sleep(0.25)

        # Perform the action function
        await action_func(dummy_nodes)

        # Allow time for the actions to be performed (i.e. messages to propagate)
        await trio.sleep(1)

        # Perform the assertion function
        for dummy_node in dummy_nodes:
            assertion_func(dummy_node)

        # Success, terminate pending tasks.


@pytest.mark.trio
async def test_simple_two_nodes():
    num_nodes = 2
    adj_map = {0: [1]}

    async def action_func(dummy_nodes):
        await dummy_nodes[0].publish_set_crypto("aspyn", 10)

    def assertion_func(dummy_node):
        assert dummy_node.get_balance("aspyn") == 10

    await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
async def test_simple_three_nodes_line_topography():
    num_nodes = 3
    adj_map = {0: [1], 1: [2]}

    async def action_func(dummy_nodes):
        await dummy_nodes[0].publish_set_crypto("aspyn", 10)

    def assertion_func(dummy_node):
        assert dummy_node.get_balance("aspyn") == 10

    await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
async def test_simple_three_nodes_triangle_topography():
    num_nodes = 3
    adj_map = {0: [1, 2], 1: [2]}

    async def action_func(dummy_nodes):
        await dummy_nodes[0].publish_set_crypto("aspyn", 20)

    def assertion_func(dummy_node):
        assert dummy_node.get_balance("aspyn") == 20

    await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
async def test_simple_seven_nodes_tree_topography():
    num_nodes = 7
    adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}

    async def action_func(dummy_nodes):
        await dummy_nodes[0].publish_set_crypto("aspyn", 20)

    def assertion_func(dummy_node):
        assert dummy_node.get_balance("aspyn") == 20

    await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
async def test_set_then_send_from_root_seven_nodes_tree_topography():
    num_nodes = 7
    adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}

    async def action_func(dummy_nodes):
        await dummy_nodes[0].publish_set_crypto("aspyn", 20)
        await trio.sleep(0.25)
        await dummy_nodes[0].publish_send_crypto("aspyn", "alex", 5)

    def assertion_func(dummy_node):
        assert dummy_node.get_balance("aspyn") == 15
        assert dummy_node.get_balance("alex") == 5

    await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography():
    num_nodes = 7
    adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}

    async def action_func(dummy_nodes):
        await dummy_nodes[6].publish_set_crypto("aspyn", 20)
        await trio.sleep(0.25)
        await dummy_nodes[4].publish_send_crypto("aspyn", "alex", 5)

    def assertion_func(dummy_node):
        assert dummy_node.get_balance("aspyn") == 15
        assert dummy_node.get_balance("alex") == 5

    await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
async def test_simple_five_nodes_ring_topography():
    num_nodes = 5
    adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}

    async def action_func(dummy_nodes):
        await dummy_nodes[0].publish_set_crypto("aspyn", 20)

    def assertion_func(dummy_node):
        assert dummy_node.get_balance("aspyn") == 20

    await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography():
    num_nodes = 5
    adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}

    async def action_func(dummy_nodes):
        await dummy_nodes[0].publish_set_crypto("alex", 20)
        await trio.sleep(0.25)
        await dummy_nodes[3].publish_send_crypto("alex", "rob", 12)

    def assertion_func(dummy_node):
        assert dummy_node.get_balance("alex") == 8
        assert dummy_node.get_balance("rob") == 12

    await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
@pytest.mark.slow
async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography():
    num_nodes = 5
    adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}

    async def action_func(dummy_nodes):
        await dummy_nodes[0].publish_set_crypto("alex", 20)
        await trio.sleep(1)
        await dummy_nodes[1].publish_send_crypto("alex", "rob", 3)
        await trio.sleep(1)
        await dummy_nodes[2].publish_send_crypto("rob", "aspyn", 2)
        await trio.sleep(1)
        await dummy_nodes[3].publish_send_crypto("aspyn", "zx", 1)
        await trio.sleep(1)
        await dummy_nodes[4].publish_send_crypto("zx", "raul", 1)

    def assertion_func(dummy_node):
        assert dummy_node.get_balance("alex") == 17
        assert dummy_node.get_balance("rob") == 1
        assert dummy_node.get_balance("aspyn") == 1
        assert dummy_node.get_balance("zx") == 0
        assert dummy_node.get_balance("raul") == 1

    await perform_test(num_nodes, adj_map, action_func, assertion_func)
tests/core/pubsub/test_floodsub.py (new file)
@@ -0,0 +1,95 @@
import functools

import pytest
import trio

from libp2p.peer.id import (
    ID,
)
from libp2p.tools.factories import (
    PubsubFactory,
)
from libp2p.tools.pubsub.floodsub_integration_test_settings import (
    floodsub_protocol_pytest_params,
    perform_test_from_obj,
)
from libp2p.tools.utils import (
    connect,
)


@pytest.mark.trio
async def test_simple_two_nodes():
    async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
        topic = "my_topic"
        data = b"some data"

        await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
        await trio.sleep(0.25)

        sub_b = await pubsubs_fsub[1].subscribe(topic)
        # Sleep to let node a learn of b's subscription.
        await trio.sleep(0.25)

        await pubsubs_fsub[0].publish(topic, data)

        res_b = await sub_b.get()

        # Check that the message received by node_b is the same
        # as the message sent by node_a.
        assert ID(res_b.from_id) == pubsubs_fsub[0].host.get_id()
        assert res_b.data == data
        assert res_b.topicIDs == [topic]


@pytest.mark.trio
async def test_lru_cache_two_nodes():
    # Two nodes with a cache_size of 4.

    # Mock `get_msg_id` to make it easier to manipulate `msg_id` via `data`.
    def get_msg_id(msg):
        # Originally it is `(msg.seqno, msg.from_id)`.
        return (msg.data, msg.from_id)

    async with PubsubFactory.create_batch_with_floodsub(
        2, cache_size=4, msg_id_constructor=get_msg_id
    ) as pubsubs_fsub:
        # `node_a` sends the following messages to `node_b`.
        message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
        # `node_b` should only receive the following.
        expected_received_indices = [1, 2, 3, 4, 5, 1]

        topic = "my_topic"

        await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
        await trio.sleep(0.25)

        sub_b = await pubsubs_fsub[1].subscribe(topic)
        await trio.sleep(0.25)

        def _make_testing_data(i: int) -> bytes:
            num_int_bytes = 4
            if i >= 2 ** (num_int_bytes * 8):
                raise ValueError("integer is too large to be serialized")
            return b"data" + i.to_bytes(num_int_bytes, "big")

        for index in message_indices:
            await pubsubs_fsub[0].publish(topic, _make_testing_data(index))
            await trio.sleep(0.25)

        for index in expected_received_indices:
            res_b = await sub_b.get()
            assert res_b.data == _make_testing_data(index)


@pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params)
@pytest.mark.trio
@pytest.mark.slow
async def test_gossipsub_run_with_floodsub_tests(test_case_obj, security_protocol):
    await perform_test_from_obj(
        test_case_obj,
        functools.partial(
            PubsubFactory.create_batch_with_floodsub,
            security_protocol=security_protocol,
        ),
    )
tests/core/pubsub/test_gossipsub.py (new file)
@@ -0,0 +1,488 @@
|
||||
import random

import pytest
import trio

from libp2p.pubsub.gossipsub import (
    PROTOCOL_ID,
)
from libp2p.tools.factories import (
    IDFactory,
    PubsubFactory,
)
from libp2p.tools.pubsub.utils import (
    dense_connect,
    one_to_all_connect,
)
from libp2p.tools.utils import (
    connect,
)


@pytest.mark.trio
async def test_join():
    async with PubsubFactory.create_batch_with_gossipsub(
        4, degree=4, degree_low=3, degree_high=5
    ) as pubsubs_gsub:
        gossipsubs = [pubsub.router for pubsub in pubsubs_gsub]
        hosts = [pubsub.host for pubsub in pubsubs_gsub]
        hosts_indices = list(range(len(pubsubs_gsub)))

        topic = "test_join"
        central_node_index = 0
        # Remove the index of the central host from the indices
        hosts_indices.remove(central_node_index)
        num_subscribed_peer = 2
        subscribed_peer_indices = random.sample(hosts_indices, num_subscribed_peer)

        # Every pubsub except the central node's subscribes to the topic
        for i in subscribed_peer_indices:
            await pubsubs_gsub[i].subscribe(topic)

        # Connect the central host to all other hosts
        await one_to_all_connect(hosts, central_node_index)

        # Wait 2 seconds for heartbeat to allow mesh to connect
        await trio.sleep(2)

        # The central node publishes to the topic so that the topic
        # is added to its fanout
        await pubsubs_gsub[central_node_index].publish(topic, b"data")

        # Check that the gossipsub of the central node has fanout for the topic
        assert topic in gossipsubs[central_node_index].fanout
        # Check that the gossipsub of the central node does not have a mesh
        # for the topic
        assert topic not in gossipsubs[central_node_index].mesh

        # The central node subscribes to the topic
        await pubsubs_gsub[central_node_index].subscribe(topic)

        await trio.sleep(2)

        # Check that the gossipsub of the central node no longer has fanout
        # for the topic
        assert topic not in gossipsubs[central_node_index].fanout

        for i in hosts_indices:
            if i in subscribed_peer_indices:
                assert hosts[i].get_id() in gossipsubs[central_node_index].mesh[topic]
                assert hosts[central_node_index].get_id() in gossipsubs[i].mesh[topic]
            else:
                assert (
                    hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic]
                )
                assert topic not in gossipsubs[i].mesh


@pytest.mark.trio
async def test_leave():
    async with PubsubFactory.create_batch_with_gossipsub(1) as pubsubs_gsub:
        gossipsub = pubsubs_gsub[0].router
        topic = "test_leave"

        assert topic not in gossipsub.mesh

        await gossipsub.join(topic)
        assert topic in gossipsub.mesh

        await gossipsub.leave(topic)
        assert topic not in gossipsub.mesh

        # Test leaving twice
        await gossipsub.leave(topic)


@pytest.mark.trio
async def test_handle_graft(monkeypatch):
    async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
        gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)

        index_alice = 0
        id_alice = pubsubs_gsub[index_alice].my_id
        index_bob = 1
        id_bob = pubsubs_gsub[index_bob].my_id
        await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)

        # Wait 2 seconds for heartbeat to allow mesh to connect
        await trio.sleep(2)

        topic = "test_handle_graft"
        # Only alice joins the topic
        await gossipsubs[index_alice].join(topic)

        # Monkeypatch bob's `emit_prune` function so we can
        # check whether it is called in `handle_graft`
        event_emit_prune = trio.Event()

        async def emit_prune(topic, sender_peer_id):
            event_emit_prune.set()
            await trio.lowlevel.checkpoint()

        monkeypatch.setattr(gossipsubs[index_bob], "emit_prune", emit_prune)

        # Check that alice is bob's peer but not his mesh peer
        assert gossipsubs[index_bob].peer_protocol[id_alice] == PROTOCOL_ID
        assert topic not in gossipsubs[index_bob].mesh

        await gossipsubs[index_alice].emit_graft(topic, id_bob)

        # Check that `emit_prune` is called
        await event_emit_prune.wait()

        # Check that bob is alice's peer but not her mesh peer
        assert topic in gossipsubs[index_alice].mesh
        assert id_bob not in gossipsubs[index_alice].mesh[topic]
        assert gossipsubs[index_alice].peer_protocol[id_bob] == PROTOCOL_ID

        await gossipsubs[index_bob].emit_graft(topic, id_alice)

        await trio.sleep(1)

        # Check that bob is now alice's mesh peer
        assert id_bob in gossipsubs[index_alice].mesh[topic]


@pytest.mark.trio
async def test_handle_prune():
    async with PubsubFactory.create_batch_with_gossipsub(
        2, heartbeat_interval=3
    ) as pubsubs_gsub:
        gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)

        index_alice = 0
        id_alice = pubsubs_gsub[index_alice].my_id
        index_bob = 1
        id_bob = pubsubs_gsub[index_bob].my_id

        topic = "test_handle_prune"
        for pubsub in pubsubs_gsub:
            await pubsub.subscribe(topic)

        await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)

        # Wait for heartbeat to allow mesh to connect
        await trio.sleep(1)

        # Check that they are each other's mesh peer
        assert id_alice in gossipsubs[index_bob].mesh[topic]
        assert id_bob in gossipsubs[index_alice].mesh[topic]

        # alice emits a prune message to bob; alice should be removed
        # from bob's mesh peers
        await gossipsubs[index_alice].emit_prune(topic, id_bob)
        # `emit_prune` does not remove bob from alice's mesh peers
        assert id_bob in gossipsubs[index_alice].mesh[topic]

        # NOTE: We increase `heartbeat_interval` to 3 seconds so that bob will not
        # add alice back to his mesh after the heartbeat.
        # Wait for bob to `handle_prune`
        await trio.sleep(0.1)

        # Check that alice is no longer bob's mesh peer
        assert id_alice not in gossipsubs[index_bob].mesh[topic]


@pytest.mark.trio
async def test_dense():
    async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
        hosts = [pubsub.host for pubsub in pubsubs_gsub]
        num_msgs = 5

        # All pubsubs subscribe to "foobar"
        queues = [await pubsub.subscribe("foobar") for pubsub in pubsubs_gsub]

        # Densely connect libp2p hosts in a random way
        await dense_connect(hosts)

        # Wait 2 seconds for heartbeat to allow mesh to connect
        await trio.sleep(2)

        for i in range(num_msgs):
            msg_content = b"foo " + i.to_bytes(1, "big")

            # Randomly pick a message origin
            origin_idx = random.randint(0, len(hosts) - 1)

            # Publish from the randomly chosen host
            await pubsubs_gsub[origin_idx].publish("foobar", msg_content)

            await trio.sleep(0.5)
            # Assert that all blocking queues receive the message
            for queue in queues:
                msg = await queue.get()
                assert msg.data == msg_content


@pytest.mark.trio
async def test_fanout():
    async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
        hosts = [pubsub.host for pubsub in pubsubs_gsub]
        num_msgs = 5

        # All pubsubs subscribe to "foobar" except for `pubsubs_gsub[0]`
        subs = [await pubsub.subscribe("foobar") for pubsub in pubsubs_gsub[1:]]

        # Densely connect libp2p hosts in a random way
        await dense_connect(hosts)

        # Wait 2 seconds for heartbeat to allow mesh to connect
        await trio.sleep(2)

        topic = "foobar"
        # Send messages with an origin that is not subscribed
        for i in range(num_msgs):
            msg_content = b"foo " + i.to_bytes(1, "big")

            # Make the node that is not subscribed to "foobar" the message origin
            origin_idx = 0

            # Publish from the chosen host
            await pubsubs_gsub[origin_idx].publish(topic, msg_content)

            await trio.sleep(0.5)
            # Assert that all blocking queues receive the message
            for sub in subs:
                msg = await sub.get()
                assert msg.data == msg_content

        # Subscribe the message origin
        subs.insert(0, await pubsubs_gsub[0].subscribe(topic))

        # Send messages again
        for i in range(num_msgs):
            msg_content = b"bar " + i.to_bytes(1, "big")

            # Keep the now-subscribed node as the message origin
            origin_idx = 0

            # Publish from the chosen host
            await pubsubs_gsub[origin_idx].publish(topic, msg_content)

            await trio.sleep(0.5)
            # Assert that all blocking queues receive the message
            for sub in subs:
                msg = await sub.get()
                assert msg.data == msg_content


@pytest.mark.trio
@pytest.mark.slow
async def test_fanout_maintenance():
    async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
        hosts = [pubsub.host for pubsub in pubsubs_gsub]
        num_msgs = 5

        # All pubsubs subscribe to "foobar" except for `pubsubs_gsub[0]`
        queues = []
        topic = "foobar"
        for i in range(1, len(pubsubs_gsub)):
            q = await pubsubs_gsub[i].subscribe(topic)

            # Add each blocking queue to an array of blocking queues
            queues.append(q)

        # Densely connect libp2p hosts in a random way
        await dense_connect(hosts)

        # Wait 2 seconds for heartbeat to allow mesh to connect
        await trio.sleep(2)

        # Send messages with an origin that is not subscribed
        for i in range(num_msgs):
            msg_content = b"foo " + i.to_bytes(1, "big")

            # Make the node that is not subscribed to "foobar" the message origin
            origin_idx = 0

            # Publish from the chosen host
            await pubsubs_gsub[origin_idx].publish(topic, msg_content)

            await trio.sleep(0.5)
            # Assert that all blocking queues receive the message
            for queue in queues:
                msg = await queue.get()
                assert msg.data == msg_content

        for sub in pubsubs_gsub:
            await sub.unsubscribe(topic)

        queues = []

        await trio.sleep(2)

        # Resubscribe and repeat
        for i in range(1, len(pubsubs_gsub)):
            q = await pubsubs_gsub[i].subscribe(topic)

            # Add each blocking queue to an array of blocking queues
            queues.append(q)

        await trio.sleep(2)

        # Check that messages can still be sent
        for i in range(num_msgs):
            msg_content = b"bar " + i.to_bytes(1, "big")

            # Make the node that is not subscribed to "foobar" the message origin
            origin_idx = 0

            # Publish from the chosen host
            await pubsubs_gsub[origin_idx].publish(topic, msg_content)

            await trio.sleep(0.5)
            # Assert that all blocking queues receive the message
            for queue in queues:
                msg = await queue.get()
                assert msg.data == msg_content


@pytest.mark.trio
async def test_gossip_propagation():
    async with PubsubFactory.create_batch_with_gossipsub(
        2, degree=1, degree_low=0, degree_high=2, gossip_window=50, gossip_history=100
    ) as pubsubs_gsub:
        topic = "foo"
        queue_0 = await pubsubs_gsub[0].subscribe(topic)

        # Node 0 publishes to the topic
        msg_content = b"foo_msg"
        await pubsubs_gsub[0].publish(topic, msg_content)

        await trio.sleep(0.5)
        # Assert that the blocking queue receives the message
        msg = await queue_0.get()
        assert msg.data == msg_content


@pytest.mark.parametrize("initial_mesh_peer_count", (7, 10, 13))
@pytest.mark.trio
async def test_mesh_heartbeat(initial_mesh_peer_count, monkeypatch):
    async with PubsubFactory.create_batch_with_gossipsub(
        1, heartbeat_initial_delay=100
    ) as pubsubs_gsub:
        # It's difficult to set up the initial peer subscription condition.
        # Ideally I would like to have an initial mesh peer count that's below
        # `GossipSubDegree` so I can test whether `mesh_heartbeat` returns the
        # correct peers to GRAFT. The problem is that I can not set it up so that
        # we have peers subscribed to the topic without being part of our mesh
        # peers (as these peers are the peers to GRAFT). So I monkeypatch the
        # peer subscriptions and our mesh peers.
        total_peer_count = 14
        topic = "TEST_MESH_HEARTBEAT"

        fake_peer_ids = [IDFactory() for _ in range(total_peer_count)]
        peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids}
        monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol)

        peer_topics = {topic: set(fake_peer_ids)}
        # Monkeypatch the peer subscriptions
        monkeypatch.setattr(pubsubs_gsub[0], "peer_topics", peer_topics)

        mesh_peer_indices = random.sample(
            range(total_peer_count), initial_mesh_peer_count
        )
        mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices]
        router_mesh = {topic: set(mesh_peers)}
        # Monkeypatch our mesh peers
        monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh)

        peers_to_graft, peers_to_prune = pubsubs_gsub[0].router.mesh_heartbeat()
        if initial_mesh_peer_count > pubsubs_gsub[0].router.degree:
            # If the number of initial mesh peers is more than `GossipSubDegree`,
            # we should PRUNE mesh peers
            assert len(peers_to_graft) == 0
            assert (
                len(peers_to_prune)
                == initial_mesh_peer_count - pubsubs_gsub[0].router.degree
            )
            for peer in peers_to_prune:
                assert peer in mesh_peers
        elif initial_mesh_peer_count < pubsubs_gsub[0].router.degree:
            # If the number of initial mesh peers is less than `GossipSubDegree`,
            # we should GRAFT more peers
            assert len(peers_to_prune) == 0
            assert (
                len(peers_to_graft)
                == pubsubs_gsub[0].router.degree - initial_mesh_peer_count
            )
            for peer in peers_to_graft:
                assert peer not in mesh_peers
        else:
            assert len(peers_to_prune) == 0 and len(peers_to_graft) == 0


@pytest.mark.parametrize("initial_peer_count", (1, 4, 7))
@pytest.mark.trio
async def test_gossip_heartbeat(initial_peer_count, monkeypatch):
    async with PubsubFactory.create_batch_with_gossipsub(
        1, heartbeat_initial_delay=100
    ) as pubsubs_gsub:
        # The problem is that I can not set it up so that we have peers subscribed
        # to the topic without being part of our mesh peers (as these peers are
        # the peers to GRAFT). So I monkeypatch the peer subscriptions and our
        # mesh peers.
        total_peer_count = 28
        topic_mesh = "TEST_GOSSIP_HEARTBEAT_1"
        topic_fanout = "TEST_GOSSIP_HEARTBEAT_2"

        fake_peer_ids = [IDFactory() for _ in range(total_peer_count)]
        peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids}
        monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol)

        topic_mesh_peer_count = 14
        # Split into mesh peers and fanout peers
        peer_topics = {
            topic_mesh: set(fake_peer_ids[:topic_mesh_peer_count]),
            topic_fanout: set(fake_peer_ids[topic_mesh_peer_count:]),
        }
        # Monkeypatch the peer subscriptions
        monkeypatch.setattr(pubsubs_gsub[0], "peer_topics", peer_topics)

        mesh_peer_indices = random.sample(
            range(topic_mesh_peer_count), initial_peer_count
        )
        mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices]
        router_mesh = {topic_mesh: set(mesh_peers)}
        # Monkeypatch our mesh peers
        monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh)
        fanout_peer_indices = random.sample(
            range(topic_mesh_peer_count, total_peer_count), initial_peer_count
        )
        fanout_peers = [fake_peer_ids[i] for i in fanout_peer_indices]
        router_fanout = {topic_fanout: set(fanout_peers)}
        # Monkeypatch our fanout peers
        monkeypatch.setattr(pubsubs_gsub[0].router, "fanout", router_fanout)

        def window(topic):
            if topic == topic_mesh:
                return [topic_mesh]
            elif topic == topic_fanout:
                return [topic_fanout]
            else:
                return []

        # Monkeypatch the message cache's `window` function
        monkeypatch.setattr(pubsubs_gsub[0].router.mcache, "window", window)

        peers_to_gossip = pubsubs_gsub[0].router.gossip_heartbeat()
        # If our mesh peer count is less than `GossipSubDegree`, we should gossip
        # to up to `GossipSubDegree` peers (excluding mesh peers).
        if topic_mesh_peer_count - initial_peer_count < pubsubs_gsub[0].router.degree:
            # The same goes for fanout, so it's two times the number of peers
            # to gossip.
            assert len(peers_to_gossip) == 2 * (
                topic_mesh_peer_count - initial_peer_count
            )
        elif (
            topic_mesh_peer_count - initial_peer_count >= pubsubs_gsub[0].router.degree
        ):
            assert len(peers_to_gossip) == 2 * (pubsubs_gsub[0].router.degree)

        for peer in peers_to_gossip:
            if peer in peer_topics[topic_mesh]:
                # Check that the peer to gossip to is not in our mesh peers
                assert peer not in mesh_peers
                assert topic_mesh in peers_to_gossip[peer]
            elif peer in peer_topics[topic_fanout]:
                # Check that the peer to gossip to is not in our fanout peers
                assert peer not in fanout_peers
                assert topic_fanout in peers_to_gossip[peer]
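The two heartbeat tests above reduce to simple arithmetic around the `degree` target: a mesh smaller than `degree` GRAFTs the difference from subscribed non-mesh peers, a mesh larger than `degree` PRUNEs the surplus, and gossip goes to up to `degree` subscribed peers outside the mesh (or fanout) per topic. A minimal sketch of that balancing rule (`balance_mesh` is an illustrative helper, not py-libp2p's API; candidate selection is assumed to be uniform):

import random


def balance_mesh(mesh: set, subscribed: set, degree: int):
    # Illustrative sketch of the GRAFT/PRUNE rule asserted in
    # test_mesh_heartbeat; not the library's implementation.
    to_graft, to_prune = set(), set()
    if len(mesh) < degree:
        # GRAFT only peers subscribed to the topic but not yet in the mesh.
        pool = list(subscribed - mesh)
        to_graft = set(random.sample(pool, min(degree - len(mesh), len(pool))))
    elif len(mesh) > degree:
        # PRUNE the surplus down to the degree target.
        to_prune = set(random.sample(list(mesh), len(mesh) - degree))
    return to_graft, to_prune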
31
tests/core/pubsub/test_gossipsub_backward_compatibility.py
Normal file
@ -0,0 +1,31 @@
import functools

import pytest

from libp2p.tools.constants import (
    FLOODSUB_PROTOCOL_ID,
)
from libp2p.tools.factories import (
    PubsubFactory,
)
from libp2p.tools.pubsub.floodsub_integration_test_settings import (
    floodsub_protocol_pytest_params,
    perform_test_from_obj,
)


@pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params)
@pytest.mark.trio
@pytest.mark.slow
async def test_gossipsub_run_with_floodsub_tests(test_case_obj):
    await perform_test_from_obj(
        test_case_obj,
        functools.partial(
            PubsubFactory.create_batch_with_gossipsub,
            protocols=[FLOODSUB_PROTOCOL_ID],
            degree=3,
            degree_low=2,
            degree_high=4,
            time_to_live=30,
        ),
    )
130
tests/core/pubsub/test_mcache.py
Normal file
@ -0,0 +1,130 @@
from libp2p.pubsub.mcache import (
    MessageCache,
)


class Msg:
    __slots__ = ["topicIDs", "seqno", "from_id"]

    def __init__(self, topicIDs, seqno, from_id):
        self.topicIDs = topicIDs
        self.seqno = seqno
        self.from_id = from_id


def test_mcache():
    # Ported from:
    # https://github.com/libp2p/go-libp2p-pubsub/blob/51b7501433411b5096cac2b4994a36a68515fc03/mcache_test.go
    mcache = MessageCache(3, 5)
    msgs = []

    for i in range(60):
        msgs.append(Msg(["test"], i, "test"))

    for i in range(10):
        mcache.put(msgs[i])

    for i in range(10):
        msg = msgs[i]
        mid = (msg.seqno, msg.from_id)
        get_msg = mcache.get(mid)

        # Successful read
        assert get_msg == msg

    gids = mcache.window("test")

    assert len(gids) == 10

    for i in range(10):
        msg = msgs[i]
        mid = (msg.seqno, msg.from_id)

        assert mid == gids[i]

    mcache.shift()

    for i in range(10, 20):
        mcache.put(msgs[i])

    for i in range(20):
        msg = msgs[i]
        mid = (msg.seqno, msg.from_id)
        get_msg = mcache.get(mid)

        assert get_msg == msg

    gids = mcache.window("test")

    assert len(gids) == 20

    for i in range(10):
        msg = msgs[i]
        mid = (msg.seqno, msg.from_id)

        assert mid == gids[10 + i]

    for i in range(10, 20):
        msg = msgs[i]
        mid = (msg.seqno, msg.from_id)

        assert mid == gids[i - 10]

    mcache.shift()

    for i in range(20, 30):
        mcache.put(msgs[i])

    mcache.shift()

    for i in range(30, 40):
        mcache.put(msgs[i])

    mcache.shift()

    for i in range(40, 50):
        mcache.put(msgs[i])

    mcache.shift()

    for i in range(50, 60):
        mcache.put(msgs[i])

    assert len(mcache.msgs) == 50

    for i in range(10):
        msg = msgs[i]
        mid = (msg.seqno, msg.from_id)
        get_msg = mcache.get(mid)

        # Should be evicted from the cache
        assert not get_msg

    for i in range(10, 60):
        msg = msgs[i]
        mid = (msg.seqno, msg.from_id)
        get_msg = mcache.get(mid)

        assert get_msg == msg

    gids = mcache.window("test")

    assert len(gids) == 30

    for i in range(10):
        msg = msgs[50 + i]
        mid = (msg.seqno, msg.from_id)

        assert mid == gids[i]

    for i in range(10, 20):
        msg = msgs[30 + i]
        mid = (msg.seqno, msg.from_id)

        assert mid == gids[i]

    for i in range(20, 30):
        msg = msgs[10 + i]
        mid = (msg.seqno, msg.from_id)

        assert mid == gids[i]
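The test above pins down the window/history semantics of `MessageCache(3, 5)`: `put` records a message in the newest history slot, `shift` rotates the slots and evicts whatever falls off the oldest one, and `window(topic)` returns the message IDs from the newest `window` slots, most recent batch first. A toy reconstruction of those semantics (assumed internals, purely illustrative, not py-libp2p's implementation):

class ToyMessageCache:
    def __init__(self, window: int, history: int):
        self.window = window                          # slots visible to window()
        self.history = [[] for _ in range(history)]   # history[0] is newest
        self.msgs = {}                                # mid -> message

    def put(self, msg):
        mid = (msg.seqno, msg.from_id)
        self.msgs[mid] = msg
        self.history[0].append((mid, msg.topicIDs))

    def shift(self):
        # Drop the oldest slot and evict its messages from the lookup table.
        for mid, _ in self.history[-1]:
            self.msgs.pop(mid, None)
        self.history = [[]] + self.history[:-1]

    def window_ids(self, topic):
        return [
            mid
            for slot in self.history[: self.window]
            for mid, topics in slot
            if topic in topics
        ]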
684
tests/core/pubsub/test_pubsub.py
Normal file
@ -0,0 +1,684 @@
from contextlib import (
    contextmanager,
)
from typing import (
    NamedTuple,
)

import pytest
import trio

from libp2p.exceptions import (
    ValidationError,
)
from libp2p.pubsub.pb import (
    rpc_pb2,
)
from libp2p.pubsub.pubsub import (
    PUBSUB_SIGNING_PREFIX,
    SUBSCRIPTION_CHANNEL_SIZE,
)
from libp2p.tools.constants import (
    MAX_READ_LEN,
)
from libp2p.tools.factories import (
    IDFactory,
    PubsubFactory,
    net_stream_pair_factory,
)
from libp2p.tools.pubsub.utils import (
    make_pubsub_msg,
)
from libp2p.tools.utils import (
    connect,
)
from libp2p.utils import (
    encode_varint_prefixed,
)

TESTING_TOPIC = "TEST_SUBSCRIBE"
TESTING_DATA = b"data"


@pytest.mark.trio
async def test_subscribe_and_unsubscribe():
    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
        assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids

        await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
        assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids


@pytest.mark.trio
async def test_re_subscribe():
    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
        assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids

        await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
        assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids


@pytest.mark.trio
async def test_re_unsubscribe():
    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        # Unsubscribe from a topic we didn't even subscribe to
        assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].topic_ids
        await pubsubs_fsub[0].unsubscribe("NOT_MY_TOPIC")
        assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].topic_ids

        await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
        assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids

        await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
        assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids

        await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
        assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids


@pytest.mark.trio
async def test_peers_subscribe():
    async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
        await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
        await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
        # Yield to let 0 notify 1
        await trio.sleep(1)
        assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
        await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
        # Yield to let 0 notify 1
        await trio.sleep(1)
        assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]


@pytest.mark.trio
async def test_get_hello_packet():
    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:

        def _get_hello_packet_topic_ids():
            packet = pubsubs_fsub[0].get_hello_packet()
            return tuple(sub.topicid for sub in packet.subscriptions)

        # Test: No subscription, so there should not be any topic ids in the
        # hello packet.
        assert len(_get_hello_packet_topic_ids()) == 0

        # Test: After subscribing, the topic ids should be in the hello packet.
        topic_ids = ["t", "o", "p", "i", "c"]
        for topic in topic_ids:
            await pubsubs_fsub[0].subscribe(topic)
        topic_ids_in_hello = _get_hello_packet_topic_ids()
        for topic in topic_ids:
            assert topic in topic_ids_in_hello


@pytest.mark.trio
async def test_set_and_remove_topic_validator():
    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        is_sync_validator_called = False

        def sync_validator(peer_id, msg):
            nonlocal is_sync_validator_called
            is_sync_validator_called = True

        is_async_validator_called = False

        async def async_validator(peer_id, msg):
            nonlocal is_async_validator_called
            is_async_validator_called = True
            await trio.lowlevel.checkpoint()

        topic = "TEST_VALIDATOR"

        assert topic not in pubsubs_fsub[0].topic_validators

        # Register the sync validator
        pubsubs_fsub[0].set_topic_validator(topic, sync_validator, False)

        assert topic in pubsubs_fsub[0].topic_validators
        topic_validator = pubsubs_fsub[0].topic_validators[topic]
        assert not topic_validator.is_async

        # Validate with the sync validator
        topic_validator.validator(peer_id=IDFactory(), msg="msg")

        assert is_sync_validator_called
        assert not is_async_validator_called

        # Register the async validator
        pubsubs_fsub[0].set_topic_validator(topic, async_validator, True)

        is_sync_validator_called = False
        assert topic in pubsubs_fsub[0].topic_validators
        topic_validator = pubsubs_fsub[0].topic_validators[topic]
        assert topic_validator.is_async

        # Validate with the async validator
        await topic_validator.validator(peer_id=IDFactory(), msg="msg")

        assert is_async_validator_called
        assert not is_sync_validator_called

        # Remove the validator
        pubsubs_fsub[0].remove_topic_validator(topic)
        assert topic not in pubsubs_fsub[0].topic_validators


@pytest.mark.trio
async def test_get_msg_validators():
    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        times_sync_validator_called = 0

        def sync_validator(peer_id, msg):
            nonlocal times_sync_validator_called
            times_sync_validator_called += 1

        times_async_validator_called = 0

        async def async_validator(peer_id, msg):
            nonlocal times_async_validator_called
            times_async_validator_called += 1
            await trio.lowlevel.checkpoint()

        topic_1 = "TEST_VALIDATOR_1"
        topic_2 = "TEST_VALIDATOR_2"
        topic_3 = "TEST_VALIDATOR_3"

        # Register the sync validator for topics 1 and 2
        pubsubs_fsub[0].set_topic_validator(topic_1, sync_validator, False)
        pubsubs_fsub[0].set_topic_validator(topic_2, sync_validator, False)

        # Register the async validator for topic 3
        pubsubs_fsub[0].set_topic_validator(topic_3, async_validator, True)

        msg = make_pubsub_msg(
            origin_id=pubsubs_fsub[0].my_id,
            topic_ids=[topic_1, topic_2, topic_3],
            data=b"1234",
            seqno=b"\x00" * 8,
        )

        topic_validators = pubsubs_fsub[0].get_msg_validators(msg)
        for topic_validator in topic_validators:
            if topic_validator.is_async:
                await topic_validator.validator(peer_id=IDFactory(), msg="msg")
            else:
                topic_validator.validator(peer_id=IDFactory(), msg="msg")

        assert times_sync_validator_called == 2
        assert times_async_validator_called == 1


@pytest.mark.parametrize(
    "is_topic_1_val_passed, is_topic_2_val_passed",
    ((False, True), (True, False), (True, True)),
)
@pytest.mark.trio
async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:

        def passed_sync_validator(peer_id, msg):
            return True

        def failed_sync_validator(peer_id, msg):
            return False

        async def passed_async_validator(peer_id, msg):
            await trio.lowlevel.checkpoint()
            return True

        async def failed_async_validator(peer_id, msg):
            await trio.lowlevel.checkpoint()
            return False

        topic_1 = "TEST_SYNC_VALIDATOR"
        topic_2 = "TEST_ASYNC_VALIDATOR"

        if is_topic_1_val_passed:
            pubsubs_fsub[0].set_topic_validator(topic_1, passed_sync_validator, False)
        else:
            pubsubs_fsub[0].set_topic_validator(topic_1, failed_sync_validator, False)

        if is_topic_2_val_passed:
            pubsubs_fsub[0].set_topic_validator(topic_2, passed_async_validator, True)
        else:
            pubsubs_fsub[0].set_topic_validator(topic_2, failed_async_validator, True)

        msg = make_pubsub_msg(
            origin_id=pubsubs_fsub[0].my_id,
            topic_ids=[topic_1, topic_2],
            data=b"1234",
            seqno=b"\x00" * 8,
        )

        if is_topic_1_val_passed and is_topic_2_val_passed:
            await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
        else:
            with pytest.raises(ValidationError):
                await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)


@pytest.mark.trio
async def test_continuously_read_stream(monkeypatch, nursery, security_protocol):
    async def wait_for_event_occurring(event):
        await trio.lowlevel.checkpoint()
        with trio.fail_after(0.1):
            await event.wait()

    class Events(NamedTuple):
        push_msg: trio.Event
        handle_subscription: trio.Event
        handle_rpc: trio.Event

    @contextmanager
    def mock_methods():
        event_push_msg = trio.Event()
        event_handle_subscription = trio.Event()
        event_handle_rpc = trio.Event()

        async def mock_push_msg(msg_forwarder, msg):
            event_push_msg.set()
            await trio.lowlevel.checkpoint()

        def mock_handle_subscription(origin_id, sub_message):
            event_handle_subscription.set()

        async def mock_handle_rpc(rpc, sender_peer_id):
            event_handle_rpc.set()
            await trio.lowlevel.checkpoint()

        with monkeypatch.context() as m:
            m.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg)
            m.setattr(pubsubs_fsub[0], "handle_subscription", mock_handle_subscription)
            m.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc)
            yield Events(event_push_msg, event_handle_subscription, event_handle_rpc)

    async with PubsubFactory.create_batch_with_floodsub(
        1, security_protocol=security_protocol
    ) as pubsubs_fsub, net_stream_pair_factory(
        security_protocol=security_protocol
    ) as stream_pair:
        await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
        # Kick off the task `continuously_read_stream`
        nursery.start_soon(pubsubs_fsub[0].continuously_read_stream, stream_pair[0])

        # Test: `push_msg` is called when publishing to a subscribed topic.
        publish_subscribed_topic = rpc_pb2.RPC(
            publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])]
        )
        with mock_methods() as events:
            await stream_pair[1].write(
                encode_varint_prefixed(publish_subscribed_topic.SerializeToString())
            )
            await wait_for_event_occurring(events.push_msg)
            # Make sure the other events are not emitted.
            with pytest.raises(trio.TooSlowError):
                await wait_for_event_occurring(events.handle_subscription)
            with pytest.raises(trio.TooSlowError):
                await wait_for_event_occurring(events.handle_rpc)

        # Test: `push_msg` is not called when publishing to a topic we are not
        # subscribed to.
        publish_not_subscribed_topic = rpc_pb2.RPC(
            publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])]
        )
        with mock_methods() as events:
            await stream_pair[1].write(
                encode_varint_prefixed(publish_not_subscribed_topic.SerializeToString())
            )
            with pytest.raises(trio.TooSlowError):
                await wait_for_event_occurring(events.push_msg)

        # Test: `handle_subscription` is called when a subscription message is
        # received.
        subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()])
        with mock_methods() as events:
            await stream_pair[1].write(
                encode_varint_prefixed(subscription_msg.SerializeToString())
            )
            await wait_for_event_occurring(events.handle_subscription)
            # Make sure the other events are not emitted.
            with pytest.raises(trio.TooSlowError):
                await wait_for_event_occurring(events.push_msg)
            with pytest.raises(trio.TooSlowError):
                await wait_for_event_occurring(events.handle_rpc)

        # Test: `handle_rpc` is called when a control message is received.
        control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage())
        with mock_methods() as events:
            await stream_pair[1].write(
                encode_varint_prefixed(control_msg.SerializeToString())
            )
            await wait_for_event_occurring(events.handle_rpc)
            # Make sure the other events are not emitted.
            with pytest.raises(trio.TooSlowError):
                await wait_for_event_occurring(events.push_msg)
            with pytest.raises(trio.TooSlowError):
                await wait_for_event_occurring(events.handle_subscription)


# TODO: Add the following tests after they are aligned with Go.
# (Issue #191: https://github.com/libp2p/py-libp2p/issues/191)
# - `test_stream_handler`
# - `test_handle_peer_queue`


@pytest.mark.trio
async def test_handle_subscription():
    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        assert len(pubsubs_fsub[0].peer_topics) == 0
        sub_msg_0 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=TESTING_TOPIC)
        peer_ids = [IDFactory() for _ in range(2)]
        # Test: One peer is subscribed
        pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_0)
        assert (
            len(pubsubs_fsub[0].peer_topics) == 1
            and TESTING_TOPIC in pubsubs_fsub[0].peer_topics
        )
        assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 1
        assert peer_ids[0] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
        # Test: Another peer is subscribed
        pubsubs_fsub[0].handle_subscription(peer_ids[1], sub_msg_0)
        assert len(pubsubs_fsub[0].peer_topics) == 1
        assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 2
        assert peer_ids[1] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
        # Test: Subscribe to another topic
        another_topic = "ANOTHER_TOPIC"
        sub_msg_1 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=another_topic)
        pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_1)
        assert len(pubsubs_fsub[0].peer_topics) == 2
        assert another_topic in pubsubs_fsub[0].peer_topics
        assert peer_ids[0] in pubsubs_fsub[0].peer_topics[another_topic]
        # Test: Unsubscribe
        unsub_msg = rpc_pb2.RPC.SubOpts(subscribe=False, topicid=TESTING_TOPIC)
        pubsubs_fsub[0].handle_subscription(peer_ids[0], unsub_msg)
        assert peer_ids[0] not in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]


@pytest.mark.trio
async def test_handle_talk():
    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
        msg_0 = make_pubsub_msg(
            origin_id=pubsubs_fsub[0].my_id,
            topic_ids=[TESTING_TOPIC],
            data=b"1234",
            seqno=b"\x00" * 8,
        )
        pubsubs_fsub[0].notify_subscriptions(msg_0)
        msg_1 = make_pubsub_msg(
            origin_id=pubsubs_fsub[0].my_id,
            topic_ids=["NOT_SUBSCRIBED"],
            data=b"1234",
            seqno=b"\x11" * 8,
        )
        pubsubs_fsub[0].notify_subscriptions(msg_1)
        assert (
            len(pubsubs_fsub[0].topic_ids) == 1
            and sub == pubsubs_fsub[0].subscribed_topics_receive[TESTING_TOPIC]
        )
        assert (await sub.get()) == msg_0


@pytest.mark.trio
async def test_message_all_peers(monkeypatch, security_protocol):
    async with PubsubFactory.create_batch_with_floodsub(
        1, security_protocol=security_protocol
    ) as pubsubs_fsub, net_stream_pair_factory(
        security_protocol=security_protocol
    ) as stream_pair:
        peer_id = IDFactory()
        mock_peers = {peer_id: stream_pair[0]}
        with monkeypatch.context() as m:
            m.setattr(pubsubs_fsub[0], "peers", mock_peers)

            empty_rpc = rpc_pb2.RPC()
            empty_rpc_bytes = empty_rpc.SerializeToString()
            empty_rpc_bytes_len_prefixed = encode_varint_prefixed(empty_rpc_bytes)
            await pubsubs_fsub[0].message_all_peers(empty_rpc_bytes)
            assert (
                await stream_pair[1].read(MAX_READ_LEN)
            ) == empty_rpc_bytes_len_prefixed


@pytest.mark.trio
async def test_subscribe_and_publish():
    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        pubsub = pubsubs_fsub[0]

        list_data = [b"d0", b"d1"]
        event_receive_data_started = trio.Event()

        async def publish_data(topic):
            await event_receive_data_started.wait()
            for data in list_data:
                await pubsub.publish(topic, data)

        async def receive_data(topic):
            i = 0
            event_receive_data_started.set()
            assert topic not in pubsub.topic_ids
            subscription = await pubsub.subscribe(topic)
            async with subscription:
                assert topic in pubsub.topic_ids
                async for msg in subscription:
                    assert msg.data == list_data[i]
                    i += 1
                    if i == len(list_data):
                        break
            assert topic not in pubsub.topic_ids

        async with trio.open_nursery() as nursery:
            nursery.start_soon(receive_data, TESTING_TOPIC)
            nursery.start_soon(publish_data, TESTING_TOPIC)


@pytest.mark.trio
async def test_subscribe_and_publish_full_channel():
    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        pubsub = pubsubs_fsub[0]

        extra_data_0 = b"extra_data_0"
        extra_data_1 = b"extra_data_1"

        # Test: The subscription channel is of size `SUBSCRIPTION_CHANNEL_SIZE`.
        # When the channel is full, newly received messages are dropped.
        # Once the channel has an empty slot, it can receive new messages again.

        # Assume `SUBSCRIPTION_CHANNEL_SIZE` is smaller than `2**(4*8)`.
        list_data = [i.to_bytes(4, "big") for i in range(SUBSCRIPTION_CHANNEL_SIZE)]
        # Expect that `extra_data_0` is dropped and `extra_data_1` is appended.
        expected_list_data = list_data + [extra_data_1]

        subscription = await pubsub.subscribe(TESTING_TOPIC)
        for data in list_data:
            await pubsub.publish(TESTING_TOPIC, data)

        # Publish `extra_data_0`, which should be dropped since the channel is
        # already full.
        await pubsub.publish(TESTING_TOPIC, extra_data_0)
        # Consume a message, so there is an empty slot in the channel.
        assert (await subscription.get()).data == expected_list_data.pop(0)
        # Publish `extra_data_1`, which should be appended to the channel.
        await pubsub.publish(TESTING_TOPIC, extra_data_1)

        for expected_data in expected_list_data:
            assert (await subscription.get()).data == expected_data


@pytest.mark.trio
async def test_publish_push_msg_is_called(monkeypatch):
    msg_forwarders = []
    msgs = []

    async def push_msg(msg_forwarder, msg):
        msg_forwarders.append(msg_forwarder)
        msgs.append(msg)
        await trio.lowlevel.checkpoint()

    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        with monkeypatch.context() as m:
            m.setattr(pubsubs_fsub[0], "push_msg", push_msg)

            await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
            await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)

            assert (
                len(msgs) == 2
            ), "`push_msg` should be called every time `publish` is called"
            assert (msg_forwarders[0] == msg_forwarders[1]) and (
                msg_forwarders[1] == pubsubs_fsub[0].my_id
            )
            assert (
                msgs[0].seqno != msgs[1].seqno
            ), "`seqno` should be different every time"


@pytest.mark.trio
async def test_push_msg(monkeypatch):
    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        msg_0 = make_pubsub_msg(
            origin_id=pubsubs_fsub[0].my_id,
            topic_ids=[TESTING_TOPIC],
            data=TESTING_DATA,
            seqno=b"\x00" * 8,
        )

        @contextmanager
        def mock_router_publish():
            event = trio.Event()

            async def router_publish(*args, **kwargs):
                event.set()
                await trio.lowlevel.checkpoint()

            with monkeypatch.context() as m:
                m.setattr(pubsubs_fsub[0].router, "publish", router_publish)
                yield event

        with mock_router_publish() as event:
            # Test: `msg` is not seen before `push_msg`, and is seen after
            # `push_msg`.
            assert not pubsubs_fsub[0]._is_msg_seen(msg_0)
            await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0)
            assert pubsubs_fsub[0]._is_msg_seen(msg_0)
            # Test: Ensure `router.publish` is called in `push_msg`
            with trio.fail_after(0.1):
                await event.wait()

        with mock_router_publish() as event:
            # Test: `push_msg` the message again and it will be rejected.
            # `router_publish` is not called in that case.
            await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0)
            await trio.sleep(0.01)
            assert not event.is_set()

            sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
            # Test: `push_msg` succeeds with another unseen msg.
            msg_1 = make_pubsub_msg(
                origin_id=pubsubs_fsub[0].my_id,
                topic_ids=[TESTING_TOPIC],
                data=TESTING_DATA,
                seqno=b"\x11" * 8,
            )
            assert not pubsubs_fsub[0]._is_msg_seen(msg_1)
            await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_1)
            assert pubsubs_fsub[0]._is_msg_seen(msg_1)
            with trio.fail_after(0.1):
                await event.wait()
            # Test: Subscribers are notified when `push_msg` receives new
            # messages.
            assert (await sub.get()) == msg_1

        with mock_router_publish() as event:
            # Test: Add a topic validator and `push_msg` a message that
            # does not pass the validation.
            # `router_publish` is not called in that case.
            def failed_sync_validator(peer_id, msg):
                return False

            pubsubs_fsub[0].set_topic_validator(
                TESTING_TOPIC, failed_sync_validator, False
            )

            msg_2 = make_pubsub_msg(
                origin_id=pubsubs_fsub[0].my_id,
                topic_ids=[TESTING_TOPIC],
                data=TESTING_DATA,
                seqno=b"\x22" * 8,
            )

            await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_2)
            await trio.sleep(0.01)
            assert not event.is_set()


@pytest.mark.trio
async def test_strict_signing():
    async with PubsubFactory.create_batch_with_floodsub(
        2, strict_signing=True
    ) as pubsubs_fsub:
        await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
        await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
        await pubsubs_fsub[1].subscribe(TESTING_TOPIC)
        await trio.sleep(1)

        await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
        await trio.sleep(1)

        assert len(pubsubs_fsub[0].seen_messages) == 1
        assert len(pubsubs_fsub[1].seen_messages) == 1


@pytest.mark.trio
async def test_strict_signing_failed_validation(monkeypatch):
    async with PubsubFactory.create_batch_with_floodsub(
        2, strict_signing=True
    ) as pubsubs_fsub:
        msg = make_pubsub_msg(
            origin_id=pubsubs_fsub[0].my_id,
            topic_ids=[TESTING_TOPIC],
            data=TESTING_DATA,
            seqno=b"\x00" * 8,
        )
        priv_key = pubsubs_fsub[0].sign_key
        signature = priv_key.sign(
            PUBSUB_SIGNING_PREFIX.encode() + msg.SerializeToString()
        )

        event = trio.Event()

        def _is_msg_seen(msg):
            return False

        # Use router publish to check whether `push_msg` succeeded.
        async def router_publish(*args, **kwargs):
            await trio.lowlevel.checkpoint()
            # The event will only be set if `push_msg` succeeds.
            event.set()

        monkeypatch.setattr(pubsubs_fsub[0], "_is_msg_seen", _is_msg_seen)
        monkeypatch.setattr(pubsubs_fsub[0].router, "publish", router_publish)

        # Test: No signature attached to `msg`
        await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg)
        await trio.sleep(0.01)
        assert not event.is_set()

        # Test: `msg.key` does not match `msg.from_id`
        msg.key = pubsubs_fsub[1].host.get_public_key().serialize()
        msg.signature = signature
        await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg)
        await trio.sleep(0.01)
        assert not event.is_set()

        # Test: Invalid signature
        msg.key = pubsubs_fsub[0].host.get_public_key().serialize()
        msg.signature = b"\x12" * 100
        await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg)
        await trio.sleep(0.01)
        assert not event.is_set()

        # Finally, assert that the signature indeed passes validation
        msg.key = pubsubs_fsub[0].host.get_public_key().serialize()
        msg.signature = signature
        await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg)
        await trio.sleep(0.01)
        assert event.is_set()
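For orientation, the strict-signing checks above boil down to two requirements: the attached `msg.key` must correspond to the claimed `from_id`, and the signature must cover `PUBSUB_SIGNING_PREFIX` plus the message serialized without its `signature`/`key` fields. A sketch of that validation, assuming standard protobuf `CopyFrom`/`ClearField` semantics (`verify_msg_signature` is an illustrative helper, not the library's API; the real check lives inside `push_msg`):

from libp2p.crypto.serialization import deserialize_public_key
from libp2p.peer.id import ID
from libp2p.pubsub.pubsub import PUBSUB_SIGNING_PREFIX


def verify_msg_signature(msg) -> bool:
    # The attached key must correspond to the claimed origin peer ID.
    pubkey = deserialize_public_key(msg.key)
    if ID.from_pubkey(pubkey) != ID(msg.from_id):
        return False
    # The signature covers the prefixed message, serialized without the
    # signature and key fields.
    unsigned = type(msg)()
    unsigned.CopyFrom(msg)
    unsigned.ClearField("signature")
    unsigned.ClearField("key")
    payload = PUBSUB_SIGNING_PREFIX.encode() + unsigned.SerializeToString()
    return pubkey.verify(payload, msg.signature)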
88
tests/core/pubsub/test_subscription.py
Normal file
@ -0,0 +1,88 @@
import math

import pytest
import trio

from libp2p.pubsub.pb import (
    rpc_pb2,
)
from libp2p.pubsub.subscription import (
    TrioSubscriptionAPI,
)

GET_TIMEOUT = 0.001


def make_trio_subscription():
    send_channel, receive_channel = trio.open_memory_channel(math.inf)

    async def unsubscribe_fn():
        await send_channel.aclose()

    return (
        send_channel,
        TrioSubscriptionAPI(receive_channel, unsubscribe_fn=unsubscribe_fn),
    )


def make_pubsub_msg():
    return rpc_pb2.Message()


async def send_something(send_channel):
    msg = make_pubsub_msg()
    await send_channel.send(msg)
    return msg


@pytest.mark.trio
async def test_trio_subscription_get():
    send_channel, sub = make_trio_subscription()
    data_0 = await send_something(send_channel)
    data_1 = await send_something(send_channel)
    assert data_0 == await sub.get()
    assert data_1 == await sub.get()
    # No more messages
    with pytest.raises(trio.TooSlowError):
        with trio.fail_after(GET_TIMEOUT):
            await sub.get()


@pytest.mark.trio
async def test_trio_subscription_iter():
    send_channel, sub = make_trio_subscription()
    received_data = []

    async def iter_subscriptions(subscription):
        async for data in subscription:
            received_data.append(data)

    async with trio.open_nursery() as nursery:
        nursery.start_soon(iter_subscriptions, sub)
        await send_something(send_channel)
        await send_something(send_channel)
        await send_channel.aclose()

    assert len(received_data) == 2


@pytest.mark.trio
async def test_trio_subscription_unsubscribe():
    send_channel, sub = make_trio_subscription()
    await sub.unsubscribe()
    # Test: If the subscription is unsubscribed, `send_channel` should be closed.
    with pytest.raises(trio.ClosedResourceError):
        await send_something(send_channel)
    # Test: No side effect when cancelled twice.
    await sub.unsubscribe()


@pytest.mark.trio
async def test_trio_subscription_async_context_manager():
    send_channel, sub = make_trio_subscription()
    async with sub:
        # Test: `sub` is not cancelled yet, so `send_something` works fine.
        await send_something(send_channel)
    # Test: `sub` is cancelled, so `send_something` fails
    with pytest.raises(trio.ClosedResourceError):
        await send_something(send_channel)
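The unbounded channel above keeps these unit tests simple; the pubsub code path tested earlier uses a bounded channel instead, where delivery to a full subscription drops the message. That drop rule is essentially a non-blocking send (illustrative helper, not the library's API):

import trio


def try_deliver(send_channel: trio.MemorySendChannel, msg) -> bool:
    # Deliver without blocking; a full channel means the message is dropped,
    # mirroring the SUBSCRIPTION_CHANNEL_SIZE behavior in test_pubsub.py.
    try:
        send_channel.send_nowait(msg)
        return True
    except trio.WouldBlock:
        return False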
32
tests/core/security/noise/test_msg_read_writer.py
Normal file
@ -0,0 +1,32 @@
import pytest

from libp2p.security.noise.io import (
    MAX_NOISE_MESSAGE_LEN,
    NoisePacketReadWriter,
)
from libp2p.tools.factories import (
    raw_conn_factory,
)


@pytest.mark.parametrize(
    "noise_msg",
    (b"", b"data", pytest.param(b"A" * MAX_NOISE_MESSAGE_LEN, id="maximum length")),
)
@pytest.mark.trio
async def test_noise_msg_read_write_round_trip(nursery, noise_msg):
    async with raw_conn_factory(nursery) as conns:
        reader, writer = (
            NoisePacketReadWriter(conns[0]),
            NoisePacketReadWriter(conns[1]),
        )
        await writer.write_msg(noise_msg)
        assert (await reader.read_msg()) == noise_msg


@pytest.mark.trio
async def test_noise_msg_write_too_long(nursery):
    async with raw_conn_factory(nursery) as conns:
        writer = NoisePacketReadWriter(conns[0])
        with pytest.raises(ValueError):
            await writer.write_msg(b"1" * (MAX_NOISE_MESSAGE_LEN + 1))
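These tests treat Noise messages as discrete packets with a hard size cap. The framing behind that cap is, per the Noise spec, a 2-byte big-endian length prefix, which would put `MAX_NOISE_MESSAGE_LEN` just under 2**16 (an assumption about the constant's value; the real I/O lives in `libp2p.security.noise.io`). A minimal sketch of such framing:

import struct

MAX_LEN = 2**16 - 1  # assumed ceiling implied by a 2-byte length prefix


def frame(payload: bytes) -> bytes:
    # Prepend a 2-byte big-endian length, rejecting oversized payloads.
    if len(payload) > MAX_LEN:
        raise ValueError(f"payload too long: {len(payload)} > {MAX_LEN}")
    return struct.pack(">H", len(payload)) + payload


def unframe(buf: bytes) -> bytes:
    # Read the length prefix and slice out exactly that many payload bytes.
    (length,) = struct.unpack(">H", buf[:2])
    return buf[2 : 2 + length]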
38
tests/core/security/noise/test_noise.py
Normal file
@ -0,0 +1,38 @@
import pytest

from libp2p.security.noise.messages import (
    NoiseHandshakePayload,
)
from libp2p.tools.factories import (
    noise_conn_factory,
    noise_handshake_payload_factory,
)

DATA_0 = b"data_0"
DATA_1 = b"1" * 1000
DATA_2 = b"data_2"


@pytest.mark.trio
async def test_noise_transport(nursery):
    async with noise_conn_factory(nursery):
        pass


@pytest.mark.trio
async def test_noise_connection(nursery):
    async with noise_conn_factory(nursery) as conns:
        local_conn, remote_conn = conns
        await local_conn.write(DATA_0)
        await local_conn.write(DATA_1)
        assert DATA_0 == (await remote_conn.read(len(DATA_0)))
        assert DATA_1 == (await remote_conn.read(len(DATA_1)))
        await local_conn.write(DATA_2)
        assert DATA_2 == (await remote_conn.read(len(DATA_2)))


def test_noise_handshake_payload():
    payload = noise_handshake_payload_factory()
    payload_serialized = payload.serialize()
    payload_deserialized = NoiseHandshakePayload.deserialize(payload_serialized)
    assert payload == payload_deserialized
60
tests/core/security/test_secio.py
Normal file
@ -0,0 +1,60 @@
import pytest
import trio

from libp2p.crypto.secp256k1 import (
    create_new_key_pair,
)
from libp2p.peer.id import (
    ID,
)
from libp2p.security.secio.transport import (
    NONCE_SIZE,
    create_secure_session,
)
from libp2p.tools.constants import (
    MAX_READ_LEN,
)
from libp2p.tools.factories import (
    raw_conn_factory,
)


@pytest.mark.trio
async def test_create_secure_session(nursery):
    local_nonce = b"\x01" * NONCE_SIZE
    local_key_pair = create_new_key_pair(b"a")
    local_peer = ID.from_pubkey(local_key_pair.public_key)

    remote_nonce = b"\x02" * NONCE_SIZE
    remote_key_pair = create_new_key_pair(b"b")
    remote_peer = ID.from_pubkey(remote_key_pair.public_key)

    async with raw_conn_factory(nursery) as conns:
        local_conn, remote_conn = conns

        local_secure_conn, remote_secure_conn = None, None

        async def local_create_secure_session():
            nonlocal local_secure_conn
            local_secure_conn = await create_secure_session(
                local_nonce,
                local_peer,
                local_key_pair.private_key,
                local_conn,
                remote_peer,
            )

        async def remote_create_secure_session():
            nonlocal remote_secure_conn
            remote_secure_conn = await create_secure_session(
                remote_nonce, remote_peer, remote_key_pair.private_key, remote_conn
            )

        async with trio.open_nursery() as nursery_1:
            nursery_1.start_soon(local_create_secure_session)
            nursery_1.start_soon(remote_create_secure_session)

        msg = b"abc"
        await local_secure_conn.write(msg)
        received_msg = await remote_secure_conn.read(MAX_READ_LEN)
        assert received_msg == msg
58
tests/core/security/test_security_multistream.py
Normal file
@ -0,0 +1,58 @@
import pytest

from libp2p.crypto.rsa import (
    create_new_key_pair,
)
from libp2p.security.insecure.transport import (
    PLAINTEXT_PROTOCOL_ID,
    InsecureSession,
)
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
from libp2p.security.secio.transport import ID as SECIO_PROTOCOL_ID
from libp2p.security.secure_session import (
    SecureSession,
)
from libp2p.tools.factories import (
    host_pair_factory,
)

initiator_key_pair = create_new_key_pair()

noninitiator_key_pair = create_new_key_pair()


async def perform_simple_test(assertion_func, security_protocol):
    async with host_pair_factory(security_protocol=security_protocol) as hosts:
        conn_0 = hosts[0].get_network().connections[hosts[1].get_id()]
        conn_1 = hosts[1].get_network().connections[hosts[0].get_id()]

        # Perform the assertion on both ends of the connection
        assertion_func(conn_0.muxed_conn.secured_conn)
        assertion_func(conn_1.muxed_conn.secured_conn)


@pytest.mark.parametrize(
    "security_protocol, transport_type",
    (
        (PLAINTEXT_PROTOCOL_ID, InsecureSession),
        (SECIO_PROTOCOL_ID, SecureSession),
        (NOISE_PROTOCOL_ID, SecureSession),
    ),
)
@pytest.mark.trio
async def test_single_insecure_security_transport_succeeds(
    security_protocol, transport_type
):
    def assertion_func(conn):
        assert isinstance(conn, transport_type)

    await perform_simple_test(assertion_func, security_protocol)


@pytest.mark.trio
async def test_default_insecure_security():
    def assertion_func(conn):
        assert isinstance(conn, InsecureSession)

    await perform_simple_test(assertion_func, None)
24
tests/core/stream_muxer/conftest.py
Normal file
@ -0,0 +1,24 @@
import pytest

from libp2p.tools.factories import (
    mplex_conn_pair_factory,
    mplex_stream_pair_factory,
)


@pytest.fixture
async def mplex_conn_pair(security_protocol):
    async with mplex_conn_pair_factory(
        security_protocol=security_protocol
    ) as mplex_conn_pair:
        assert mplex_conn_pair[0].is_initiator
        assert not mplex_conn_pair[1].is_initiator
        yield mplex_conn_pair[0], mplex_conn_pair[1]


@pytest.fixture
async def mplex_stream_pair(security_protocol):
    async with mplex_stream_pair_factory(
        security_protocol=security_protocol
    ) as mplex_stream_pair:
        yield mplex_stream_pair
41
tests/core/stream_muxer/test_mplex_conn.py
Normal file
@ -0,0 +1,41 @@
import pytest
|
||||
import trio
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_mplex_conn(mplex_conn_pair):
|
||||
conn_0, conn_1 = mplex_conn_pair
|
||||
|
||||
assert len(conn_0.streams) == 0
|
||||
assert len(conn_1.streams) == 0
|
||||
|
||||
# Test: Open a stream, and both side get 1 more stream.
|
||||
stream_0 = await conn_0.open_stream()
|
||||
await trio.sleep(0.01)
|
||||
assert len(conn_0.streams) == 1
|
||||
assert len(conn_1.streams) == 1
|
||||
# Test: From another side.
|
||||
stream_1 = await conn_1.open_stream()
|
||||
await trio.sleep(0.01)
|
||||
assert len(conn_0.streams) == 2
|
||||
assert len(conn_1.streams) == 2
|
||||
|
||||
# Close from one side.
|
||||
await conn_0.close()
|
||||
# Sleep for a while for both side to handle `close`.
|
||||
await trio.sleep(0.01)
|
||||
# Test: Both side is closed.
|
||||
assert conn_0.is_closed
|
||||
assert conn_1.is_closed
|
||||
# Test: All streams should have been closed.
|
||||
assert stream_0.event_remote_closed.is_set()
|
||||
assert stream_0.event_reset.is_set()
|
||||
assert stream_0.event_local_closed.is_set()
|
||||
# Test: All streams on the other side are also closed.
|
||||
assert stream_1.event_remote_closed.is_set()
|
||||
assert stream_1.event_reset.is_set()
|
||||
assert stream_1.event_local_closed.is_set()
|
||||
|
||||
# Test: No effect to close more than once between two side.
|
||||
await conn_0.close()
|
||||
await conn_1.close()
|
||||
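
Hedged aside (not from this diff): the fixed `trio.sleep(0.01)` pauses above just yield control so the peer's background tasks can process frames; the stream tests below use `trio.testing.wait_all_tasks_blocked()` for the same purpose without a magic delay:

from trio.testing import wait_all_tasks_blocked

async def let_peer_catch_up():
    # Wait until every other task is blocked, i.e. all delivered mplex
    # frames have been handled, instead of sleeping a hard-coded interval.
    await wait_all_tasks_blocked()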
215
tests/core/stream_muxer/test_mplex_stream.py
Normal file
@ -0,0 +1,215 @@
import pytest
import trio
from trio.testing import (
    wait_all_tasks_blocked,
)

from libp2p.stream_muxer.mplex.exceptions import (
    MplexStreamClosed,
    MplexStreamEOF,
    MplexStreamReset,
)
from libp2p.stream_muxer.mplex.mplex import (
    MPLEX_MESSAGE_CHANNEL_SIZE,
)
from libp2p.tools.constants import (
    MAX_READ_LEN,
)

DATA = b"data_123"

@pytest.mark.trio
async def test_mplex_stream_read_write(mplex_stream_pair):
    stream_0, stream_1 = mplex_stream_pair
    await stream_0.write(DATA)
    assert (await stream_1.read(MAX_READ_LEN)) == DATA


@pytest.mark.trio
async def test_mplex_stream_full_buffer(mplex_stream_pair):
    stream_0, stream_1 = mplex_stream_pair
    # Test: The message channel is of size `MPLEX_MESSAGE_CHANNEL_SIZE`.
    # Reading should still work when exactly `MPLEX_MESSAGE_CHANNEL_SIZE`
    # messages have already arrived.
    for _ in range(MPLEX_MESSAGE_CHANNEL_SIZE):
        await stream_0.write(DATA)
    await wait_all_tasks_blocked()
    # Sanity check
    assert MAX_READ_LEN >= MPLEX_MESSAGE_CHANNEL_SIZE * len(DATA)
    assert (await stream_1.read(MAX_READ_LEN)) == MPLEX_MESSAGE_CHANNEL_SIZE * DATA

    # Test: Read after `MPLEX_MESSAGE_CHANNEL_SIZE + 1` messages have
    # arrived, which exceeds the channel size. The stream should have been
    # reset.
    for _ in range(MPLEX_MESSAGE_CHANNEL_SIZE + 1):
        await stream_0.write(DATA)
    await wait_all_tasks_blocked()
    with pytest.raises(MplexStreamReset):
        await stream_1.read(MAX_READ_LEN)

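Aside (editor's illustration, not from this diff): the reset-on-overflow behavior exercised above matches how a bounded trio memory channel rejects writes once `max_buffer_size` items are queued; mplex presumably buffers incoming frames per stream in a channel like this:

import trio

async def demo_bounded_channel(size: int = 8) -> None:
    send, receive = trio.open_memory_channel(size)
    for i in range(size):
        send.send_nowait(i)  # fills the bounded buffer
    try:
        send.send_nowait(size)  # one past capacity
    except trio.WouldBlock:
        print("buffer full: mplex resets the stream in this situation")

trio.run(demo_bounded_channel)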
@pytest.mark.trio
async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair):
    read_bytes = bytearray()
    stream_0, stream_1 = mplex_stream_pair

    async def read_until_eof():
        read_bytes.extend(await stream_1.read())

    expected_data = bytearray()

    async with trio.open_nursery() as nursery:
        nursery.start_soon(read_until_eof)
        # Test: `read` doesn't return before `close` is called.
        await stream_0.write(DATA)
        expected_data.extend(DATA)
        await trio.sleep(0.01)
        assert len(read_bytes) == 0
        # Test: `read` still doesn't return after a second write.
        await stream_0.write(DATA)
        expected_data.extend(DATA)
        await trio.sleep(0.01)
        assert len(read_bytes) == 0

        # Test: Close the stream; `read` returns all previously sent data.
        await stream_0.close()

    assert read_bytes == expected_data

@pytest.mark.trio
async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair):
    stream_0, stream_1 = mplex_stream_pair
    assert not stream_1.event_remote_closed.is_set()
    await stream_0.write(DATA)
    assert not stream_0.event_local_closed.is_set()
    await trio.sleep(0.01)
    await wait_all_tasks_blocked()
    await stream_0.close()
    assert stream_0.event_local_closed.is_set()
    await trio.sleep(0.01)
    await wait_all_tasks_blocked()
    assert stream_1.event_remote_closed.is_set()
    assert (await stream_1.read(MAX_READ_LEN)) == DATA
    with pytest.raises(MplexStreamEOF):
        await stream_1.read(MAX_READ_LEN)


@pytest.mark.trio
async def test_mplex_stream_read_after_local_reset(mplex_stream_pair):
    stream_0, stream_1 = mplex_stream_pair
    await stream_0.reset()
    with pytest.raises(MplexStreamReset):
        await stream_0.read(MAX_READ_LEN)


@pytest.mark.trio
async def test_mplex_stream_read_after_remote_reset(mplex_stream_pair):
    stream_0, stream_1 = mplex_stream_pair
    await stream_0.write(DATA)
    await stream_0.reset()
    # Sleep to let `stream_1` receive the message.
    await trio.sleep(0.1)
    await wait_all_tasks_blocked()
    with pytest.raises(MplexStreamReset):
        await stream_1.read(MAX_READ_LEN)


@pytest.mark.trio
async def test_mplex_stream_read_after_remote_closed_and_reset(mplex_stream_pair):
    stream_0, stream_1 = mplex_stream_pair
    await stream_0.write(DATA)
    await stream_0.close()
    await stream_0.reset()
    # Sleep to let `stream_1` receive the message.
    await trio.sleep(0.01)
    assert (await stream_1.read(MAX_READ_LEN)) == DATA


@pytest.mark.trio
async def test_mplex_stream_write_after_local_closed(mplex_stream_pair):
    stream_0, stream_1 = mplex_stream_pair
    await stream_0.write(DATA)
    await stream_0.close()
    with pytest.raises(MplexStreamClosed):
        await stream_0.write(DATA)


@pytest.mark.trio
async def test_mplex_stream_write_after_local_reset(mplex_stream_pair):
    stream_0, stream_1 = mplex_stream_pair
    await stream_0.reset()
    with pytest.raises(MplexStreamClosed):
        await stream_0.write(DATA)


@pytest.mark.trio
async def test_mplex_stream_write_after_remote_reset(mplex_stream_pair):
    stream_0, stream_1 = mplex_stream_pair
    await stream_1.reset()
    await trio.sleep(0.01)
    with pytest.raises(MplexStreamClosed):
        await stream_0.write(DATA)


@pytest.mark.trio
async def test_mplex_stream_both_close(mplex_stream_pair):
    stream_0, stream_1 = mplex_stream_pair
    # Flags are not set initially.
    assert not stream_0.event_local_closed.is_set()
    assert not stream_1.event_local_closed.is_set()
    assert not stream_0.event_remote_closed.is_set()
    assert not stream_1.event_remote_closed.is_set()
    # Streams are present in their `muxed_conn`.
    assert stream_0 in stream_0.muxed_conn.streams.values()
    assert stream_1 in stream_1.muxed_conn.streams.values()

    # Test: Close one side.
    await stream_0.close()
    await trio.sleep(0.01)

    assert stream_0.event_local_closed.is_set()
    assert not stream_1.event_local_closed.is_set()
    assert not stream_0.event_remote_closed.is_set()
    assert stream_1.event_remote_closed.is_set()
    # Streams are still present in their `muxed_conn`.
    assert stream_0 in stream_0.muxed_conn.streams.values()
    assert stream_1 in stream_1.muxed_conn.streams.values()

    # Test: Close the other side.
    await stream_1.close()
    await trio.sleep(0.01)
    # Both sides are closed.
    assert stream_0.event_local_closed.is_set()
    assert stream_1.event_local_closed.is_set()
    assert stream_0.event_remote_closed.is_set()
    assert stream_1.event_remote_closed.is_set()
    # Streams are removed from their `muxed_conn`.
    assert stream_0 not in stream_0.muxed_conn.streams.values()
    assert stream_1 not in stream_1.muxed_conn.streams.values()

    # Test: Reset after both sides have closed is a no-op.
    await stream_0.reset()


@pytest.mark.trio
async def test_mplex_stream_reset(mplex_stream_pair):
    stream_0, stream_1 = mplex_stream_pair
    await stream_0.reset()
    await trio.sleep(0.01)

    # Both sides are closed.
    assert stream_0.event_local_closed.is_set()
    assert stream_1.event_local_closed.is_set()
    assert stream_0.event_remote_closed.is_set()
    assert stream_1.event_remote_closed.is_set()
    # Streams are removed from their `muxed_conn`.
    assert stream_0 not in stream_0.muxed_conn.streams.values()
    assert stream_1 not in stream_1.muxed_conn.streams.values()

    # `close` should do nothing.
    await stream_0.close()
    await stream_1.close()
    # `reset` should do nothing as well.
    await stream_0.reset()
    await stream_1.reset()
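
Taken together, these tests pin down a small state machine for mplex streams. As a hedged summary (an editor's restatement in code form, not an API from this diff):

# Expected mplex stream behavior as exercised above.
EXPECTED = {
    "read after remote close (buffer drained)": "raises MplexStreamEOF",
    "read after local or remote reset": "raises MplexStreamReset",
    "write after local close or any reset": "raises MplexStreamClosed",
    "close/reset on an already-reset stream": "no-op",
}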
309
tests/core/test_libp2p/test_libp2p.py
Normal file
@ -0,0 +1,309 @@
import multiaddr
import pytest

from libp2p.network.stream.exceptions import (
    StreamError,
)
from libp2p.tools.constants import (
    MAX_READ_LEN,
)
from libp2p.tools.factories import (
    HostFactory,
)
from libp2p.tools.utils import (
    connect,
    create_echo_stream_handler,
)
from libp2p.typing import (
    TProtocol,
)

PROTOCOL_ID_0 = TProtocol("/echo/0")
PROTOCOL_ID_1 = TProtocol("/echo/1")
PROTOCOL_ID_2 = TProtocol("/echo/2")
PROTOCOL_ID_3 = TProtocol("/echo/3")

ACK_STR_0 = "ack_0:"
ACK_STR_1 = "ack_1:"
ACK_STR_2 = "ack_2:"
ACK_STR_3 = "ack_3:"

@pytest.mark.trio
async def test_simple_messages(security_protocol):
    async with HostFactory.create_batch_and_listen(
        2, security_protocol=security_protocol
    ) as hosts:
        hosts[1].set_stream_handler(
            PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
        )

        # Associate the peer with its local IP address
        # (see the default parameters of Libp2p()).
        hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)

        stream = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])

        messages = ["hello" + str(x) for x in range(10)]
        for message in messages:
            await stream.write(message.encode())
            response = (await stream.read(MAX_READ_LEN)).decode()
            assert response == (ACK_STR_0 + message)

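For reference, `create_echo_stream_handler` is imported from `libp2p.tools.utils`; judging from the assertions here, it presumably behaves like the following sketch (the name and loop structure are assumptions, mirroring the handler in the next test):

def sketch_echo_stream_handler(ack_prefix: str):
    async def handler(stream):
        while True:
            try:
                payload = (await stream.read(MAX_READ_LEN)).decode()
            except StreamError:
                break
            # Echo each message back, prefixed with the ack string.
            await stream.write((ack_prefix + payload).encode())
    return handler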
@pytest.mark.trio
async def test_double_response(security_protocol):
    async with HostFactory.create_batch_and_listen(
        2, security_protocol=security_protocol
    ) as hosts:

        async def double_response_stream_handler(stream):
            while True:
                try:
                    read_string = (await stream.read(MAX_READ_LEN)).decode()
                except StreamError:
                    break

                response = ACK_STR_0 + read_string
                try:
                    await stream.write(response.encode())
                except StreamError:
                    break

                response = ACK_STR_1 + read_string
                try:
                    await stream.write(response.encode())
                except StreamError:
                    break

        hosts[1].set_stream_handler(PROTOCOL_ID_0, double_response_stream_handler)

        # Associate the peer with its local IP address
        # (see the default parameters of Libp2p()).
        hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
        stream = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])

        messages = ["hello" + str(x) for x in range(10)]
        for message in messages:
            await stream.write(message.encode())

            response1 = (await stream.read(MAX_READ_LEN)).decode()
            assert response1 == (ACK_STR_0 + message)

            response2 = (await stream.read(MAX_READ_LEN)).decode()
            assert response2 == (ACK_STR_1 + message)

@pytest.mark.trio
async def test_multiple_streams(security_protocol):
    # hosts[0] should be able to open a stream with hosts[1], and vice versa.
    # Stream IDs should be generated uniquely so that stream state is not
    # overwritten.

    async with HostFactory.create_batch_and_listen(
        2, security_protocol=security_protocol
    ) as hosts:
        hosts[0].set_stream_handler(
            PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
        )
        hosts[1].set_stream_handler(
            PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1)
        )

        # Associate each peer with the other's addresses.
        hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
        hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)

        stream_a = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1])
        stream_b = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_0])

        # A writes to /echo/1 via stream_a, and B writes to /echo/0 via stream_b.
        messages = ["hello" + str(x) for x in range(10)]
        for message in messages:
            a_message = message + "_a"
            b_message = message + "_b"

            await stream_a.write(a_message.encode())
            await stream_b.write(b_message.encode())

            response_a = (await stream_a.read(MAX_READ_LEN)).decode()
            response_b = (await stream_b.read(MAX_READ_LEN)).decode()

            assert response_a == (ACK_STR_1 + a_message) and response_b == (
                ACK_STR_0 + b_message
            )

@pytest.mark.trio
async def test_multiple_streams_same_initiator_different_protocols(security_protocol):
    async with HostFactory.create_batch_and_listen(
        2, security_protocol=security_protocol
    ) as hosts:
        hosts[1].set_stream_handler(
            PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
        )
        hosts[1].set_stream_handler(
            PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1)
        )
        hosts[1].set_stream_handler(
            PROTOCOL_ID_2, create_echo_stream_handler(ACK_STR_2)
        )

        # Associate each peer with the other's addresses.
        hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
        hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)

        # Open three streams to hosts[1], one per echo protocol.
        stream_a1 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
        stream_a2 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1])
        stream_a3 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_2])

        messages = ["hello" + str(x) for x in range(10)]
        for message in messages:
            a1_message = message + "_a1"
            a2_message = message + "_a2"
            a3_message = message + "_a3"

            await stream_a1.write(a1_message.encode())
            await stream_a2.write(a2_message.encode())
            await stream_a3.write(a3_message.encode())

            response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode()
            response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode()
            response_a3 = (await stream_a3.read(MAX_READ_LEN)).decode()

            assert (
                response_a1 == (ACK_STR_0 + a1_message)
                and response_a2 == (ACK_STR_1 + a2_message)
                and response_a3 == (ACK_STR_2 + a3_message)
            )

@pytest.mark.trio
async def test_multiple_streams_two_initiators(security_protocol):
    async with HostFactory.create_batch_and_listen(
        2, security_protocol=security_protocol
    ) as hosts:
        hosts[0].set_stream_handler(
            PROTOCOL_ID_2, create_echo_stream_handler(ACK_STR_2)
        )
        hosts[0].set_stream_handler(
            PROTOCOL_ID_3, create_echo_stream_handler(ACK_STR_3)
        )

        hosts[1].set_stream_handler(
            PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
        )
        hosts[1].set_stream_handler(
            PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1)
        )

        # Associate each peer with the other's addresses.
        hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
        hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)

        stream_a1 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
        stream_a2 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1])

        stream_b1 = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_2])
        stream_b2 = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_3])

        # Each side writes on the streams it initiated and reads the echoes back.
        messages = ["hello" + str(x) for x in range(10)]
        for message in messages:
            a1_message = message + "_a1"
            a2_message = message + "_a2"

            b1_message = message + "_b1"
            b2_message = message + "_b2"

            await stream_a1.write(a1_message.encode())
            await stream_a2.write(a2_message.encode())

            await stream_b1.write(b1_message.encode())
            await stream_b2.write(b2_message.encode())

            response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode()
            response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode()

            response_b1 = (await stream_b1.read(MAX_READ_LEN)).decode()
            response_b2 = (await stream_b2.read(MAX_READ_LEN)).decode()

            assert (
                response_a1 == (ACK_STR_0 + a1_message)
                and response_a2 == (ACK_STR_1 + a2_message)
                and response_b1 == (ACK_STR_2 + b1_message)
                and response_b2 == (ACK_STR_3 + b2_message)
            )

@pytest.mark.trio
async def test_triangle_nodes_connection(security_protocol):
    async with HostFactory.create_batch_and_listen(
        3, security_protocol=security_protocol
    ) as hosts:
        hosts[0].set_stream_handler(
            PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
        )
        hosts[1].set_stream_handler(
            PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
        )
        hosts[2].set_stream_handler(
            PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
        )

        # Associate each peer with the addresses of the other two,
        # covering all permutations.
        hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
        hosts[0].get_peerstore().add_addrs(hosts[2].get_id(), hosts[2].get_addrs(), 10)

        hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)
        hosts[1].get_peerstore().add_addrs(hosts[2].get_id(), hosts[2].get_addrs(), 10)

        hosts[2].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)
        hosts[2].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)

        stream_0_to_1 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
        stream_0_to_2 = await hosts[0].new_stream(hosts[2].get_id(), [PROTOCOL_ID_0])

        stream_1_to_0 = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_0])
        stream_1_to_2 = await hosts[1].new_stream(hosts[2].get_id(), [PROTOCOL_ID_0])

        stream_2_to_0 = await hosts[2].new_stream(hosts[0].get_id(), [PROTOCOL_ID_0])
        stream_2_to_1 = await hosts[2].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])

        messages = ["hello" + str(x) for x in range(5)]
        streams = [
            stream_0_to_1,
            stream_0_to_2,
            stream_1_to_0,
            stream_1_to_2,
            stream_2_to_0,
            stream_2_to_1,
        ]

        for message in messages:
            for stream in streams:
                await stream.write(message.encode())
                response = (await stream.read(MAX_READ_LEN)).decode()
                assert response == (ACK_STR_0 + message)

@pytest.mark.trio
async def test_host_connect(security_protocol):
    async with HostFactory.create_batch_and_listen(
        2, security_protocol=security_protocol
    ) as hosts:
        assert len(hosts[0].get_peerstore().peer_ids()) == 1

        await connect(hosts[0], hosts[1])
        assert len(hosts[0].get_peerstore().peer_ids()) == 2

        await connect(hosts[0], hosts[1])
        # Make sure we don't open a duplicate connection.
        assert len(hosts[0].get_peerstore().peer_ids()) == 2

        assert hosts[1].get_id() in hosts[0].get_peerstore().peer_ids()
        ma_node_b = multiaddr.Multiaddr("/p2p/%s" % hosts[1].get_id().pretty())
        for addr in hosts[0].get_peerstore().addrs(hosts[1].get_id()):
            assert addr.encapsulate(ma_node_b) in hosts[1].get_addrs()
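
The final assertion above relies on multiaddr encapsulation. A small hedged illustration (editor's addition; the peer ID below is a well-known example value from libp2p docs, not one from this test run):

from multiaddr import Multiaddr

base = Multiaddr("/ip4/127.0.0.1/tcp/8000")
peer = Multiaddr("/p2p/QmYyQSo1c1Ym7orWxLYvCrM2EmxFTANf8wXmmE7DWjhx5N")
# Encapsulation appends the /p2p/<peer-id> component to the transport address.
full = base.encapsulate(peer)
print(full)  # /ip4/127.0.0.1/tcp/8000/p2p/QmYyQSo1c1Ym7orWxLYvCrM2EmxFTANf8wXmmE7DWjhx5N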
63
tests/core/transport/test_tcp.py
Normal file
@ -0,0 +1,63 @@
from multiaddr import (
    Multiaddr,
)
import pytest
import trio

from libp2p.network.connection.raw_connection import (
    RawConnection,
)
from libp2p.tools.constants import (
    LISTEN_MADDR,
)
from libp2p.transport.exceptions import (
    OpenConnectionError,
)
from libp2p.transport.tcp.tcp import (
    TCP,
)


@pytest.mark.trio
async def test_tcp_listener(nursery):
    transport = TCP()

    async def handler(tcp_stream):
        pass

    listener = transport.create_listener(handler)
    assert len(listener.get_addrs()) == 0
    await listener.listen(LISTEN_MADDR, nursery)
    assert len(listener.get_addrs()) == 1
    await listener.listen(LISTEN_MADDR, nursery)
    assert len(listener.get_addrs()) == 2


@pytest.mark.trio
async def test_tcp_dial(nursery):
    transport = TCP()
    raw_conn_other_side = None
    event = trio.Event()

    async def handler(tcp_stream):
        nonlocal raw_conn_other_side
        raw_conn_other_side = RawConnection(tcp_stream, False)
        event.set()
        await trio.sleep_forever()

    # Test: `OpenConnectionError` is raised when dialing a port that
    # nothing is listening on.
    with pytest.raises(OpenConnectionError):
        await transport.dial(Multiaddr("/ip4/127.0.0.1/tcp/1"))

    listener = transport.create_listener(handler)
    await listener.listen(LISTEN_MADDR, nursery)
    addrs = listener.get_addrs()
    assert len(addrs) == 1
    listen_addr = addrs[0]
    raw_conn = await transport.dial(listen_addr)
    await event.wait()

    data = b"123"
    await raw_conn_other_side.write(data)
    assert (await raw_conn.read(len(data))) == data
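
A note on the `nursery` parameter: pytest-trio provides it as a built-in fixture, and the listener spawns its accept tasks into it, so they are cancelled when the test ends. A rough hand-rolled equivalent (editor's sketch, reusing `TCP` and `LISTEN_MADDR` from this file):

async def serve_briefly(handler):
    transport = TCP()
    # Open a nursery ourselves instead of relying on the fixture; the
    # timeout cancels the accept tasks so the nursery can exit.
    with trio.move_on_after(1):
        async with trio.open_nursery() as nursery:
            listener = transport.create_listener(handler)
            await listener.listen(LISTEN_MADDR, nursery)
            await trio.sleep_forever()  # cancelled by move_on_after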