Merge pull request #404 from libp2p/feature/trio

Merge `feature/trio` into `master`
Kevin Mai-Hsuan Chia
2020-02-06 10:49:53 +08:00
committed by GitHub
80 changed files with 3563 additions and 3378 deletions


@@ -5,16 +5,16 @@ matrix:
   - python: 3.6-dev
     dist: xenial
     env: TOXENV=py36-test
-  - python: 3.7-dev
+  - python: 3.7
     dist: xenial
     env: TOXENV=py37-test
-  - python: 3.7-dev
+  - python: 3.7
     dist: xenial
     env: TOXENV=lint
-  - python: 3.7-dev
+  - python: 3.7
     dist: xenial
     env: TOXENV=docs
-  - python: 3.7-dev
+  - python: 3.7
     dist: xenial
     env: TOXENV=py37-interop
 sudo: true


@@ -51,7 +51,7 @@ lint:
 	black --check $(FILES_TO_LINT)
 	isort --recursive --check-only --diff $(FILES_TO_LINT)
 	docformatter --pre-summary-newline --check --recursive $(FILES_TO_LINT)
-	tox -elint # This is probably redundant, but just in case...
+	tox -e lint # This is probably redundant, but just in case...

 lint-roll:
 	isort --recursive $(FILES_TO_LINT)


@@ -11,6 +11,22 @@ Subpackages
 Submodules
 ----------

+libp2p.pubsub.abc module
+------------------------
+
+.. automodule:: libp2p.pubsub.abc
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+libp2p.pubsub.exceptions module
+-------------------------------
+
+.. automodule:: libp2p.pubsub.exceptions
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 libp2p.pubsub.floodsub module
 -----------------------------
@@ -51,10 +67,10 @@ libp2p.pubsub.pubsub\_notifee module
    :undoc-members:
    :show-inheritance:

-libp2p.pubsub.pubsub\_router\_interface module
-----------------------------------------------
+libp2p.pubsub.subscription module
+---------------------------------

-.. automodule:: libp2p.pubsub.pubsub_router_interface
+.. automodule:: libp2p.pubsub.subscription
    :members:
    :undoc-members:
    :show-inheritance:


@ -1,11 +1,10 @@
import argparse import argparse
import asyncio
import sys import sys
import urllib.request
import multiaddr import multiaddr
import trio
from libp2p import new_node from libp2p import new_host
from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
@ -26,53 +25,47 @@ async def read_data(stream: INetStream) -> None:
async def write_data(stream: INetStream) -> None: async def write_data(stream: INetStream) -> None:
loop = asyncio.get_event_loop() async_f = trio.wrap_file(sys.stdin)
while True: while True:
line = await loop.run_in_executor(None, sys.stdin.readline) line = await async_f.readline()
await stream.write(line.encode()) await stream.write(line.encode())
async def run(port: int, destination: str, localhost: bool) -> None: async def run(port: int, destination: str) -> None:
if localhost: localhost_ip = "127.0.0.1"
ip = "127.0.0.1" listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
else: host = new_host()
ip = urllib.request.urlopen("https://v4.ident.me/").read().decode("utf8") async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
transport_opt = f"/ip4/{ip}/tcp/{port}" if not destination: # its the server
host = await new_node(transport_opt=[transport_opt])
await host.get_network().listen(multiaddr.Multiaddr(transport_opt)) async def stream_handler(stream: INetStream) -> None:
nursery.start_soon(read_data, stream)
nursery.start_soon(write_data, stream)
if not destination: # its the server host.set_stream_handler(PROTOCOL_ID, stream_handler)
async def stream_handler(stream: INetStream) -> None: print(
asyncio.ensure_future(read_data(stream)) f"Run 'python ./examples/chat/chat.py "
asyncio.ensure_future(write_data(stream)) f"-p {int(port) + 1} "
f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id().pretty()}' "
"on another console."
)
print("Waiting for incoming connection...")
host.set_stream_handler(PROTOCOL_ID, stream_handler) else: # its the client
maddr = multiaddr.Multiaddr(destination)
info = info_from_p2p_addr(maddr)
# Associate the peer with local ip address
await host.connect(info)
# Start a stream with the destination.
# Multiaddress of the destination peer is fetched from the peerstore using 'peerId'.
stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
localhost_opt = " --localhost" if localhost else "" nursery.start_soon(read_data, stream)
nursery.start_soon(write_data, stream)
print(f"Connected to peer {info.addrs[0]}")
print( await trio.sleep_forever()
f"Run 'python ./examples/chat/chat.py"
+ localhost_opt
+ f" -p {int(port) + 1} -d /ip4/{ip}/tcp/{port}/p2p/{host.get_id().pretty()}'"
+ " on another console."
)
print("Waiting for incoming connection...")
else: # its the client
maddr = multiaddr.Multiaddr(destination)
info = info_from_p2p_addr(maddr)
# Associate the peer with local ip address
await host.connect(info)
# Start a stream with the destination.
# Multiaddress of the destination peer is fetched from the peerstore using 'peerId'.
stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
asyncio.ensure_future(read_data(stream))
asyncio.ensure_future(write_data(stream))
print("Connected to peer %s" % info.addrs[0])
def main() -> None: def main() -> None:
@ -86,11 +79,6 @@ def main() -> None:
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q" "/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
) )
parser = argparse.ArgumentParser(description=description) parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"--debug",
action="store_true",
help="generate the same node ID on every execution",
)
parser.add_argument( parser.add_argument(
"-p", "--port", default=8000, type=int, help="source port number" "-p", "--port", default=8000, type=int, help="source port number"
) )
@ -100,26 +88,15 @@ def main() -> None:
type=str, type=str,
help=f"destination multiaddr string, e.g. {example_maddr}", help=f"destination multiaddr string, e.g. {example_maddr}",
) )
parser.add_argument(
"-l",
"--localhost",
dest="localhost",
action="store_true",
help="flag indicating if localhost should be used or an external IP",
)
args = parser.parse_args() args = parser.parse_args()
if not args.port: if not args.port:
raise RuntimeError("was not able to determine a local port") raise RuntimeError("was not able to determine a local port")
loop = asyncio.get_event_loop()
try: try:
asyncio.ensure_future(run(args.port, args.destination, args.localhost)) trio.run(run, *(args.port, args.destination))
loop.run_forever()
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
finally:
loop.close()
if __name__ == "__main__": if __name__ == "__main__":
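
The chat example above drops the asyncio executor trick for reading stdin and replaces fire-and-forget futures with one trio nursery. As a standalone illustration of that pattern (not code from this PR; both task bodies are invented), a minimal sketch:

import sys

import trio


async def echo_stdin() -> None:
    # trio.wrap_file turns the blocking stdin file object into an async one,
    # replacing the old loop.run_in_executor(None, sys.stdin.readline) trick.
    async_stdin = trio.wrap_file(sys.stdin)
    while True:
        line = await async_stdin.readline()
        print(f"you typed: {line!r}")


async def heartbeat() -> None:
    while True:
        await trio.sleep(5)
        print("(still running)")


async def main() -> None:
    # A nursery replaces the fire-and-forget asyncio.ensure_future calls:
    # both tasks are tied to this block and cancelled together on exit.
    async with trio.open_nursery() as nursery:
        nursery.start_soon(echo_stdin)
        nursery.start_soon(heartbeat)


if __name__ == "__main__":
    trio.run(main)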


@ -1,10 +1,9 @@
import argparse import argparse
import asyncio
import urllib.request
import multiaddr import multiaddr
import trio
from libp2p import new_node from libp2p import new_host
from libp2p.crypto.secp256k1 import create_new_key_pair from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.peer.peerinfo import info_from_p2p_addr
@ -20,12 +19,9 @@ async def _echo_stream_handler(stream: INetStream) -> None:
await stream.close() await stream.close()
async def run(port: int, destination: str, localhost: bool, seed: int = None) -> None: async def run(port: int, destination: str, seed: int = None) -> None:
if localhost: localhost_ip = "127.0.0.1"
ip = "127.0.0.1" listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
else:
ip = urllib.request.urlopen("https://v4.ident.me/").read().decode("utf8")
transport_opt = f"/ip4/{ip}/tcp/{port}"
if seed: if seed:
import random import random
@ -38,47 +34,43 @@ async def run(port: int, destination: str, localhost: bool, seed: int = None) ->
secret = secrets.token_bytes(32) secret = secrets.token_bytes(32)
host = await new_node( host = new_host(key_pair=create_new_key_pair(secret))
key_pair=create_new_key_pair(secret), transport_opt=[transport_opt] async with host.run(listen_addrs=[listen_addr]):
)
print(f"I am {host.get_id().to_string()}") print(f"I am {host.get_id().to_string()}")
await host.get_network().listen(multiaddr.Multiaddr(transport_opt)) if not destination: # its the server
if not destination: # its the server host.set_stream_handler(PROTOCOL_ID, _echo_stream_handler)
host.set_stream_handler(PROTOCOL_ID, _echo_stream_handler) print(
f"Run 'python ./examples/echo/echo.py "
f"-p {int(port) + 1} "
f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id().pretty()}' "
"on another console."
)
print("Waiting for incoming connections...")
await trio.sleep_forever()
localhost_opt = " --localhost" if localhost else "" else: # its the client
maddr = multiaddr.Multiaddr(destination)
info = info_from_p2p_addr(maddr)
# Associate the peer with local ip address
await host.connect(info)
print( # Start a stream with the destination.
f"Run 'python ./examples/echo/echo.py" # Multiaddress of the destination peer is fetched from the peerstore using 'peerId'.
+ localhost_opt stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
+ f" -p {int(port) + 1} -d /ip4/{ip}/tcp/{port}/p2p/{host.get_id().pretty()}'"
+ " on another console."
)
print("Waiting for incoming connections...")
else: # its the client msg = b"hi, there!\n"
maddr = multiaddr.Multiaddr(destination)
info = info_from_p2p_addr(maddr)
# Associate the peer with local ip address
await host.connect(info)
# Start a stream with the destination. await stream.write(msg)
# Multiaddress of the destination peer is fetched from the peerstore using 'peerId'. # Notify the other side about EOF
stream = await host.new_stream(info.peer_id, [PROTOCOL_ID]) await stream.close()
response = await stream.read()
msg = b"hi, there!\n" print(f"Sent: {msg}")
print(f"Got: {response}")
await stream.write(msg)
# Notify the other side about EOF
await stream.close()
response = await stream.read()
print(f"Sent: {msg}")
print(f"Got: {response}")
def main() -> None: def main() -> None:
@ -94,11 +86,6 @@ def main() -> None:
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q" "/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
) )
parser = argparse.ArgumentParser(description=description) parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"--debug",
action="store_true",
help="generate the same node ID on every execution",
)
parser.add_argument( parser.add_argument(
"-p", "--port", default=8000, type=int, help="source port number" "-p", "--port", default=8000, type=int, help="source port number"
) )
@ -108,13 +95,6 @@ def main() -> None:
type=str, type=str,
help=f"destination multiaddr string, e.g. {example_maddr}", help=f"destination multiaddr string, e.g. {example_maddr}",
) )
parser.add_argument(
"-l",
"--localhost",
dest="localhost",
action="store_true",
help="flag indicating if localhost should be used or an external IP",
)
parser.add_argument( parser.add_argument(
"-s", "-s",
"--seed", "--seed",
@ -126,16 +106,10 @@ def main() -> None:
if not args.port: if not args.port:
raise RuntimeError("was not able to determine a local port") raise RuntimeError("was not able to determine a local port")
loop = asyncio.get_event_loop()
try: try:
asyncio.ensure_future( trio.run(run, args.port, args.destination, args.seed)
run(args.port, args.destination, args.localhost, args.seed)
)
loop.run_forever()
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
finally:
loop.close()
if __name__ == "__main__": if __name__ == "__main__":
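
The only change in `main` above is the entry point. Shown in isolation (the argument values below are placeholders), the conversion looks like this:

import trio


async def run(port: int, destination: str, seed: int = None) -> None:
    ...  # the real logic lives in the example above


def main() -> None:
    # trio.run replaces the get_event_loop()/ensure_future()/run_forever()
    # boilerplate: it takes the async function followed by its positional
    # arguments and owns the event loop for the lifetime of the call.
    try:
        trio.run(run, 8000, "", None)
    except KeyboardInterrupt:
        pass


if __name__ == "__main__":
    main()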


@ -1,12 +1,9 @@
import asyncio
from typing import Sequence
from libp2p.crypto.keys import KeyPair from libp2p.crypto.keys import KeyPair
from libp2p.crypto.rsa import create_new_key_pair from libp2p.crypto.rsa import create_new_key_pair
from libp2p.host.basic_host import BasicHost from libp2p.host.basic_host import BasicHost
from libp2p.host.host_interface import IHost from libp2p.host.host_interface import IHost
from libp2p.host.routed_host import RoutedHost from libp2p.host.routed_host import RoutedHost
from libp2p.network.network_interface import INetwork from libp2p.network.network_interface import INetworkService
from libp2p.network.swarm import Swarm from libp2p.network.swarm import Swarm
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.peer.peerstore import PeerStore from libp2p.peer.peerstore import PeerStore
@ -21,18 +18,6 @@ from libp2p.transport.upgrader import TransportUpgrader
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
async def cleanup_done_tasks() -> None:
"""clean up asyncio done tasks to free up resources."""
while True:
for task in asyncio.all_tasks():
if task.done():
await task
# Need not run often
# Some sleep necessary to context switch
await asyncio.sleep(3)
def generate_new_rsa_identity() -> KeyPair: def generate_new_rsa_identity() -> KeyPair:
return create_new_key_pair() return create_new_key_pair()
@ -42,29 +27,28 @@ def generate_peer_id_from(key_pair: KeyPair) -> ID:
return ID.from_pubkey(public_key) return ID.from_pubkey(public_key)
def initialize_default_swarm( def new_swarm(
key_pair: KeyPair, key_pair: KeyPair = None,
id_opt: ID = None,
transport_opt: Sequence[str] = None,
muxer_opt: TMuxerOptions = None, muxer_opt: TMuxerOptions = None,
sec_opt: TSecurityOptions = None, sec_opt: TSecurityOptions = None,
peerstore_opt: IPeerStore = None, peerstore_opt: IPeerStore = None,
) -> Swarm: ) -> INetworkService:
""" """
initialize swarm when no swarm is passed in. Create a swarm instance based on the parameters.
:param id_opt: optional id for host :param key_pair: optional choice of the ``KeyPair``
:param transport_opt: optional choice of transport upgrade
:param muxer_opt: optional choice of stream muxer :param muxer_opt: optional choice of stream muxer
:param sec_opt: optional choice of security upgrade :param sec_opt: optional choice of security upgrade
:param peerstore_opt: optional peerstore :param peerstore_opt: optional peerstore
:return: return a default swarm instance :return: return a default swarm instance
""" """
if not id_opt: if key_pair is None:
id_opt = generate_peer_id_from(key_pair) key_pair = generate_new_rsa_identity()
# TODO: Parse `transport_opt` to determine transport id_opt = generate_peer_id_from(key_pair)
# TODO: Parse `listen_addrs` to determine transport
transport = TCP() transport = TCP()
muxer_transports_by_protocol = muxer_opt or {MPLEX_PROTOCOL_ID: Mplex} muxer_transports_by_protocol = muxer_opt or {MPLEX_PROTOCOL_ID: Mplex}
@ -80,57 +64,35 @@ def initialize_default_swarm(
# Store our key pair in peerstore # Store our key pair in peerstore
peerstore.add_key_pair(id_opt, key_pair) peerstore.add_key_pair(id_opt, key_pair)
# TODO: Initialize discovery if not presented
return Swarm(id_opt, peerstore, upgrader, transport) return Swarm(id_opt, peerstore, upgrader, transport)
async def new_node( def new_host(
key_pair: KeyPair = None, key_pair: KeyPair = None,
swarm_opt: INetwork = None,
transport_opt: Sequence[str] = None,
muxer_opt: TMuxerOptions = None, muxer_opt: TMuxerOptions = None,
sec_opt: TSecurityOptions = None, sec_opt: TSecurityOptions = None,
peerstore_opt: IPeerStore = None, peerstore_opt: IPeerStore = None,
disc_opt: IPeerRouting = None, disc_opt: IPeerRouting = None,
) -> BasicHost: ) -> IHost:
""" """
create new libp2p node. Create a new libp2p host based on the given parameters.
:param key_pair: key pair for deriving an identity :param key_pair: optional choice of the ``KeyPair``
:param swarm_opt: optional swarm
:param id_opt: optional id for host
:param transport_opt: optional choice of transport upgrade
:param muxer_opt: optional choice of stream muxer :param muxer_opt: optional choice of stream muxer
:param sec_opt: optional choice of security upgrade :param sec_opt: optional choice of security upgrade
:param peerstore_opt: optional peerstore :param peerstore_opt: optional peerstore
:param disc_opt: optional discovery :param disc_opt: optional discovery
:return: return a host instance :return: return a host instance
""" """
swarm = new_swarm(
if not key_pair: key_pair=key_pair,
key_pair = generate_new_rsa_identity() muxer_opt=muxer_opt,
sec_opt=sec_opt,
id_opt = generate_peer_id_from(key_pair) peerstore_opt=peerstore_opt,
)
if not swarm_opt: host: IHost
swarm_opt = initialize_default_swarm(
key_pair=key_pair,
id_opt=id_opt,
transport_opt=transport_opt,
muxer_opt=muxer_opt,
sec_opt=sec_opt,
peerstore_opt=peerstore_opt,
)
# TODO enable support for other host type
# TODO routing unimplemented
host: IHost # If not explicitly typed, MyPy raises error
if disc_opt: if disc_opt:
host = RoutedHost(swarm_opt, disc_opt) host = RoutedHost(swarm, disc_opt)
else: else:
host = BasicHost(swarm_opt) host = BasicHost(swarm)
# Kick off cleanup job
asyncio.ensure_future(cleanup_done_tasks())
return host return host
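
Putting the new synchronous `new_host` factory together with the host's `run` context manager gives the following usage sketch (the listen address is arbitrary and this is an illustration of the signatures introduced above, not code from the PR):

import multiaddr
import trio

from libp2p import new_host


async def main() -> None:
    host = new_host()  # synchronous now; an RSA identity is generated by default
    listen_addr = multiaddr.Multiaddr("/ip4/0.0.0.0/tcp/8000")
    # host.run is an async context manager (see BasicHost.run below):
    # the underlying network service only runs inside this block.
    async with host.run(listen_addrs=[listen_addr]):
        print("listening as", host.get_id().pretty())
        await trio.sleep_forever()


if __name__ == "__main__":
    trio.run(main)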


@@ -1,12 +1,14 @@
 import logging
-from typing import TYPE_CHECKING, List, Sequence
+from typing import TYPE_CHECKING, AsyncIterator, List, Sequence

+from async_generator import asynccontextmanager
+from async_service import background_trio_service
 import multiaddr

 from libp2p.crypto.keys import PrivateKey, PublicKey
 from libp2p.host.defaults import get_default_protocols
 from libp2p.host.exceptions import StreamFailure
-from libp2p.network.network_interface import INetwork
+from libp2p.network.network_interface import INetworkService
 from libp2p.network.stream.net_stream_interface import INetStream
 from libp2p.peer.id import ID
 from libp2p.peer.peerinfo import PeerInfo
@@ -39,7 +41,7 @@ class BasicHost(IHost):
     right after a stream is initialized.
     """

-    _network: INetwork
+    _network: INetworkService
     peerstore: IPeerStore

     multiselect: Multiselect
@@ -47,7 +49,7 @@ class BasicHost(IHost):
     def __init__(
         self,
-        network: INetwork,
+        network: INetworkService,
         default_protocols: "OrderedDict[TProtocol, StreamHandlerFn]" = None,
     ) -> None:
         self._network = network
@@ -70,7 +72,7 @@ class BasicHost(IHost):
     def get_private_key(self) -> PrivateKey:
         return self.peerstore.privkey(self.get_id())

-    def get_network(self) -> INetwork:
+    def get_network(self) -> INetworkService:
         """
         :return: network instance of host
         """
@@ -101,6 +103,20 @@ class BasicHost(IHost):
             addrs.append(addr.encapsulate(p2p_part))
         return addrs

+    @asynccontextmanager
+    async def run(
+        self, listen_addrs: Sequence[multiaddr.Multiaddr]
+    ) -> AsyncIterator[None]:
+        """
+        run the host instance and listen to ``listen_addrs``.
+
+        :param listen_addrs: a sequence of multiaddrs that we want to listen to
+        """
+        network = self.get_network()
+        async with background_trio_service(network):
+            await network.listen(*listen_addrs)
+            yield
+
     def set_stream_handler(
         self, protocol_id: TProtocol, stream_handler: StreamHandlerFn
     ) -> None:
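
`BasicHost.run` above couples an `asynccontextmanager` with `background_trio_service`. Reduced to a self-contained sketch (the `Network` class and `run_host` helper below are stand-ins, not the real swarm):

from typing import AsyncIterator, Sequence

from async_generator import asynccontextmanager
from async_service import Service, background_trio_service
import trio


class Network(Service):
    async def run(self) -> None:
        # Keep running until the surrounding manager is cancelled.
        await self.manager.wait_finished()

    async def listen(self, *addrs: str) -> None:
        print("listening on", addrs)


@asynccontextmanager
async def run_host(
    network: Network, listen_addrs: Sequence[str]
) -> AsyncIterator[None]:
    async with background_trio_service(network):
        await network.listen(*listen_addrs)
        yield  # caller's body runs here; the service is torn down on exit


async def main() -> None:
    async with run_host(Network(), ["/ip4/127.0.0.1/tcp/0"]):
        await trio.sleep(0.1)


if __name__ == "__main__":
    trio.run(main)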


@@ -1,10 +1,10 @@
 from abc import ABC, abstractmethod
-from typing import Any, List, Sequence
+from typing import Any, AsyncContextManager, List, Sequence

 import multiaddr

 from libp2p.crypto.keys import PrivateKey, PublicKey
-from libp2p.network.network_interface import INetwork
+from libp2p.network.network_interface import INetworkService
 from libp2p.network.stream.net_stream_interface import INetStream
 from libp2p.peer.id import ID
 from libp2p.peer.peerinfo import PeerInfo
@@ -31,7 +31,7 @@ class IHost(ABC):
         """

     @abstractmethod
-    def get_network(self) -> INetwork:
+    def get_network(self) -> INetworkService:
         """
         :return: network instance of host
         """
@@ -49,6 +49,16 @@ class IHost(ABC):
         :return: all the multiaddr addresses this host is listening to
         """

+    @abstractmethod
+    def run(
+        self, listen_addrs: Sequence[multiaddr.Multiaddr]
+    ) -> AsyncContextManager[None]:
+        """
+        run the host instance and listen to ``listen_addrs``.
+
+        :param listen_addrs: a sequence of multiaddrs that we want to listen to
+        """
+
     @abstractmethod
     def set_stream_handler(
         self, protocol_id: TProtocol, stream_handler: StreamHandlerFn


@@ -1,6 +1,7 @@
-import asyncio
 import logging

+import trio
+
 from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset
 from libp2p.network.stream.net_stream_interface import INetStream
 from libp2p.peer.id import ID as PeerID
@@ -17,8 +18,9 @@ async def _handle_ping(stream: INetStream, peer_id: PeerID) -> bool:
     """Return a boolean indicating if we expect more pings from the peer at
     ``peer_id``."""
     try:
-        payload = await asyncio.wait_for(stream.read(PING_LENGTH), RESP_TIMEOUT)
-    except asyncio.TimeoutError as error:
+        with trio.fail_after(RESP_TIMEOUT):
+            payload = await stream.read(PING_LENGTH)
+    except trio.TooSlowError as error:
         logger.debug("Timed out waiting for ping from %s: %s", peer_id, error)
         raise
     except StreamEOF:
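
`trio.fail_after` is the direct replacement for `asyncio.wait_for` here. A self-contained sketch of the behaviour (the slow reader below is invented):

import trio


async def slow_read() -> bytes:
    await trio.sleep(10)
    return b"ping"


async def main() -> None:
    try:
        # trio.fail_after replaces asyncio.wait_for; on timeout it raises
        # trio.TooSlowError instead of asyncio.TimeoutError.
        with trio.fail_after(0.1):
            payload = await slow_read()
    except trio.TooSlowError:
        print("timed out waiting for ping")
    else:
        print("got", payload)


trio.run(main)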


@@ -1,6 +1,6 @@
 from libp2p.host.basic_host import BasicHost
 from libp2p.host.exceptions import ConnectionFailure
-from libp2p.network.network_interface import INetwork
+from libp2p.network.network_interface import INetworkService
 from libp2p.peer.peerinfo import PeerInfo
 from libp2p.routing.interfaces import IPeerRouting
@@ -10,7 +10,7 @@ from libp2p.routing.interfaces import IPeerRouting
 class RoutedHost(BasicHost):
     _router: IPeerRouting

-    def __init__(self, network: INetwork, router: IPeerRouting):
+    def __init__(self, network: INetworkService, router: IPeerRouting):
         super().__init__(network)
         self._router = router


@@ -8,7 +8,7 @@ class Closer(ABC):
 class Reader(ABC):
     @abstractmethod
-    async def read(self, n: int = -1) -> bytes:
+    async def read(self, n: int = None) -> bytes:
         ...


@@ -54,7 +54,7 @@ class MsgIOReader(ReadCloser):
         self.read_closer = read_closer
         self.next_length = None

-    async def read(self, n: int = -1) -> bytes:
+    async def read(self, n: int = None) -> bytes:
         return await self.read_msg()

     async def read_msg(self) -> bytes:

libp2p/io/trio.py (new file, 40 lines)

@@ -0,0 +1,40 @@
import logging
import trio
from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException
logger = logging.getLogger("libp2p.io.trio")
class TrioTCPStream(ReadWriteCloser):
stream: trio.SocketStream
# NOTE: Add both read and write lock to avoid `trio.BusyResourceError`
read_lock: trio.Lock
write_lock: trio.Lock
def __init__(self, stream: trio.SocketStream) -> None:
self.stream = stream
self.read_lock = trio.Lock()
self.write_lock = trio.Lock()
async def write(self, data: bytes) -> None:
"""Raise `RawConnError` if the underlying connection breaks."""
async with self.write_lock:
try:
await self.stream.send_all(data)
except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
raise IOException from error
async def read(self, n: int = None) -> bytes:
async with self.read_lock:
if n is not None and n == 0:
return b""
try:
return await self.stream.receive_some(n)
except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
raise IOException from error
async def close(self) -> None:
await self.stream.aclose()
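
A usage sketch for the new wrapper, assuming a plain TCP dial with `trio.open_tcp_stream` (the address and payload are placeholders, not part of this PR):

import trio

from libp2p.io.trio import TrioTCPStream


async def dial(host: str, port: int) -> None:
    socket_stream = await trio.open_tcp_stream(host, port)
    stream = TrioTCPStream(socket_stream)
    try:
        await stream.write(b"hello\n")
        # read() raises IOException if the connection breaks, mirroring write().
        data = await stream.read(1024)
        print("received", data)
    finally:
        await stream.close()


if __name__ == "__main__":
    trio.run(dial, "127.0.0.1", 8000)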


@@ -1,6 +1,8 @@
 from abc import abstractmethod
 from typing import Tuple

+import trio
+
 from libp2p.io.abc import Closer
 from libp2p.network.stream.net_stream_interface import INetStream
 from libp2p.stream_muxer.abc import IMuxedConn
@@ -8,11 +10,12 @@ from libp2p.stream_muxer.abc import IMuxedConn
 class INetConn(Closer):
     muxed_conn: IMuxedConn
+    event_started: trio.Event

     @abstractmethod
     async def new_stream(self) -> INetStream:
         ...

     @abstractmethod
-    async def get_streams(self) -> Tuple[INetStream, ...]:
+    def get_streams(self) -> Tuple[INetStream, ...]:
         ...


@ -1,46 +1,26 @@
import asyncio from libp2p.io.abc import ReadWriteCloser
import sys from libp2p.io.exceptions import IOException
from .exceptions import RawConnError from .exceptions import RawConnError
from .raw_connection_interface import IRawConnection from .raw_connection_interface import IRawConnection
class RawConnection(IRawConnection): class RawConnection(IRawConnection):
reader: asyncio.StreamReader stream: ReadWriteCloser
writer: asyncio.StreamWriter
is_initiator: bool is_initiator: bool
_drain_lock: asyncio.Lock def __init__(self, stream: ReadWriteCloser, initiator: bool) -> None:
self.stream = stream
def __init__(
self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
initiator: bool,
) -> None:
self.reader = reader
self.writer = writer
self.is_initiator = initiator self.is_initiator = initiator
self._drain_lock = asyncio.Lock()
async def write(self, data: bytes) -> None: async def write(self, data: bytes) -> None:
"""Raise `RawConnError` if the underlying connection breaks.""" """Raise `RawConnError` if the underlying connection breaks."""
# Detect if underlying transport is closing before write data to it try:
# ref: https://github.com/ethereum/trinity/pull/614 await self.stream.write(data)
if self.writer.transport.is_closing(): except IOException as error:
raise RawConnError("Transport is closing") raise RawConnError from error
self.writer.write(data)
# Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102 # noqa: E501
# Use a lock to serialize drain() calls. Circumvents this bug:
# https://bugs.python.org/issue29930
async with self._drain_lock:
try:
await self.writer.drain()
except ConnectionResetError as error:
raise RawConnError() from error
async def read(self, n: int = -1) -> bytes: async def read(self, n: int = None) -> bytes:
""" """
Read up to ``n`` bytes from the underlying stream. This call is Read up to ``n`` bytes from the underlying stream. This call is
delegated directly to the underlying ``self.reader``. delegated directly to the underlying ``self.reader``.
@ -48,18 +28,9 @@ class RawConnection(IRawConnection):
Raise `RawConnError` if the underlying connection breaks Raise `RawConnError` if the underlying connection breaks
""" """
try: try:
return await self.reader.read(n) return await self.stream.read(n)
except ConnectionResetError as error: except IOException as error:
raise RawConnError() from error raise RawConnError from error
async def close(self) -> None: async def close(self) -> None:
if self.writer.transport.is_closing(): await self.stream.close()
return
self.writer.close()
if sys.version_info < (3, 7):
return
try:
await self.writer.wait_closed()
# In case the connection is already reset.
except ConnectionResetError:
return


@ -1,5 +1,6 @@
import asyncio from typing import TYPE_CHECKING, Set, Tuple
from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple
import trio
from libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.connection.net_connection_interface import INetConn
from libp2p.network.stream.net_stream import NetStream from libp2p.network.stream.net_stream import NetStream
@ -19,90 +20,78 @@ class SwarmConn(INetConn):
muxed_conn: IMuxedConn muxed_conn: IMuxedConn
swarm: "Swarm" swarm: "Swarm"
streams: Set[NetStream] streams: Set[NetStream]
event_closed: asyncio.Event event_closed: trio.Event
_tasks: List["asyncio.Future[Any]"]
def __init__(self, muxed_conn: IMuxedConn, swarm: "Swarm") -> None: def __init__(self, muxed_conn: IMuxedConn, swarm: "Swarm") -> None:
self.muxed_conn = muxed_conn self.muxed_conn = muxed_conn
self.swarm = swarm self.swarm = swarm
self.streams = set() self.streams = set()
self.event_closed = asyncio.Event() self.event_closed = trio.Event()
self.event_started = trio.Event()
self._tasks = [] @property
def is_closed(self) -> bool:
return self.event_closed.is_set()
async def close(self) -> None: async def close(self) -> None:
if self.event_closed.is_set(): if self.event_closed.is_set():
return return
self.event_closed.set() self.event_closed.set()
await self._cleanup()
async def _cleanup(self) -> None:
self.swarm.remove_conn(self) self.swarm.remove_conn(self)
await self.muxed_conn.close() await self.muxed_conn.close()
# This is just for cleaning up state. The connection has already been closed. # This is just for cleaning up state. The connection has already been closed.
# We *could* optimize this but it really isn't worth it. # We *could* optimize this but it really isn't worth it.
for stream in self.streams: for stream in self.streams.copy():
await stream.reset() await stream.reset()
# Force context switch for stream handlers to process the stream reset event we just emit # Force context switch for stream handlers to process the stream reset event we just emit
# before we cancel the stream handler tasks. # before we cancel the stream handler tasks.
await asyncio.sleep(0.1) await trio.sleep(0.1)
for task in self._tasks: await self._notify_disconnected()
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
# Schedule `self._notify_disconnected` to make it execute after `close` is finished.
self._notify_disconnected()
async def _handle_new_streams(self) -> None: async def _handle_new_streams(self) -> None:
while True: self.event_started.set()
try: async with trio.open_nursery() as nursery:
stream = await self.muxed_conn.accept_stream() while True:
except MuxedConnUnavailable: try:
# If there is anything wrong in the MuxedConn, stream = await self.muxed_conn.accept_stream()
# we should break the loop and close the connection. except MuxedConnUnavailable:
break await self.close()
# Asynchronously handle the accepted stream, to avoid blocking the next stream. break
await self.run_task(self._handle_muxed_stream(stream)) # Asynchronously handle the accepted stream, to avoid blocking the next stream.
nursery.start_soon(self._handle_muxed_stream, stream)
await self.close()
async def _call_stream_handler(self, net_stream: NetStream) -> None:
try:
await self.swarm.common_stream_handler(net_stream)
# TODO: More exact exceptions
except Exception:
# TODO: Emit logs.
# TODO: Clean up and remove the stream from SwarmConn if there is anything wrong.
self.remove_stream(net_stream)
async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None: async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None:
net_stream = self._add_stream(muxed_stream) net_stream = await self._add_stream(muxed_stream)
if self.swarm.common_stream_handler is not None: try:
await self.run_task(self._call_stream_handler(net_stream)) # Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427
await self.swarm.common_stream_handler(net_stream) # type: ignore
finally:
# As long as `common_stream_handler`, remove the stream.
self.remove_stream(net_stream)
def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
net_stream = NetStream(muxed_stream) net_stream = NetStream(muxed_stream)
self.streams.add(net_stream) self.streams.add(net_stream)
self.swarm.notify_opened_stream(net_stream) await self.swarm.notify_opened_stream(net_stream)
return net_stream return net_stream
def _notify_disconnected(self) -> None: async def _notify_disconnected(self) -> None:
self.swarm.notify_disconnected(self) await self.swarm.notify_disconnected(self)
async def start(self) -> None: async def start(self) -> None:
await self.run_task(self._handle_new_streams()) await self._handle_new_streams()
async def run_task(self, coro: Awaitable[Any]) -> None:
self._tasks.append(asyncio.ensure_future(coro))
async def new_stream(self) -> NetStream: async def new_stream(self) -> NetStream:
muxed_stream = await self.muxed_conn.open_stream() muxed_stream = await self.muxed_conn.open_stream()
return self._add_stream(muxed_stream) return await self._add_stream(muxed_stream)
async def get_streams(self) -> Tuple[NetStream, ...]: def get_streams(self) -> Tuple[NetStream, ...]:
return tuple(self.streams) return tuple(self.streams)
def remove_stream(self, stream: NetStream) -> None: def remove_stream(self, stream: NetStream) -> None:


@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Dict, Sequence

+from async_service import ServiceAPI
 from multiaddr import Multiaddr

 from libp2p.network.connection.net_connection_interface import INetConn
@@ -70,3 +71,7 @@ class INetwork(ABC):
     @abstractmethod
     async def close_peer(self, peer_id: ID) -> None:
         pass
+
+
+class INetworkService(INetwork, ServiceAPI):
+    pass


@@ -37,7 +37,7 @@ class NetStream(INetStream):
         """
         self.protocol_id = protocol_id

-    async def read(self, n: int = -1) -> bytes:
+    async def read(self, n: int = None) -> bytes:
         """
         reads from stream.


@ -1,9 +1,11 @@
import asyncio
import logging import logging
from typing import Dict, List, Optional from typing import Dict, List, Optional
from async_service import Service
from multiaddr import Multiaddr from multiaddr import Multiaddr
import trio
from libp2p.io.abc import ReadWriteCloser
from libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.connection.net_connection_interface import INetConn
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.peer.peerstore import PeerStoreError from libp2p.peer.peerstore import PeerStoreError
@ -23,14 +25,21 @@ from ..exceptions import MultiError
from .connection.raw_connection import RawConnection from .connection.raw_connection import RawConnection
from .connection.swarm_connection import SwarmConn from .connection.swarm_connection import SwarmConn
from .exceptions import SwarmException from .exceptions import SwarmException
from .network_interface import INetwork from .network_interface import INetworkService
from .notifee_interface import INotifee from .notifee_interface import INotifee
from .stream.net_stream_interface import INetStream from .stream.net_stream_interface import INetStream
logger = logging.getLogger("libp2p.network.swarm") logger = logging.getLogger("libp2p.network.swarm")
class Swarm(INetwork): def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn:
async def stream_handler(stream: INetStream) -> None:
await network.get_manager().wait_finished()
return stream_handler
class Swarm(Service, INetworkService):
self_id: ID self_id: ID
peerstore: IPeerStore peerstore: IPeerStore
@ -40,7 +49,9 @@ class Swarm(INetwork):
# whereas in Go one `peer_id` may point to multiple connections. # whereas in Go one `peer_id` may point to multiple connections.
connections: Dict[ID, INetConn] connections: Dict[ID, INetConn]
listeners: Dict[str, IListener] listeners: Dict[str, IListener]
common_stream_handler: Optional[StreamHandlerFn] common_stream_handler: StreamHandlerFn
listener_nursery: Optional[trio.Nursery]
event_listener_nursery_created: trio.Event
notifees: List[INotifee] notifees: List[INotifee]
@ -61,13 +72,31 @@ class Swarm(INetwork):
# Create Notifee array # Create Notifee array
self.notifees = [] self.notifees = []
self.common_stream_handler = None # Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427
self.common_stream_handler = create_default_stream_handler(self) # type: ignore
self.listener_nursery = None
self.event_listener_nursery_created = trio.Event()
async def run(self) -> None:
async with trio.open_nursery() as nursery:
# Create a nursery for listener tasks.
self.listener_nursery = nursery
self.event_listener_nursery_created.set()
try:
await self.manager.wait_finished()
finally:
# The service ended. Cancel listener tasks.
nursery.cancel_scope.cancel()
# Indicate that the nursery has been cancelled.
self.listener_nursery = None
def get_peer_id(self) -> ID: def get_peer_id(self) -> ID:
return self.self_id return self.self_id
def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None: def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None:
self.common_stream_handler = stream_handler # Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427
self.common_stream_handler = stream_handler # type: ignore
async def dial_peer(self, peer_id: ID) -> INetConn: async def dial_peer(self, peer_id: ID) -> INetConn:
""" """
@ -195,19 +224,15 @@ class Swarm(INetwork):
- Call listener listen with the multiaddr - Call listener listen with the multiaddr
- Map multiaddr to listener - Map multiaddr to listener
""" """
# We need to wait until `self.listener_nursery` is created.
await self.event_listener_nursery_created.wait()
for maddr in multiaddrs: for maddr in multiaddrs:
if str(maddr) in self.listeners: if str(maddr) in self.listeners:
return True return True
async def conn_handler( async def conn_handler(read_write_closer: ReadWriteCloser) -> None:
reader: asyncio.StreamReader, writer: asyncio.StreamWriter raw_conn = RawConnection(read_write_closer, False)
) -> None:
connection_info = writer.get_extra_info("peername")
# TODO make a proper multiaddr
peer_addr = f"/ip4/{connection_info[0]}/tcp/{connection_info[1]}"
logger.debug("inbound connection at %s", peer_addr)
# logger.debug("inbound connection request", peer_id)
raw_conn = RawConnection(reader, writer, False)
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure # Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure
# the conn and then mux the conn # the conn and then mux the conn
@ -217,16 +242,13 @@ class Swarm(INetwork):
raw_conn, ID(b""), False raw_conn, ID(b""), False
) )
except SecurityUpgradeFailure as error: except SecurityUpgradeFailure as error:
logger.debug("failed to upgrade security for peer at %s", peer_addr) logger.debug("failed to upgrade security for peer at %s", maddr)
await raw_conn.close() await raw_conn.close()
raise SwarmException( raise SwarmException(
f"failed to upgrade security for peer at {peer_addr}" f"failed to upgrade security for peer at {maddr}"
) from error ) from error
peer_id = secured_conn.get_remote_peer() peer_id = secured_conn.get_remote_peer()
logger.debug("upgraded security for peer at %s", peer_addr)
logger.debug("identified peer at %s as %s", peer_addr, peer_id)
try: try:
muxed_conn = await self.upgrader.upgrade_connection( muxed_conn = await self.upgrader.upgrade_connection(
secured_conn, peer_id secured_conn, peer_id
@ -240,17 +262,24 @@ class Swarm(INetwork):
logger.debug("upgraded mux for peer %s", peer_id) logger.debug("upgraded mux for peer %s", peer_id)
await self.add_conn(muxed_conn) await self.add_conn(muxed_conn)
logger.debug("successfully opened connection to peer %s", peer_id) logger.debug("successfully opened connection to peer %s", peer_id)
# NOTE: This is a intentional barrier to prevent from the handler exiting and
# closing the connection.
await self.manager.wait_finished()
try: try:
# Success # Success
listener = self.transport.create_listener(conn_handler) listener = self.transport.create_listener(conn_handler)
self.listeners[str(maddr)] = listener self.listeners[str(maddr)] = listener
await listener.listen(maddr) # TODO: `listener.listen` is not bounded with nursery. If we want to be
# I/O agnostic, we should change the API.
if self.listener_nursery is None:
raise SwarmException("swarm instance hasn't been run")
await listener.listen(maddr, self.listener_nursery)
# Call notifiers since event occurred # Call notifiers since event occurred
self.notify_listen(maddr) await self.notify_listen(maddr)
return True return True
except IOError: except IOError:
@ -261,26 +290,12 @@ class Swarm(INetwork):
return False return False
async def close(self) -> None: async def close(self) -> None:
# TODO: Prevent from new listeners and conns being added. await self.manager.stop()
# Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501
# Close listeners
await asyncio.gather(
*[listener.close() for listener in self.listeners.values()]
)
# Close connections
await asyncio.gather(
*[connection.close() for connection in self.connections.values()]
)
logger.debug("swarm successfully closed") logger.debug("swarm successfully closed")
async def close_peer(self, peer_id: ID) -> None: async def close_peer(self, peer_id: ID) -> None:
if peer_id not in self.connections: if peer_id not in self.connections:
return return
# TODO: Should be changed to close multisple connections,
# if we have several connections per peer in the future.
connection = self.connections[peer_id] connection = self.connections[peer_id]
# NOTE: `connection.close` will delete `peer_id` from `self.connections` # NOTE: `connection.close` will delete `peer_id` from `self.connections`
# and `notify_disconnected` for us. # and `notify_disconnected` for us.
@ -293,11 +308,14 @@ class Swarm(INetwork):
and start to monitor the connection for its new streams and and start to monitor the connection for its new streams and
disconnection.""" disconnection."""
swarm_conn = SwarmConn(muxed_conn, self) swarm_conn = SwarmConn(muxed_conn, self)
self.manager.run_task(muxed_conn.start)
await muxed_conn.event_started.wait()
self.manager.run_task(swarm_conn.start)
await swarm_conn.event_started.wait()
# Store muxed_conn with peer id # Store muxed_conn with peer id
self.connections[muxed_conn.peer_id] = swarm_conn self.connections[muxed_conn.peer_id] = swarm_conn
# Call notifiers since event occurred # Call notifiers since event occurred
self.notify_connected(swarm_conn) await self.notify_connected(swarm_conn)
await swarm_conn.start()
return swarm_conn return swarm_conn
def remove_conn(self, swarm_conn: SwarmConn) -> None: def remove_conn(self, swarm_conn: SwarmConn) -> None:
@ -306,14 +324,10 @@ class Swarm(INetwork):
peer_id = swarm_conn.muxed_conn.peer_id peer_id = swarm_conn.muxed_conn.peer_id
if peer_id not in self.connections: if peer_id not in self.connections:
return return
# TODO: Should be changed to remove the exact connection,
# if we have several connections per peer in the future.
del self.connections[peer_id] del self.connections[peer_id]
# Notifee # Notifee
# TODO: Remeber the spawn notifying tasks and clean them up when closing.
def register_notifee(self, notifee: INotifee) -> None: def register_notifee(self, notifee: INotifee) -> None:
""" """
:param notifee: object implementing Notifee interface :param notifee: object implementing Notifee interface
@ -321,20 +335,28 @@ class Swarm(INetwork):
""" """
self.notifees.append(notifee) self.notifees.append(notifee)
def notify_opened_stream(self, stream: INetStream) -> None: async def notify_opened_stream(self, stream: INetStream) -> None:
asyncio.gather( async with trio.open_nursery() as nursery:
*[notifee.opened_stream(self, stream) for notifee in self.notifees] for notifee in self.notifees:
) nursery.start_soon(notifee.opened_stream, self, stream)
# TODO: `notify_closed_stream` async def notify_connected(self, conn: INetConn) -> None:
async with trio.open_nursery() as nursery:
for notifee in self.notifees:
nursery.start_soon(notifee.connected, self, conn)
def notify_connected(self, conn: INetConn) -> None: async def notify_disconnected(self, conn: INetConn) -> None:
asyncio.gather(*[notifee.connected(self, conn) for notifee in self.notifees]) async with trio.open_nursery() as nursery:
for notifee in self.notifees:
nursery.start_soon(notifee.disconnected, self, conn)
def notify_disconnected(self, conn: INetConn) -> None: async def notify_listen(self, multiaddr: Multiaddr) -> None:
asyncio.gather(*[notifee.disconnected(self, conn) for notifee in self.notifees]) async with trio.open_nursery() as nursery:
for notifee in self.notifees:
nursery.start_soon(notifee.listen, self, multiaddr)
def notify_listen(self, multiaddr: Multiaddr) -> None: async def notify_closed_stream(self, stream: INetStream) -> None:
asyncio.gather(*[notifee.listen(self, multiaddr) for notifee in self.notifees]) raise NotImplementedError
# TODO: `notify_listen_close` async def notify_listen_close(self, multiaddr: Multiaddr) -> None:
raise NotImplementedError
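
The `notify_*` methods above all follow one pattern: run every notifee callback in its own task and only return once the whole group has finished. A stripped-down sketch (the callback signature is simplified to a single string argument):

from typing import Awaitable, Callable, List

import trio

Callback = Callable[[str], Awaitable[None]]


async def notify_all(callbacks: List[Callback], event: str) -> None:
    # One task per callback; the nursery waits for all of them.
    async with trio.open_nursery() as nursery:
        for callback in callbacks:
            nursery.start_soon(callback, event)


async def main() -> None:
    async def log(event: str) -> None:
        print("saw", event)

    await notify_all([log, log], "listen")


trio.run(main)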


@@ -25,9 +25,6 @@ def info_from_p2p_addr(addr: multiaddr.Multiaddr) -> PeerInfo:
     if not addr:
         raise InvalidAddrError("`addr` should not be `None`")

-    if not isinstance(addr, multiaddr.Multiaddr):
-        raise InvalidAddrError(f"`addr`={addr} should be of type `Multiaddr`")
-
     parts = addr.split()
     if not parts:
         raise InvalidAddrError(


@@ -1,15 +1,37 @@
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, List
+from typing import (
+    TYPE_CHECKING,
+    AsyncContextManager,
+    AsyncIterable,
+    KeysView,
+    List,
+    Tuple,
+)
+
+from async_service import ServiceAPI

 from libp2p.peer.id import ID
 from libp2p.typing import TProtocol

 from .pb import rpc_pb2
+from .typing import ValidatorFn

 if TYPE_CHECKING:
     from .pubsub import Pubsub  # noqa: F401


+class ISubscriptionAPI(
+    AsyncContextManager["ISubscriptionAPI"], AsyncIterable[rpc_pb2.Message]
+):
+    @abstractmethod
+    async def unsubscribe(self) -> None:
+        ...
+
+    @abstractmethod
+    async def get(self) -> rpc_pb2.Message:
+        ...
+
+
 class IPubsubRouter(ABC):
     @abstractmethod
     def get_protocols(self) -> List[TProtocol]:
@@ -53,7 +75,6 @@ class IPubsubRouter(ABC):
         :param rpc: rpc message
         """

-    # FIXME: Should be changed to type 'peer.ID'
     @abstractmethod
     async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None:
         """
@@ -80,3 +101,46 @@ class IPubsubRouter(ABC):
         :param topic: topic to leave
         """

class IPubsub(ServiceAPI):
@property
@abstractmethod
def my_id(self) -> ID:
...
@property
@abstractmethod
def protocols(self) -> Tuple[TProtocol, ...]:
...
@property
@abstractmethod
def topic_ids(self) -> KeysView[str]:
...
@abstractmethod
def set_topic_validator(
self, topic: str, validator: ValidatorFn, is_async_validator: bool
) -> None:
...
@abstractmethod
def remove_topic_validator(self, topic: str) -> None:
...
@abstractmethod
async def wait_until_ready(self) -> None:
...
@abstractmethod
async def subscribe(self, topic_id: str) -> ISubscriptionAPI:
...
@abstractmethod
async def unsubscribe(self, topic_id: str) -> None:
...
@abstractmethod
async def publish(self, topic_id: str, data: bytes) -> None:
...
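
Only the interface is shown above, so the following consumption sketch is an assumption about intended usage: a subscription acts as an async context manager (presumably unsubscribing on exit) and an async iterator of pubsub messages.

from libp2p.pubsub.pubsub import Pubsub


async def consume(pubsub: Pubsub, topic: str) -> None:
    subscription = await pubsub.subscribe(topic)
    async with subscription:
        async for message in subscription:
            print("message from", message.from_id)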


@@ -0,0 +1,9 @@
from libp2p.exceptions import BaseLibp2pError
class PubsubRouterError(BaseLibp2pError):
pass
class NoPubsubAttached(PubsubRouterError):
pass


@@ -1,14 +1,16 @@
 import logging
 from typing import Iterable, List, Sequence

+import trio
+
 from libp2p.network.stream.exceptions import StreamClosed
 from libp2p.peer.id import ID
 from libp2p.typing import TProtocol
 from libp2p.utils import encode_varint_prefixed

+from .abc import IPubsubRouter
 from .pb import rpc_pb2
 from .pubsub import Pubsub
-from .pubsub_router_interface import IPubsubRouter

 PROTOCOL_ID = TProtocol("/floodsub/1.0.0")
@@ -61,6 +63,8 @@ class FloodSub(IPubsubRouter):
         :param rpc: rpc message
         """
+        # Checkpoint
+        await trio.hazmat.checkpoint()

     async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None:
         """
@@ -107,6 +111,8 @@ class FloodSub(IPubsubRouter):
         :param topic: topic to join
         """
+        # Checkpoint
+        await trio.hazmat.checkpoint()

     async def leave(self, topic: str) -> None:
         """
@@ -115,6 +121,8 @@ class FloodSub(IPubsubRouter):
         :param topic: topic to leave
         """
+        # Checkpoint
+        await trio.hazmat.checkpoint()

     def _get_peers_to_send(
         self, topic_ids: Iterable[str], msg_forwarder: ID, origin: ID
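
The bare `await trio.hazmat.checkpoint()` calls exist because trio expects every async function to pass through at least one checkpoint, so even a no-op handler stays cancellable and yields to the scheduler. A minimal sketch (the class name is invented):

import trio


class NoopHandler:
    async def handle_rpc(self, rpc: object) -> None:
        # Nothing to do for this router, but keep trio's cancellation and
        # scheduling guarantees by hitting a checkpoint explicitly.
        await trio.hazmat.checkpoint()

In newer trio releases the same call lives at `trio.lowlevel.checkpoint()`.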


@ -1,28 +1,30 @@
from ast import literal_eval from ast import literal_eval
import asyncio
from collections import defaultdict from collections import defaultdict
import logging import logging
import random import random
from typing import Any, DefaultDict, Dict, Iterable, List, Sequence, Set, Tuple from typing import Any, DefaultDict, Dict, Iterable, List, Sequence, Set, Tuple
from async_service import Service
import trio
from libp2p.network.stream.exceptions import StreamClosed from libp2p.network.stream.exceptions import StreamClosed
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.pubsub import floodsub from libp2p.pubsub import floodsub
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from libp2p.utils import encode_varint_prefixed from libp2p.utils import encode_varint_prefixed
from .abc import IPubsubRouter
from .exceptions import NoPubsubAttached
from .mcache import MessageCache from .mcache import MessageCache
from .pb import rpc_pb2 from .pb import rpc_pb2
from .pubsub import Pubsub from .pubsub import Pubsub
from .pubsub_router_interface import IPubsubRouter
PROTOCOL_ID = TProtocol("/meshsub/1.0.0") PROTOCOL_ID = TProtocol("/meshsub/1.0.0")
logger = logging.getLogger("libp2p.pubsub.gossipsub") logger = logging.getLogger("libp2p.pubsub.gossipsub")
class GossipSub(IPubsubRouter): class GossipSub(IPubsubRouter, Service):
protocols: List[TProtocol] protocols: List[TProtocol]
pubsub: Pubsub pubsub: Pubsub
@ -38,7 +40,8 @@ class GossipSub(IPubsubRouter):
# The protocol peer supports # The protocol peer supports
peer_protocol: Dict[ID, TProtocol] peer_protocol: Dict[ID, TProtocol]
time_since_last_publish: Dict[str, int] # TODO: Add `time_since_last_publish`
# Create topic --> time since last publish map.
mcache: MessageCache mcache: MessageCache
@ -75,9 +78,6 @@ class GossipSub(IPubsubRouter):
# Create peer --> protocol mapping # Create peer --> protocol mapping
self.peer_protocol = {} self.peer_protocol = {}
# Create topic --> time since last publish map
self.time_since_last_publish = {}
# Create message cache # Create message cache
self.mcache = MessageCache(gossip_window, gossip_history) self.mcache = MessageCache(gossip_window, gossip_history)
@ -85,6 +85,12 @@ class GossipSub(IPubsubRouter):
self.heartbeat_initial_delay = heartbeat_initial_delay self.heartbeat_initial_delay = heartbeat_initial_delay
self.heartbeat_interval = heartbeat_interval self.heartbeat_interval = heartbeat_interval
async def run(self) -> None:
if self.pubsub is None:
raise NoPubsubAttached
self.manager.run_daemon_task(self.heartbeat)
await self.manager.wait_finished()
# Interface functions # Interface functions
def get_protocols(self) -> List[TProtocol]: def get_protocols(self) -> List[TProtocol]:
@ -104,9 +110,6 @@ class GossipSub(IPubsubRouter):
logger.debug("attached to pusub") logger.debug("attached to pusub")
# Start heartbeat now that we have a pubsub instance
asyncio.ensure_future(self.heartbeat())
def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None: def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None:
""" """
Notifies the router that a new peer has been connected. Notifies the router that a new peer has been connected.
@ -370,7 +373,7 @@ class GossipSub(IPubsubRouter):
state changes in the preceding heartbeat state changes in the preceding heartbeat
""" """
# Start after a delay. Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/01b9825fbee1848751d90a8469e3f5f43bac8466/gossipsub.go#L410 # Noqa: E501 # Start after a delay. Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/01b9825fbee1848751d90a8469e3f5f43bac8466/gossipsub.go#L410 # Noqa: E501
await asyncio.sleep(self.heartbeat_initial_delay) await trio.sleep(self.heartbeat_initial_delay)
while True: while True:
# Maintain mesh and keep track of which peers to send GRAFT or PRUNE to # Maintain mesh and keep track of which peers to send GRAFT or PRUNE to
peers_to_graft, peers_to_prune = self.mesh_heartbeat() peers_to_graft, peers_to_prune = self.mesh_heartbeat()
@ -385,7 +388,7 @@ class GossipSub(IPubsubRouter):
self.mcache.shift() self.mcache.shift()
await asyncio.sleep(self.heartbeat_interval) await trio.sleep(self.heartbeat_interval)
def mesh_heartbeat( def mesh_heartbeat(
self self
@ -413,7 +416,7 @@ class GossipSub(IPubsubRouter):
if num_mesh_peers_in_topic > self.degree_high: if num_mesh_peers_in_topic > self.degree_high:
# Select |mesh[topic]| - D peers from mesh[topic] # Select |mesh[topic]| - D peers from mesh[topic]
selected_peers = GossipSub.select_from_minus( selected_peers = self.select_from_minus(
num_mesh_peers_in_topic - self.degree, self.mesh[topic], set() num_mesh_peers_in_topic - self.degree, self.mesh[topic], set()
) )
for peer in selected_peers: for peer in selected_peers:
@ -428,15 +431,10 @@ class GossipSub(IPubsubRouter):
# Note: the comments here are the exact pseudocode from the spec # Note: the comments here are the exact pseudocode from the spec
for topic in self.fanout: for topic in self.fanout:
# Delete topic entry if it's not in `pubsub.peer_topics` # Delete topic entry if it's not in `pubsub.peer_topics`
# or if it's time-since-last-published > ttl # or (TODO) if it's time-since-last-published > ttl
# TODO: there's no way time_since_last_publish gets set anywhere yet if topic not in self.pubsub.peer_topics:
if (
topic not in self.pubsub.peer_topics
or self.time_since_last_publish[topic] > self.time_to_live
):
# Remove topic from fanout # Remove topic from fanout
del self.fanout[topic] del self.fanout[topic]
del self.time_since_last_publish[topic]
else: else:
# Check if fanout peers are still in the topic and remove the ones that are not # Check if fanout peers are still in the topic and remove the ones that are not
# ref: https://github.com/libp2p/go-libp2p-pubsub/blob/01b9825fbee1848751d90a8469e3f5f43bac8466/gossipsub.go#L498-L504 # noqa: E501 # ref: https://github.com/libp2p/go-libp2p-pubsub/blob/01b9825fbee1848751d90a8469e3f5f43bac8466/gossipsub.go#L498-L504 # noqa: E501
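With this change the heartbeat is no longer spawned with asyncio.ensure_future inside attach; GossipSub now exposes run() and is driven as an async-service Service. Below is a minimal sketch (not part of this diff) of what driving it looks like under trio, assuming gossipsub is a GossipSub instance that has already been attached to a Pubsub; the function name run_gossipsub_example is just a placeholder.

import trio
from async_service import background_trio_service

async def run_gossipsub_example(gossipsub):
    # run() raises NoPubsubAttached unless a Pubsub has been attached first.
    async with background_trio_service(gossipsub):
        # GossipSub.run schedules `heartbeat` as a daemon task of its manager,
        # so mesh/fanout maintenance keeps running until this block exits.
        await trio.sleep(gossipsub.heartbeat_interval * 2)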

View File

@ -1,21 +1,12 @@
import asyncio import functools
import logging import logging
import time import time
from typing import ( from typing import TYPE_CHECKING, Dict, KeysView, List, NamedTuple, Set, Tuple, cast
TYPE_CHECKING,
Awaitable,
Callable,
Dict,
List,
NamedTuple,
Set,
Tuple,
Union,
cast,
)
from async_service import Service
import base58 import base58
from lru import LRU from lru import LRU
import trio
from libp2p.crypto.keys import PrivateKey from libp2p.crypto.keys import PrivateKey
from libp2p.exceptions import ParseError, ValidationError from libp2p.exceptions import ParseError, ValidationError
@ -28,15 +19,21 @@ from libp2p.peer.id import ID
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from libp2p.utils import encode_varint_prefixed, read_varint_prefixed_bytes from libp2p.utils import encode_varint_prefixed, read_varint_prefixed_bytes
from .abc import IPubsub, ISubscriptionAPI
from .pb import rpc_pb2 from .pb import rpc_pb2
from .pubsub_notifee import PubsubNotifee from .pubsub_notifee import PubsubNotifee
from .subscription import TrioSubscriptionAPI
from .typing import AsyncValidatorFn, SyncValidatorFn, ValidatorFn
from .validators import PUBSUB_SIGNING_PREFIX, signature_validator from .validators import PUBSUB_SIGNING_PREFIX, signature_validator
if TYPE_CHECKING: if TYPE_CHECKING:
from .pubsub_router_interface import IPubsubRouter # noqa: F401 from .abc import IPubsubRouter # noqa: F401
from typing import Any # noqa: F401 from typing import Any # noqa: F401
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/40e1c94708658b155f30cf99e4574f384756d83c/topic.go#L97 # noqa: E501
SUBSCRIPTION_CHANNEL_SIZE = 32
logger = logging.getLogger("libp2p.pubsub") logger = logging.getLogger("libp2p.pubsub")
@ -45,34 +42,24 @@ def get_msg_id(msg: rpc_pb2.Message) -> Tuple[bytes, bytes]:
return (msg.seqno, msg.from_id) return (msg.seqno, msg.from_id)
SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
class TopicValidator(NamedTuple): class TopicValidator(NamedTuple):
validator: ValidatorFn validator: ValidatorFn
is_async: bool is_async: bool
class Pubsub: class Pubsub(Service, IPubsub):
host: IHost host: IHost
my_id: ID
router: "IPubsubRouter" router: "IPubsubRouter"
peer_queue: "asyncio.Queue[ID]" peer_receive_channel: "trio.MemoryReceiveChannel[ID]"
dead_peer_queue: "asyncio.Queue[ID]" dead_peer_receive_channel: "trio.MemoryReceiveChannel[ID]"
protocols: List[TProtocol]
incoming_msgs_from_peers: "asyncio.Queue[rpc_pb2.Message]"
outgoing_messages: "asyncio.Queue[rpc_pb2.Message]"
seen_messages: LRU seen_messages: LRU
my_topics: Dict[str, "asyncio.Queue[rpc_pb2.Message]"] subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"]
subscribed_topics_receive: Dict[str, "TrioSubscriptionAPI"]
peer_topics: Dict[str, Set[ID]] peer_topics: Dict[str, Set[ID]]
peers: Dict[ID, INetStream] peers: Dict[ID, INetStream]
@ -81,17 +68,17 @@ class Pubsub:
counter: int # uint64 counter: int # uint64
_tasks: List["asyncio.Future[Any]"]
# Indicate if we should enforce signature verification # Indicate if we should enforce signature verification
strict_signing: bool strict_signing: bool
sign_key: PrivateKey sign_key: PrivateKey
event_handle_peer_queue_started: trio.Event
event_handle_dead_peer_queue_started: trio.Event
def __init__( def __init__(
self, self,
host: IHost, host: IHost,
router: "IPubsubRouter", router: "IPubsubRouter",
my_id: ID,
cache_size: int = None, cache_size: int = None,
strict_signing: bool = True, strict_signing: bool = True,
) -> None: ) -> None:
@ -107,39 +94,44 @@ class Pubsub:
""" """
self.host = host self.host = host
self.router = router self.router = router
self.my_id = my_id
# Attach this new Pubsub object to the router # Attach this new Pubsub object to the router
self.router.attach(self) self.router.attach(self)
peer_send, peer_receive = trio.open_memory_channel[ID](0)
dead_peer_send, dead_peer_receive = trio.open_memory_channel[ID](0)
# Only keep the receive channels in `Pubsub`.
# Therefore, we can only close from the receive side.
self.peer_receive_channel = peer_receive
self.dead_peer_receive_channel = dead_peer_receive
# Register a notifee # Register a notifee
self.peer_queue = asyncio.Queue()
self.dead_peer_queue = asyncio.Queue()
self.host.get_network().register_notifee( self.host.get_network().register_notifee(
PubsubNotifee(self.peer_queue, self.dead_peer_queue) PubsubNotifee(peer_send, dead_peer_send)
) )
# Register stream handlers for each pubsub router protocol to handle # Register stream handlers for each pubsub router protocol to handle
# the pubsub streams opened on those protocols # the pubsub streams opened on those protocols
self.protocols = self.router.get_protocols() for protocol in router.get_protocols():
for protocol in self.protocols:
self.host.set_stream_handler(protocol, self.stream_handler) self.host.set_stream_handler(protocol, self.stream_handler)
# Use asyncio queues for proper context switching
self.incoming_msgs_from_peers = asyncio.Queue()
self.outgoing_messages = asyncio.Queue()
# keeps track of seen messages as LRU cache # keeps track of seen messages as LRU cache
if cache_size is None: if cache_size is None:
self.cache_size = 128 self.cache_size = 128
else: else:
self.cache_size = cache_size self.cache_size = cache_size
self.strict_signing = strict_signing
if strict_signing:
self.sign_key = self.host.get_private_key()
else:
self.sign_key = None
self.seen_messages = LRU(self.cache_size) self.seen_messages = LRU(self.cache_size)
# Map of topics we are subscribed to blocking queues # Map of topics we are subscribed to blocking queues
# for when the given topic receives a message # for when the given topic receives a message
self.my_topics = {} self.subscribed_topics_send = {}
self.subscribed_topics_receive = {}
# Map of topic to peers to keep track of what peers are subscribed to # Map of topic to peers to keep track of what peers are subscribed to
self.peer_topics = {} self.peer_topics = {}
@ -152,22 +144,31 @@ class Pubsub:
self.counter = int(time.time()) self.counter = int(time.time())
self._tasks = [] self.event_handle_peer_queue_started = trio.Event()
# Call handle peer to keep waiting for updates to peer queue self.event_handle_dead_peer_queue_started = trio.Event()
self._tasks.append(asyncio.ensure_future(self.handle_peer_queue()))
self._tasks.append(asyncio.ensure_future(self.handle_dead_peer_queue()))
self.strict_signing = strict_signing async def run(self) -> None:
if strict_signing: self.manager.run_daemon_task(self.handle_peer_queue)
self.sign_key = self.host.get_private_key() self.manager.run_daemon_task(self.handle_dead_peer_queue)
else: await self.manager.wait_finished()
self.sign_key = None
@property
def my_id(self) -> ID:
return self.host.get_id()
@property
def protocols(self) -> Tuple[TProtocol, ...]:
return tuple(self.router.get_protocols())
@property
def topic_ids(self) -> KeysView[str]:
return self.subscribed_topics_receive.keys()
def get_hello_packet(self) -> rpc_pb2.RPC: def get_hello_packet(self) -> rpc_pb2.RPC:
"""Generate subscription message with all topics we are subscribed to """Generate subscription message with all topics we are subscribed to
only send hello packet if we have subscribed topics.""" only send hello packet if we have subscribed topics."""
packet = rpc_pb2.RPC() packet = rpc_pb2.RPC()
for topic_id in self.my_topics: for topic_id in self.topic_ids:
packet.subscriptions.extend( packet.subscriptions.extend(
[rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)] [rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)]
) )
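Pubsub itself is now an async-service Service as well: run() keeps the peer-queue handlers alive as daemon tasks, and wait_until_ready() blocks until both handlers have started. A hedged startup sketch, where host and router stand for any IHost/IPubsubRouter pair and start_pubsub is a placeholder name:

from async_service import background_trio_service
from libp2p.pubsub.pubsub import Pubsub

async def start_pubsub(host, router):
    pubsub = Pubsub(host=host, router=router)      # `my_id` is now derived from the host
    async with background_trio_service(pubsub):
        # Both queue handlers run as daemon tasks of the service manager.
        await pubsub.wait_until_ready()
        # ... subscribe / publish while the service is running ...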
@ -182,7 +183,7 @@ class Pubsub:
""" """
peer_id = stream.muxed_conn.peer_id peer_id = stream.muxed_conn.peer_id
while True: while self.manager.is_running:
incoming: bytes = await read_varint_prefixed_bytes(stream) incoming: bytes = await read_varint_prefixed_bytes(stream)
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC() rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
rpc_incoming.ParseFromString(incoming) rpc_incoming.ParseFromString(incoming)
@ -194,11 +195,7 @@ class Pubsub:
logger.debug( logger.debug(
"received `publish` message %s from peer %s", msg, peer_id "received `publish` message %s from peer %s", msg, peer_id
) )
self._tasks.append( self.manager.run_task(self.push_msg, peer_id, msg)
asyncio.ensure_future(
self.push_msg(msg_forwarder=peer_id, msg=msg)
)
)
if rpc_incoming.subscriptions: if rpc_incoming.subscriptions:
# deal with RPC.subscriptions # deal with RPC.subscriptions
@ -226,9 +223,6 @@ class Pubsub:
) )
await self.router.handle_rpc(rpc_incoming, peer_id) await self.router.handle_rpc(rpc_incoming, peer_id)
# Force context switch
await asyncio.sleep(0)
def set_topic_validator( def set_topic_validator(
self, topic: str, validator: ValidatorFn, is_async_validator: bool self, topic: str, validator: ValidatorFn, is_async_validator: bool
) -> None: ) -> None:
@ -283,6 +277,10 @@ class Pubsub:
await stream.reset() await stream.reset()
self._handle_dead_peer(peer_id) self._handle_dead_peer(peer_id)
async def wait_until_ready(self) -> None:
await self.event_handle_peer_queue_started.wait()
await self.event_handle_dead_peer_queue_started.wait()
async def _handle_new_peer(self, peer_id: ID) -> None: async def _handle_new_peer(self, peer_id: ID) -> None:
try: try:
stream: INetStream = await self.host.new_stream(peer_id, self.protocols) stream: INetStream = await self.host.new_stream(peer_id, self.protocols)
@ -325,18 +323,21 @@ class Pubsub:
"""Continuously read from peer queue and each time a new peer is found, """Continuously read from peer queue and each time a new peer is found,
open a stream to the peer using a supported pubsub protocol pubsub open a stream to the peer using a supported pubsub protocol pubsub
protocols we support.""" protocols we support."""
while True: async with self.peer_receive_channel:
peer_id: ID = await self.peer_queue.get() self.event_handle_peer_queue_started.set()
# Add Peer async for peer_id in self.peer_receive_channel:
self._tasks.append(asyncio.ensure_future(self._handle_new_peer(peer_id))) # Add Peer
self.manager.run_task(self._handle_new_peer, peer_id)
async def handle_dead_peer_queue(self) -> None: async def handle_dead_peer_queue(self) -> None:
"""Continuously read from dead peer queue and close the stream between """Continuously read from dead peer channel and close the stream
that peer and remove peer info from pubsub and pubsub router.""" between that peer and remove peer info from pubsub and pubsub
while True: router."""
peer_id: ID = await self.dead_peer_queue.get() async with self.dead_peer_receive_channel:
# Remove Peer self.event_handle_dead_peer_queue_started.set()
self._handle_dead_peer(peer_id) async for peer_id in self.dead_peer_receive_channel:
# Remove Peer
self._handle_dead_peer(peer_id)
def handle_subscription( def handle_subscription(
self, origin_id: ID, sub_message: rpc_pb2.RPC.SubOpts self, origin_id: ID, sub_message: rpc_pb2.RPC.SubOpts
@ -360,8 +361,7 @@ class Pubsub:
if origin_id in self.peer_topics[sub_message.topicid]: if origin_id in self.peer_topics[sub_message.topicid]:
self.peer_topics[sub_message.topicid].discard(origin_id) self.peer_topics[sub_message.topicid].discard(origin_id)
# FIXME(mhchia): Change the function name? def notify_subscriptions(self, publish_message: rpc_pb2.Message) -> None:
async def handle_talk(self, publish_message: rpc_pb2.Message) -> None:
""" """
Put incoming message from a peer onto my blocking queue. Put incoming message from a peer onto my blocking queue.
@ -370,13 +370,19 @@ class Pubsub:
# Check if this message has any topics that we are subscribed to # Check if this message has any topics that we are subscribed to
for topic in publish_message.topicIDs: for topic in publish_message.topicIDs:
if topic in self.my_topics: if topic in self.topic_ids:
# we are subscribed to a topic this message was sent for, # we are subscribed to a topic this message was sent for,
# so add message to the subscription output queue # so add message to the subscription output queue
# for each topic # for each topic
await self.my_topics[topic].put(publish_message) try:
self.subscribed_topics_send[topic].send_nowait(publish_message)
except trio.WouldBlock:
# Channel is full, ignore this message.
logger.warning(
"fail to deliver message to subscription for topic %s", topic
)
async def subscribe(self, topic_id: str) -> "asyncio.Queue[rpc_pb2.Message]": async def subscribe(self, topic_id: str) -> ISubscriptionAPI:
""" """
Subscribe ourself to a topic. Subscribe ourself to a topic.
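The per-topic subscription channels are bounded (SUBSCRIPTION_CHANNEL_SIZE = 32) and notify_subscriptions uses send_nowait, so a subscriber that stops draining its channel loses messages instead of stalling the router. The same pattern in plain trio, independent of any libp2p types (the demo name and the capacity of 2 are arbitrary):

import trio

async def demo_drop_on_full():
    send, receive = trio.open_memory_channel(2)   # bounded, like the subscription channels
    for i in range(5):
        try:
            send.send_nowait(i)                   # what notify_subscriptions does
        except trio.WouldBlock:
            print("dropped", i)                   # Pubsub logs a warning instead of blocking
    print("delivered", await receive.receive())
    print("delivered", await receive.receive())

trio.run(demo_drop_on_full)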
@ -386,11 +392,19 @@ class Pubsub:
logger.debug("subscribing to topic %s", topic_id) logger.debug("subscribing to topic %s", topic_id)
# Already subscribed # Already subscribed
if topic_id in self.my_topics: if topic_id in self.topic_ids:
return self.my_topics[topic_id] return self.subscribed_topics_receive[topic_id]
# Map topic_id to blocking queue send_channel, receive_channel = trio.open_memory_channel[rpc_pb2.Message](
self.my_topics[topic_id] = asyncio.Queue() SUBSCRIPTION_CHANNEL_SIZE
)
subscription = TrioSubscriptionAPI(
receive_channel,
unsubscribe_fn=functools.partial(self.unsubscribe, topic_id),
)
self.subscribed_topics_send[topic_id] = send_channel
self.subscribed_topics_receive[topic_id] = subscription
# Create subscribe message # Create subscribe message
packet: rpc_pb2.RPC = rpc_pb2.RPC() packet: rpc_pb2.RPC = rpc_pb2.RPC()
@ -404,8 +418,8 @@ class Pubsub:
# Tell router we are joining this topic # Tell router we are joining this topic
await self.router.join(topic_id) await self.router.join(topic_id)
# Return the asyncio queue for messages on this topic # Return the subscription for messages on this topic
return self.my_topics[topic_id] return subscription
async def unsubscribe(self, topic_id: str) -> None: async def unsubscribe(self, topic_id: str) -> None:
""" """
@ -417,10 +431,14 @@ class Pubsub:
logger.debug("unsubscribing from topic %s", topic_id) logger.debug("unsubscribing from topic %s", topic_id)
# Return if we already unsubscribed from the topic # Return if we already unsubscribed from the topic
if topic_id not in self.my_topics: if topic_id not in self.topic_ids:
return return
# Remove topic_id from map if present # Remove topic_id from the maps before yielding
del self.my_topics[topic_id] send_channel = self.subscribed_topics_send[topic_id]
del self.subscribed_topics_send[topic_id]
del self.subscribed_topics_receive[topic_id]
# Only close the send side
await send_channel.aclose()
# Create unsubscribe message # Create unsubscribe message
packet: rpc_pb2.RPC = rpc_pb2.RPC() packet: rpc_pb2.RPC = rpc_pb2.RPC()
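Since subscribe() now returns a TrioSubscriptionAPI instead of a bare queue, callers get both an async iterator over messages and an async context manager whose exit unsubscribes. A hedged usage sketch, assuming pubsub is a running, ready Pubsub; the topic name and function name are arbitrary:

async def consume_topic(pubsub):
    subscription = await pubsub.subscribe("example-topic")
    async with subscription:                       # __aexit__ awaits unsubscribe()
        async for message in subscription:         # equivalently: await subscription.get()
            print("received", message.data)
            break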
@ -462,7 +480,7 @@ class Pubsub:
data=data, data=data,
topicIDs=[topic_id], topicIDs=[topic_id],
# Origin is ourself. # Origin is ourself.
from_id=self.host.get_id().to_bytes(), from_id=self.my_id.to_bytes(),
seqno=self._next_seqno(), seqno=self._next_seqno(),
) )
@ -474,7 +492,7 @@ class Pubsub:
msg.key = self.host.get_public_key().serialize() msg.key = self.host.get_public_key().serialize()
msg.signature = signature msg.signature = signature
await self.push_msg(self.host.get_id(), msg) await self.push_msg(self.my_id, msg)
logger.debug("successfully published message %s", msg) logger.debug("successfully published message %s", msg)
@ -485,12 +503,12 @@ class Pubsub:
:param msg_forwarder: the peer who forward us the message. :param msg_forwarder: the peer who forward us the message.
:param msg: the message. :param msg: the message.
""" """
sync_topic_validators = [] sync_topic_validators: List[SyncValidatorFn] = []
async_topic_validator_futures: List[Awaitable[bool]] = [] async_topic_validators: List[AsyncValidatorFn] = []
for topic_validator in self.get_msg_validators(msg): for topic_validator in self.get_msg_validators(msg):
if topic_validator.is_async: if topic_validator.is_async:
async_topic_validator_futures.append( async_topic_validators.append(
cast(Awaitable[bool], topic_validator.validator(msg_forwarder, msg)) cast(AsyncValidatorFn, topic_validator.validator)
) )
else: else:
sync_topic_validators.append( sync_topic_validators.append(
@ -503,9 +521,20 @@ class Pubsub:
# TODO: Implement throttle on async validators # TODO: Implement throttle on async validators
if len(async_topic_validator_futures) > 0: if len(async_topic_validators) > 0:
results = await asyncio.gather(*async_topic_validator_futures) # TODO: Use a better pattern
if not all(results): final_result: bool = True
async def run_async_validator(func: AsyncValidatorFn) -> None:
nonlocal final_result
result = await func(msg_forwarder, msg)
final_result = final_result and result
async with trio.open_nursery() as nursery:
for async_validator in async_topic_validators:
nursery.start_soon(run_async_validator, async_validator)
if not final_result:
raise ValidationError(f"Validation failed for msg={msg}") raise ValidationError(f"Validation failed for msg={msg}")
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None: async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
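Async topic validators are now run concurrently inside a single trio nursery, and a message is accepted only if every validator returns True. Validators are still registered through set_topic_validator with the is_async flag; the sketch below shows one function of each shape (topic names and function names are placeholders):

import trio

def sync_validator(peer_id, message) -> bool:
    return len(message.data) > 0                   # runs inline, must not block

async def async_validator(peer_id, message) -> bool:
    await trio.sleep(0)                            # e.g. consult some async resource
    return True

def register_validators(pubsub):
    pubsub.set_topic_validator("topic-a", sync_validator, False)
    pubsub.set_topic_validator("topic-b", async_validator, True)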
@ -548,7 +577,7 @@ class Pubsub:
return return
self._mark_msg_seen(msg) self._mark_msg_seen(msg)
await self.handle_talk(msg) self.notify_subscriptions(msg)
await self.router.publish(msg_forwarder, msg) await self.router.publish(msg_forwarder, msg)
def _next_seqno(self) -> bytes: def _next_seqno(self) -> bytes:
@ -567,14 +596,4 @@ class Pubsub:
self.seen_messages[msg_id] = 1 self.seen_messages[msg_id] = 1
def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool: def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
if not self.my_topics: return any(topic in self.topic_ids for topic in msg.topicIDs)
return False
return any(topic in self.my_topics for topic in msg.topicIDs)
async def close(self) -> None:
for task in self._tasks:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass

View File

@ -1,6 +1,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from multiaddr import Multiaddr from multiaddr import Multiaddr
import trio
from libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.connection.net_connection_interface import INetConn
from libp2p.network.network_interface import INetwork from libp2p.network.network_interface import INetwork
@ -8,19 +9,18 @@ from libp2p.network.notifee_interface import INotifee
from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.stream.net_stream_interface import INetStream
if TYPE_CHECKING: if TYPE_CHECKING:
import asyncio # noqa: F401
from libp2p.peer.id import ID # noqa: F401 from libp2p.peer.id import ID # noqa: F401
class PubsubNotifee(INotifee): class PubsubNotifee(INotifee):
initiator_peers_queue: "asyncio.Queue[ID]" initiator_peers_queue: "trio.MemorySendChannel[ID]"
dead_peers_queue: "asyncio.Queue[ID]" dead_peers_queue: "trio.MemorySendChannel[ID]"
def __init__( def __init__(
self, self,
initiator_peers_queue: "asyncio.Queue[ID]", initiator_peers_queue: "trio.MemorySendChannel[ID]",
dead_peers_queue: "asyncio.Queue[ID]", dead_peers_queue: "trio.MemorySendChannel[ID]",
) -> None: ) -> None:
""" """
:param initiator_peers_queue: queue to add new peers to so that pubsub :param initiator_peers_queue: queue to add new peers to so that pubsub
@ -32,10 +32,10 @@ class PubsubNotifee(INotifee):
self.dead_peers_queue = dead_peers_queue self.dead_peers_queue = dead_peers_queue
async def opened_stream(self, network: INetwork, stream: INetStream) -> None: async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
pass await trio.hazmat.checkpoint()
async def closed_stream(self, network: INetwork, stream: INetStream) -> None: async def closed_stream(self, network: INetwork, stream: INetStream) -> None:
pass await trio.hazmat.checkpoint()
async def connected(self, network: INetwork, conn: INetConn) -> None: async def connected(self, network: INetwork, conn: INetConn) -> None:
""" """
@ -46,7 +46,11 @@ class PubsubNotifee(INotifee):
:param network: network the connection was opened on :param network: network the connection was opened on
:param conn: connection that was opened :param conn: connection that was opened
""" """
await self.initiator_peers_queue.put(conn.muxed_conn.peer_id) try:
await self.initiator_peers_queue.send(conn.muxed_conn.peer_id)
except trio.BrokenResourceError:
# The receive channel is closed by Pubsub. We should do nothing here.
pass
async def disconnected(self, network: INetwork, conn: INetConn) -> None: async def disconnected(self, network: INetwork, conn: INetConn) -> None:
""" """
@ -56,10 +60,14 @@ class PubsubNotifee(INotifee):
:param network: network the connection was opened on :param network: network the connection was opened on
:param conn: connection that was opened :param conn: connection that was opened
""" """
await self.dead_peers_queue.put(conn.muxed_conn.peer_id) try:
await self.dead_peers_queue.send(conn.muxed_conn.peer_id)
except trio.BrokenResourceError:
# The receive channel is closed by Pubsub. We should do nothing here.
pass
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None: async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
pass await trio.hazmat.checkpoint()
async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None: async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
pass await trio.hazmat.checkpoint()
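The notifee now holds only the send ends of the two memory channels while Pubsub keeps the receive ends, so once Pubsub closes its side any later notification raises BrokenResourceError and is silently dropped. A plain-trio illustration of that wiring (no libp2p types; all names are placeholders):

import trio

async def demo_notifee_channel():
    peer_send, peer_receive = trio.open_memory_channel(0)

    async def notify(peer_id):
        try:
            await peer_send.send(peer_id)
        except trio.BrokenResourceError:
            print("receiver gone, dropping", peer_id)   # what PubsubNotifee does

    async with trio.open_nursery() as nursery:
        nursery.start_soon(notify, "peer-1")
        print("new peer:", await peer_receive.receive())
        await peer_receive.aclose()                     # Pubsub closes the receive side
        nursery.start_soon(notify, "peer-2")            # this send now fails and is ignored

trio.run(demo_notifee_channel)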

View File

@ -0,0 +1,46 @@
from types import TracebackType
from typing import AsyncIterator, Optional, Type
import trio
from .abc import ISubscriptionAPI
from .pb import rpc_pb2
from .typing import UnsubscribeFn
class BaseSubscriptionAPI(ISubscriptionAPI):
async def __aenter__(self) -> "BaseSubscriptionAPI":
await trio.hazmat.checkpoint()
return self
async def __aexit__(
self,
exc_type: "Optional[Type[BaseException]]",
exc_value: "Optional[BaseException]",
traceback: "Optional[TracebackType]",
) -> None:
await self.unsubscribe()
class TrioSubscriptionAPI(BaseSubscriptionAPI):
receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]"
unsubscribe_fn: UnsubscribeFn
def __init__(
self,
receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]",
unsubscribe_fn: UnsubscribeFn,
) -> None:
self.receive_channel = receive_channel
# Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427
self.unsubscribe_fn = unsubscribe_fn # type: ignore
async def unsubscribe(self) -> None:
# Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427
await self.unsubscribe_fn() # type: ignore
def __aiter__(self) -> AsyncIterator[rpc_pb2.Message]:
return self.receive_channel.__aiter__()
async def get(self) -> rpc_pb2.Message:
return await self.receive_channel.receive()
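For completeness, a hedged sketch of wiring a TrioSubscriptionAPI by hand (for example in a test), mirroring what Pubsub.subscribe does internally; the channel size and the stand-in unsubscribe function are assumptions:

import trio
from libp2p.pubsub.pb import rpc_pb2
from libp2p.pubsub.subscription import TrioSubscriptionAPI

async def demo_subscription():
    send, receive = trio.open_memory_channel(32)

    async def close_send():                        # stands in for Pubsub.unsubscribe
        await send.aclose()

    subscription = TrioSubscriptionAPI(receive, unsubscribe_fn=close_send)
    send.send_nowait(rpc_pb2.Message(data=b"hello"))
    async with subscription:                       # exit awaits unsubscribe()
        print((await subscription.get()).data)

trio.run(demo_subscription)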

libp2p/pubsub/typing.py (new file)
View File

@ -0,0 +1,11 @@
from typing import Awaitable, Callable, Union
from libp2p.peer.id import ID
from .pb import rpc_pb2
SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
UnsubscribeFn = Callable[[], Awaitable[None]]

View File

@ -39,7 +39,7 @@ class InsecureSession(BaseSession):
await self.conn.write(data) await self.conn.write(data)
return len(data) return len(data)
async def read(self, n: int = -1) -> bytes: async def read(self, n: int = None) -> bytes:
return await self.conn.read(n) return await self.conn.read(n)
async def close(self) -> None: async def close(self) -> None:

View File

@ -94,7 +94,7 @@ class SecureSession(BaseSession):
data = self.buf.getbuffer()[self.low_watermark : self.high_watermark] data = self.buf.getbuffer()[self.low_watermark : self.high_watermark]
if n < 0: if n is None:
n = len(data) n = len(data)
result = data[:n].tobytes() result = data[:n].tobytes()
self.low_watermark += len(result) self.low_watermark += len(result)
@ -111,7 +111,7 @@ class SecureSession(BaseSession):
self.low_watermark = 0 self.low_watermark = 0
self.high_watermark = len(msg) self.high_watermark = len(msg)
async def read(self, n: int = -1) -> bytes: async def read(self, n: int = None) -> bytes:
if n == 0: if n == 0:
return bytes() return bytes()

View File

@ -1,5 +1,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import trio
from libp2p.io.abc import ReadWriteCloser from libp2p.io.abc import ReadWriteCloser
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
@ -11,6 +13,7 @@ class IMuxedConn(ABC):
""" """
peer_id: ID peer_id: ID
event_started: trio.Event
@abstractmethod @abstractmethod
def __init__(self, conn: ISecureConn, peer_id: ID) -> None: def __init__(self, conn: ISecureConn, peer_id: ID) -> None:
@ -25,12 +28,17 @@ class IMuxedConn(ABC):
@property @property
@abstractmethod @abstractmethod
def is_initiator(self) -> bool: def is_initiator(self) -> bool:
pass """if this connection is the initiator."""
@abstractmethod
async def start(self) -> None:
"""start the multiplexer."""
@abstractmethod @abstractmethod
async def close(self) -> None: async def close(self) -> None:
"""close connection.""" """close connection."""
@property
@abstractmethod @abstractmethod
def is_closed(self) -> bool: def is_closed(self) -> bool:
""" """

View File

@ -1,7 +1,7 @@
import asyncio
import logging import logging
from typing import Any # noqa: F401 from typing import Dict, Optional, Tuple
from typing import Awaitable, Dict, List, Optional, Tuple
import trio
from libp2p.exceptions import ParseError from libp2p.exceptions import ParseError
from libp2p.io.exceptions import IncompleteReadError from libp2p.io.exceptions import IncompleteReadError
@ -23,6 +23,8 @@ from .exceptions import MplexUnavailable
from .mplex_stream import MplexStream from .mplex_stream import MplexStream
MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0") MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0")
# Ref: https://github.com/libp2p/go-mplex/blob/414db61813d9ad3e6f4a7db5c1b1612de343ace9/multiplex.go#L115 # noqa: E501
MPLEX_MESSAGE_CHANNEL_SIZE = 8
logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex") logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex")
@ -36,12 +38,14 @@ class Mplex(IMuxedConn):
peer_id: ID peer_id: ID
next_channel_id: int next_channel_id: int
streams: Dict[StreamID, MplexStream] streams: Dict[StreamID, MplexStream]
streams_lock: asyncio.Lock streams_lock: trio.Lock
new_stream_queue: "asyncio.Queue[IMuxedStream]" streams_msg_channels: Dict[StreamID, "trio.MemorySendChannel[bytes]"]
event_shutting_down: asyncio.Event new_stream_send_channel: "trio.MemorySendChannel[IMuxedStream]"
event_closed: asyncio.Event new_stream_receive_channel: "trio.MemoryReceiveChannel[IMuxedStream]"
_tasks: List["asyncio.Future[Any]"] event_shutting_down: trio.Event
event_closed: trio.Event
event_started: trio.Event
def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None: def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None:
""" """
@ -61,15 +65,16 @@ class Mplex(IMuxedConn):
# Mapping from stream ID -> buffer of messages for that stream # Mapping from stream ID -> buffer of messages for that stream
self.streams = {} self.streams = {}
self.streams_lock = asyncio.Lock() self.streams_lock = trio.Lock()
self.new_stream_queue = asyncio.Queue() self.streams_msg_channels = {}
self.event_shutting_down = asyncio.Event() channels = trio.open_memory_channel[IMuxedStream](0)
self.event_closed = asyncio.Event() self.new_stream_send_channel, self.new_stream_receive_channel = channels
self.event_shutting_down = trio.Event()
self.event_closed = trio.Event()
self.event_started = trio.Event()
self._tasks = [] async def start(self) -> None:
await self.handle_incoming()
# Kick off reading
self._tasks.append(asyncio.ensure_future(self.handle_incoming()))
@property @property
def is_initiator(self) -> bool: def is_initiator(self) -> bool:
@ -85,6 +90,7 @@ class Mplex(IMuxedConn):
# Blocked until `close` is finally set. # Blocked until `close` is finally set.
await self.event_closed.wait() await self.event_closed.wait()
@property
def is_closed(self) -> bool: def is_closed(self) -> bool:
""" """
check connection is fully closed. check connection is fully closed.
@ -104,9 +110,13 @@ class Mplex(IMuxedConn):
return next_id return next_id
async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream: async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream:
stream = MplexStream(name, stream_id, self) send_channel, receive_channel = trio.open_memory_channel[bytes](
MPLEX_MESSAGE_CHANNEL_SIZE
)
stream = MplexStream(name, stream_id, self, receive_channel)
async with self.streams_lock: async with self.streams_lock:
self.streams[stream_id] = stream self.streams[stream_id] = stream
self.streams_msg_channels[stream_id] = send_channel
return stream return stream
async def open_stream(self) -> IMuxedStream: async def open_stream(self) -> IMuxedStream:
@ -123,27 +133,12 @@ class Mplex(IMuxedConn):
await self.send_message(HeaderTags.NewStream, name.encode(), stream_id) await self.send_message(HeaderTags.NewStream, name.encode(), stream_id)
return stream return stream
async def _wait_until_shutting_down_or_closed(self, coro: Awaitable[Any]) -> Any:
task_coro = asyncio.ensure_future(coro)
task_wait_closed = asyncio.ensure_future(self.event_closed.wait())
task_wait_shutting_down = asyncio.ensure_future(self.event_shutting_down.wait())
done, pending = await asyncio.wait(
[task_coro, task_wait_closed, task_wait_shutting_down],
return_when=asyncio.FIRST_COMPLETED,
)
for fut in pending:
fut.cancel()
if task_wait_closed in done:
raise MplexUnavailable("Mplex is closed")
if task_wait_shutting_down in done:
raise MplexUnavailable("Mplex is shutting down")
return task_coro.result()
async def accept_stream(self) -> IMuxedStream: async def accept_stream(self) -> IMuxedStream:
"""accepts a muxed stream opened by the other end.""" """accepts a muxed stream opened by the other end."""
return await self._wait_until_shutting_down_or_closed( try:
self.new_stream_queue.get() return await self.new_stream_receive_channel.receive()
) except trio.EndOfChannel:
raise MplexUnavailable
async def send_message( async def send_message(
self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID
@ -151,7 +146,7 @@ class Mplex(IMuxedConn):
""" """
sends a message over the connection. sends a message over the connection.
:param header: header to use :param flag: header to use
:param data: data to send in the message :param data: data to send in the message
:param stream_id: stream the message is in :param stream_id: stream the message is in
""" """
@ -163,9 +158,7 @@ class Mplex(IMuxedConn):
_bytes = header + encode_varint_prefixed(data) _bytes = header + encode_varint_prefixed(data)
return await self._wait_until_shutting_down_or_closed( return await self.write_to_stream(_bytes)
self.write_to_stream(_bytes)
)
async def write_to_stream(self, _bytes: bytes) -> int: async def write_to_stream(self, _bytes: bytes) -> int:
""" """
@ -174,21 +167,25 @@ class Mplex(IMuxedConn):
:param _bytes: byte array to write :param _bytes: byte array to write
:return: length written :return: length written
""" """
await self.secured_conn.write(_bytes) try:
await self.secured_conn.write(_bytes)
except RawConnError as e:
raise MplexUnavailable(
"failed to write message to the underlying connection"
) from e
return len(_bytes) return len(_bytes)
async def handle_incoming(self) -> None: async def handle_incoming(self) -> None:
"""Read a message off of the secured connection and add it to the """Read a message off of the secured connection and add it to the
corresponding message buffer.""" corresponding message buffer."""
self.event_started.set()
while True: while True:
try: try:
await self._handle_incoming_message() await self._handle_incoming_message()
except MplexUnavailable as e: except MplexUnavailable as e:
logger.debug("mplex unavailable while waiting for incoming: %s", e) logger.debug("mplex unavailable while waiting for incoming: %s", e)
break break
# Force context switch
await asyncio.sleep(0)
# If we enter here, it means this connection is shutting down. # If we enter here, it means this connection is shutting down.
# We should clean things up. # We should clean things up.
await self._cleanup() await self._cleanup()
@ -200,20 +197,19 @@ class Mplex(IMuxedConn):
:return: stream_id, flag, message contents :return: stream_id, flag, message contents
""" """
# FIXME: No timeout is used in Go implementation.
try: try:
header = await decode_uvarint_from_stream(self.secured_conn) header = await decode_uvarint_from_stream(self.secured_conn)
message = await asyncio.wait_for(
read_varint_prefixed_bytes(self.secured_conn), timeout=5
)
except (ParseError, RawConnError, IncompleteReadError) as error: except (ParseError, RawConnError, IncompleteReadError) as error:
raise MplexUnavailable( raise MplexUnavailable(
"failed to read messages correctly from the underlying connection" f"failed to read the header correctly from the underlying connection: {error}"
) from error )
except asyncio.TimeoutError as error: try:
message = await read_varint_prefixed_bytes(self.secured_conn)
except (ParseError, RawConnError, IncompleteReadError) as error:
raise MplexUnavailable( raise MplexUnavailable(
"failed to read more message body within the timeout" "failed to read the message body correctly from the underlying connection: "
) from error f"{error}"
)
flag = header & 0x07 flag = header & 0x07
channel_id = header >> 3 channel_id = header >> 3
@ -226,9 +222,7 @@ class Mplex(IMuxedConn):
:raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down. :raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down.
""" """
channel_id, flag, message = await self._wait_until_shutting_down_or_closed( channel_id, flag, message = await self.read_message()
self.read_message()
)
stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1)) stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
if flag == HeaderTags.NewStream.value: if flag == HeaderTags.NewStream.value:
@ -258,9 +252,10 @@ class Mplex(IMuxedConn):
f"received NewStream message for existing stream: {stream_id}" f"received NewStream message for existing stream: {stream_id}"
) )
mplex_stream = await self._initialize_stream(stream_id, message.decode()) mplex_stream = await self._initialize_stream(stream_id, message.decode())
await self._wait_until_shutting_down_or_closed( try:
self.new_stream_queue.put(mplex_stream) await self.new_stream_send_channel.send(mplex_stream)
) except trio.ClosedResourceError:
raise MplexUnavailable
async def _handle_message(self, stream_id: StreamID, message: bytes) -> None: async def _handle_message(self, stream_id: StreamID, message: bytes) -> None:
async with self.streams_lock: async with self.streams_lock:
@ -270,13 +265,21 @@ class Mplex(IMuxedConn):
# TODO: Warn and emit logs about this. # TODO: Warn and emit logs about this.
return return
stream = self.streams[stream_id] stream = self.streams[stream_id]
send_channel = self.streams_msg_channels[stream_id]
async with stream.close_lock: async with stream.close_lock:
if stream.event_remote_closed.is_set(): if stream.event_remote_closed.is_set():
# TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501 # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501
return return
await self._wait_until_shutting_down_or_closed( try:
stream.incoming_data.put(message) send_channel.send_nowait(message)
) except (trio.BrokenResourceError, trio.ClosedResourceError):
raise MplexUnavailable
except trio.WouldBlock:
# `send_channel` is full, reset this stream.
logger.warning(
"message channel of stream %s is full: stream is reset", stream_id
)
await stream.reset()
async def _handle_close(self, stream_id: StreamID) -> None: async def _handle_close(self, stream_id: StreamID) -> None:
async with self.streams_lock: async with self.streams_lock:
@ -284,6 +287,8 @@ class Mplex(IMuxedConn):
# Ignore unmatched messages for now. # Ignore unmatched messages for now.
return return
stream = self.streams[stream_id] stream = self.streams[stream_id]
send_channel = self.streams_msg_channels[stream_id]
await send_channel.aclose()
# NOTE: If remote is already closed, then return: Technically a bug # NOTE: If remote is already closed, then return: Technically a bug
# on the other side. We should consider killing the connection. # on the other side. We should consider killing the connection.
async with stream.close_lock: async with stream.close_lock:
@ -305,27 +310,30 @@ class Mplex(IMuxedConn):
# This is *ok*. We forget the stream on reset. # This is *ok*. We forget the stream on reset.
return return
stream = self.streams[stream_id] stream = self.streams[stream_id]
send_channel = self.streams_msg_channels[stream_id]
await send_channel.aclose()
async with stream.close_lock: async with stream.close_lock:
if not stream.event_remote_closed.is_set(): if not stream.event_remote_closed.is_set():
stream.event_reset.set() stream.event_reset.set()
stream.event_remote_closed.set() stream.event_remote_closed.set()
# If local is not closed, we should close it. # If local is not closed, we should close it.
if not stream.event_local_closed.is_set(): if not stream.event_local_closed.is_set():
stream.event_local_closed.set() stream.event_local_closed.set()
async with self.streams_lock: async with self.streams_lock:
self.streams.pop(stream_id, None) self.streams.pop(stream_id, None)
self.streams_msg_channels.pop(stream_id, None)
async def _cleanup(self) -> None: async def _cleanup(self) -> None:
if not self.event_shutting_down.is_set(): if not self.event_shutting_down.is_set():
self.event_shutting_down.set() self.event_shutting_down.set()
async with self.streams_lock: async with self.streams_lock:
for stream in self.streams.values(): for stream_id, stream in self.streams.items():
async with stream.close_lock: async with stream.close_lock:
if not stream.event_remote_closed.is_set(): if not stream.event_remote_closed.is_set():
stream.event_remote_closed.set() stream.event_remote_closed.set()
stream.event_reset.set() stream.event_reset.set()
stream.event_local_closed.set() stream.event_local_closed.set()
self.streams = None send_channel = self.streams_msg_channels[stream_id]
await send_channel.aclose()
self.event_closed.set() self.event_closed.set()
await self.new_stream_send_channel.aclose()
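Two usage notes follow from these changes. First, the reader loop is no longer kicked off in __init__: whoever owns the connection runs start() itself (typically in a nursery) and can wait on event_started before using it. Second, accept_stream now just reads from a memory channel and raises MplexUnavailable once that channel is closed, instead of racing futures against shutdown events. A hedged sketch of both, where conn is an Mplex over an already-secured connection, handler is any coroutine taking a stream, and the function name is a placeholder:

from libp2p.stream_muxer.mplex.exceptions import MplexUnavailable

async def drive_mplex(conn, handler, nursery):
    nursery.start_soon(conn.start)             # runs the incoming-message loop
    await conn.event_started.wait()            # streams can be opened/accepted now
    while True:
        try:
            stream = await conn.accept_stream()
        except MplexUnavailable:
            break                              # the connection is shutting down
        nursery.start_soon(handler, stream)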

View File

@ -1,7 +1,9 @@
import asyncio
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import trio
from libp2p.stream_muxer.abc import IMuxedStream from libp2p.stream_muxer.abc import IMuxedStream
from libp2p.stream_muxer.exceptions import MuxedConnUnavailable
from .constants import HeaderTags from .constants import HeaderTags
from .datastructures import StreamID from .datastructures import StreamID
@ -22,18 +24,25 @@ class MplexStream(IMuxedStream):
read_deadline: int read_deadline: int
write_deadline: int write_deadline: int
close_lock: asyncio.Lock # TODO: Add lock for read/write to avoid interleaving receiving messages?
close_lock: trio.Lock
# NOTE: `dataIn` is size of 8 in Go implementation. # NOTE: `dataIn` is size of 8 in Go implementation.
incoming_data: "asyncio.Queue[bytes]" incoming_data_channel: "trio.MemoryReceiveChannel[bytes]"
event_local_closed: asyncio.Event event_local_closed: trio.Event
event_remote_closed: asyncio.Event event_remote_closed: trio.Event
event_reset: asyncio.Event event_reset: trio.Event
_buf: bytearray _buf: bytearray
def __init__(self, name: str, stream_id: StreamID, muxed_conn: "Mplex") -> None: def __init__(
self,
name: str,
stream_id: StreamID,
muxed_conn: "Mplex",
incoming_data_channel: "trio.MemoryReceiveChannel[bytes]",
) -> None:
""" """
create new MuxedStream in muxer. create new MuxedStream in muxer.
@ -45,99 +54,82 @@ class MplexStream(IMuxedStream):
self.muxed_conn = muxed_conn self.muxed_conn = muxed_conn
self.read_deadline = None self.read_deadline = None
self.write_deadline = None self.write_deadline = None
self.event_local_closed = asyncio.Event() self.event_local_closed = trio.Event()
self.event_remote_closed = asyncio.Event() self.event_remote_closed = trio.Event()
self.event_reset = asyncio.Event() self.event_reset = trio.Event()
self.close_lock = asyncio.Lock() self.close_lock = trio.Lock()
self.incoming_data = asyncio.Queue() self.incoming_data_channel = incoming_data_channel
self._buf = bytearray() self._buf = bytearray()
@property @property
def is_initiator(self) -> bool: def is_initiator(self) -> bool:
return self.stream_id.is_initiator return self.stream_id.is_initiator
async def _wait_for_data(self) -> None:
task_event_reset = asyncio.ensure_future(self.event_reset.wait())
task_incoming_data_get = asyncio.ensure_future(self.incoming_data.get())
task_event_remote_closed = asyncio.ensure_future(
self.event_remote_closed.wait()
)
done, pending = await asyncio.wait( # type: ignore
[ # type: ignore
task_event_reset,
task_incoming_data_get,
task_event_remote_closed,
],
return_when=asyncio.FIRST_COMPLETED,
)
for fut in pending:
fut.cancel()
if task_event_reset in done:
if self.event_reset.is_set():
raise MplexStreamReset()
else:
# However, it is abnormal that `Event.wait` is unblocked without any of the flag
# is set. The task is probably cancelled.
raise Exception(
"Should not enter here. "
f"It is probably because {task_event_remote_closed} is cancelled."
)
if task_incoming_data_get in done:
data = task_incoming_data_get.result()
self._buf.extend(data)
return
if task_event_remote_closed in done:
if self.event_remote_closed.is_set():
raise MplexStreamEOF()
else:
# However, it is abnormal that `Event.wait` is unblocked without any of the flag
# is set. The task is probably cancelled.
raise Exception(
"Should not enter here. "
f"It is probably because {task_event_remote_closed} is cancelled."
)
# TODO: Handle timeout when deadline is used.
async def _read_until_eof(self) -> bytes: async def _read_until_eof(self) -> bytes:
while True: async for data in self.incoming_data_channel:
try: self._buf.extend(data)
await self._wait_for_data()
except MplexStreamEOF:
break
payload = self._buf payload = self._buf
self._buf = self._buf[len(payload) :] self._buf = self._buf[len(payload) :]
return bytes(payload) return bytes(payload)
async def read(self, n: int = -1) -> bytes: def _read_return_when_blocked(self) -> bytes:
buf = bytearray()
while True:
try:
data = self.incoming_data_channel.receive_nowait()
buf.extend(data)
except (trio.WouldBlock, trio.EndOfChannel):
break
return buf
async def read(self, n: int = None) -> bytes:
""" """
Read up to n bytes. Read possibly returns fewer than `n` bytes, if Read up to n bytes. Read possibly returns fewer than `n` bytes, if
there are not enough bytes in the Mplex buffer. If `n == -1`, read there are not enough bytes in the Mplex buffer. If `n is None`, read
until EOF. until EOF.
:param n: number of bytes to read :param n: number of bytes to read
:return: bytes actually read :return: bytes actually read
""" """
if n < 0 and n != -1: if n is not None and n < 0:
raise ValueError( raise ValueError(
f"the number of bytes to read `n` must be positive or -1 to indicate read until EOF" f"the number of bytes to read `n` must be non-negative or "
"`None` to indicate read until EOF"
) )
if self.event_reset.is_set(): if self.event_reset.is_set():
raise MplexStreamReset() raise MplexStreamReset
if n == -1: if n is None:
return await self._read_until_eof() return await self._read_until_eof()
if len(self._buf) == 0 and self.incoming_data.empty(): if len(self._buf) == 0:
await self._wait_for_data() data: bytes
# Now we are sure we have something to read. # Peek whether there is data available. If yes, we just read until there is no data,
# Try to put enough incoming data into `self._buf`. # and then return.
while len(self._buf) < n:
try: try:
self._buf.extend(self.incoming_data.get_nowait()) data = self.incoming_data_channel.receive_nowait()
except asyncio.QueueEmpty: self._buf.extend(data)
break except trio.EndOfChannel:
raise MplexStreamEOF
except trio.WouldBlock:
# We know `receive` will be blocked here. Wait for data here with `receive` and
# catch all kinds of errors here.
try:
data = await self.incoming_data_channel.receive()
self._buf.extend(data)
except trio.EndOfChannel:
if self.event_reset.is_set():
raise MplexStreamReset
if self.event_remote_closed.is_set():
raise MplexStreamEOF
except trio.ClosedResourceError as error:
# Probably `incoming_data_channel` is closed in `reset` when we are waiting
# for `receive`.
if self.event_reset.is_set():
raise MplexStreamReset
raise Exception(
"`incoming_data_channel` is closed but stream is not reset. "
"This should never happen."
) from error
self._buf.extend(self._read_return_when_blocked())
payload = self._buf[:n] payload = self._buf[:n]
self._buf = self._buf[len(payload) :] self._buf = self._buf[len(payload) :]
return bytes(payload) return bytes(payload)
@ -198,14 +190,17 @@ class MplexStream(IMuxedStream):
if self.is_initiator if self.is_initiator
else HeaderTags.ResetReceiver else HeaderTags.ResetReceiver
) )
asyncio.ensure_future( # Try to send reset message to the other side. Ignore if there is anything wrong.
self.muxed_conn.send_message(flag, None, self.stream_id) try:
) await self.muxed_conn.send_message(flag, None, self.stream_id)
await asyncio.sleep(0) except MuxedConnUnavailable:
pass
self.event_local_closed.set() self.event_local_closed.set()
self.event_remote_closed.set() self.event_remote_closed.set()
await self.incoming_data_channel.aclose()
async with self.muxed_conn.streams_lock: async with self.muxed_conn.streams_lock:
if self.muxed_conn.streams is not None: if self.muxed_conn.streams is not None:
self.muxed_conn.streams.pop(self.stream_id, None) self.muxed_conn.streams.pop(self.stream_id, None)
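The read API changes its read-until-EOF sentinel from n == -1 to n is None, and a bounded read now blocks only until some data (or EOF/reset) arrives. A small usage sketch, assuming stream is an open MplexStream and the function name is a placeholder:

async def read_examples(stream):
    # May raise MplexStreamEOF / MplexStreamReset if the remote closed or reset the stream.
    first_chunk = await stream.read(2048)      # up to 2048 bytes, possibly fewer
    rest = await stream.read(None)             # drain everything until EOF
    return first_chunk + rest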

View File

@ -7,7 +7,7 @@ from libp2p.pubsub import floodsub, gossipsub
# Just a arbitrary large number. # Just a arbitrary large number.
# It is used when calling `MplexStream.read(MAX_READ_LEN)`, # It is used when calling `MplexStream.read(MAX_READ_LEN)`,
# to avoid `MplexStream.read()`, which blocking reads until EOF. # to avoid `MplexStream.read()`, which blocking reads until EOF.
MAX_READ_LEN = 2 ** 32 - 1 MAX_READ_LEN = 65535
LISTEN_MADDR = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0") LISTEN_MADDR = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")

View File

@ -1,40 +1,55 @@
import asyncio from typing import Any, AsyncIterator, Dict, List, Sequence, Tuple, cast
from typing import Any, AsyncIterator, Dict, Tuple, cast
# NOTE: import ``asynccontextmanager`` from ``contextlib`` when support for python 3.6 is dropped. from async_exit_stack import AsyncExitStack
from async_generator import asynccontextmanager from async_generator import asynccontextmanager
from async_service import background_trio_service
import factory import factory
from multiaddr import Multiaddr
import trio
from libp2p import generate_new_rsa_identity, generate_peer_id_from from libp2p import generate_new_rsa_identity, generate_peer_id_from
from libp2p.crypto.keys import KeyPair from libp2p.crypto.keys import KeyPair
from libp2p.host.basic_host import BasicHost from libp2p.host.basic_host import BasicHost
from libp2p.host.host_interface import IHost
from libp2p.host.routed_host import RoutedHost
from libp2p.io.abc import ReadWriteCloser
from libp2p.network.connection.raw_connection import RawConnection
from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.network.connection.swarm_connection import SwarmConn from libp2p.network.connection.swarm_connection import SwarmConn
from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.network.swarm import Swarm from libp2p.network.swarm import Swarm
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo
from libp2p.peer.peerstore import PeerStore from libp2p.peer.peerstore import PeerStore
from libp2p.pubsub.abc import IPubsubRouter
from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.gossipsub import GossipSub
from libp2p.pubsub.pubsub import Pubsub from libp2p.pubsub.pubsub import Pubsub
from libp2p.routing.interfaces import IPeerRouting
from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
import libp2p.security.secio.transport as secio import libp2p.security.secio.transport as secio
from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex
from libp2p.stream_muxer.mplex.mplex_stream import MplexStream from libp2p.stream_muxer.mplex.mplex_stream import MplexStream
from libp2p.tools.constants import GOSSIPSUB_PARAMS
from libp2p.transport.tcp.tcp import TCP from libp2p.transport.tcp.tcp import TCP
from libp2p.transport.typing import TMuxerOptions from libp2p.transport.typing import TMuxerOptions
from libp2p.transport.upgrader import TransportUpgrader from libp2p.transport.upgrader import TransportUpgrader
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from .constants import ( from .constants import FLOODSUB_PROTOCOL_ID, GOSSIPSUB_PROTOCOL_ID, LISTEN_MADDR
FLOODSUB_PROTOCOL_ID,
GOSSIPSUB_PARAMS,
GOSSIPSUB_PROTOCOL_ID,
LISTEN_MADDR,
)
from .utils import connect, connect_swarm from .utils import connect, connect_swarm
class IDFactory(factory.Factory):
class Meta:
model = ID
peer_id_bytes = factory.LazyFunction(
lambda: generate_peer_id_from(generate_new_rsa_identity())
)
def initialize_peerstore_with_our_keypair(self_id: ID, key_pair: KeyPair) -> PeerStore: def initialize_peerstore_with_our_keypair(self_id: ID, key_pair: KeyPair) -> PeerStore:
peer_store = PeerStore() peer_store = PeerStore()
peer_store.add_key_pair(self_id, key_pair) peer_store.add_key_pair(self_id, key_pair)
@ -50,6 +65,29 @@ def security_transport_factory(
return {secio.ID: secio.Transport(key_pair)} return {secio.ID: secio.Transport(key_pair)}
@asynccontextmanager
async def raw_conn_factory(
nursery: trio.Nursery
) -> AsyncIterator[Tuple[IRawConnection, IRawConnection]]:
conn_0 = None
conn_1 = None
event = trio.Event()
async def tcp_stream_handler(stream: ReadWriteCloser) -> None:
nonlocal conn_1
conn_1 = RawConnection(stream, initiator=False)
event.set()
await trio.sleep_forever()
tcp_transport = TCP()
listener = tcp_transport.create_listener(tcp_stream_handler)
await listener.listen(LISTEN_MADDR, nursery)
listening_maddr = listener.get_addrs()[0]
conn_0 = await tcp_transport.dial(listening_maddr)
await event.wait()
yield conn_0, conn_1
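A hedged usage sketch for the new raw_conn_factory helper: it yields a connected dialer/listener pair of raw connections whose listener lives in the supplied nursery, so the nursery has to be cancelled once the pair is no longer needed. The import path, the read/write calls, and the demo name are assumptions for illustration:

import trio

from libp2p.tools.factories import raw_conn_factory    # assumed module path

async def demo_raw_conn_pair():
    async with trio.open_nursery() as nursery:
        async with raw_conn_factory(nursery) as (conn_0, conn_1):
            await conn_0.write(b"ping")
            print(await conn_1.read(4))
        nursery.cancel_scope.cancel()          # stop the listener's handler task

trio.run(demo_raw_conn_pair)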
class SwarmFactory(factory.Factory): class SwarmFactory(factory.Factory):
class Meta: class Meta:
model = Swarm model = Swarm
@ -71,9 +109,10 @@ class SwarmFactory(factory.Factory):
transport = factory.LazyFunction(TCP) transport = factory.LazyFunction(TCP)
@classmethod @classmethod
@asynccontextmanager
async def create_and_listen( async def create_and_listen(
cls, is_secure: bool, key_pair: KeyPair = None, muxer_opt: TMuxerOptions = None cls, is_secure: bool, key_pair: KeyPair = None, muxer_opt: TMuxerOptions = None
) -> Swarm: ) -> AsyncIterator[Swarm]:
# `factory.Factory.__init__` does *not* prepare a *default value* if we pass # `factory.Factory.__init__` does *not* prepare a *default value* if we pass
# an argument explicitly with `None`. If an argument is `None`, we don't pass it to # an argument explicitly with `None`. If an argument is `None`, we don't pass it to
# `factory.Factory.__init__`, in order to let the function initialize it. # `factory.Factory.__init__`, in order to let the function initialize it.
@ -83,20 +122,23 @@ class SwarmFactory(factory.Factory):
if muxer_opt is not None: if muxer_opt is not None:
optional_kwargs["muxer_opt"] = muxer_opt optional_kwargs["muxer_opt"] = muxer_opt
swarm = cls(is_secure=is_secure, **optional_kwargs) swarm = cls(is_secure=is_secure, **optional_kwargs)
await swarm.listen(LISTEN_MADDR) async with background_trio_service(swarm):
return swarm await swarm.listen(LISTEN_MADDR)
yield swarm
@classmethod @classmethod
@asynccontextmanager
async def create_batch_and_listen( async def create_batch_and_listen(
cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None
) -> Tuple[Swarm, ...]: ) -> AsyncIterator[Tuple[Swarm, ...]]:
# Ignore typing since we are removing asyncio soon async with AsyncExitStack() as stack:
return await asyncio.gather( # type: ignore ctx_mgrs = [
*[ await stack.enter_async_context(
cls.create_and_listen(is_secure=is_secure, muxer_opt=muxer_opt) cls.create_and_listen(is_secure=is_secure, muxer_opt=muxer_opt)
)
for _ in range(number) for _ in range(number)
] ]
) yield tuple(ctx_mgrs)
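create_and_listen and create_batch_and_listen are now async context managers that run each Swarm under background_trio_service and tear it down on exit, replacing the asyncio.gather batch helper. A hedged test-style usage sketch (import paths inferred from this module; the demo name is a placeholder):

from libp2p.tools.factories import SwarmFactory        # assumed module path
from libp2p.tools.utils import connect_swarm

async def demo_swarm_pair():
    async with SwarmFactory.create_batch_and_listen(False, 2) as swarms:
        await connect_swarm(swarms[0], swarms[1])
        # Both swarms and their background services are torn down on exit.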
class HostFactory(factory.Factory): class HostFactory(factory.Factory):
@ -107,22 +149,57 @@ class HostFactory(factory.Factory):
is_secure = False is_secure = False
key_pair = factory.LazyFunction(generate_new_rsa_identity) key_pair = factory.LazyFunction(generate_new_rsa_identity)
network = factory.LazyAttribute( network = factory.LazyAttribute(lambda o: SwarmFactory(is_secure=o.is_secure))
lambda o: SwarmFactory(is_secure=o.is_secure, key_pair=o.key_pair)
)
@classmethod @classmethod
@asynccontextmanager
async def create_batch_and_listen( async def create_batch_and_listen(
cls, is_secure: bool, number: int cls, is_secure: bool, number: int
) -> Tuple[BasicHost, ...]: ) -> AsyncIterator[Tuple[BasicHost, ...]]:
key_pairs = [generate_new_rsa_identity() for _ in range(number)] async with SwarmFactory.create_batch_and_listen(is_secure, number) as swarms:
swarms = await asyncio.gather( hosts = tuple(BasicHost(swarm) for swarm in swarms)
*[ yield hosts
SwarmFactory.create_and_listen(is_secure, key_pair)
for key_pair in key_pairs
] class DummyRouter(IPeerRouting):
) _routing_table: Dict[ID, PeerInfo]
return tuple(BasicHost(swarm) for swarm in swarms)
def __init__(self) -> None:
self._routing_table = dict()
def _add_peer(self, peer_id: ID, addrs: List[Multiaddr]) -> None:
self._routing_table[peer_id] = PeerInfo(peer_id, addrs)
async def find_peer(self, peer_id: ID) -> PeerInfo:
await trio.hazmat.checkpoint()
return self._routing_table.get(peer_id, None)
class RoutedHostFactory(factory.Factory):
class Meta:
model = RoutedHost
class Params:
is_secure = False
network = factory.LazyAttribute(
lambda o: HostFactory(is_secure=o.is_secure).get_network()
)
router = factory.LazyFunction(DummyRouter)
@classmethod
@asynccontextmanager
async def create_batch_and_listen(
cls, is_secure: bool, number: int
) -> AsyncIterator[Tuple[RoutedHost, ...]]:
routing_table = DummyRouter()
async with HostFactory.create_batch_and_listen(is_secure, number) as hosts:
for host in hosts:
routing_table._add_peer(host.get_id(), host.get_addrs())
routed_hosts = tuple(
RoutedHost(host.get_network(), routing_table) for host in hosts
)
yield routed_hosts
class FloodsubFactory(factory.Factory): class FloodsubFactory(factory.Factory):
@ -153,89 +230,192 @@ class PubsubFactory(factory.Factory):
host = factory.SubFactory(HostFactory) host = factory.SubFactory(HostFactory)
router = None router = None
my_id = factory.LazyAttribute(lambda obj: obj.host.get_id())
cache_size = None cache_size = None
strict_signing = False strict_signing = False
@classmethod
@asynccontextmanager
async def create_and_start(
cls, host: IHost, router: IPubsubRouter, cache_size: int, strict_signing: bool
) -> AsyncIterator[Pubsub]:
pubsub = cls(
host=host,
router=router,
cache_size=cache_size,
strict_signing=strict_signing,
)
async with background_trio_service(pubsub):
await pubsub.wait_until_ready()
yield pubsub
@classmethod
@asynccontextmanager
async def _create_batch_with_router(
cls,
number: int,
routers: Sequence[IPubsubRouter],
is_secure: bool = False,
cache_size: int = None,
strict_signing: bool = False,
) -> AsyncIterator[Tuple[Pubsub, ...]]:
async with HostFactory.create_batch_and_listen(is_secure, number) as hosts:
# Pubsubs should exit before hosts
async with AsyncExitStack() as stack:
pubsubs = [
await stack.enter_async_context(
cls.create_and_start(host, router, cache_size, strict_signing)
)
for host, router in zip(hosts, routers)
]
yield tuple(pubsubs)
@classmethod
@asynccontextmanager
async def create_batch_with_floodsub(
cls,
number: int,
is_secure: bool = False,
cache_size: int = None,
strict_signing: bool = False,
protocols: Sequence[TProtocol] = None,
) -> AsyncIterator[Tuple[Pubsub, ...]]:
if protocols is not None:
floodsubs = FloodsubFactory.create_batch(number, protocols=list(protocols))
else:
floodsubs = FloodsubFactory.create_batch(number)
async with cls._create_batch_with_router(
number, floodsubs, is_secure, cache_size, strict_signing
) as pubsubs:
yield pubsubs
@classmethod
@asynccontextmanager
async def create_batch_with_gossipsub(
cls,
number: int,
*,
is_secure: bool = False,
cache_size: int = None,
strict_signing: bool = False,
protocols: Sequence[TProtocol] = None,
degree: int = GOSSIPSUB_PARAMS.degree,
degree_low: int = GOSSIPSUB_PARAMS.degree_low,
degree_high: int = GOSSIPSUB_PARAMS.degree_high,
time_to_live: int = GOSSIPSUB_PARAMS.time_to_live,
gossip_window: int = GOSSIPSUB_PARAMS.gossip_window,
gossip_history: int = GOSSIPSUB_PARAMS.gossip_history,
heartbeat_interval: float = GOSSIPSUB_PARAMS.heartbeat_interval,
heartbeat_initial_delay: float = GOSSIPSUB_PARAMS.heartbeat_initial_delay,
) -> AsyncIterator[Tuple[Pubsub, ...]]:
if protocols is not None:
gossipsubs = GossipsubFactory.create_batch(
number,
protocols=protocols,
degree=degree,
degree_low=degree_low,
degree_high=degree_high,
time_to_live=time_to_live,
gossip_window=gossip_window,
heartbeat_interval=heartbeat_interval,
)
else:
gossipsubs = GossipsubFactory.create_batch(
number,
degree=degree,
degree_low=degree_low,
degree_high=degree_high,
time_to_live=time_to_live,
gossip_window=gossip_window,
heartbeat_interval=heartbeat_interval,
)
async with cls._create_batch_with_router(
number, gossipsubs, is_secure, cache_size, strict_signing
) as pubsubs:
async with AsyncExitStack() as stack:
for router in gossipsubs:
await stack.enter_async_context(background_trio_service(router))
yield pubsubs
@asynccontextmanager
async def swarm_pair_factory( async def swarm_pair_factory(
is_secure: bool, muxer_opt: TMuxerOptions = None is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Tuple[Swarm, Swarm]: ) -> AsyncIterator[Tuple[Swarm, Swarm]]:
swarms = await SwarmFactory.create_batch_and_listen( async with SwarmFactory.create_batch_and_listen(
is_secure, 2, muxer_opt=muxer_opt is_secure, 2, muxer_opt=muxer_opt
) ) as swarms:
await connect_swarm(swarms[0], swarms[1]) await connect_swarm(swarms[0], swarms[1])
return swarms[0], swarms[1] yield swarms[0], swarms[1]
async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]: @asynccontextmanager
hosts = await HostFactory.create_batch_and_listen(is_secure, 2) async def host_pair_factory(
await connect(hosts[0], hosts[1]) is_secure: bool
return hosts[0], hosts[1]
@asynccontextmanager # type: ignore
async def pair_of_connected_hosts(
is_secure: bool = True
) -> AsyncIterator[Tuple[BasicHost, BasicHost]]: ) -> AsyncIterator[Tuple[BasicHost, BasicHost]]:
a, b = await host_pair_factory(is_secure) async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts:
yield a, b await connect(hosts[0], hosts[1])
close_tasks = (a.close(), b.close()) yield hosts[0], hosts[1]
await asyncio.gather(*close_tasks)
@asynccontextmanager
async def swarm_conn_pair_factory( async def swarm_conn_pair_factory(
is_secure: bool, muxer_opt: TMuxerOptions = None is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]: ) -> AsyncIterator[Tuple[SwarmConn, SwarmConn]]:
swarms = await swarm_pair_factory(is_secure) async with swarm_pair_factory(is_secure) as swarms:
conn_0 = swarms[0].connections[swarms[1].get_peer_id()] conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
conn_1 = swarms[1].connections[swarms[0].get_peer_id()] conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
return cast(SwarmConn, conn_0), swarms[0], cast(SwarmConn, conn_1), swarms[1] yield cast(SwarmConn, conn_0), cast(SwarmConn, conn_1)
async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Swarm, Mplex, Swarm]: @asynccontextmanager
async def mplex_conn_pair_factory(
is_secure: bool
) -> AsyncIterator[Tuple[Mplex, Mplex]]:
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex} muxer_opt = {MPLEX_PROTOCOL_ID: Mplex}
conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory( async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair:
is_secure, muxer_opt=muxer_opt yield (
) cast(Mplex, swarm_pair[0].muxed_conn),
return ( cast(Mplex, swarm_pair[1].muxed_conn),
cast(Mplex, conn_0.muxed_conn), )
swarm_0,
cast(Mplex, conn_1.muxed_conn),
swarm_1,
)
@asynccontextmanager
async def mplex_stream_pair_factory( async def mplex_stream_pair_factory(
is_secure: bool is_secure: bool
) -> Tuple[MplexStream, Swarm, MplexStream, Swarm]: ) -> AsyncIterator[Tuple[MplexStream, MplexStream]]:
mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory( async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info:
is_secure mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info
) stream_0 = cast(MplexStream, await mplex_conn_0.open_stream())
stream_0 = await mplex_conn_0.open_stream() await trio.sleep(0.01)
await asyncio.sleep(0.01) stream_1: MplexStream
stream_1: MplexStream async with mplex_conn_1.streams_lock:
async with mplex_conn_1.streams_lock: if len(mplex_conn_1.streams) != 1:
if len(mplex_conn_1.streams) != 1: raise Exception("Mplex should not have any other stream")
raise Exception("Mplex should not have any stream upon connection") stream_1 = tuple(mplex_conn_1.streams.values())[0]
stream_1 = tuple(mplex_conn_1.streams.values())[0] yield stream_0, stream_1
return cast(MplexStream, stream_0), swarm_0, stream_1, swarm_1
@asynccontextmanager
async def net_stream_pair_factory( async def net_stream_pair_factory(
is_secure: bool is_secure: bool
) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]: ) -> AsyncIterator[Tuple[INetStream, INetStream]]:
protocol_id = TProtocol("/example/id/1") protocol_id = TProtocol("/example/id/1")
stream_1: INetStream stream_1: INetStream
# Just a proxy, we only care about the stream # Just a proxy, we only care about the stream.
def handler(stream: INetStream) -> None: # Add a barrier to avoid the stream being removed.
event_handler_finished = trio.Event()
async def handler(stream: INetStream) -> None:
nonlocal stream_1 nonlocal stream_1
stream_1 = stream stream_1 = stream
await event_handler_finished.wait()
host_0, host_1 = await host_pair_factory(is_secure) async with host_pair_factory(is_secure) as hosts:
host_1.set_stream_handler(protocol_id, handler) hosts[1].set_stream_handler(protocol_id, handler)
stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id]) stream_0 = await hosts[0].new_stream(hosts[1].get_id(), [protocol_id])
return stream_0, host_0, stream_1, host_1 yield stream_0, stream_1
event_handler_finished.set()
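For context, a minimal sketch (not part of the diff) of how these async-context-manager helpers are consumed under pytest-trio; the test name and payload are illustrative only, and MAX_READ_LEN is assumed to come from libp2p.tools.constants as elsewhere in this change:
import pytest

from libp2p.tools.constants import MAX_READ_LEN
from libp2p.tools.factories import net_stream_pair_factory


@pytest.mark.trio
async def test_net_stream_pair_roundtrip():
    # Streams, hosts and swarms are all torn down when the `async with` exits.
    async with net_stream_pair_factory(is_secure=False) as (stream_0, stream_1):
        await stream_0.write(b"ping")
        assert (await stream_1.read(MAX_READ_LEN)) == b"ping"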

View File

@ -1,2 +1 @@
LOCALHOST_IP = "127.0.0.1" LOCALHOST_IP = "127.0.0.1"
PEXPECT_NEW_LINE = "\r\n"

View File

@ -1,52 +1,22 @@
import asyncio from typing import AsyncIterator
import time
from typing import Any, Awaitable, Callable, List
from async_generator import asynccontextmanager
import multiaddr import multiaddr
from multiaddr import Multiaddr from multiaddr import Multiaddr
from p2pclient import Client from p2pclient import Client
import pytest import trio
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr
from .constants import LOCALHOST_IP from .constants import LOCALHOST_IP
from .envs import GO_BIN_PATH from .envs import GO_BIN_PATH
from .process import BaseInteractiveProcess
P2PD_PATH = GO_BIN_PATH / "p2pd" P2PD_PATH = GO_BIN_PATH / "p2pd"
TIMEOUT_DURATION = 30 class P2PDProcess(BaseInteractiveProcess):
async def try_until_success(
coro_func: Callable[[], Awaitable[Any]], timeout: int = TIMEOUT_DURATION
) -> None:
"""
Keep running ``coro_func`` until either it succeeds or time is up.
All arguments of ``coro_func`` should be filled, i.e. it should be
called without arguments.
"""
t_start = time.monotonic()
while True:
result = await coro_func()
if result:
break
if (time.monotonic() - t_start) >= timeout:
# timeout
pytest.fail(f"{coro_func} is still failing after `{timeout}` seconds")
await asyncio.sleep(0.01)
class P2PDProcess:
proc: asyncio.subprocess.Process
cmd: str = str(P2PD_PATH)
args: List[Any]
is_proc_running: bool
_tasks: List["asyncio.Future[Any]"]
def __init__( def __init__(
self, self,
control_maddr: Multiaddr, control_maddr: Multiaddr,
@ -75,74 +45,21 @@ class P2PDProcess:
# - gossipsubHeartbeatInterval: GossipSubHeartbeatInitialDelay = 100 * time.Millisecond # noqa: E501 # - gossipsubHeartbeatInterval: GossipSubHeartbeatInitialDelay = 100 * time.Millisecond # noqa: E501
# - gossipsubHeartbeatInitialDelay: GossipSubHeartbeatInterval = 1 * time.Second # - gossipsubHeartbeatInitialDelay: GossipSubHeartbeatInterval = 1 * time.Second
# Reference: https://github.com/libp2p/go-libp2p-daemon/blob/b95e77dbfcd186ccf817f51e95f73f9fd5982600/p2pd/main.go#L348-L353 # noqa: E501 # Reference: https://github.com/libp2p/go-libp2p-daemon/blob/b95e77dbfcd186ccf817f51e95f73f9fd5982600/p2pd/main.go#L348-L353 # noqa: E501
self.proc = None
self.cmd = str(P2PD_PATH)
self.args = args self.args = args
self.is_proc_running = False self.patterns = (b"Control socket:", b"Peer ID:", b"Peer Addrs:")
self.bytes_read = bytearray()
self._tasks = [] self.event_ready = trio.Event()
async def wait_until_ready(self) -> None:
lines_head_pattern = (b"Control socket:", b"Peer ID:", b"Peer Addrs:")
lines_head_occurred = {line: False for line in lines_head_pattern}
async def read_from_daemon_and_check() -> bool:
line = await self.proc.stdout.readline()
for head_pattern in lines_head_occurred:
if line.startswith(head_pattern):
lines_head_occurred[head_pattern] = True
return all([value for value in lines_head_occurred.values()])
await try_until_success(read_from_daemon_and_check)
# Sleep a little bit to ensure the listener is up after logs are emitted.
await asyncio.sleep(0.01)
async def start_printing_logs(self) -> None:
async def _print_from_stream(
src_name: str, reader: asyncio.StreamReader
) -> None:
while True:
line = await reader.readline()
if line != b"":
print(f"{src_name}\t: {line.rstrip().decode()}")
await asyncio.sleep(0.01)
self._tasks.append(
asyncio.ensure_future(_print_from_stream("out", self.proc.stdout))
)
self._tasks.append(
asyncio.ensure_future(_print_from_stream("err", self.proc.stderr))
)
await asyncio.sleep(0)
async def start(self) -> None:
if self.is_proc_running:
return
self.proc = await asyncio.subprocess.create_subprocess_exec(
self.cmd,
*self.args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
bufsize=0,
)
self.is_proc_running = True
await self.wait_until_ready()
await self.start_printing_logs()
async def close(self) -> None:
if self.is_proc_running:
self.proc.terminate()
await self.proc.wait()
self.is_proc_running = False
for task in self._tasks:
task.cancel()
class Daemon: class Daemon:
p2pd_proc: P2PDProcess p2pd_proc: BaseInteractiveProcess
control: Client control: Client
peer_info: PeerInfo peer_info: PeerInfo
def __init__( def __init__(
self, p2pd_proc: P2PDProcess, control: Client, peer_info: PeerInfo self, p2pd_proc: BaseInteractiveProcess, control: Client, peer_info: PeerInfo
) -> None: ) -> None:
self.p2pd_proc = p2pd_proc self.p2pd_proc = p2pd_proc
self.control = control self.control = control
@ -164,6 +81,7 @@ class Daemon:
await self.control.close() await self.control.close()
@asynccontextmanager
async def make_p2pd( async def make_p2pd(
daemon_control_port: int, daemon_control_port: int,
client_callback_port: int, client_callback_port: int,
@ -172,7 +90,7 @@ async def make_p2pd(
is_gossipsub: bool = True, is_gossipsub: bool = True,
is_pubsub_signing: bool = False, is_pubsub_signing: bool = False,
is_pubsub_signing_strict: bool = False, is_pubsub_signing_strict: bool = False,
) -> Daemon: ) -> AsyncIterator[Daemon]:
control_maddr = Multiaddr(f"/ip4/{LOCALHOST_IP}/tcp/{daemon_control_port}") control_maddr = Multiaddr(f"/ip4/{LOCALHOST_IP}/tcp/{daemon_control_port}")
p2pd_proc = P2PDProcess( p2pd_proc = P2PDProcess(
control_maddr, control_maddr,
@ -185,21 +103,22 @@ async def make_p2pd(
await p2pd_proc.start() await p2pd_proc.start()
client_callback_maddr = Multiaddr(f"/ip4/{LOCALHOST_IP}/tcp/{client_callback_port}") client_callback_maddr = Multiaddr(f"/ip4/{LOCALHOST_IP}/tcp/{client_callback_port}")
p2pc = Client(control_maddr, client_callback_maddr) p2pc = Client(control_maddr, client_callback_maddr)
await p2pc.listen()
peer_id, maddrs = await p2pc.identify() async with p2pc.listen():
listen_maddr: Multiaddr = None peer_id, maddrs = await p2pc.identify()
for maddr in maddrs: listen_maddr: Multiaddr = None
try: for maddr in maddrs:
ip = maddr.value_for_protocol(multiaddr.protocols.P_IP4) try:
# NOTE: Check if this `maddr` uses `tcp`. ip = maddr.value_for_protocol(multiaddr.protocols.P_IP4)
maddr.value_for_protocol(multiaddr.protocols.P_TCP) # NOTE: Check if this `maddr` uses `tcp`.
except multiaddr.exceptions.ProtocolLookupError: maddr.value_for_protocol(multiaddr.protocols.P_TCP)
continue except multiaddr.exceptions.ProtocolLookupError:
if ip == LOCALHOST_IP: continue
listen_maddr = maddr if ip == LOCALHOST_IP:
break listen_maddr = maddr
assert listen_maddr is not None, "no loopback maddr is found" break
peer_info = info_from_p2p_addr( assert listen_maddr is not None, "no loopback maddr is found"
listen_maddr.encapsulate(Multiaddr(f"/p2p/{peer_id.to_string()}")) peer_info = info_from_p2p_addr(
) listen_maddr.encapsulate(Multiaddr(f"/p2p/{peer_id.to_string()}"))
return Daemon(p2pd_proc, p2pc, peer_info) )
yield Daemon(p2pd_proc, p2pc, peer_info)
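The change above turns `make_p2pd` into an async context manager so the daemon process and its control client are closed even when a test fails partway through. A self-contained sketch of that pattern (hypothetical resource name, not part of the diff):
from async_generator import asynccontextmanager
import trio


@asynccontextmanager
async def open_resource(name: str):
    # Setup happens before `yield`; cleanup runs after the `async with`
    # body finishes, even if the body raises.
    print(f"starting {name}")
    try:
        yield name
    finally:
        print(f"closing {name}")


async def main() -> None:
    async with open_resource("p2pd") as resource:
        print(f"using {resource}")


trio.run(main)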

View File

@ -0,0 +1,66 @@
from abc import ABC, abstractmethod
import subprocess
from typing import Iterable, List
import trio
TIMEOUT_DURATION = 30
class AbstractInterativeProcess(ABC):
@abstractmethod
async def start(self) -> None:
...
@abstractmethod
async def close(self) -> None:
...
class BaseInteractiveProcess(AbstractInterativeProcess):
proc: trio.Process = None
cmd: str
args: List[str]
bytes_read: bytearray
patterns: Iterable[bytes] = None
event_ready: trio.Event
async def wait_until_ready(self) -> None:
patterns_occurred = {pat: False for pat in self.patterns}
async def read_from_daemon_and_check() -> None:
async for data in self.proc.stdout:
# TODO: It takes O(n^2), which is quite bad.
# But it should succeed in a few seconds.
self.bytes_read.extend(data)
for pat, occurred in patterns_occurred.items():
if occurred:
continue
if pat in self.bytes_read:
patterns_occurred[pat] = True
if all([value for value in patterns_occurred.values()]):
return
with trio.fail_after(TIMEOUT_DURATION):
await read_from_daemon_and_check()
self.event_ready.set()
# Sleep a little bit to ensure the listener is up after logs are emitted.
await trio.sleep(0.01)
async def start(self) -> None:
if self.proc is not None:
return
# NOTE: Ignore type checks here since mypy complains about bufsize=0
self.proc = await trio.open_process( # type: ignore
[self.cmd] + self.args,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, # Redirect stderr to stdout, which makes parsing easier
bufsize=0,
)
await self.wait_until_ready()
async def close(self) -> None:
if self.proc is None:
return
self.proc.terminate()
await self.proc.wait()
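To illustrate how `BaseInteractiveProcess` is intended to be subclassed (mirroring what `P2PDProcess` does above), a hypothetical subclass is sketched below; the binary name and readiness pattern are made up and reuse the imports from this file:
class EchoDaemonProcess(BaseInteractiveProcess):
    """Hypothetical process considered ready once 'Listening on' appears in stdout."""

    def __init__(self, port: int) -> None:
        self.cmd = "echo-daemon"  # assumed binary, not part of this diff
        self.args = ["--port", str(port)]
        self.patterns = (b"Listening on",)  # markers scanned by `wait_until_ready`
        self.bytes_read = bytearray()
        self.event_ready = trio.Event()
Calling `await proc.start()` then blocks in `wait_until_ready` until every pattern has been seen in the captured stdout, or `trio.fail_after(TIMEOUT_DURATION)` raises.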

View File

@ -1,7 +1,7 @@
import asyncio
from typing import Union from typing import Union
from multiaddr import Multiaddr from multiaddr import Multiaddr
import trio
from libp2p.host.host_interface import IHost from libp2p.host.host_interface import IHost
from libp2p.peer.id import ID from libp2p.peer.id import ID
@ -50,7 +50,7 @@ async def connect(a: TDaemonOrHost, b: TDaemonOrHost) -> None:
else: # isinstance(b, IHost) else: # isinstance(b, IHost)
await a.connect(b_peer_info) await a.connect(b_peer_info)
# Allow additional sleep for both side to establish the connection. # Allow additional sleep for both side to establish the connection.
await asyncio.sleep(0.1) await trio.sleep(0.1)
a_peer_info = _get_peer_info(a) a_peer_info = _get_peer_info(a)

View File

@ -1,12 +1,12 @@
import asyncio from typing import AsyncIterator, Dict, Tuple
from typing import Dict
import uuid from async_exit_stack import AsyncExitStack
from async_generator import asynccontextmanager
from async_service import Service, background_trio_service
from libp2p.host.host_interface import IHost from libp2p.host.host_interface import IHost
from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.pubsub import Pubsub from libp2p.pubsub.pubsub import Pubsub
from libp2p.tools.constants import LISTEN_MADDR from libp2p.tools.factories import PubsubFactory
from libp2p.tools.factories import FloodsubFactory, PubsubFactory
CRYPTO_TOPIC = "ethereum" CRYPTO_TOPIC = "ethereum"
@ -18,7 +18,7 @@ CRYPTO_TOPIC = "ethereum"
# Determine message type by looking at first item before first comma # Determine message type by looking at first item before first comma
class DummyAccountNode: class DummyAccountNode(Service):
""" """
Node which has an internal balance mapping, meant to serve as a dummy Node which has an internal balance mapping, meant to serve as a dummy
crypto blockchain. crypto blockchain.
@ -27,19 +27,24 @@ class DummyAccountNode:
crypto each user in the mappings holds crypto each user in the mappings holds
""" """
libp2p_node: IHost
pubsub: Pubsub pubsub: Pubsub
floodsub: FloodSub
def __init__(self, libp2p_node: IHost, pubsub: Pubsub, floodsub: FloodSub): def __init__(self, pubsub: Pubsub) -> None:
self.libp2p_node = libp2p_node
self.pubsub = pubsub self.pubsub = pubsub
self.floodsub = floodsub
self.balances: Dict[str, int] = {} self.balances: Dict[str, int] = {}
self.node_id = str(uuid.uuid1())
@property
def host(self) -> IHost:
return self.pubsub.host
async def run(self) -> None:
self.subscription = await self.pubsub.subscribe(CRYPTO_TOPIC)
self.manager.run_daemon_task(self.handle_incoming_msgs)
await self.manager.wait_finished()
@classmethod @classmethod
async def create(cls) -> "DummyAccountNode": @asynccontextmanager
async def create(cls, number: int) -> AsyncIterator[Tuple["DummyAccountNode", ...]]:
""" """
Create a new DummyAccountNode and attach a libp2p node, a floodsub, and Create a new DummyAccountNode and attach a libp2p node, a floodsub, and
a pubsub instance to this new node. a pubsub instance to this new node.
@ -47,15 +52,17 @@ class DummyAccountNode:
We use create as this serves as a factory function and allows us We use create as this serves as a factory function and allows us
to use async await, unlike the init function to use async await, unlike the init function
""" """
async with PubsubFactory.create_batch_with_floodsub(number) as pubsubs:
pubsub = PubsubFactory(router=FloodsubFactory()) async with AsyncExitStack() as stack:
await pubsub.host.get_network().listen(LISTEN_MADDR) dummy_acount_nodes = tuple(cls(pubsub) for pubsub in pubsubs)
return cls(libp2p_node=pubsub.host, pubsub=pubsub, floodsub=pubsub.router) for node in dummy_acount_nodes:
await stack.enter_async_context(background_trio_service(node))
yield dummy_acount_nodes
async def handle_incoming_msgs(self) -> None: async def handle_incoming_msgs(self) -> None:
"""Handle all incoming messages on the CRYPTO_TOPIC from peers.""" """Handle all incoming messages on the CRYPTO_TOPIC from peers."""
while True: while True:
incoming = await self.q.get() incoming = await self.subscription.get()
msg_comps = incoming.data.decode("utf-8").split(",") msg_comps = incoming.data.decode("utf-8").split(",")
if msg_comps[0] == "send": if msg_comps[0] == "send":
@ -63,13 +70,6 @@ class DummyAccountNode:
elif msg_comps[0] == "set": elif msg_comps[0] == "set":
self.handle_set_crypto(msg_comps[1], int(msg_comps[2])) self.handle_set_crypto(msg_comps[1], int(msg_comps[2]))
async def setup_crypto_networking(self) -> None:
"""Subscribe to CRYPTO_TOPIC and perform call to function that handles
all incoming messages on said topic."""
self.q = await self.pubsub.subscribe(CRYPTO_TOPIC)
asyncio.ensure_future(self.handle_incoming_msgs())
async def publish_send_crypto( async def publish_send_crypto(
self, source_user: str, dest_user: str, amount: int self, source_user: str, dest_user: str, amount: int
) -> None: ) -> None:
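A short sketch (hypothetical test, assuming the module lives at tests.pubsub.dummy_account_node) of how the new classmethod is driven now that it yields a batch of running service nodes:
import pytest
import trio

from tests.pubsub.dummy_account_node import DummyAccountNode  # assumed path


@pytest.mark.trio
async def test_dummy_account_nodes_start_and_stop():
    async with DummyAccountNode.create(2) as nodes:
        # Each node is subscribed to CRYPTO_TOPIC and running as a service here.
        assert len(nodes) == 2
        await trio.sleep(0)
    # Leaving the block stops the services, pubsubs and hosts in reverse order.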

View File

@ -1,12 +1,10 @@
# type: ignore # type: ignore
# To add typing to this module, it's better to do it after refactoring test cases into classes # To add typing to this module, it's better to do it after refactoring test cases into classes
import asyncio
import pytest import pytest
import trio
from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID, LISTEN_MADDR from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID
from libp2p.tools.factories import PubsubFactory
from libp2p.tools.utils import connect from libp2p.tools.utils import connect
SUPPORTED_PROTOCOLS = [FLOODSUB_PROTOCOL_ID] SUPPORTED_PROTOCOLS = [FLOODSUB_PROTOCOL_ID]
@ -15,6 +13,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{ {
"name": "simple_two_nodes", "name": "simple_two_nodes",
"supported_protocols": SUPPORTED_PROTOCOLS, "supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["A", "B"],
"adj_list": {"A": ["B"]}, "adj_list": {"A": ["B"]},
"topic_map": {"topic1": ["B"]}, "topic_map": {"topic1": ["B"]},
"messages": [{"topics": ["topic1"], "data": b"foo", "node_id": "A"}], "messages": [{"topics": ["topic1"], "data": b"foo", "node_id": "A"}],
@ -22,6 +21,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{ {
"name": "three_nodes_two_topics", "name": "three_nodes_two_topics",
"supported_protocols": SUPPORTED_PROTOCOLS, "supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["A", "B", "C"],
"adj_list": {"A": ["B"], "B": ["C"]}, "adj_list": {"A": ["B"], "B": ["C"]},
"topic_map": {"topic1": ["B", "C"], "topic2": ["B", "C"]}, "topic_map": {"topic1": ["B", "C"], "topic2": ["B", "C"]},
"messages": [ "messages": [
@ -32,6 +32,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{ {
"name": "two_nodes_one_topic_single_subscriber_is_sender", "name": "two_nodes_one_topic_single_subscriber_is_sender",
"supported_protocols": SUPPORTED_PROTOCOLS, "supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["A", "B"],
"adj_list": {"A": ["B"]}, "adj_list": {"A": ["B"]},
"topic_map": {"topic1": ["B"]}, "topic_map": {"topic1": ["B"]},
"messages": [{"topics": ["topic1"], "data": b"Alex is tall", "node_id": "B"}], "messages": [{"topics": ["topic1"], "data": b"Alex is tall", "node_id": "B"}],
@ -39,6 +40,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{ {
"name": "two_nodes_one_topic_two_msgs", "name": "two_nodes_one_topic_two_msgs",
"supported_protocols": SUPPORTED_PROTOCOLS, "supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["A", "B"],
"adj_list": {"A": ["B"]}, "adj_list": {"A": ["B"]},
"topic_map": {"topic1": ["B"]}, "topic_map": {"topic1": ["B"]},
"messages": [ "messages": [
@ -49,6 +51,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{ {
"name": "seven_nodes_tree_one_topics", "name": "seven_nodes_tree_one_topics",
"supported_protocols": SUPPORTED_PROTOCOLS, "supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4", "5", "6", "7"],
"adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]}, "adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]},
"topic_map": {"astrophysics": ["2", "3", "4", "5", "6", "7"]}, "topic_map": {"astrophysics": ["2", "3", "4", "5", "6", "7"]},
"messages": [{"topics": ["astrophysics"], "data": b"e=mc^2", "node_id": "1"}], "messages": [{"topics": ["astrophysics"], "data": b"e=mc^2", "node_id": "1"}],
@ -56,6 +59,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{ {
"name": "seven_nodes_tree_three_topics", "name": "seven_nodes_tree_three_topics",
"supported_protocols": SUPPORTED_PROTOCOLS, "supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4", "5", "6", "7"],
"adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]}, "adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]},
"topic_map": { "topic_map": {
"astrophysics": ["2", "3", "4", "5", "6", "7"], "astrophysics": ["2", "3", "4", "5", "6", "7"],
@ -71,6 +75,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{ {
"name": "seven_nodes_tree_three_topics_diff_origin", "name": "seven_nodes_tree_three_topics_diff_origin",
"supported_protocols": SUPPORTED_PROTOCOLS, "supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4", "5", "6", "7"],
"adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]}, "adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]},
"topic_map": { "topic_map": {
"astrophysics": ["1", "2", "3", "4", "5", "6", "7"], "astrophysics": ["1", "2", "3", "4", "5", "6", "7"],
@ -86,6 +91,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{ {
"name": "three_nodes_clique_two_topic_diff_origin", "name": "three_nodes_clique_two_topic_diff_origin",
"supported_protocols": SUPPORTED_PROTOCOLS, "supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3"],
"adj_list": {"1": ["2", "3"], "2": ["3"]}, "adj_list": {"1": ["2", "3"], "2": ["3"]},
"topic_map": {"astrophysics": ["1", "2", "3"], "school": ["1", "2", "3"]}, "topic_map": {"astrophysics": ["1", "2", "3"], "school": ["1", "2", "3"]},
"messages": [ "messages": [
@ -97,6 +103,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{ {
"name": "four_nodes_clique_two_topic_diff_origin_many_msgs", "name": "four_nodes_clique_two_topic_diff_origin_many_msgs",
"supported_protocols": SUPPORTED_PROTOCOLS, "supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4"],
"adj_list": { "adj_list": {
"1": ["2", "3", "4"], "1": ["2", "3", "4"],
"2": ["1", "3", "4"], "2": ["1", "3", "4"],
@ -120,6 +127,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{ {
"name": "five_nodes_ring_two_topic_diff_origin_many_msgs", "name": "five_nodes_ring_two_topic_diff_origin_many_msgs",
"supported_protocols": SUPPORTED_PROTOCOLS, "supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4", "5"],
"adj_list": {"1": ["2"], "2": ["3"], "3": ["4"], "4": ["5"], "5": ["1"]}, "adj_list": {"1": ["2"], "2": ["3"], "3": ["4"], "4": ["5"], "5": ["1"]},
"topic_map": { "topic_map": {
"astrophysics": ["1", "2", "3", "4", "5"], "astrophysics": ["1", "2", "3", "4", "5"],
@ -143,15 +151,7 @@ floodsub_protocol_pytest_params = [
] ]
def _collect_node_ids(adj_list): async def perform_test_from_obj(obj, pubsub_factory) -> None:
node_ids = set()
for node, neighbors in adj_list.items():
node_ids.add(node)
node_ids.update(set(neighbors))
return node_ids
async def perform_test_from_obj(obj, router_factory) -> None:
""" """
Perform pubsub tests from a test object, which is composed as follows: Perform pubsub tests from a test object, which is composed as follows:
@ -185,68 +185,75 @@ async def perform_test_from_obj(obj, router_factory) -> None:
# Step 1) Create graph # Step 1) Create graph
adj_list = obj["adj_list"] adj_list = obj["adj_list"]
node_list = obj["nodes"]
node_map = {} node_map = {}
pubsub_map = {} pubsub_map = {}
async def add_node(node_id_str: str): async with pubsub_factory(
pubsub_router = router_factory(protocols=obj["supported_protocols"]) number=len(node_list), protocols=obj["supported_protocols"]
pubsub = PubsubFactory(router=pubsub_router) ) as pubsubs:
await pubsub.host.get_network().listen(LISTEN_MADDR) for node_id_str, pubsub in zip(node_list, pubsubs):
node_map[node_id_str] = pubsub.host node_map[node_id_str] = pubsub.host
pubsub_map[node_id_str] = pubsub pubsub_map[node_id_str] = pubsub
all_node_ids = _collect_node_ids(adj_list) # Connect nodes and wait at least for 2 seconds
async with trio.open_nursery() as nursery:
for start_node_id in adj_list:
# For each neighbor of start_node, create if does not yet exist,
# then connect start_node to neighbor
for neighbor_id in adj_list[start_node_id]:
nursery.start_soon(
connect, node_map[start_node_id], node_map[neighbor_id]
)
nursery.start_soon(trio.sleep, 2)
for node in all_node_ids: # Step 2) Subscribe to topics
await add_node(node) queues_map = {}
topic_map = obj["topic_map"]
for node, neighbors in adj_list.items(): async def subscribe_node(node_id, topic):
for neighbor_id in neighbors:
await connect(node_map[node], node_map[neighbor_id])
# NOTE: the test using this routine will fail w/o these sleeps...
await asyncio.sleep(1)
# Step 2) Subscribe to topics
queues_map = {}
topic_map = obj["topic_map"]
for topic, node_ids in topic_map.items():
for node_id in node_ids:
queue = await pubsub_map[node_id].subscribe(topic)
if node_id not in queues_map: if node_id not in queues_map:
queues_map[node_id] = {} queues_map[node_id] = {}
# Store queue in topic-queue map for node # Avoid repeated work
queues_map[node_id][topic] = queue if topic in queues_map[node_id]:
# Checkpoint
await trio.hazmat.checkpoint()
return
sub = await pubsub_map[node_id].subscribe(topic)
queues_map[node_id][topic] = sub
# NOTE: the test using this routine will fail w/o these sleeps... async with trio.open_nursery() as nursery:
await asyncio.sleep(1) for topic, node_ids in topic_map.items():
for node_id in node_ids:
nursery.start_soon(subscribe_node, node_id, topic)
nursery.start_soon(trio.sleep, 2)
# Step 3) Publish messages # Step 3) Publish messages
topics_in_msgs_ordered = [] topics_in_msgs_ordered = []
messages = obj["messages"] messages = obj["messages"]
for msg in messages: for msg in messages:
topics = msg["topics"] topics = msg["topics"]
data = msg["data"] data = msg["data"]
node_id = msg["node_id"] node_id = msg["node_id"]
# Publish message
# TODO: Should be single RPC package with several topics
for topic in topics:
await pubsub_map[node_id].publish(topic, data)
# Publish message
# TODO: Should be single RPC package with several topics
for topic in topics:
await pubsub_map[node_id].publish(topic, data)
# For each topic in topics, add (topic, node_id, data) tuple to ordered test list # For each topic in topics, add (topic, node_id, data) tuple to ordered test list
topics_in_msgs_ordered.append((topic, node_id, data)) for topic in topics:
topics_in_msgs_ordered.append((topic, node_id, data))
# Allow time for publishing before continuing
await trio.sleep(1)
# Step 4) Check that all messages were received correctly. # Step 4) Check that all messages were received correctly.
for topic, origin_node_id, data in topics_in_msgs_ordered: for topic, origin_node_id, data in topics_in_msgs_ordered:
# Look at each node in each topic # Look at each node in each topic
for node_id in topic_map[topic]: for node_id in topic_map[topic]:
# Get message from subscription queue # Get message from subscription queue
queue = queues_map[node_id][topic] msg = await queues_map[node_id][topic].get()
msg = await queue.get() assert data == msg.data
assert data == msg.data # Check the message origin
# Check the message origin assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id
assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id
# Success, terminate pending tasks.
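For reference, a hedged sketch of how a parametrized test might feed these cases into the reworked `perform_test_from_obj` (the test name is illustrative; the real test modules are not shown in this diff, and `floodsub_protocol_pytest_params` comes from this file):
import pytest

from libp2p.tools.factories import PubsubFactory


@pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params)
@pytest.mark.trio
async def test_floodsub_protocol_case(test_case_obj):
    await perform_test_from_obj(
        test_case_obj, PubsubFactory.create_batch_with_floodsub
    )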

View File

@ -1,17 +1,10 @@
from typing import Dict, Sequence, Tuple, cast from typing import Awaitable, Callable
import multiaddr
from libp2p import new_node
from libp2p.host.basic_host import BasicHost
from libp2p.host.host_interface import IHost from libp2p.host.host_interface import IHost
from libp2p.host.routed_host import RoutedHost from libp2p.network.stream.exceptions import StreamError
from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.network.swarm import Swarm from libp2p.network.swarm import Swarm
from libp2p.peer.id import ID from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr
from libp2p.routing.interfaces import IPeerRouting
from libp2p.typing import StreamHandlerFn, TProtocol
from .constants import MAX_READ_LEN from .constants import MAX_READ_LEN
@ -36,63 +29,20 @@ async def connect(node1: IHost, node2: IHost) -> None:
await node1.connect(info) await node1.connect(info)
async def set_up_nodes_by_transport_opt( def create_echo_stream_handler(
transport_opt_list: Sequence[Sequence[str]] ack_prefix: str
) -> Tuple[BasicHost, ...]: ) -> Callable[[INetStream], Awaitable[None]]:
nodes_list = [] async def echo_stream_handler(stream: INetStream) -> None:
for transport_opt in transport_opt_list: while True:
node = await new_node(transport_opt=transport_opt) try:
await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0])) read_string = (await stream.read(MAX_READ_LEN)).decode()
nodes_list.append(node) except StreamError:
return tuple(nodes_list) break
resp = ack_prefix + read_string
try:
await stream.write(resp.encode())
except StreamError:
break
async def echo_stream_handler(stream: INetStream) -> None: return echo_stream_handler
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
resp = f"ack:{read_string}"
await stream.write(resp.encode())
async def perform_two_host_set_up(
handler: StreamHandlerFn = echo_stream_handler
) -> Tuple[BasicHost, BasicHost]:
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
node_b.set_stream_handler(TProtocol("/echo/1.0.0"), handler)
# Associate the peer with local ip address (see default parameters of Libp2p())
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
return node_a, node_b
class DummyRouter(IPeerRouting):
_routing_table: Dict[ID, PeerInfo]
def __init__(self) -> None:
self._routing_table = dict()
async def find_peer(self, peer_id: ID) -> PeerInfo:
return self._routing_table.get(peer_id, None)
async def set_up_routed_hosts() -> Tuple[RoutedHost, RoutedHost]:
router_a, router_b = DummyRouter(), DummyRouter()
transport = "/ip4/127.0.0.1/tcp/0"
host_a = await new_node(transport_opt=[transport], disc_opt=router_a)
host_b = await new_node(transport_opt=[transport], disc_opt=router_b)
address = multiaddr.Multiaddr(transport)
await host_a.get_network().listen(address)
await host_b.get_network().listen(address)
mock_routing_table = {
host_a.get_id(): PeerInfo(host_a.get_id(), host_a.get_addrs()),
host_b.get_id(): PeerInfo(host_b.get_id(), host_b.get_addrs()),
}
router_a._routing_table = router_b._routing_table = mock_routing_table
return cast(RoutedHost, host_a), cast(RoutedHost, host_b)
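A small usage sketch for the new `create_echo_stream_handler` helper; it mirrors what the rewritten tests below do, with an arbitrary protocol id and ack prefix:
import pytest

from libp2p.tools.constants import MAX_READ_LEN
from libp2p.tools.factories import HostFactory
from libp2p.tools.utils import connect, create_echo_stream_handler
from libp2p.typing import TProtocol

ECHO = TProtocol("/echo/1.0.0")


@pytest.mark.trio
async def test_echo_roundtrip():
    async with HostFactory.create_batch_and_listen(False, 2) as hosts:
        hosts[1].set_stream_handler(ECHO, create_echo_stream_handler("ack:"))
        await connect(hosts[0], hosts[1])
        stream = await hosts[0].new_stream(hosts[1].get_id(), [ECHO])
        await stream.write(b"hi")
        assert (await stream.read(MAX_READ_LEN)).decode() == "ack:hi"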

View File

@ -1,12 +1,13 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List from typing import Tuple
from multiaddr import Multiaddr from multiaddr import Multiaddr
import trio
class IListener(ABC): class IListener(ABC):
@abstractmethod @abstractmethod
async def listen(self, maddr: Multiaddr) -> bool: async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
""" """
put listener in listening mode and wait for incoming connections. put listener in listening mode and wait for incoming connections.
@ -15,7 +16,7 @@ class IListener(ABC):
""" """
@abstractmethod @abstractmethod
def get_addrs(self) -> List[Multiaddr]: def get_addrs(self) -> Tuple[Multiaddr, ...]:
""" """
retrieve list of addresses the listener is listening on. retrieve list of addresses the listener is listening on.
@ -24,5 +25,4 @@ class IListener(ABC):
@abstractmethod @abstractmethod
async def close(self) -> None: async def close(self) -> None:
"""close the listener such that no more connections can be open on this ...
transport instance."""

View File

@ -1,10 +1,11 @@
import asyncio import logging
from socket import socket from typing import Awaitable, Callable, List, Sequence, Tuple
import sys
from typing import List
from multiaddr import Multiaddr from multiaddr import Multiaddr
import trio
from trio_typing import TaskStatus
from libp2p.io.trio import TrioTCPStream
from libp2p.network.connection.raw_connection import RawConnection from libp2p.network.connection.raw_connection import RawConnection
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.transport.exceptions import OpenConnectionError from libp2p.transport.exceptions import OpenConnectionError
@ -12,53 +13,61 @@ from libp2p.transport.listener_interface import IListener
from libp2p.transport.transport_interface import ITransport from libp2p.transport.transport_interface import ITransport
from libp2p.transport.typing import THandler from libp2p.transport.typing import THandler
logger = logging.getLogger("libp2p.transport.tcp")
class TCPListener(IListener): class TCPListener(IListener):
multiaddrs: List[Multiaddr] listeners: List[trio.SocketListener]
server = None
def __init__(self, handler_function: THandler) -> None: def __init__(self, handler_function: THandler) -> None:
self.multiaddrs = [] self.listeners = []
self.server = None
self.handler = handler_function self.handler = handler_function
async def listen(self, maddr: Multiaddr) -> bool: # TODO: Get rid of `nursery`?
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> None:
""" """
put listener in listening mode and wait for incoming connections. put listener in listening mode and wait for incoming connections.
:param maddr: maddr of peer :param maddr: maddr of peer
:return: return True if successful :return: return True if successful
""" """
self.server = await asyncio.start_server(
self.handler, async def serve_tcp(
handler: Callable[[trio.SocketStream], Awaitable[None]],
port: int,
host: str,
task_status: TaskStatus[Sequence[trio.SocketListener]] = None,
) -> None:
"""Just a proxy function to add logging here."""
logger.debug("serve_tcp %s %s", host, port)
await trio.serve_tcp(handler, port, host=host, task_status=task_status)
async def handler(stream: trio.SocketStream) -> None:
tcp_stream = TrioTCPStream(stream)
await self.handler(tcp_stream)
listeners = await nursery.start(
serve_tcp,
handler,
int(maddr.value_for_protocol("tcp")),
maddr.value_for_protocol("ip4"), maddr.value_for_protocol("ip4"),
maddr.value_for_protocol("tcp"),
) )
socket = self.server.sockets[0] self.listeners.extend(listeners)
self.multiaddrs.append(_multiaddr_from_socket(socket))
return True def get_addrs(self) -> Tuple[Multiaddr, ...]:
def get_addrs(self) -> List[Multiaddr]:
""" """
retrieve list of addresses the listener is listening on. retrieve list of addresses the listener is listening on.
:return: return list of addrs :return: return list of addrs
""" """
# TODO check if server is listening return tuple(
return self.multiaddrs _multiaddr_from_socket(listener.socket) for listener in self.listeners
)
async def close(self) -> None: async def close(self) -> None:
"""close the listener such that no more connections can be open on this async with trio.open_nursery() as nursery:
transport instance.""" for listener in self.listeners:
if self.server is None: nursery.start_soon(listener.aclose)
return
self.server.close()
server = self.server
self.server = None
if sys.version_info < (3, 7):
return
await server.wait_closed()
class TCP(ITransport): class TCP(ITransport):
@ -74,11 +83,12 @@ class TCP(ITransport):
self.port = int(maddr.value_for_protocol("tcp")) self.port = int(maddr.value_for_protocol("tcp"))
try: try:
reader, writer = await asyncio.open_connection(self.host, self.port) stream = await trio.open_tcp_stream(self.host, self.port)
except (ConnectionAbortedError, ConnectionRefusedError) as error: except OSError as error:
raise OpenConnectionError(error) raise OpenConnectionError from error
read_write_closer = TrioTCPStream(stream)
return RawConnection(reader, writer, True) return RawConnection(read_write_closer, True)
def create_listener(self, handler_function: THandler) -> TCPListener: def create_listener(self, handler_function: THandler) -> TCPListener:
""" """
@ -91,6 +101,6 @@ class TCP(ITransport):
return TCPListener(handler_function) return TCPListener(handler_function)
def _multiaddr_from_socket(socket: socket) -> Multiaddr: def _multiaddr_from_socket(socket: trio.socket.SocketType) -> Multiaddr:
addr, port = socket.getsockname()[:2] ip, port = socket.getsockname() # type: ignore
return Multiaddr(f"/ip4/{addr}/tcp/{port}") return Multiaddr(f"/ip4/{ip}/tcp/{port}")

View File

@ -1,11 +1,11 @@
from asyncio import StreamReader, StreamWriter
from typing import Awaitable, Callable, Mapping, Type from typing import Awaitable, Callable, Mapping, Type
from libp2p.io.abc import ReadWriteCloser
from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.security.secure_transport_interface import ISecureTransport
from libp2p.stream_muxer.abc import IMuxedConn from libp2p.stream_muxer.abc import IMuxedConn
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
THandler = Callable[[StreamReader, StreamWriter], Awaitable[None]] THandler = Callable[[ReadWriteCloser], Awaitable[None]]
TSecurityOptions = Mapping[TProtocol, ISecureTransport] TSecurityOptions = Mapping[TProtocol, ISecureTransport]
TMuxerClass = Type[IMuxedConn] TMuxerClass = Type[IMuxedConn]
TMuxerOptions = Mapping[TProtocol, TMuxerClass] TMuxerOptions = Mapping[TProtocol, TMuxerClass]
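The handler type now takes a single ReadWriteCloser instead of an asyncio reader/writer pair; a conforming handler might look like the sketch below (buffer size and echo behaviour are arbitrary):
from libp2p.io.abc import ReadWriteCloser


async def echo_once(conn: ReadWriteCloser) -> None:
    # Matches the new THandler signature: one duplex connection object.
    data = await conn.read(1024)
    await conn.write(data)
    await conn.close()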

View File

@ -7,8 +7,8 @@ from setuptools import find_packages, setup
extras_require = { extras_require = {
"test": [ "test": [
"pytest>=4.6.3,<5.0.0", "pytest>=4.6.3,<5.0.0",
"pytest-xdist>=1.30.0,<2", "pytest-xdist>=1.30.0",
"pytest-asyncio>=0.10.0,<1.0.0", "pytest-trio>=0.5.2",
"factory-boy>=2.12.0,<3.0.0", "factory-boy>=2.12.0,<3.0.0",
], ],
"lint": [ "lint": [
@ -74,6 +74,10 @@ install_requires = [
"pynacl==1.3.0", "pynacl==1.3.0",
"dataclasses>=0.7, <1;python_version<'3.7'", "dataclasses>=0.7, <1;python_version<'3.7'",
"async_generator==1.10", "async_generator==1.10",
"trio>=0.13.0",
"async-service>=0.1.0a6",
"async-exit-stack==1.0.1",
"trio-typing>=0.3.0,<0.4.0",
] ]

View File

@ -1,8 +1,5 @@
import asyncio
import pytest import pytest
from libp2p.tools.constants import LISTEN_MADDR
from libp2p.tools.factories import HostFactory from libp2p.tools.factories import HostFactory
@ -17,17 +14,6 @@ def num_hosts():
@pytest.fixture @pytest.fixture
async def hosts(num_hosts, is_host_secure): async def hosts(num_hosts, is_host_secure, nursery):
_hosts = HostFactory.create_batch(num_hosts, is_secure=is_host_secure) async with HostFactory.create_batch_and_listen(is_host_secure, num_hosts) as _hosts:
await asyncio.gather(
*[_host.get_network().listen(LISTEN_MADDR) for _host in _hosts]
)
try:
yield _hosts yield _hosts
finally:
# TODO: It's possible that `close` raises exceptions currently,
# due to the connection reset things. Though we don't care much about that when
# cleaning up the tasks, it is probably better to handle the exceptions properly.
await asyncio.gather(
*[_host.close() for _host in _hosts], return_exceptions=True
)

View File

@ -1,10 +1,10 @@
import asyncio
import pytest import pytest
import trio
from libp2p.host.exceptions import StreamFailure from libp2p.host.exceptions import StreamFailure
from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.tools.utils import set_up_nodes_by_transport_opt from libp2p.tools.factories import HostFactory
from libp2p.tools.utils import MAX_READ_LEN
PROTOCOL_ID = "/chat/1.0.0" PROTOCOL_ID = "/chat/1.0.0"
@ -25,7 +25,7 @@ async def hello_world(host_a, host_b):
# Multiaddress of the destination peer is fetched from the peerstore using 'peerId'. # Multiaddress of the destination peer is fetched from the peerstore using 'peerId'.
stream = await host_b.new_stream(host_a.get_id(), [PROTOCOL_ID]) stream = await host_b.new_stream(host_a.get_id(), [PROTOCOL_ID])
await stream.write(hello_world_from_host_b) await stream.write(hello_world_from_host_b)
read = await stream.read() read = await stream.read(MAX_READ_LEN)
assert read == hello_world_from_host_a assert read == hello_world_from_host_a
await stream.close() await stream.close()
@ -47,7 +47,7 @@ async def connect_write(host_a, host_b):
await stream.write(message.encode()) await stream.write(message.encode())
# Reader needs time due to async reads # Reader needs time due to async reads
await asyncio.sleep(2) await trio.sleep(2)
await stream.close() await stream.close()
assert received == messages assert received == messages
@ -88,16 +88,14 @@ async def no_common_protocol(host_a, host_b):
await host_b.new_stream(host_a.get_id(), ["/fakeproto/0.0.1"]) await host_b.new_stream(host_a.get_id(), ["/fakeproto/0.0.1"])
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"test", [(hello_world), (connect_write), (connect_read), (no_common_protocol)] "test", [(hello_world), (connect_write), (connect_read), (no_common_protocol)]
) )
async def test_chat(test): @pytest.mark.trio
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] async def test_chat(test, is_host_secure):
(host_a, host_b) = await set_up_nodes_by_transport_opt(transport_opt_list) async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
addr = hosts[0].get_addrs()[0]
info = info_from_p2p_addr(addr)
await hosts[1].connect(info)
addr = host_a.get_addrs()[0] await test(hosts[0], hosts[1])
info = info_from_p2p_addr(addr)
await host_b.connect(info)
await test(host_a, host_b)

View File

@ -1,4 +1,4 @@
from libp2p import initialize_default_swarm from libp2p import new_swarm
from libp2p.crypto.rsa import create_new_key_pair from libp2p.crypto.rsa import create_new_key_pair
from libp2p.host.basic_host import BasicHost from libp2p.host.basic_host import BasicHost
from libp2p.host.defaults import get_default_protocols from libp2p.host.defaults import get_default_protocols
@ -6,7 +6,7 @@ from libp2p.host.defaults import get_default_protocols
def test_default_protocols(): def test_default_protocols():
key_pair = create_new_key_pair() key_pair = create_new_key_pair()
swarm = initialize_default_swarm(key_pair) swarm = new_swarm(key_pair)
host = BasicHost(swarm) host = BasicHost(swarm)
mux = host.get_mux() mux = host.get_mux()

View File

@ -1,18 +1,19 @@
import asyncio
import secrets import secrets
import pytest import pytest
import trio
from libp2p.host.ping import ID, PING_LENGTH from libp2p.host.ping import ID, PING_LENGTH
from libp2p.tools.factories import pair_of_connected_hosts from libp2p.tools.factories import host_pair_factory
@pytest.mark.asyncio @pytest.mark.trio
async def test_ping_once(): async def test_ping_once(is_host_secure):
async with pair_of_connected_hosts() as (host_a, host_b): async with host_pair_factory(is_host_secure) as (host_a, host_b):
stream = await host_b.new_stream(host_a.get_id(), (ID,)) stream = await host_b.new_stream(host_a.get_id(), (ID,))
some_ping = secrets.token_bytes(PING_LENGTH) some_ping = secrets.token_bytes(PING_LENGTH)
await stream.write(some_ping) await stream.write(some_ping)
await trio.sleep(0.01)
some_pong = await stream.read(PING_LENGTH) some_pong = await stream.read(PING_LENGTH)
assert some_ping == some_pong assert some_ping == some_pong
await stream.close() await stream.close()
@ -21,9 +22,9 @@ async def test_ping_once():
SOME_PING_COUNT = 3 SOME_PING_COUNT = 3
@pytest.mark.asyncio @pytest.mark.trio
async def test_ping_several(): async def test_ping_several(is_host_secure):
async with pair_of_connected_hosts() as (host_a, host_b): async with host_pair_factory(is_host_secure) as (host_a, host_b):
stream = await host_b.new_stream(host_a.get_id(), (ID,)) stream = await host_b.new_stream(host_a.get_id(), (ID,))
for _ in range(SOME_PING_COUNT): for _ in range(SOME_PING_COUNT):
some_ping = secrets.token_bytes(PING_LENGTH) some_ping = secrets.token_bytes(PING_LENGTH)
@ -33,5 +34,5 @@ async def test_ping_several():
# NOTE: simulate some time to sleep to mirror a real # NOTE: simulate some time to sleep to mirror a real
# world usage where a peer sends pings on some periodic interval # world usage where a peer sends pings on some periodic interval
# NOTE: this interval can be `0` for this test. # NOTE: this interval can be `0` for this test.
await asyncio.sleep(0) await trio.sleep(0)
await stream.close() await stream.close()

View File

@ -1,33 +1,26 @@
import asyncio
import pytest import pytest
from libp2p.host.exceptions import ConnectionFailure from libp2p.host.exceptions import ConnectionFailure
from libp2p.peer.peerinfo import PeerInfo from libp2p.peer.peerinfo import PeerInfo
from libp2p.tools.utils import set_up_nodes_by_transport_opt, set_up_routed_hosts from libp2p.tools.factories import HostFactory, RoutedHostFactory
@pytest.mark.asyncio @pytest.mark.trio
async def test_host_routing_success(): async def test_host_routing_success():
host_a, host_b = await set_up_routed_hosts() async with RoutedHostFactory.create_batch_and_listen(False, 2) as hosts:
# forces to use routing as no addrs are provided # forces to use routing as no addrs are provided
await host_a.connect(PeerInfo(host_b.get_id(), [])) await hosts[0].connect(PeerInfo(hosts[1].get_id(), []))
await host_b.connect(PeerInfo(host_a.get_id(), [])) await hosts[1].connect(PeerInfo(hosts[0].get_id(), []))
# Clean up
await asyncio.gather(*[host_a.close(), host_b.close()])
@pytest.mark.asyncio @pytest.mark.trio
async def test_host_routing_fail(): async def test_host_routing_fail():
host_a, host_b = await set_up_routed_hosts() is_secure = False
basic_host_c = (await set_up_nodes_by_transport_opt([["/ip4/127.0.0.1/tcp/0"]]))[0] async with RoutedHostFactory.create_batch_and_listen(
is_secure, 2
# routing fails because host_c does not use routing ) as routed_hosts, HostFactory.create_batch_and_listen(is_secure, 1) as basic_hosts:
with pytest.raises(ConnectionFailure): # routing fails because host_c does not use routing
await host_a.connect(PeerInfo(basic_host_c.get_id(), [])) with pytest.raises(ConnectionFailure):
with pytest.raises(ConnectionFailure): await routed_hosts[0].connect(PeerInfo(basic_hosts[0].get_id(), []))
await host_b.connect(PeerInfo(basic_host_c.get_id(), [])) with pytest.raises(ConnectionFailure):
await routed_hosts[1].connect(PeerInfo(basic_hosts[0].get_id(), []))
# Clean up
await asyncio.gather(*[host_a.close(), host_b.close(), basic_host_c.close()])

View File

@ -2,12 +2,12 @@ import pytest
from libp2p.identity.identify.pb.identify_pb2 import Identify from libp2p.identity.identify.pb.identify_pb2 import Identify
from libp2p.identity.identify.protocol import ID, _mk_identify_protobuf from libp2p.identity.identify.protocol import ID, _mk_identify_protobuf
from libp2p.tools.factories import pair_of_connected_hosts from libp2p.tools.factories import host_pair_factory
@pytest.mark.asyncio @pytest.mark.trio
async def test_identify_protocol(): async def test_identify_protocol(is_host_secure):
async with pair_of_connected_hosts() as (host_a, host_b): async with host_pair_factory(is_host_secure) as (host_a, host_b):
stream = await host_b.new_stream(host_a.get_id(), (ID,)) stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read() response = await stream.read()
await stream.close() await stream.close()

View File

@ -1,350 +1,285 @@
import multiaddr import multiaddr
import pytest import pytest
from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.network.stream.exceptions import StreamError
from libp2p.tools.constants import MAX_READ_LEN from libp2p.tools.constants import MAX_READ_LEN
from libp2p.tools.utils import set_up_nodes_by_transport_opt from libp2p.tools.factories import HostFactory
from libp2p.tools.utils import connect, create_echo_stream_handler
from libp2p.typing import TProtocol
PROTOCOL_ID_0 = TProtocol("/echo/0")
PROTOCOL_ID_1 = TProtocol("/echo/1")
PROTOCOL_ID_2 = TProtocol("/echo/2")
PROTOCOL_ID_3 = TProtocol("/echo/3")
ACK_STR_0 = "ack_0:"
ACK_STR_1 = "ack_1:"
ACK_STR_2 = "ack_2:"
ACK_STR_3 = "ack_3:"
@pytest.mark.asyncio @pytest.mark.trio
async def test_simple_messages(): async def test_simple_messages(is_host_secure):
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list) hosts[1].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
async def stream_handler(stream):
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
response = "ack:" + read_string
await stream.write(response.encode())
node_b.set_stream_handler("/echo/1.0.0", stream_handler)
# Associate the peer with local ip address (see default parameters of Libp2p())
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"])
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
await stream.write(message.encode())
response = (await stream.read(MAX_READ_LEN)).decode()
assert response == ("ack:" + message)
# Success, terminate pending tasks.
@pytest.mark.asyncio
async def test_double_response():
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
async def stream_handler(stream):
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
response = "ack1:" + read_string
await stream.write(response.encode())
response = "ack2:" + read_string
await stream.write(response.encode())
node_b.set_stream_handler("/echo/1.0.0", stream_handler)
# Associate the peer with local ip address (see default parameters of Libp2p())
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"])
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
await stream.write(message.encode())
response1 = (await stream.read(MAX_READ_LEN)).decode()
assert response1 == ("ack1:" + message)
response2 = (await stream.read(MAX_READ_LEN)).decode()
assert response2 == ("ack2:" + message)
# Success, terminate pending tasks.
@pytest.mark.asyncio
async def test_multiple_streams():
# Node A should be able to open a stream with node B and then vice versa.
# Stream IDs should be generated uniquely so that the stream state is not overwritten
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
async def stream_handler_a(stream):
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
response = "ack_a:" + read_string
await stream.write(response.encode())
async def stream_handler_b(stream):
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
response = "ack_b:" + read_string
await stream.write(response.encode())
node_a.set_stream_handler("/echo_a/1.0.0", stream_handler_a)
node_b.set_stream_handler("/echo_b/1.0.0", stream_handler_b)
# Associate the peer with local ip address (see default parameters of Libp2p())
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
node_b.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10)
stream_a = await node_a.new_stream(node_b.get_id(), ["/echo_b/1.0.0"])
stream_b = await node_b.new_stream(node_a.get_id(), ["/echo_a/1.0.0"])
# A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
a_message = message + "_a"
b_message = message + "_b"
await stream_a.write(a_message.encode())
await stream_b.write(b_message.encode())
response_a = (await stream_a.read(MAX_READ_LEN)).decode()
response_b = (await stream_b.read(MAX_READ_LEN)).decode()
assert response_a == ("ack_b:" + a_message) and response_b == (
"ack_a:" + b_message
) )
# Success, terminate pending tasks. # Associate the peer with local ip address (see default parameters of Libp2p())
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
stream = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
await stream.write(message.encode())
response = (await stream.read(MAX_READ_LEN)).decode()
assert response == (ACK_STR_0 + message)
@pytest.mark.asyncio @pytest.mark.trio
async def test_multiple_streams_same_initiator_different_protocols(): async def test_double_response(is_host_secure):
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
async def stream_handler_a1(stream): async def double_response_stream_handler(stream):
while True: while True:
read_string = (await stream.read(MAX_READ_LEN)).decode() try:
read_string = (await stream.read(MAX_READ_LEN)).decode()
except StreamError:
break
response = "ack_a1:" + read_string response = ACK_STR_0 + read_string
await stream.write(response.encode()) try:
await stream.write(response.encode())
except StreamError:
break
async def stream_handler_a2(stream): response = ACK_STR_1 + read_string
while True: try:
read_string = (await stream.read(MAX_READ_LEN)).decode() await stream.write(response.encode())
except StreamError:
break
response = "ack_a2:" + read_string hosts[1].set_stream_handler(PROTOCOL_ID_0, double_response_stream_handler)
await stream.write(response.encode())
async def stream_handler_a3(stream): # Associate the peer with local ip address (see default parameters of Libp2p())
while True: hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
read_string = (await stream.read(MAX_READ_LEN)).decode() stream = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
response = "ack_a3:" + read_string messages = ["hello" + str(x) for x in range(10)]
await stream.write(response.encode()) for message in messages:
node_b.set_stream_handler("/echo_a1/1.0.0", stream_handler_a1)
node_b.set_stream_handler("/echo_a2/1.0.0", stream_handler_a2)
node_b.set_stream_handler("/echo_a3/1.0.0", stream_handler_a3)
# Associate the peer with local ip address (see default parameters of Libp2p())
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
node_b.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10)
# Open streams to node_b over echo_a1 echo_a2 echo_a3 protocols
stream_a1 = await node_a.new_stream(node_b.get_id(), ["/echo_a1/1.0.0"])
stream_a2 = await node_a.new_stream(node_b.get_id(), ["/echo_a2/1.0.0"])
stream_a3 = await node_a.new_stream(node_b.get_id(), ["/echo_a3/1.0.0"])
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
a1_message = message + "_a1"
a2_message = message + "_a2"
a3_message = message + "_a3"
await stream_a1.write(a1_message.encode())
await stream_a2.write(a2_message.encode())
await stream_a3.write(a3_message.encode())
response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode()
response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode()
response_a3 = (await stream_a3.read(MAX_READ_LEN)).decode()
assert (
response_a1 == ("ack_a1:" + a1_message)
and response_a2 == ("ack_a2:" + a2_message)
and response_a3 == ("ack_a3:" + a3_message)
)
# Success, terminate pending tasks.
@pytest.mark.asyncio
async def test_multiple_streams_two_initiators():
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
async def stream_handler_a1(stream):
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
response = "ack_a1:" + read_string
await stream.write(response.encode())
async def stream_handler_a2(stream):
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
response = "ack_a2:" + read_string
await stream.write(response.encode())
async def stream_handler_b1(stream):
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
response = "ack_b1:" + read_string
await stream.write(response.encode())
async def stream_handler_b2(stream):
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
response = "ack_b2:" + read_string
await stream.write(response.encode())
node_a.set_stream_handler("/echo_b1/1.0.0", stream_handler_b1)
node_a.set_stream_handler("/echo_b2/1.0.0", stream_handler_b2)
node_b.set_stream_handler("/echo_a1/1.0.0", stream_handler_a1)
node_b.set_stream_handler("/echo_a2/1.0.0", stream_handler_a2)
# Associate the peer with local ip address (see default parameters of Libp2p())
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
node_b.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10)
stream_a1 = await node_a.new_stream(node_b.get_id(), ["/echo_a1/1.0.0"])
stream_a2 = await node_a.new_stream(node_b.get_id(), ["/echo_a2/1.0.0"])
stream_b1 = await node_b.new_stream(node_a.get_id(), ["/echo_b1/1.0.0"])
stream_b2 = await node_b.new_stream(node_a.get_id(), ["/echo_b2/1.0.0"])
# A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
a1_message = message + "_a1"
a2_message = message + "_a2"
b1_message = message + "_b1"
b2_message = message + "_b2"
await stream_a1.write(a1_message.encode())
await stream_a2.write(a2_message.encode())
await stream_b1.write(b1_message.encode())
await stream_b2.write(b2_message.encode())
response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode()
response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode()
response_b1 = (await stream_b1.read(MAX_READ_LEN)).decode()
response_b2 = (await stream_b2.read(MAX_READ_LEN)).decode()
assert (
response_a1 == ("ack_a1:" + a1_message)
and response_a2 == ("ack_a2:" + a2_message)
and response_b1 == ("ack_b1:" + b1_message)
and response_b2 == ("ack_b2:" + b2_message)
)
# Success, terminate pending tasks.
@pytest.mark.asyncio
async def test_triangle_nodes_connection():
transport_opt_list = [
["/ip4/127.0.0.1/tcp/0"],
["/ip4/127.0.0.1/tcp/0"],
["/ip4/127.0.0.1/tcp/0"],
]
(node_a, node_b, node_c) = await set_up_nodes_by_transport_opt(transport_opt_list)
async def stream_handler(stream):
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
response = "ack:" + read_string
await stream.write(response.encode())
node_a.set_stream_handler("/echo/1.0.0", stream_handler)
node_b.set_stream_handler("/echo/1.0.0", stream_handler)
node_c.set_stream_handler("/echo/1.0.0", stream_handler)
# Associate the peer with local ip address (see default parameters of Libp2p())
# Associate all permutations
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
node_a.get_peerstore().add_addrs(node_c.get_id(), node_c.get_addrs(), 10)
node_b.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10)
node_b.get_peerstore().add_addrs(node_c.get_id(), node_c.get_addrs(), 10)
node_c.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10)
node_c.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
stream_a_to_b = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"])
stream_a_to_c = await node_a.new_stream(node_c.get_id(), ["/echo/1.0.0"])
stream_b_to_a = await node_b.new_stream(node_a.get_id(), ["/echo/1.0.0"])
stream_b_to_c = await node_b.new_stream(node_c.get_id(), ["/echo/1.0.0"])
stream_c_to_a = await node_c.new_stream(node_a.get_id(), ["/echo/1.0.0"])
stream_c_to_b = await node_c.new_stream(node_b.get_id(), ["/echo/1.0.0"])
messages = ["hello" + str(x) for x in range(5)]
streams = [
stream_a_to_b,
stream_a_to_c,
stream_b_to_a,
stream_b_to_c,
stream_c_to_a,
stream_c_to_b,
]
for message in messages:
for stream in streams:
await stream.write(message.encode()) await stream.write(message.encode())
response = (await stream.read(MAX_READ_LEN)).decode() response1 = (await stream.read(MAX_READ_LEN)).decode()
assert response1 == (ACK_STR_0 + message)
assert response == ("ack:" + message) response2 = (await stream.read(MAX_READ_LEN)).decode()
assert response2 == (ACK_STR_1 + message)
# Success, terminate pending tasks.
@pytest.mark.asyncio @pytest.mark.trio
async def test_host_connect(): async def test_multiple_streams(is_host_secure):
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] # hosts[0] should be able to open a stream with hosts[1] and then vice versa.
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list) # Stream IDs should be generated uniquely so that the stream state is not overwritten
# Only our peer ID is stored in peer store async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
assert len(node_a.get_peerstore().peer_ids()) == 1 hosts[0].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
)
hosts[1].set_stream_handler(
PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1)
)
addr = node_b.get_addrs()[0] # Associate the peer with local ip address (see default parameters of Libp2p())
info = info_from_p2p_addr(addr) hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
await node_a.connect(info) hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)
assert len(node_a.get_peerstore().peer_ids()) == 2 stream_a = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1])
stream_b = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_0])
await node_a.connect(info) # A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
a_message = message + "_a"
b_message = message + "_b"
# make sure we don't do double connection await stream_a.write(a_message.encode())
assert len(node_a.get_peerstore().peer_ids()) == 2 await stream_b.write(b_message.encode())
assert node_b.get_id() in node_a.get_peerstore().peer_ids() response_a = (await stream_a.read(MAX_READ_LEN)).decode()
ma_node_b = multiaddr.Multiaddr("/p2p/%s" % node_b.get_id().pretty()) response_b = (await stream_b.read(MAX_READ_LEN)).decode()
for addr in node_a.get_peerstore().addrs(node_b.get_id()):
assert addr.encapsulate(ma_node_b) in node_b.get_addrs()
# Success, terminate pending tasks. assert response_a == (ACK_STR_1 + a_message) and response_b == (
ACK_STR_0 + b_message
)
@pytest.mark.trio
async def test_multiple_streams_same_initiator_different_protocols(is_host_secure):
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
hosts[1].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
)
hosts[1].set_stream_handler(
PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1)
)
hosts[1].set_stream_handler(
PROTOCOL_ID_2, create_echo_stream_handler(ACK_STR_2)
)
# Associate the peer with local ip address (see default parameters of Libp2p())
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)
# Open streams to hosts[1] over echo_a1 echo_a2 echo_a3 protocols
stream_a1 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
stream_a2 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1])
stream_a3 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_2])
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
a1_message = message + "_a1"
a2_message = message + "_a2"
a3_message = message + "_a3"
await stream_a1.write(a1_message.encode())
await stream_a2.write(a2_message.encode())
await stream_a3.write(a3_message.encode())
response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode()
response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode()
response_a3 = (await stream_a3.read(MAX_READ_LEN)).decode()
assert (
response_a1 == (ACK_STR_0 + a1_message)
and response_a2 == (ACK_STR_1 + a2_message)
and response_a3 == (ACK_STR_2 + a3_message)
)
# Success, terminate pending tasks.
@pytest.mark.trio
async def test_multiple_streams_two_initiators(is_host_secure):
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
hosts[0].set_stream_handler(
PROTOCOL_ID_2, create_echo_stream_handler(ACK_STR_2)
)
hosts[0].set_stream_handler(
PROTOCOL_ID_3, create_echo_stream_handler(ACK_STR_3)
)
hosts[1].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
)
hosts[1].set_stream_handler(
PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1)
)
# Associate the peer with local ip address (see default parameters of Libp2p())
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)
stream_a1 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
stream_a2 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1])
stream_b1 = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_2])
stream_b2 = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_3])
# A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
a1_message = message + "_a1"
a2_message = message + "_a2"
b1_message = message + "_b1"
b2_message = message + "_b2"
await stream_a1.write(a1_message.encode())
await stream_a2.write(a2_message.encode())
await stream_b1.write(b1_message.encode())
await stream_b2.write(b2_message.encode())
response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode()
response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode()
response_b1 = (await stream_b1.read(MAX_READ_LEN)).decode()
response_b2 = (await stream_b2.read(MAX_READ_LEN)).decode()
assert (
response_a1 == (ACK_STR_0 + a1_message)
and response_a2 == (ACK_STR_1 + a2_message)
and response_b1 == (ACK_STR_2 + b1_message)
and response_b2 == (ACK_STR_3 + b2_message)
)
@pytest.mark.trio
async def test_triangle_nodes_connection(is_host_secure):
async with HostFactory.create_batch_and_listen(is_host_secure, 3) as hosts:
hosts[0].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
)
hosts[1].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
)
hosts[2].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
)
# Associate the peer with local ip address (see default parameters of Libp2p())
# Associate all permutations
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
hosts[0].get_peerstore().add_addrs(hosts[2].get_id(), hosts[2].get_addrs(), 10)
hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)
hosts[1].get_peerstore().add_addrs(hosts[2].get_id(), hosts[2].get_addrs(), 10)
hosts[2].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)
hosts[2].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
stream_0_to_1 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
stream_0_to_2 = await hosts[0].new_stream(hosts[2].get_id(), [PROTOCOL_ID_0])
stream_1_to_0 = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_0])
stream_1_to_2 = await hosts[1].new_stream(hosts[2].get_id(), [PROTOCOL_ID_0])
stream_2_to_0 = await hosts[2].new_stream(hosts[0].get_id(), [PROTOCOL_ID_0])
stream_2_to_1 = await hosts[2].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
messages = ["hello" + str(x) for x in range(5)]
streams = [
stream_0_to_1,
stream_0_to_2,
stream_1_to_0,
stream_1_to_2,
stream_2_to_0,
stream_2_to_1,
]
for message in messages:
for stream in streams:
await stream.write(message.encode())
response = (await stream.read(MAX_READ_LEN)).decode()
assert response == (ACK_STR_0 + message)
@pytest.mark.trio
async def test_host_connect(is_host_secure):
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
assert len(hosts[0].get_peerstore().peer_ids()) == 1
await connect(hosts[0], hosts[1])
assert len(hosts[0].get_peerstore().peer_ids()) == 2
await connect(hosts[0], hosts[1])
# make sure we don't do double connection
assert len(hosts[0].get_peerstore().peer_ids()) == 2
assert hosts[1].get_id() in hosts[0].get_peerstore().peer_ids()
ma_node_b = multiaddr.Multiaddr("/p2p/%s" % hosts[1].get_id().pretty())
for addr in hosts[0].get_peerstore().addrs(hosts[1].get_id()):
assert addr.encapsulate(ma_node_b) in hosts[1].get_addrs()
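The tests above lean on two helpers from `libp2p.tools.utils` that this diff does not show: `create_echo_stream_handler`, which builds an echo handler for a given ack prefix, and `connect`, which dials one host from another. A minimal sketch of both, inferred only from how they are called here (the shipped implementations may differ, e.g. by catching `StreamError` the way the inline handler above does):

from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.tools.constants import MAX_READ_LEN


def create_echo_stream_handler(ack_prefix):
    # Build a stream handler that prefixes every received message with
    # `ack_prefix` and echoes it back on the same stream.
    async def echo_stream_handler(stream):
        while True:
            read_string = (await stream.read(MAX_READ_LEN)).decode()
            await stream.write((ack_prefix + read_string).encode())

    return echo_stream_handler


async def connect(host_a, host_b):
    # Dial host_b from host_a via one of host_b's listen addresses, the same
    # steps the old version of `test_host_connect` performed inline.
    addr = host_b.get_addrs()[0]
    info = info_from_p2p_addr(addr)
    await host_a.connect(info)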

View File

@ -1,5 +1,3 @@
import asyncio
import pytest import pytest
from libp2p.tools.factories import ( from libp2p.tools.factories import (
@ -11,26 +9,17 @@ from libp2p.tools.factories import (
@pytest.fixture @pytest.fixture
async def net_stream_pair(is_host_secure): async def net_stream_pair(is_host_secure):
stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory(is_host_secure) async with net_stream_pair_factory(is_host_secure) as net_stream_pair:
try: yield net_stream_pair
yield stream_0, stream_1
finally:
await asyncio.gather(*[host_0.close(), host_1.close()])
@pytest.fixture @pytest.fixture
async def swarm_pair(is_host_secure): async def swarm_pair(is_host_secure):
swarm_0, swarm_1 = await swarm_pair_factory(is_host_secure) async with swarm_pair_factory(is_host_secure) as swarms:
try: yield swarms
yield swarm_0, swarm_1
finally:
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
@pytest.fixture @pytest.fixture
async def swarm_conn_pair(is_host_secure): async def swarm_conn_pair(is_host_secure):
conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory(is_host_secure) async with swarm_conn_pair_factory(is_host_secure) as swarm_conn_pair:
try: yield swarm_conn_pair
yield conn_0, conn_1
finally:
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
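The rewritten fixtures all follow one pattern: each factory is now an async context manager, so teardown happens when the block exits instead of via explicit `close()` calls gathered at the end. A self-contained sketch of that shape with a hypothetical `make_pair` factory (not part of libp2p), runnable under pytest-trio:

from contextlib import asynccontextmanager

import pytest
import trio


@asynccontextmanager
async def make_pair():
    # Hypothetical factory: build two resources, hand them to the caller,
    # and always clean up when the `async with` block exits.
    resource_0, resource_1 = object(), object()
    try:
        yield resource_0, resource_1
    finally:
        # Close / release resource_0 and resource_1 here.
        pass


@pytest.fixture
async def resource_pair():
    # The fixture just enters the factory; whatever it yields reaches the
    # test, and the factory's cleanup runs after the test returns.
    async with make_pair() as pair:
        yield pair


@pytest.mark.trio
async def test_pair(resource_pair):
    first, second = resource_pair
    assert first is not second
    await trio.sleep(0)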

View File

@ -1,6 +1,5 @@
import asyncio
import pytest import pytest
import trio
from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset
from libp2p.tools.constants import MAX_READ_LEN from libp2p.tools.constants import MAX_READ_LEN
@ -8,7 +7,7 @@ from libp2p.tools.constants import MAX_READ_LEN
DATA = b"data_123" DATA = b"data_123"
@pytest.mark.asyncio @pytest.mark.trio
async def test_net_stream_read_write(net_stream_pair): async def test_net_stream_read_write(net_stream_pair):
stream_0, stream_1 = net_stream_pair stream_0, stream_1 = net_stream_pair
assert ( assert (
@ -19,7 +18,7 @@ async def test_net_stream_read_write(net_stream_pair):
assert (await stream_1.read(MAX_READ_LEN)) == DATA assert (await stream_1.read(MAX_READ_LEN)) == DATA
@pytest.mark.asyncio @pytest.mark.trio
async def test_net_stream_read_until_eof(net_stream_pair): async def test_net_stream_read_until_eof(net_stream_pair):
read_bytes = bytearray() read_bytes = bytearray()
stream_0, stream_1 = net_stream_pair stream_0, stream_1 = net_stream_pair
@ -27,41 +26,39 @@ async def test_net_stream_read_until_eof(net_stream_pair):
async def read_until_eof(): async def read_until_eof():
read_bytes.extend(await stream_1.read()) read_bytes.extend(await stream_1.read())
task = asyncio.ensure_future(read_until_eof()) async with trio.open_nursery() as nursery:
nursery.start_soon(read_until_eof)
expected_data = bytearray()
expected_data = bytearray() # Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA)
expected_data.extend(DATA)
await trio.sleep(0.01)
assert len(read_bytes) == 0
# Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA)
expected_data.extend(DATA)
await trio.sleep(0.01)
assert len(read_bytes) == 0
# Test: `read` doesn't return before `close` is called. # Test: Close the stream, `read` returns, and previously sent data is received.
await stream_0.write(DATA) await stream_0.close()
expected_data.extend(DATA) await trio.sleep(0.01)
await asyncio.sleep(0.01) assert read_bytes == expected_data
assert len(read_bytes) == 0
# Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA)
expected_data.extend(DATA)
await asyncio.sleep(0.01)
assert len(read_bytes) == 0
# Test: Close the stream, `read` returns, and previously sent data is received.
await stream_0.close()
await asyncio.sleep(0.01)
assert read_bytes == expected_data
task.cancel()
@pytest.mark.asyncio @pytest.mark.trio
async def test_net_stream_read_after_remote_closed(net_stream_pair): async def test_net_stream_read_after_remote_closed(net_stream_pair):
stream_0, stream_1 = net_stream_pair stream_0, stream_1 = net_stream_pair
await stream_0.write(DATA) await stream_0.write(DATA)
await stream_0.close() await stream_0.close()
await asyncio.sleep(0.01) await trio.sleep(0.01)
assert (await stream_1.read(MAX_READ_LEN)) == DATA assert (await stream_1.read(MAX_READ_LEN)) == DATA
with pytest.raises(StreamEOF): with pytest.raises(StreamEOF):
await stream_1.read(MAX_READ_LEN) await stream_1.read(MAX_READ_LEN)
@pytest.mark.asyncio @pytest.mark.trio
async def test_net_stream_read_after_local_reset(net_stream_pair): async def test_net_stream_read_after_local_reset(net_stream_pair):
stream_0, stream_1 = net_stream_pair stream_0, stream_1 = net_stream_pair
await stream_0.reset() await stream_0.reset()
@ -69,29 +66,29 @@ async def test_net_stream_read_after_local_reset(net_stream_pair):
await stream_0.read(MAX_READ_LEN) await stream_0.read(MAX_READ_LEN)
@pytest.mark.asyncio @pytest.mark.trio
async def test_net_stream_read_after_remote_reset(net_stream_pair): async def test_net_stream_read_after_remote_reset(net_stream_pair):
stream_0, stream_1 = net_stream_pair stream_0, stream_1 = net_stream_pair
await stream_0.write(DATA) await stream_0.write(DATA)
await stream_0.reset() await stream_0.reset()
# Sleep to let `stream_1` receive the message. # Sleep to let `stream_1` receive the message.
await asyncio.sleep(0.01) await trio.sleep(0.01)
with pytest.raises(StreamReset): with pytest.raises(StreamReset):
await stream_1.read(MAX_READ_LEN) await stream_1.read(MAX_READ_LEN)
@pytest.mark.asyncio @pytest.mark.trio
async def test_net_stream_read_after_remote_closed_and_reset(net_stream_pair): async def test_net_stream_read_after_remote_closed_and_reset(net_stream_pair):
stream_0, stream_1 = net_stream_pair stream_0, stream_1 = net_stream_pair
await stream_0.write(DATA) await stream_0.write(DATA)
await stream_0.close() await stream_0.close()
await stream_0.reset() await stream_0.reset()
# Sleep to let `stream_1` receive the message. # Sleep to let `stream_1` receive the message.
await asyncio.sleep(0.01) await trio.sleep(0.01)
assert (await stream_1.read(MAX_READ_LEN)) == DATA assert (await stream_1.read(MAX_READ_LEN)) == DATA
@pytest.mark.asyncio @pytest.mark.trio
async def test_net_stream_write_after_local_closed(net_stream_pair): async def test_net_stream_write_after_local_closed(net_stream_pair):
stream_0, stream_1 = net_stream_pair stream_0, stream_1 = net_stream_pair
await stream_0.write(DATA) await stream_0.write(DATA)
@ -100,7 +97,7 @@ async def test_net_stream_write_after_local_closed(net_stream_pair):
await stream_0.write(DATA) await stream_0.write(DATA)
@pytest.mark.asyncio @pytest.mark.trio
async def test_net_stream_write_after_local_reset(net_stream_pair): async def test_net_stream_write_after_local_reset(net_stream_pair):
stream_0, stream_1 = net_stream_pair stream_0, stream_1 = net_stream_pair
await stream_0.reset() await stream_0.reset()
@ -108,10 +105,10 @@ async def test_net_stream_write_after_local_reset(net_stream_pair):
await stream_0.write(DATA) await stream_0.write(DATA)
@pytest.mark.asyncio @pytest.mark.trio
async def test_net_stream_write_after_remote_reset(net_stream_pair): async def test_net_stream_write_after_remote_reset(net_stream_pair):
stream_0, stream_1 = net_stream_pair stream_0, stream_1 = net_stream_pair
await stream_1.reset() await stream_1.reset()
await asyncio.sleep(0.01) await trio.sleep(0.01)
with pytest.raises(StreamClosed): with pytest.raises(StreamClosed):
await stream_0.write(DATA) await stream_0.write(DATA)
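The `read_until_eof` rewrite above shows the recurring asyncio-to-trio move in this PR: a detached `asyncio.ensure_future(...)` plus a trailing `task.cancel()` becomes a child task in a nursery, which is awaited (or cancelled) when the `async with` block ends. A standalone sketch of the trio side of that pattern, using only a memory channel:

import trio


async def watch(receive_channel, collected):
    # Background reader: drain the channel until the sender closes it.
    async for item in receive_channel:
        collected.append(item)


async def main():
    collected = []
    send_channel, receive_channel = trio.open_memory_channel(0)
    async with trio.open_nursery() as nursery:
        # Equivalent of `asyncio.ensure_future(read_until_eof())`.
        nursery.start_soon(watch, receive_channel, collected)
        async with send_channel:
            await send_channel.send("data")
        # Closing the sender ends `watch`; no explicit `task.cancel()` is
        # needed, and the nursery waits for the child before exiting.
    assert collected == ["data"]


trio.run(main)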

View File

@ -8,11 +8,11 @@ into network after network has already started listening
TODO: Add tests for closed_stream, listen_close when those TODO: Add tests for closed_stream, listen_close when those
features are implemented in swarm features are implemented in swarm
""" """
import asyncio
import enum import enum
from async_service import background_trio_service
import pytest import pytest
import trio
from libp2p.network.notifee_interface import INotifee from libp2p.network.notifee_interface import INotifee
from libp2p.tools.constants import LISTEN_MADDR from libp2p.tools.constants import LISTEN_MADDR
@ -54,59 +54,63 @@ class MyNotifee(INotifee):
pass pass
@pytest.mark.asyncio @pytest.mark.trio
async def test_notify(is_host_secure): async def test_notify(is_host_secure):
swarms = [SwarmFactory(is_secure=is_host_secure) for _ in range(2)] swarms = [SwarmFactory(is_secure=is_host_secure) for _ in range(2)]
events_0_0 = [] events_0_0 = []
events_1_0 = [] events_1_0 = []
events_0_without_listen = [] events_0_without_listen = []
swarms[0].register_notifee(MyNotifee(events_0_0)) # Run swarms.
swarms[1].register_notifee(MyNotifee(events_1_0)) async with background_trio_service(swarms[0]), background_trio_service(swarms[1]):
# Listen # Register events before listening, so that `MyNotifee` is notified of the
await asyncio.gather(*[swarm.listen(LISTEN_MADDR) for swarm in swarms]) # `listen` event.
swarms[0].register_notifee(MyNotifee(events_0_0))
swarms[1].register_notifee(MyNotifee(events_1_0))
swarms[0].register_notifee(MyNotifee(events_0_without_listen)) # Listen
async with trio.open_nursery() as nursery:
nursery.start_soon(swarms[0].listen, LISTEN_MADDR)
nursery.start_soon(swarms[1].listen, LISTEN_MADDR)
# Connected swarms[0].register_notifee(MyNotifee(events_0_without_listen))
await connect_swarm(swarms[0], swarms[1])
# OpenedStream: first
await swarms[0].new_stream(swarms[1].get_peer_id())
# OpenedStream: second
await swarms[0].new_stream(swarms[1].get_peer_id())
# OpenedStream: third, but different direction.
await swarms[1].new_stream(swarms[0].get_peer_id())
await asyncio.sleep(0.01) # Connected
await connect_swarm(swarms[0], swarms[1])
# OpenedStream: first
await swarms[0].new_stream(swarms[1].get_peer_id())
# OpenedStream: second
await swarms[0].new_stream(swarms[1].get_peer_id())
# OpenedStream: third, but different direction.
await swarms[1].new_stream(swarms[0].get_peer_id())
# TODO: Check `ClosedStream` and `ListenClose` events after they are ready. await trio.sleep(0.01)
# Disconnected # TODO: Check `ClosedStream` and `ListenClose` events after they are ready.
await swarms[0].close_peer(swarms[1].get_peer_id())
await asyncio.sleep(0.01)
# Connected again, but different direction. # Disconnected
await connect_swarm(swarms[1], swarms[0]) await swarms[0].close_peer(swarms[1].get_peer_id())
await asyncio.sleep(0.01) await trio.sleep(0.01)
# Disconnected again, but different direction. # Connected again, but different direction.
await swarms[1].close_peer(swarms[0].get_peer_id()) await connect_swarm(swarms[1], swarms[0])
await asyncio.sleep(0.01) await trio.sleep(0.01)
expected_events_without_listen = [ # Disconnected again, but different direction.
Event.Connected, await swarms[1].close_peer(swarms[0].get_peer_id())
Event.OpenedStream, await trio.sleep(0.01)
Event.OpenedStream,
Event.OpenedStream,
Event.Disconnected,
Event.Connected,
Event.Disconnected,
]
expected_events = [Event.Listen] + expected_events_without_listen
assert events_0_0 == expected_events expected_events_without_listen = [
assert events_1_0 == expected_events Event.Connected,
assert events_0_without_listen == expected_events_without_listen Event.OpenedStream,
Event.OpenedStream,
Event.OpenedStream,
Event.Disconnected,
Event.Connected,
Event.Disconnected,
]
expected_events = [Event.Listen] + expected_events_without_listen
# Clean up assert events_0_0 == expected_events
await asyncio.gather(*[swarm.close() for swarm in swarms]) assert events_1_0 == expected_events
assert events_0_without_listen == expected_events_without_listen
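`background_trio_service` (from the `async-service` package) is what now runs each swarm for the duration of the test. A toy sketch of the same pattern with a trivial service of our own, assuming the `Service`/`background_trio_service` API used above; the manager cancels `run()` when the block exits:

import trio
from async_service import Service, background_trio_service


class Ticker(Service):
    # Toy service: counts ticks until the manager cancels it on exit.
    ticks = 0

    async def run(self):
        while True:
            self.ticks += 1
            await trio.sleep(0.01)


async def main():
    ticker = Ticker()
    async with background_trio_service(ticker):
        # The service runs concurrently while this block executes.
        await trio.sleep(0.05)
    # Leaving the block cancelled `run()` and waited for it to finish.
    assert ticker.ticks > 0


trio.run(main)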

View File

@ -1,89 +1,84 @@
import asyncio
from multiaddr import Multiaddr from multiaddr import Multiaddr
import pytest import pytest
import trio
from trio.testing import wait_all_tasks_blocked
from libp2p.network.exceptions import SwarmException from libp2p.network.exceptions import SwarmException
from libp2p.tools.factories import SwarmFactory from libp2p.tools.factories import SwarmFactory
from libp2p.tools.utils import connect_swarm from libp2p.tools.utils import connect_swarm
@pytest.mark.asyncio @pytest.mark.trio
async def test_swarm_dial_peer(is_host_secure): async def test_swarm_dial_peer(is_host_secure):
swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3) async with SwarmFactory.create_batch_and_listen(is_host_secure, 3) as swarms:
# Test: No addr found. # Test: No addr found.
with pytest.raises(SwarmException): with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
# Test: len(addr) in the peerstore is 0.
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000)
with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
# Test: Succeed if addrs of the peer_id are present in the peerstore.
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
await swarms[0].dial_peer(swarms[1].get_peer_id()) await swarms[0].dial_peer(swarms[1].get_peer_id())
assert swarms[0].get_peer_id() in swarms[1].connections
assert swarms[1].get_peer_id() in swarms[0].connections
# Test: len(addr) in the peerstore is 0. # Test: Reuse connections when we already have ones with a peer.
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000) conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
with pytest.raises(SwarmException): conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
await swarms[0].dial_peer(swarms[1].get_peer_id()) assert conn is conn_to_1
# Test: Succeed if addrs of the peer_id are present in the peerstore.
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
await swarms[0].dial_peer(swarms[1].get_peer_id())
assert swarms[0].get_peer_id() in swarms[1].connections
assert swarms[1].get_peer_id() in swarms[0].connections
# Test: Reuse connections when we already have ones with a peer.
conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
assert conn is conn_to_1
# Clean up
await asyncio.gather(*[swarm.close() for swarm in swarms])
@pytest.mark.asyncio @pytest.mark.trio
async def test_swarm_close_peer(is_host_secure): async def test_swarm_close_peer(is_host_secure):
swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3) async with SwarmFactory.create_batch_and_listen(is_host_secure, 3) as swarms:
# 0 <> 1 <> 2 # 0 <> 1 <> 2
await connect_swarm(swarms[0], swarms[1]) await connect_swarm(swarms[0], swarms[1])
await connect_swarm(swarms[1], swarms[2]) await connect_swarm(swarms[1], swarms[2])
# peer 1 closes peer 0 # peer 1 closes peer 0
await swarms[1].close_peer(swarms[0].get_peer_id()) await swarms[1].close_peer(swarms[0].get_peer_id())
await asyncio.sleep(0.01) await trio.sleep(0.01)
# 0 1 <> 2 await wait_all_tasks_blocked()
assert len(swarms[0].connections) == 0 # 0 1 <> 2
assert ( assert len(swarms[0].connections) == 0
len(swarms[1].connections) == 1 assert (
and swarms[2].get_peer_id() in swarms[1].connections len(swarms[1].connections) == 1
) and swarms[2].get_peer_id() in swarms[1].connections
)
# peer 1 is closed by peer 2 # peer 1 is closed by peer 2
await swarms[2].close_peer(swarms[1].get_peer_id()) await swarms[2].close_peer(swarms[1].get_peer_id())
await asyncio.sleep(0.01) await trio.sleep(0.01)
# 0 1 2 # 0 1 2
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0 assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
await connect_swarm(swarms[0], swarms[1]) await connect_swarm(swarms[0], swarms[1])
# 0 <> 1 2 # 0 <> 1 2
assert ( assert (
len(swarms[0].connections) == 1 len(swarms[0].connections) == 1
and swarms[1].get_peer_id() in swarms[0].connections and swarms[1].get_peer_id() in swarms[0].connections
) )
assert ( assert (
len(swarms[1].connections) == 1 len(swarms[1].connections) == 1
and swarms[0].get_peer_id() in swarms[1].connections and swarms[0].get_peer_id() in swarms[1].connections
) )
# peer 0 closes peer 1 # peer 0 closes peer 1
await swarms[0].close_peer(swarms[1].get_peer_id()) await swarms[0].close_peer(swarms[1].get_peer_id())
await asyncio.sleep(0.01) await trio.sleep(0.01)
# 0 1 2 # 0 1 2
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0 assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
# Clean up
await asyncio.gather(*[swarm.close() for swarm in swarms])
@pytest.mark.asyncio @pytest.mark.trio
async def test_swarm_remove_conn(swarm_pair): async def test_swarm_remove_conn(swarm_pair):
swarm_0, swarm_1 = swarm_pair swarm_0, swarm_1 = swarm_pair
conn_0 = swarm_0.connections[swarm_1.get_peer_id()] conn_0 = swarm_0.connections[swarm_1.get_peer_id()]
@ -94,57 +89,54 @@ async def test_swarm_remove_conn(swarm_pair):
assert swarm_1.get_peer_id() not in swarm_0.connections assert swarm_1.get_peer_id() not in swarm_0.connections
@pytest.mark.asyncio @pytest.mark.trio
async def test_swarm_multiaddr(is_host_secure): async def test_swarm_multiaddr(is_host_secure):
swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3) async with SwarmFactory.create_batch_and_listen(is_host_secure, 3) as swarms:
def clear(): def clear():
swarms[0].peerstore.clear_addrs(swarms[1].get_peer_id()) swarms[0].peerstore.clear_addrs(swarms[1].get_peer_id())
clear() clear()
# No addresses # No addresses
with pytest.raises(SwarmException): with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
clear()
# Wrong addresses
swarms[0].peerstore.add_addrs(
swarms[1].get_peer_id(), [Multiaddr("/ip4/0.0.0.0/tcp/9999")], 10000
)
with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
clear()
# Multiple wrong addresses
swarms[0].peerstore.add_addrs(
swarms[1].get_peer_id(),
[Multiaddr("/ip4/0.0.0.0/tcp/9999"), Multiaddr("/ip4/0.0.0.0/tcp/9998")],
10000,
)
with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
# Test one address
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs[:1], 10000)
await swarms[0].dial_peer(swarms[1].get_peer_id()) await swarms[0].dial_peer(swarms[1].get_peer_id())
clear() # Test multiple addresses
# Wrong addresses addrs = tuple(
swarms[0].peerstore.add_addrs( addr
swarms[1].get_peer_id(), [Multiaddr("/ip4/0.0.0.0/tcp/9999")], 10000 for transport in swarms[1].listeners.values()
) for addr in transport.get_addrs()
)
with pytest.raises(SwarmException): swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs + addrs, 10000)
await swarms[0].dial_peer(swarms[1].get_peer_id()) await swarms[0].dial_peer(swarms[1].get_peer_id())
clear()
# Multiple wrong addresses
swarms[0].peerstore.add_addrs(
swarms[1].get_peer_id(),
[Multiaddr("/ip4/0.0.0.0/tcp/9999"), Multiaddr("/ip4/0.0.0.0/tcp/9998")],
10000,
)
with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
# Test one address
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs[:1], 10000)
await swarms[0].dial_peer(swarms[1].get_peer_id())
# Test multiple addresses
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs + addrs, 10000)
await swarms[0].dial_peer(swarms[1].get_peer_id())
for swarm in swarms:
await swarm.close()
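The address-collection expression above (iterating `swarm.listeners.values()` and flattening `get_addrs()`) appears in several of these dial tests; purely as a readability sketch, not a helper that exists in the library, it could be factored out like this:

def listen_addrs(swarm):
    # Every multiaddr the swarm's transport listeners are currently bound to,
    # in the form the tests feed to `peerstore.add_addrs`.
    return tuple(
        addr
        for listener in swarm.listeners.values()
        for addr in listener.get_addrs()
    )

With it, each repeated block above collapses to `swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), listen_addrs(swarms[1]), 10000)`.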

View File

@ -1,45 +1,46 @@
import asyncio
import pytest import pytest
import trio
from trio.testing import wait_all_tasks_blocked
@pytest.mark.asyncio @pytest.mark.trio
async def test_swarm_conn_close(swarm_conn_pair): async def test_swarm_conn_close(swarm_conn_pair):
conn_0, conn_1 = swarm_conn_pair conn_0, conn_1 = swarm_conn_pair
assert not conn_0.event_closed.is_set() assert not conn_0.is_closed
assert not conn_1.event_closed.is_set() assert not conn_1.is_closed
await conn_0.close() await conn_0.close()
await asyncio.sleep(0.01) await trio.sleep(0.1)
await wait_all_tasks_blocked()
assert conn_0.event_closed.is_set() assert conn_0.is_closed
assert conn_1.event_closed.is_set() assert conn_1.is_closed
assert conn_0 not in conn_0.swarm.connections.values() assert conn_0 not in conn_0.swarm.connections.values()
assert conn_1 not in conn_1.swarm.connections.values() assert conn_1 not in conn_1.swarm.connections.values()
@pytest.mark.asyncio @pytest.mark.trio
async def test_swarm_conn_streams(swarm_conn_pair): async def test_swarm_conn_streams(swarm_conn_pair):
conn_0, conn_1 = swarm_conn_pair conn_0, conn_1 = swarm_conn_pair
assert len(await conn_0.get_streams()) == 0 assert len(conn_0.get_streams()) == 0
assert len(await conn_1.get_streams()) == 0 assert len(conn_1.get_streams()) == 0
stream_0_0 = await conn_0.new_stream() stream_0_0 = await conn_0.new_stream()
await asyncio.sleep(0.01) await trio.sleep(0.01)
assert len(await conn_0.get_streams()) == 1 assert len(conn_0.get_streams()) == 1
assert len(await conn_1.get_streams()) == 1 assert len(conn_1.get_streams()) == 1
stream_0_1 = await conn_0.new_stream() stream_0_1 = await conn_0.new_stream()
await asyncio.sleep(0.01) await trio.sleep(0.01)
assert len(await conn_0.get_streams()) == 2 assert len(conn_0.get_streams()) == 2
assert len(await conn_1.get_streams()) == 2 assert len(conn_1.get_streams()) == 2
conn_0.remove_stream(stream_0_0) conn_0.remove_stream(stream_0_0)
assert len(await conn_0.get_streams()) == 1 assert len(conn_0.get_streams()) == 1
conn_0.remove_stream(stream_0_1) conn_0.remove_stream(stream_0_1)
assert len(await conn_0.get_streams()) == 0 assert len(conn_0.get_streams()) == 0
# Nothing happens if `stream_0_1` is not present or already removed. # Nothing happens if `stream_0_1` is not present or already removed.
conn_0.remove_stream(stream_0_1) conn_0.remove_stream(stream_0_1)
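`trio.testing.wait_all_tasks_blocked`, used above alongside a short sleep, waits until every other task is blocked, which makes "let the background work settle" checks deterministic rather than timing-dependent. A standalone sketch of the idea:

import pytest
import trio
from trio.testing import wait_all_tasks_blocked


@pytest.mark.trio
async def test_wait_for_background_task():
    done = trio.Event()

    async def background():
        await trio.sleep(0)  # yield once, then finish
        done.set()

    async with trio.open_nursery() as nursery:
        nursery.start_soon(background)
        # Instead of guessing a sleep duration, wait until every other task
        # has either finished or blocked waiting on something.
        await wait_all_tasks_blocked()
        assert done.is_set()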

View File

@ -25,8 +25,6 @@ def test_init_():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"addr", "addr",
( (
pytest.param(None),
pytest.param(random.randint(0, 255), id="random integer"),
pytest.param(multiaddr.Multiaddr("/"), id="empty multiaddr"), pytest.param(multiaddr.Multiaddr("/"), id="empty multiaddr"),
pytest.param( pytest.param(
multiaddr.Multiaddr("/ip4/127.0.0.1"), multiaddr.Multiaddr("/ip4/127.0.0.1"),

View File

@ -1,83 +1,94 @@
import pytest import pytest
from libp2p.host.exceptions import StreamFailure from libp2p.host.exceptions import StreamFailure
from libp2p.tools.utils import echo_stream_handler, set_up_nodes_by_transport_opt from libp2p.tools.factories import HostFactory
from libp2p.tools.utils import create_echo_stream_handler
# TODO: Add tests for multiple streams being opened on different PROTOCOL_ECHO = "/echo/1.0.0"
# protocols through the same connection PROTOCOL_POTATO = "/potato/1.0.0"
PROTOCOL_FOO = "/foo/1.0.0"
PROTOCOL_ROCK = "/rock/1.0.0"
# Note: async issues occurred when using the same port ACK_PREFIX = "ack:"
# so that's why I use different ports here.
# TODO: modify tests so that those async issues don't occur
# when using the same ports across tests
async def perform_simple_test( async def perform_simple_test(
expected_selected_protocol, protocols_for_client, protocols_with_handlers expected_selected_protocol,
protocols_for_client,
protocols_with_handlers,
is_host_secure,
): ):
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list) for protocol in protocols_with_handlers:
hosts[1].set_stream_handler(
protocol, create_echo_stream_handler(ACK_PREFIX)
)
for protocol in protocols_with_handlers: # Associate the peer with local ip address (see default parameters of Libp2p())
node_b.set_stream_handler(protocol, echo_stream_handler) hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
stream = await hosts[0].new_stream(hosts[1].get_id(), protocols_for_client)
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
expected_resp = "ack:" + message
await stream.write(message.encode())
response = (await stream.read(len(expected_resp))).decode()
assert response == expected_resp
# Associate the peer with local ip address (see default parameters of Libp2p()) assert expected_selected_protocol == stream.get_protocol()
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
stream = await node_a.new_stream(node_b.get_id(), protocols_for_client)
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
expected_resp = "ack:" + message
await stream.write(message.encode())
response = (await stream.read(len(expected_resp))).decode()
assert response == expected_resp
assert expected_selected_protocol == stream.get_protocol()
# Success, terminate pending tasks.
@pytest.mark.asyncio @pytest.mark.trio
async def test_single_protocol_succeeds(): async def test_single_protocol_succeeds(is_host_secure):
expected_selected_protocol = "/echo/1.0.0" expected_selected_protocol = PROTOCOL_ECHO
await perform_simple_test( await perform_simple_test(
expected_selected_protocol, ["/echo/1.0.0"], ["/echo/1.0.0"] expected_selected_protocol,
[expected_selected_protocol],
[expected_selected_protocol],
is_host_secure,
) )
@pytest.mark.asyncio @pytest.mark.trio
async def test_single_protocol_fails(): async def test_single_protocol_fails(is_host_secure):
with pytest.raises(StreamFailure): with pytest.raises(StreamFailure):
await perform_simple_test("", ["/echo/1.0.0"], ["/potato/1.0.0"]) await perform_simple_test(
"", [PROTOCOL_ECHO], [PROTOCOL_POTATO], is_host_secure
)
# Cleanup not reached on error # Cleanup not reached on error
@pytest.mark.asyncio @pytest.mark.trio
async def test_multiple_protocol_first_is_valid_succeeds(): async def test_multiple_protocol_first_is_valid_succeeds(is_host_secure):
expected_selected_protocol = "/echo/1.0.0" expected_selected_protocol = PROTOCOL_ECHO
protocols_for_client = ["/echo/1.0.0", "/potato/1.0.0"] protocols_for_client = [PROTOCOL_ECHO, PROTOCOL_POTATO]
protocols_for_listener = ["/foo/1.0.0", "/echo/1.0.0"] protocols_for_listener = [PROTOCOL_FOO, PROTOCOL_ECHO]
await perform_simple_test( await perform_simple_test(
expected_selected_protocol, protocols_for_client, protocols_for_listener expected_selected_protocol,
protocols_for_client,
protocols_for_listener,
is_host_secure,
) )
@pytest.mark.asyncio @pytest.mark.trio
async def test_multiple_protocol_second_is_valid_succeeds(): async def test_multiple_protocol_second_is_valid_succeeds(is_host_secure):
expected_selected_protocol = "/foo/1.0.0" expected_selected_protocol = PROTOCOL_FOO
protocols_for_client = ["/rock/1.0.0", "/foo/1.0.0"] protocols_for_client = [PROTOCOL_ROCK, PROTOCOL_FOO]
protocols_for_listener = ["/foo/1.0.0", "/echo/1.0.0"] protocols_for_listener = [PROTOCOL_FOO, PROTOCOL_ECHO]
await perform_simple_test( await perform_simple_test(
expected_selected_protocol, protocols_for_client, protocols_for_listener expected_selected_protocol,
protocols_for_client,
protocols_for_listener,
is_host_secure,
) )
@pytest.mark.asyncio @pytest.mark.trio
async def test_multiple_protocol_fails(): async def test_multiple_protocol_fails(is_host_secure):
protocols_for_client = ["/rock/1.0.0", "/foo/1.0.0", "/bar/1.0.0"] protocols_for_client = [PROTOCOL_ROCK, PROTOCOL_FOO, "/bar/1.0.0"]
protocols_for_listener = ["/aspyn/1.0.0", "/rob/1.0.0", "/zx/1.0.0", "/alex/1.0.0"] protocols_for_listener = ["/aspyn/1.0.0", "/rob/1.0.0", "/zx/1.0.0", "/alex/1.0.0"]
with pytest.raises(StreamFailure): with pytest.raises(StreamFailure):
await perform_simple_test("", protocols_for_client, protocols_for_listener) await perform_simple_test(
"", protocols_for_client, protocols_for_listener, is_host_secure
# Cleanup not reached on error )

View File

@ -1,58 +0,0 @@
import pytest
from libp2p.tools.constants import GOSSIPSUB_PARAMS
from libp2p.tools.factories import FloodsubFactory, GossipsubFactory, PubsubFactory
@pytest.fixture
def is_strict_signing():
return False
def _make_pubsubs(hosts, pubsub_routers, cache_size, is_strict_signing):
if len(pubsub_routers) != len(hosts):
raise ValueError(
f"lenght of pubsub_routers={pubsub_routers} should be equaled to the "
f"length of hosts={len(hosts)}"
)
return tuple(
PubsubFactory(
host=host,
router=router,
cache_size=cache_size,
strict_signing=is_strict_signing,
)
for host, router in zip(hosts, pubsub_routers)
)
@pytest.fixture
def pubsub_cache_size():
return None # default
@pytest.fixture
def gossipsub_params():
return GOSSIPSUB_PARAMS
@pytest.fixture
def pubsubs_fsub(num_hosts, hosts, pubsub_cache_size, is_strict_signing):
floodsubs = FloodsubFactory.create_batch(num_hosts)
_pubsubs_fsub = _make_pubsubs(
hosts, floodsubs, pubsub_cache_size, is_strict_signing
)
yield _pubsubs_fsub
# TODO: Clean up
@pytest.fixture
def pubsubs_gsub(
num_hosts, hosts, pubsub_cache_size, gossipsub_params, is_strict_signing
):
gossipsubs = GossipsubFactory.create_batch(num_hosts, **gossipsub_params._asdict())
_pubsubs_gsub = _make_pubsubs(
hosts, gossipsubs, pubsub_cache_size, is_strict_signing
)
yield _pubsubs_gsub
# TODO: Clean up
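These module-level pubsub fixtures are replaced by factory classmethods that the tests below enter directly as async context managers. A minimal usage sketch, assuming only the call shape visible in this diff (`PubsubFactory.create_batch_with_floodsub(n, cache_size=..., is_secure=...)` yields wired-up `Pubsub` objects and tears their hosts down on exit):

import pytest

from libp2p.tools.factories import PubsubFactory


@pytest.mark.trio
async def test_floodsub_batch_factory():
    # Two Pubsub instances backed by floodsub routers, cleaned up automatically.
    async with PubsubFactory.create_batch_with_floodsub(2, cache_size=4) as pubsubs:
        # Each pubsub exposes its host, as the tests below rely on.
        assert pubsubs[0].host.get_id() != pubsubs[1].host.get_id()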

View File

@ -1,19 +1,10 @@
import asyncio
from threading import Thread
import pytest import pytest
import trio
from libp2p.tools.pubsub.dummy_account_node import DummyAccountNode from libp2p.tools.pubsub.dummy_account_node import DummyAccountNode
from libp2p.tools.utils import connect from libp2p.tools.utils import connect
def create_setup_in_new_thread_func(dummy_node):
def setup_in_new_thread():
asyncio.ensure_future(dummy_node.setup_crypto_networking())
return setup_in_new_thread
async def perform_test(num_nodes, adjacency_map, action_func, assertion_func): async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
""" """
Helper function to allow for easy construction of custom tests for dummy Helper function to allow for easy construction of custom tests for dummy
@ -26,47 +17,35 @@ async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
:param assertion_func: assertions for testing the results of the actions are correct :param assertion_func: assertions for testing the results of the actions are correct
""" """
# Create nodes async with DummyAccountNode.create(num_nodes) as dummy_nodes:
dummy_nodes = [] # Create connections between nodes according to `adjacency_map`
for _ in range(num_nodes): async with trio.open_nursery() as nursery:
dummy_nodes.append(await DummyAccountNode.create()) for source_num in adjacency_map:
target_nums = adjacency_map[source_num]
for target_num in target_nums:
nursery.start_soon(
connect,
dummy_nodes[source_num].host,
dummy_nodes[target_num].host,
)
# Create network # Allow time for network creation to take place
for source_num in adjacency_map: await trio.sleep(0.25)
target_nums = adjacency_map[source_num]
for target_num in target_nums:
await connect(
dummy_nodes[source_num].libp2p_node, dummy_nodes[target_num].libp2p_node
)
# Allow time for network creation to take place # Perform action function
await asyncio.sleep(0.25) await action_func(dummy_nodes)
# Start a thread for each node so that each node can listen and respond # Allow time for action function to be performed (i.e. messages to propagate)
# to messages on its own thread, which will avoid waiting indefinitely await trio.sleep(1)
# on the main thread. On this thread, call the setup func for the node,
# which subscribes the node to the CRYPTO_TOPIC topic
for dummy_node in dummy_nodes:
thread = Thread(target=create_setup_in_new_thread_func(dummy_node))
thread.run()
# Allow time for nodes to subscribe to CRYPTO_TOPIC topic # Perform assertion function
await asyncio.sleep(0.25) for dummy_node in dummy_nodes:
assertion_func(dummy_node)
# Perform action function
await action_func(dummy_nodes)
# Allow time for action function to be performed (i.e. messages to propagate)
await asyncio.sleep(1)
# Perform assertion function
for dummy_node in dummy_nodes:
assertion_func(dummy_node)
# Success, terminate pending tasks. # Success, terminate pending tasks.
@pytest.mark.asyncio @pytest.mark.trio
async def test_simple_two_nodes(): async def test_simple_two_nodes():
num_nodes = 2 num_nodes = 2
adj_map = {0: [1]} adj_map = {0: [1]}
@ -80,7 +59,7 @@ async def test_simple_two_nodes():
await perform_test(num_nodes, adj_map, action_func, assertion_func) await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio @pytest.mark.trio
async def test_simple_three_nodes_line_topography(): async def test_simple_three_nodes_line_topography():
num_nodes = 3 num_nodes = 3
adj_map = {0: [1], 1: [2]} adj_map = {0: [1], 1: [2]}
@ -94,7 +73,7 @@ async def test_simple_three_nodes_line_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func) await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio @pytest.mark.trio
async def test_simple_three_nodes_triangle_topography(): async def test_simple_three_nodes_triangle_topography():
num_nodes = 3 num_nodes = 3
adj_map = {0: [1, 2], 1: [2]} adj_map = {0: [1, 2], 1: [2]}
@ -108,7 +87,7 @@ async def test_simple_three_nodes_triangle_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func) await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio @pytest.mark.trio
async def test_simple_seven_nodes_tree_topography(): async def test_simple_seven_nodes_tree_topography():
num_nodes = 7 num_nodes = 7
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]} adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
@ -122,14 +101,14 @@ async def test_simple_seven_nodes_tree_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func) await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio @pytest.mark.trio
async def test_set_then_send_from_root_seven_nodes_tree_topography(): async def test_set_then_send_from_root_seven_nodes_tree_topography():
num_nodes = 7 num_nodes = 7
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]} adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
async def action_func(dummy_nodes): async def action_func(dummy_nodes):
await dummy_nodes[0].publish_set_crypto("aspyn", 20) await dummy_nodes[0].publish_set_crypto("aspyn", 20)
await asyncio.sleep(0.25) await trio.sleep(0.25)
await dummy_nodes[0].publish_send_crypto("aspyn", "alex", 5) await dummy_nodes[0].publish_send_crypto("aspyn", "alex", 5)
def assertion_func(dummy_node): def assertion_func(dummy_node):
@ -139,14 +118,14 @@ async def test_set_then_send_from_root_seven_nodes_tree_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func) await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio @pytest.mark.trio
async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography(): async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography():
num_nodes = 7 num_nodes = 7
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]} adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
async def action_func(dummy_nodes): async def action_func(dummy_nodes):
await dummy_nodes[6].publish_set_crypto("aspyn", 20) await dummy_nodes[6].publish_set_crypto("aspyn", 20)
await asyncio.sleep(0.25) await trio.sleep(0.25)
await dummy_nodes[4].publish_send_crypto("aspyn", "alex", 5) await dummy_nodes[4].publish_send_crypto("aspyn", "alex", 5)
def assertion_func(dummy_node): def assertion_func(dummy_node):
@ -156,7 +135,7 @@ async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func) await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio @pytest.mark.trio
async def test_simple_five_nodes_ring_topography(): async def test_simple_five_nodes_ring_topography():
num_nodes = 5 num_nodes = 5
adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]} adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
@ -170,14 +149,14 @@ async def test_simple_five_nodes_ring_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func) await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio @pytest.mark.trio
async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography(): async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography():
num_nodes = 5 num_nodes = 5
adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]} adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
async def action_func(dummy_nodes): async def action_func(dummy_nodes):
await dummy_nodes[0].publish_set_crypto("alex", 20) await dummy_nodes[0].publish_set_crypto("alex", 20)
await asyncio.sleep(0.25) await trio.sleep(0.25)
await dummy_nodes[3].publish_send_crypto("alex", "rob", 12) await dummy_nodes[3].publish_send_crypto("alex", "rob", 12)
def assertion_func(dummy_node): def assertion_func(dummy_node):
@ -187,7 +166,7 @@ async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func) await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio @pytest.mark.trio
@pytest.mark.slow @pytest.mark.slow
async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography(): async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography():
num_nodes = 5 num_nodes = 5
@ -195,13 +174,13 @@ async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography():
async def action_func(dummy_nodes): async def action_func(dummy_nodes):
await dummy_nodes[0].publish_set_crypto("alex", 20) await dummy_nodes[0].publish_set_crypto("alex", 20)
await asyncio.sleep(1) await trio.sleep(1)
await dummy_nodes[1].publish_send_crypto("alex", "rob", 3) await dummy_nodes[1].publish_send_crypto("alex", "rob", 3)
await asyncio.sleep(1) await trio.sleep(1)
await dummy_nodes[2].publish_send_crypto("rob", "aspyn", 2) await dummy_nodes[2].publish_send_crypto("rob", "aspyn", 2)
await asyncio.sleep(1) await trio.sleep(1)
await dummy_nodes[3].publish_send_crypto("aspyn", "zx", 1) await dummy_nodes[3].publish_send_crypto("aspyn", "zx", 1)
await asyncio.sleep(1) await trio.sleep(1)
await dummy_nodes[4].publish_send_crypto("zx", "raul", 1) await dummy_nodes[4].publish_send_crypto("zx", "raul", 1)
def assertion_func(dummy_node): def assertion_func(dummy_node):

View File

@ -1,9 +1,10 @@
import asyncio import functools
import pytest import pytest
import trio
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.tools.factories import FloodsubFactory from libp2p.tools.factories import PubsubFactory
from libp2p.tools.pubsub.floodsub_integration_test_settings import ( from libp2p.tools.pubsub.floodsub_integration_test_settings import (
floodsub_protocol_pytest_params, floodsub_protocol_pytest_params,
perform_test_from_obj, perform_test_from_obj,
@ -11,79 +12,80 @@ from libp2p.tools.pubsub.floodsub_integration_test_settings import (
from libp2p.tools.utils import connect from libp2p.tools.utils import connect
@pytest.mark.parametrize("num_hosts", (2,)) @pytest.mark.trio
@pytest.mark.asyncio async def test_simple_two_nodes():
async def test_simple_two_nodes(pubsubs_fsub): async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
topic = "my_topic" topic = "my_topic"
data = b"some data" data = b"some data"
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
await asyncio.sleep(0.25) await trio.sleep(0.25)
sub_b = await pubsubs_fsub[1].subscribe(topic) sub_b = await pubsubs_fsub[1].subscribe(topic)
# Sleep to let node a know of node b's subscription # Sleep to let node a know of node b's subscription
await asyncio.sleep(0.25) await trio.sleep(0.25)
await pubsubs_fsub[0].publish(topic, data) await pubsubs_fsub[0].publish(topic, data)
res_b = await sub_b.get()
# Check that the msg received by node_b is the same
# as the message sent by node_a
assert ID(res_b.from_id) == pubsubs_fsub[0].host.get_id()
assert res_b.data == data
assert res_b.topicIDs == [topic]
# Success, terminate pending tasks.
# Initialize Pubsub with a cache_size of 4
@pytest.mark.parametrize("num_hosts, pubsub_cache_size", ((2, 4),))
@pytest.mark.asyncio
async def test_lru_cache_two_nodes(pubsubs_fsub, monkeypatch):
# two nodes with cache_size of 4
# `node_a` sends the following messages to node_b
message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
# `node_b` should only receive the following
expected_received_indices = [1, 2, 3, 4, 5, 1]
topic = "my_topic"
# Mock `get_msg_id` to make it easier to manipulate `msg_id` via `data`.
def get_msg_id(msg):
# Originally it is `(msg.seqno, msg.from_id)`
return (msg.data, msg.from_id)
import libp2p.pubsub.pubsub
monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id)
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
await asyncio.sleep(0.25)
sub_b = await pubsubs_fsub[1].subscribe(topic)
await asyncio.sleep(0.25)
def _make_testing_data(i: int) -> bytes:
num_int_bytes = 4
if i >= 2 ** (num_int_bytes * 8):
raise ValueError("integer is too large to be serialized")
return b"data" + i.to_bytes(num_int_bytes, "big")
for index in message_indices:
await pubsubs_fsub[0].publish(topic, _make_testing_data(index))
await asyncio.sleep(0.25)
for index in expected_received_indices:
res_b = await sub_b.get()
assert res_b.data == _make_testing_data(index)
assert sub_b.empty()
# Success, terminate pending tasks.
# Check that the msg received by node_b is the same
# as the message sent by node_a
assert ID(res_b.from_id) == pubsubs_fsub[0].host.get_id()
assert res_b.data == data
assert res_b.topicIDs == [topic]
@pytest.mark.trio
async def test_lru_cache_two_nodes(monkeypatch):
# two nodes with cache_size of 4
async with PubsubFactory.create_batch_with_floodsub(
2, cache_size=4
) as pubsubs_fsub:
# `node_a` send the following messages to node_b
message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
# `node_b` should only receive the following
expected_received_indices = [1, 2, 3, 4, 5, 1]
topic = "my_topic"
# Mock `get_msg_id` to make it easier for us to manipulate `msg_id` via `data`.
def get_msg_id(msg):
# Originally it is `(msg.seqno, msg.from_id)`
return (msg.data, msg.from_id)
import libp2p.pubsub.pubsub
monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id)
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
await trio.sleep(0.25)
sub_b = await pubsubs_fsub[1].subscribe(topic)
await trio.sleep(0.25)
def _make_testing_data(i: int) -> bytes:
num_int_bytes = 4
if i >= 2 ** (num_int_bytes * 8):
raise ValueError("integer is too large to be serialized")
return b"data" + i.to_bytes(num_int_bytes, "big")
for index in message_indices:
await pubsubs_fsub[0].publish(topic, _make_testing_data(index))
await trio.sleep(0.25)
for index in expected_received_indices:
res_b = await sub_b.get()
assert res_b.data == _make_testing_data(index)
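The expected indices above fall out of a bounded seen-message cache keyed by the mocked `msg_id`: a message is delivered only if its ID is not cached, and once the cache overflows, the oldest ID is evicted so the first message can be delivered again. A minimal, self-contained sketch of that behaviour (the `BoundedSeenCache` class below is illustrative only, not py-libp2p's actual cache implementation):

from collections import OrderedDict

class BoundedSeenCache:
    """Toy seen-message cache with FIFO eviction, for illustration only."""

    def __init__(self, cache_size):
        self._seen = OrderedDict()
        self._cache_size = cache_size

    def add(self, msg_id):
        """Return True if the message is new and should be delivered."""
        if msg_id in self._seen:
            return False
        self._seen[msg_id] = None
        if len(self._seen) > self._cache_size:
            # Evict the oldest ID, so the same ID can be delivered again later.
            self._seen.popitem(last=False)
        return True

cache = BoundedSeenCache(cache_size=4)
message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
delivered = [
    i for i in message_indices
    if cache.add((b"data" + i.to_bytes(4, "big"), b"from_a"))
]
assert delivered == [1, 2, 3, 4, 5, 1]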
@pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params)
@pytest.mark.trio
@pytest.mark.slow
async def test_gossipsub_run_with_floodsub_tests(test_case_obj, is_host_secure):
await perform_test_from_obj(
test_case_obj,
functools.partial(
PubsubFactory.create_batch_with_floodsub, is_secure=is_host_secure
),
)


@ -1,495 +1,478 @@
import asyncio
import random
import pytest
import trio
from libp2p.peer.id import ID
from libp2p.pubsub.gossipsub import PROTOCOL_ID
from libp2p.tools.factories import IDFactory, PubsubFactory
from libp2p.tools.pubsub.utils import dense_connect, one_to_all_connect
from libp2p.tools.utils import connect
@pytest.mark.trio
async def test_join():
async with PubsubFactory.create_batch_with_gossipsub(
4, degree=4, degree_low=3, degree_high=5
) as pubsubs_gsub:
gossipsubs = [pubsub.router for pubsub in pubsubs_gsub]
hosts = [pubsub.host for pubsub in pubsubs_gsub]
hosts_indices = list(range(len(pubsubs_gsub)))
topic = "test_join"
central_node_index = 0
# Remove index of central host from the indices
hosts_indices.remove(central_node_index)
num_subscribed_peer = 2
subscribed_peer_indices = random.sample(hosts_indices, num_subscribed_peer)
# All pubsub except the one of central node subscribe to topic
for i in subscribed_peer_indices:
await pubsubs_gsub[i].subscribe(topic)
# Connect central host to all other hosts
await one_to_all_connect(hosts, central_node_index)
# Wait 2 seconds for heartbeat to allow mesh to connect
await trio.sleep(2)
# Central node publish to the topic so that this topic
# is added to central node's fanout
# publish from the randomly chosen host
await pubsubs_gsub[central_node_index].publish(topic, b"data")
# Check that the gossipsub of central node has fanout for the topic
assert topic in gossipsubs[central_node_index].fanout
# Check that the gossipsub of central node does not have a mesh for the topic
assert topic not in gossipsubs[central_node_index].mesh
# Central node subscribes the topic
await pubsubs_gsub[central_node_index].subscribe(topic)
await asyncio.sleep(2)
# Check that the gossipsub of central node no longer has fanout for the topic
assert topic not in gossipsubs[central_node_index].fanout
for i in hosts_indices:
if i in subscribed_peer_indices:
assert hosts[i].get_id() in gossipsubs[central_node_index].mesh[topic]
assert hosts[central_node_index].get_id() in gossipsubs[i].mesh[topic]
else:
assert hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic]
assert topic not in gossipsubs[i].mesh
@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.asyncio
async def test_leave(pubsubs_gsub):
gossipsub = pubsubs_gsub[0].router
topic = "test_leave"
assert topic not in gossipsub.mesh
await gossipsub.join(topic)
assert topic in gossipsub.mesh
await gossipsub.leave(topic)
assert topic not in gossipsub.mesh
# Test re-leave
await gossipsub.leave(topic)
@pytest.mark.parametrize("num_hosts", (2,))
@pytest.mark.asyncio
async def test_handle_graft(pubsubs_gsub, hosts, event_loop, monkeypatch):
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
index_alice = 0
id_alice = hosts[index_alice].get_id()
index_bob = 1
id_bob = hosts[index_bob].get_id()
await connect(hosts[index_alice], hosts[index_bob])
# Wait 2 seconds for heartbeat to allow mesh to connect
await asyncio.sleep(2)
topic = "test_handle_graft"
# Only alice subscribes to the topic
await gossipsubs[index_alice].join(topic)
# Monkey patch bob's `emit_prune` function so we can
# check if it is called in `handle_graft`
event_emit_prune = asyncio.Event()
async def emit_prune(topic, sender_peer_id):
event_emit_prune.set()
monkeypatch.setattr(gossipsubs[index_bob], "emit_prune", emit_prune)
# Check that alice is bob's peer but not his mesh peer
assert gossipsubs[index_bob].peer_protocol[id_alice] == PROTOCOL_ID
assert topic not in gossipsubs[index_bob].mesh
await gossipsubs[index_alice].emit_graft(topic, id_bob)
# Check that `emit_prune` is called
await asyncio.wait_for(event_emit_prune.wait(), timeout=1, loop=event_loop)
assert event_emit_prune.is_set()
# Check that bob is alice's peer but not her mesh peer
assert topic in gossipsubs[index_alice].mesh
assert id_bob not in gossipsubs[index_alice].mesh[topic]
assert gossipsubs[index_alice].peer_protocol[id_bob] == PROTOCOL_ID
await gossipsubs[index_bob].emit_graft(topic, id_alice)
await asyncio.sleep(1)
# Check that bob is now alice's mesh peer
assert id_bob in gossipsubs[index_alice].mesh[topic]
@pytest.mark.parametrize(
"num_hosts, gossipsub_params", ((2, GossipsubParams(heartbeat_interval=3)),)
)
@pytest.mark.asyncio
async def test_handle_prune(pubsubs_gsub, hosts):
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
index_alice = 0
id_alice = hosts[index_alice].get_id()
index_bob = 1
id_bob = hosts[index_bob].get_id()
topic = "test_handle_prune"
for pubsub in pubsubs_gsub:
await pubsub.subscribe(topic)
await connect(hosts[index_alice], hosts[index_bob])
# Wait for heartbeat to allow mesh to connect
await asyncio.sleep(1)
# Check that they are each other's mesh peer
assert id_alice in gossipsubs[index_bob].mesh[topic]
assert id_bob in gossipsubs[index_alice].mesh[topic]
# alice emit prune message to bob, alice should be removed
# from bob's mesh peer
await gossipsubs[index_alice].emit_prune(topic, id_bob)
# `emit_prune` does not remove bob from alice's mesh peers
assert id_bob in gossipsubs[index_alice].mesh[topic]
# NOTE: We increase `heartbeat_interval` to 3 seconds so that bob will not
# add alice back to his mesh after heartbeat.
# Wait for bob to `handle_prune`
await asyncio.sleep(0.1)
# Check that alice is no longer bob's mesh peer
assert id_alice not in gossipsubs[index_bob].mesh[topic]
@pytest.mark.parametrize("num_hosts", (10,))
@pytest.mark.asyncio
async def test_dense(num_hosts, pubsubs_gsub, hosts):
num_msgs = 5
# All pubsub subscribe to foobar
queues = []
for pubsub in pubsubs_gsub:
q = await pubsub.subscribe("foobar")
# Add each blocking queue to an array of blocking queues
queues.append(q)
# Densely connect libp2p hosts in a random way
await dense_connect(hosts)
# Wait 2 seconds for heartbeat to allow mesh to connect
await asyncio.sleep(2)
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# randomly pick a message origin
origin_idx = random.randint(0, num_hosts - 1)
# Central node publish to the topic so that this topic
# is added to central node's fanout
# publish from the randomly chosen host
await pubsubs_gsub[central_node_index].publish(topic, b"data")
# Check that the gossipsub of central node has fanout for the topic
assert topic in gossipsubs[central_node_index].fanout
# Check that the gossipsub of central node does not have a mesh for the topic
assert topic not in gossipsubs[central_node_index].mesh
# Central node subscribes the topic
await pubsubs_gsub[central_node_index].subscribe(topic)
await trio.sleep(2)
# Check that the gossipsub of central node no longer has fanout for the topic
assert topic not in gossipsubs[central_node_index].fanout
for i in hosts_indices:
if i in subscribed_peer_indices:
assert hosts[i].get_id() in gossipsubs[central_node_index].mesh[topic]
assert hosts[central_node_index].get_id() in gossipsubs[i].mesh[topic]
else:
assert (
hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic]
)
assert topic not in gossipsubs[i].mesh
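The fanout/mesh assertions above boil down to one rule: publishing to a topic you have not joined only populates `fanout`, while joining the topic promotes those peers into `mesh` and clears the `fanout` entry. A toy, dictionary-based sketch of that state transition (illustrative only; the real GossipSub router also emits GRAFT/PRUNE control messages and respects the degree bounds):

class ToyGossipState:
    """Hypothetical, simplified router state used only to illustrate the rule."""

    def __init__(self):
        self.mesh = {}    # topic -> peers we exchange full messages with
        self.fanout = {}  # topic -> peers used for topics we publish to but did not join

    def publish(self, topic, peers):
        if topic not in self.mesh:
            self.fanout.setdefault(topic, set()).update(peers)

    def join(self, topic):
        # Joining moves the fanout peers (if any) into the mesh and drops the fanout entry.
        self.mesh[topic] = self.fanout.pop(topic, set())

state = ToyGossipState()
state.publish("test_join", {"peer_1", "peer_2"})
assert "test_join" in state.fanout and "test_join" not in state.mesh
state.join("test_join")
assert "test_join" in state.mesh and "test_join" not in state.fanout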
@pytest.mark.trio
async def test_leave():
async with PubsubFactory.create_batch_with_gossipsub(1) as pubsubs_gsub:
gossipsub = pubsubs_gsub[0].router
topic = "test_leave"
assert topic not in gossipsub.mesh
await gossipsub.join(topic)
assert topic in gossipsub.mesh
await gossipsub.leave(topic)
assert topic not in gossipsub.mesh
# Test re-leave
await gossipsub.leave(topic)
@pytest.mark.parametrize("num_hosts", (10,))
@pytest.mark.asyncio
async def test_fanout(hosts, pubsubs_gsub):
num_msgs = 5
# All pubsub subscribe to foobar except for `pubsubs_gsub[0]`
queues = []
for i in range(1, len(pubsubs_gsub)):
q = await pubsubs_gsub[i].subscribe("foobar")
# Add each blocking queue to an array of blocking queues
queues.append(q)
# Sparsely connect libp2p hosts in random way
await dense_connect(hosts)
# Wait 2 seconds for heartbeat to allow mesh to connect
await asyncio.sleep(2)
topic = "foobar"
# Send messages with origin not subscribed
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await asyncio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
# Subscribe message origin
queues.insert(0, await pubsubs_gsub[0].subscribe(topic))
# Send messages again
for i in range(num_msgs):
msg_content = b"bar " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await asyncio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
@pytest.mark.trio
async def test_handle_graft(monkeypatch):
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
index_alice = 0
id_alice = pubsubs_gsub[index_alice].my_id
index_bob = 1
id_bob = pubsubs_gsub[index_bob].my_id
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
# Wait 2 seconds for heartbeat to allow mesh to connect
await trio.sleep(2)
topic = "test_handle_graft"
# Only alice subscribes to the topic
await gossipsubs[index_alice].join(topic)
# Monkey patch bob's `emit_prune` function so we can
# check if it is called in `handle_graft`
event_emit_prune = trio.Event()
async def emit_prune(topic, sender_peer_id):
event_emit_prune.set()
await trio.hazmat.checkpoint()
monkeypatch.setattr(gossipsubs[index_bob], "emit_prune", emit_prune)
# Check that alice is bob's peer but not his mesh peer
assert gossipsubs[index_bob].peer_protocol[id_alice] == PROTOCOL_ID
assert topic not in gossipsubs[index_bob].mesh
await gossipsubs[index_alice].emit_graft(topic, id_bob)
# Check that `emit_prune` is called
await event_emit_prune.wait()
# Check that bob is alice's peer but not her mesh peer
assert topic in gossipsubs[index_alice].mesh
assert id_bob not in gossipsubs[index_alice].mesh[topic]
assert gossipsubs[index_alice].peer_protocol[id_bob] == PROTOCOL_ID
await gossipsubs[index_bob].emit_graft(topic, id_alice)
await trio.sleep(1)
# Check that bob is now alice's mesh peer
assert id_bob in gossipsubs[index_alice].mesh[topic]
@pytest.mark.trio
async def test_handle_prune():
async with PubsubFactory.create_batch_with_gossipsub(
2, heartbeat_interval=3
) as pubsubs_gsub:
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
index_alice = 0
id_alice = pubsubs_gsub[index_alice].my_id
index_bob = 1
id_bob = pubsubs_gsub[index_bob].my_id
topic = "test_handle_prune"
for pubsub in pubsubs_gsub:
await pubsub.subscribe(topic)
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
# Wait for heartbeat to allow mesh to connect
await trio.sleep(1)
# Check that they are each other's mesh peer
assert id_alice in gossipsubs[index_bob].mesh[topic]
assert id_bob in gossipsubs[index_alice].mesh[topic]
# alice emit prune message to bob, alice should be removed
# from bob's mesh peer
await gossipsubs[index_alice].emit_prune(topic, id_bob)
# `emit_prune` does not remove bob from alice's mesh peers
assert id_bob in gossipsubs[index_alice].mesh[topic]
# NOTE: We increase `heartbeat_interval` to 3 seconds so that bob will not
# add alice back to his mesh after heartbeat.
# Wait for bob to `handle_prune`
await trio.sleep(0.1)
# Check that alice is no longer bob's mesh peer
assert id_alice not in gossipsubs[index_bob].mesh[topic]
@pytest.mark.trio
async def test_dense():
async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
hosts = [pubsub.host for pubsub in pubsubs_gsub]
num_msgs = 5
# All pubsub subscribe to foobar
queues = [await pubsub.subscribe("foobar") for pubsub in pubsubs_gsub]
# Densely connect libp2p hosts in a random way
await dense_connect(hosts)
# Wait 2 seconds for heartbeat to allow mesh to connect
await trio.sleep(2)
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# randomly pick a message origin
origin_idx = random.randint(0, len(hosts) - 1)
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish("foobar", msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
@pytest.mark.trio
async def test_fanout():
async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
hosts = [pubsub.host for pubsub in pubsubs_gsub]
num_msgs = 5
# All pubsub subscribe to foobar except for `pubsubs_gsub[0]`
subs = [await pubsub.subscribe("foobar") for pubsub in pubsubs_gsub[1:]]
# Sparsely connect libp2p hosts in random way
await dense_connect(hosts)
# Wait 2 seconds for heartbeat to allow mesh to connect
await trio.sleep(2)
topic = "foobar"
# Send messages with origin not subscribed
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for sub in subs:
msg = await sub.get()
assert msg.data == msg_content
# Subscribe message origin
subs.insert(0, await pubsubs_gsub[0].subscribe(topic))
# Send messages again
for i in range(num_msgs):
msg_content = b"bar " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for sub in subs:
msg = await sub.get()
assert msg.data == msg_content
@pytest.mark.trio
@pytest.mark.slow
async def test_fanout_maintenance():
async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
hosts = [pubsub.host for pubsub in pubsubs_gsub]
num_msgs = 5
# All pubsub subscribe to foobar
queues = []
topic = "foobar"
for i in range(1, len(pubsubs_gsub)):
q = await pubsubs_gsub[i].subscribe(topic)
# Add each blocking queue to an array of blocking queues
queues.append(q)
# Sparsely connect libp2p hosts in random way
await dense_connect(hosts)
# Wait 2 seconds for heartbeat to allow mesh to connect
await trio.sleep(2)
# Send messages with origin not subscribed
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
for sub in pubsubs_gsub:
await sub.unsubscribe(topic)
queues = []
await trio.sleep(2)
# Resub and repeat
for i in range(1, len(pubsubs_gsub)):
q = await pubsubs_gsub[i].subscribe(topic)
# Add each blocking queue to an array of blocking queues
queues.append(q)
await trio.sleep(2)
# Check messages can still be sent
for i in range(num_msgs):
msg_content = b"bar " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
@pytest.mark.trio
async def test_gossip_propagation():
async with PubsubFactory.create_batch_with_gossipsub(
2, degree=1, degree_low=0, degree_high=2, gossip_window=50, gossip_history=100
) as pubsubs_gsub:
topic = "foo"
queue_0 = await pubsubs_gsub[0].subscribe(topic)
# node 0 publish to topic
msg_content = b"foo_msg"
# publish from the randomly chosen host
await pubsubs_gsub[0].publish(topic, msg_content)
await trio.sleep(0.5)
# Assert that the blocking queues receive the message
msg = await queue_0.get()
assert msg.data == msg_content
for sub in pubsubs_gsub:
await sub.unsubscribe(topic)
queues = []
await asyncio.sleep(2)
# Resub and repeat
for i in range(1, len(pubsubs_gsub)):
q = await pubsubs_gsub[i].subscribe(topic)
# Add each blocking queue to an array of blocking queues
queues.append(q)
await asyncio.sleep(2)
# Check messages can still be sent
for i in range(num_msgs):
msg_content = b"bar " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await asyncio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
@pytest.mark.parametrize(
"num_hosts, gossipsub_params",
(
(
2,
GossipsubParams(
degree=1,
degree_low=0,
degree_high=2,
gossip_window=50,
gossip_history=100,
),
),
),
)
@pytest.mark.asyncio
async def test_gossip_propagation(hosts, pubsubs_gsub):
topic = "foo"
await pubsubs_gsub[0].subscribe(topic)
# node 0 publish to topic
msg_content = b"foo_msg"
# publish from the randomly chosen host
await pubsubs_gsub[0].publish(topic, msg_content)
# now node 1 subscribes
queue_1 = await pubsubs_gsub[1].subscribe(topic)
await connect(hosts[0], hosts[1])
# wait for gossip heartbeat
await asyncio.sleep(2)
# should be able to read message
msg = await queue_1.get()
assert msg.data == msg_content
@pytest.mark.parametrize(
"num_hosts, gossipsub_params", ((1, GossipsubParams(heartbeat_initial_delay=100)),)
)
@pytest.mark.parametrize("initial_mesh_peer_count", (7, 10, 13))
@pytest.mark.trio
async def test_mesh_heartbeat(initial_mesh_peer_count, monkeypatch):
async with PubsubFactory.create_batch_with_gossipsub(
1, heartbeat_initial_delay=100
) as pubsubs_gsub:
# It's difficult to set up the initial peer subscription condition.
# Ideally I would like an initial mesh peer count that's below ``GossipSubDegree``
# so I can test whether `mesh_heartbeat` returns the correct peers to GRAFT.
# The problem is that I cannot set it up so that we have peers subscribed to the topic
# but not part of our mesh peers (as these peers are the peers to GRAFT).
# So I monkeypatch the peer subscriptions and our mesh peers.
total_peer_count = 14
topic = "TEST_MESH_HEARTBEAT"
fake_peer_ids = [IDFactory() for _ in range(total_peer_count)]
peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids}
monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol)
peer_topics = {topic: set(fake_peer_ids)}
# Monkeypatch the peer subscriptions
monkeypatch.setattr(pubsubs_gsub[0], "peer_topics", peer_topics)
mesh_peer_indices = random.sample(
range(total_peer_count), initial_mesh_peer_count
)
mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices]
router_mesh = {topic: set(mesh_peers)}
# Monkeypatch our mesh peers
monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh)
peers_to_graft, peers_to_prune = pubsubs_gsub[0].router.mesh_heartbeat()
if initial_mesh_peer_count > pubsubs_gsub[0].router.degree:
# If the number of initial mesh peers is more than `GossipSubDegree`,
# we should PRUNE mesh peers
assert len(peers_to_graft) == 0
assert (
len(peers_to_prune)
== initial_mesh_peer_count - pubsubs_gsub[0].router.degree
)
for peer in peers_to_prune:
assert peer in mesh_peers
elif initial_mesh_peer_count < pubsubs_gsub[0].router.degree:
# If the number of initial mesh peers is less than `GossipSubDegree`,
# we should GRAFT more peers
assert len(peers_to_prune) == 0
assert (
len(peers_to_graft)
== pubsubs_gsub[0].router.degree - initial_mesh_peer_count
)
for peer in peers_to_graft:
assert peer not in mesh_peers
else:
assert len(peers_to_prune) == 0 and len(peers_to_graft) == 0
@pytest.mark.parametrize("initial_peer_count", (1, 4, 7))
@pytest.mark.trio
async def test_gossip_heartbeat(initial_peer_count, monkeypatch):
async with PubsubFactory.create_batch_with_gossipsub(
1, heartbeat_initial_delay=100
) as pubsubs_gsub:
# The problem is that I cannot set it up so that we have peers subscribed to the topic
# but not part of our mesh peers (as these peers are the peers to GRAFT).
# So I monkeypatch the peer subscriptions and our mesh peers.
total_peer_count = 28
topic_mesh = "TEST_GOSSIP_HEARTBEAT_1"
topic_fanout = "TEST_GOSSIP_HEARTBEAT_2"
fake_peer_ids = [IDFactory() for _ in range(total_peer_count)]
peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids}
monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol)
topic_mesh_peer_count = 14
# Split into mesh peers and fanout peers
peer_topics = {
topic_mesh: set(fake_peer_ids[:topic_mesh_peer_count]),
topic_fanout: set(fake_peer_ids[topic_mesh_peer_count:]),
}
# Monkeypatch the peer subscriptions
monkeypatch.setattr(pubsubs_gsub[0], "peer_topics", peer_topics)
mesh_peer_indices = random.sample(
range(topic_mesh_peer_count), initial_peer_count
)
mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices]
router_mesh = {topic_mesh: set(mesh_peers)}
# Monkeypatch our mesh peers
monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh)
fanout_peer_indices = random.sample(
range(topic_mesh_peer_count, total_peer_count), initial_peer_count
)
fanout_peers = [fake_peer_ids[i] for i in fanout_peer_indices]
router_fanout = {topic_fanout: set(fanout_peers)}
# Monkeypatch our fanout peers
monkeypatch.setattr(pubsubs_gsub[0].router, "fanout", router_fanout)
def window(topic):
if topic == topic_mesh:
return [topic_mesh]
elif topic == topic_fanout:
return [topic_fanout]
else:
return []
# Monkeypatch the memory cache messages
monkeypatch.setattr(pubsubs_gsub[0].router.mcache, "window", window)
peers_to_gossip = pubsubs_gsub[0].router.gossip_heartbeat()
# If our mesh peer count is less than `GossipSubDegree`, we should gossip to up to
# `GossipSubDegree` peers (exclude mesh peers).
if topic_mesh_peer_count - initial_peer_count < pubsubs_gsub[0].router.degree:
# The same goes for fanout so it's two times the number of peers to gossip.
assert len(peers_to_gossip) == 2 * (
topic_mesh_peer_count - initial_peer_count
)
elif (
topic_mesh_peer_count - initial_peer_count >= pubsubs_gsub[0].router.degree
):
assert len(peers_to_gossip) == 2 * (pubsubs_gsub[0].router.degree)
for peer in peers_to_gossip:
if peer in peer_topics[topic_mesh]:
# Check that the peer to gossip to is not in our mesh peers
assert peer not in mesh_peers
assert topic_mesh in peers_to_gossip[peer]
elif peer in peer_topics[topic_fanout]:
# Check that the peer to gossip to is not in our fanout peers
assert peer not in fanout_peers
assert topic_fanout in peers_to_gossip[peer]
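The counts asserted in these two heartbeat tests are pure arithmetic over the router's `degree`: the mesh heartbeat grafts up to `degree - len(mesh)` peers or prunes `len(mesh) - degree`, and the gossip heartbeat picks at most `degree` candidates per topic, which is why the expected total doubles when both the mesh topic and the fanout topic gossip. A rough sketch of that arithmetic (illustrative helper functions, not the router's actual selection code):

def expected_graft_prune_counts(degree, mesh_size):
    """How many peers a mesh heartbeat would GRAFT and PRUNE, respectively."""
    if mesh_size < degree:
        return degree - mesh_size, 0
    if mesh_size > degree:
        return 0, mesh_size - degree
    return 0, 0

def expected_gossip_targets_per_topic(degree, non_mesh_candidates):
    """Gossip goes to at most `degree` peers outside the mesh/fanout for a topic."""
    return min(degree, non_mesh_candidates)

# Example with degree=6 (an assumed value; the tests use whatever degree the factory configures).
assert expected_graft_prune_counts(6, 10) == (0, 4)
assert expected_graft_prune_counts(6, 3) == (3, 0)
assert expected_gossip_targets_per_topic(6, 13) == 6
assert expected_gossip_targets_per_topic(6, 4) == 4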


@ -3,25 +3,25 @@ import functools
import pytest
from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID
from libp2p.tools.factories import PubsubFactory
from libp2p.tools.pubsub.floodsub_integration_test_settings import (
floodsub_protocol_pytest_params,
perform_test_from_obj,
)
@pytest.mark.asyncio
async def test_gossipsub_initialize_with_floodsub_protocol():
GossipsubFactory(protocols=[FLOODSUB_PROTOCOL_ID])
@pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params)
@pytest.mark.trio
@pytest.mark.slow
async def test_gossipsub_run_with_floodsub_tests(test_case_obj):
await perform_test_from_obj(
test_case_obj,
functools.partial(
PubsubFactory.create_batch_with_gossipsub,
protocols=[FLOODSUB_PROTOCOL_ID],
degree=3,
degree_low=2,
degree_high=4,
time_to_live=30,
),
)

View File

@ -1,5 +1,3 @@
import pytest
from libp2p.pubsub.mcache import MessageCache
@ -12,8 +10,7 @@ class Msg:
self.from_id = from_id
def test_mcache():
# Ported from:
# https://github.com/libp2p/go-libp2p-pubsub/blob/51b7501433411b5096cac2b4994a36a68515fc03/mcache_test.go
mcache = MessageCache(3, 5)

File diff suppressed because it is too large.


@ -0,0 +1,84 @@
import math
import pytest
import trio
from libp2p.pubsub.pb import rpc_pb2
from libp2p.pubsub.subscription import TrioSubscriptionAPI
GET_TIMEOUT = 0.001
def make_trio_subscription():
send_channel, receive_channel = trio.open_memory_channel(math.inf)
async def unsubscribe_fn():
await send_channel.aclose()
return (
send_channel,
TrioSubscriptionAPI(receive_channel, unsubscribe_fn=unsubscribe_fn),
)
def make_pubsub_msg():
return rpc_pb2.Message()
async def send_something(send_channel):
msg = make_pubsub_msg()
await send_channel.send(msg)
return msg
@pytest.mark.trio
async def test_trio_subscription_get():
send_channel, sub = make_trio_subscription()
data_0 = await send_something(send_channel)
data_1 = await send_something(send_channel)
assert data_0 == await sub.get()
assert data_1 == await sub.get()
# No more message
with pytest.raises(trio.TooSlowError):
with trio.fail_after(GET_TIMEOUT):
await sub.get()
@pytest.mark.trio
async def test_trio_subscription_iter():
send_channel, sub = make_trio_subscription()
received_data = []
async def iter_subscriptions(subscription):
async for data in sub:
received_data.append(data)
async with trio.open_nursery() as nursery:
nursery.start_soon(iter_subscriptions, sub)
await send_something(send_channel)
await send_something(send_channel)
await send_channel.aclose()
assert len(received_data) == 2
@pytest.mark.trio
async def test_trio_subscription_unsubscribe():
send_channel, sub = make_trio_subscription()
await sub.unsubscribe()
# Test: If the subscription is unsubscribed, `send_channel` should be closed.
with pytest.raises(trio.ClosedResourceError):
await send_something(send_channel)
# Test: No side effect when cancelled twice.
await sub.unsubscribe()
@pytest.mark.trio
async def test_trio_subscription_async_context_manager():
send_channel, sub = make_trio_subscription()
async with sub:
# Test: `sub` is not cancelled yet, so `send_something` works fine.
await send_something(send_channel)
# Test: `sub` is cancelled, `send_something` fails
with pytest.raises(trio.ClosedResourceError):
await send_something(send_channel)


@ -1,70 +1,15 @@
import asyncio
import pytest
import trio
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID
from libp2p.security.secio.transport import NONCE_SIZE, create_secure_session
from libp2p.tools.constants import MAX_READ_LEN
from libp2p.tools.factories import raw_conn_factory
class InMemoryConnection(IRawConnection):
def __init__(self, peer, is_initiator=False):
self.peer = peer
self.recv_queue = asyncio.Queue()
self.send_queue = asyncio.Queue()
self.is_initiator = is_initiator
self.current_msg = None
self.current_position = 0
self.closed = False
async def write(self, data: bytes) -> int:
if self.closed:
raise Exception("InMemoryConnection is closed for writing")
await self.send_queue.put(data)
return len(data)
async def read(self, n: int = -1) -> bytes:
"""
NOTE: have to buffer the current message and juggle packets
off the recv queue to satisfy the semantics of this function.
"""
if self.closed:
raise Exception("InMemoryConnection is closed for reading")
if not self.current_msg:
self.current_msg = await self.recv_queue.get()
self.current_position = 0
if n < 0:
msg = self.current_msg
self.current_msg = None
return msg
next_msg = self.current_msg[self.current_position : self.current_position + n]
self.current_position += n
if self.current_position == len(self.current_msg):
self.current_msg = None
return next_msg
async def close(self) -> None:
self.closed = True
async def create_pipe(local_conn, remote_conn):
try:
while True:
next_msg = await local_conn.send_queue.get()
await remote_conn.recv_queue.put(next_msg)
except asyncio.CancelledError:
return
@pytest.mark.trio
async def test_create_secure_session(nursery):
local_nonce = b"\x01" * NONCE_SIZE
local_key_pair = create_new_key_pair(b"a")
local_peer = ID.from_pubkey(local_key_pair.public_key)
@ -73,30 +18,32 @@ async def test_create_secure_session():
remote_key_pair = create_new_key_pair(b"b")
remote_peer = ID.from_pubkey(remote_key_pair.public_key)
async with raw_conn_factory(nursery) as conns:
local_conn, remote_conn = conns
local_secure_conn, remote_secure_conn = None, None
async def local_create_secure_session():
nonlocal local_secure_conn
local_secure_conn = await create_secure_session(
local_nonce,
local_peer,
local_key_pair.private_key,
local_conn,
remote_peer,
)
async def remote_create_secure_session():
nonlocal remote_secure_conn
remote_secure_conn = await create_secure_session(
remote_nonce, remote_peer, remote_key_pair.private_key, remote_conn
)
async with trio.open_nursery() as nursery_1:
nursery_1.start_soon(local_create_secure_session)
nursery_1.start_soon(remote_create_secure_session)
msg = b"abc"
await local_secure_conn.write(msg)
received_msg = await remote_secure_conn.read(MAX_READ_LEN)
assert received_msg == msg


@ -1,8 +1,7 @@
import asyncio
import pytest
import trio
from libp2p import new_host
from libp2p.crypto.rsa import create_new_key_pair
from libp2p.security.insecure.transport import InsecureSession, InsecureTransport
from libp2p.tools.constants import LISTEN_MADDR
@ -24,42 +23,36 @@ noninitiator_key_pair = create_new_key_pair()
async def perform_simple_test(
assertion_func, transports_for_initiator, transports_for_noninitiator
):
# Create libp2p nodes and connect them, then secure the connection, then check
# the proper security was chosen
# TODO: implement -- note we need to introduce the notion of communicating over a raw connection
# for testing, we do NOT want to communicate over a stream so we can't just create two nodes
# and use their conn because our mplex will internally relay messages to a stream
node1 = new_host(key_pair=initiator_key_pair, sec_opt=transports_for_initiator)
node2 = new_host(
key_pair=noninitiator_key_pair, sec_opt=transports_for_noninitiator
)
async with node1.run(listen_addrs=[LISTEN_MADDR]), node2.run(
listen_addrs=[LISTEN_MADDR]
):
await connect(node1, node2)
# Wait a very short period to allow conns to be stored (since the functions
# storing the conns are async, they may happen at slightly different times
# on each node)
await trio.sleep(0.1)
# Get conns
node1_conn = node1.get_network().connections[peer_id_for_node(node2)]
node2_conn = node2.get_network().connections[peer_id_for_node(node1)]
# Perform assertion
assertion_func(node1_conn.muxed_conn.secured_conn)
assertion_func(node2_conn.muxed_conn.secured_conn)
@pytest.mark.trio
async def test_single_insecure_security_transport_succeeds():
transports_for_initiator = {"foo": InsecureTransport(initiator_key_pair)}
transports_for_noninitiator = {"foo": InsecureTransport(noninitiator_key_pair)}
@ -72,7 +65,7 @@ async def test_single_insecure_security_transport_succeeds():
)
@pytest.mark.trio
async def test_default_insecure_security():
transports_for_initiator = None
transports_for_noninitiator = None


@ -1,5 +1,3 @@
import asyncio
import pytest
from libp2p.tools.factories import mplex_conn_pair_factory, mplex_stream_pair_factory
@ -7,23 +5,13 @@ from libp2p.tools.factories import mplex_conn_pair_factory, mplex_stream_pair_fa
@pytest.fixture
async def mplex_conn_pair(is_host_secure):
async with mplex_conn_pair_factory(is_host_secure) as mplex_conn_pair:
assert mplex_conn_pair[0].is_initiator
assert not mplex_conn_pair[1].is_initiator
yield mplex_conn_pair[0], mplex_conn_pair[1]
@pytest.fixture
async def mplex_stream_pair(is_host_secure):
async with mplex_stream_pair_factory(is_host_secure) as mplex_stream_pair:
yield mplex_stream_pair


@ -1,49 +1,40 @@
import pytest
import trio
@pytest.mark.trio
async def test_mplex_conn(mplex_conn_pair):
conn_0, conn_1 = mplex_conn_pair
assert len(conn_0.streams) == 0
assert len(conn_1.streams) == 0
assert not conn_0.event_shutting_down.is_set()
assert not conn_1.event_shutting_down.is_set()
assert not conn_0.event_closed.is_set()
assert not conn_1.event_closed.is_set()
# Test: Open a stream, and both side get 1 more stream.
stream_0 = await conn_0.open_stream()
await trio.sleep(0.01)
assert len(conn_0.streams) == 1
assert len(conn_1.streams) == 1
# Test: From another side.
stream_1 = await conn_1.open_stream()
await trio.sleep(0.01)
assert len(conn_0.streams) == 2
assert len(conn_1.streams) == 2
# Close from one side.
await conn_0.close()
# Sleep for a while for both side to handle `close`.
await trio.sleep(0.01)
# Test: Both side is closed.
assert conn_0.is_closed
assert conn_1.is_closed
# Test: All streams should have been closed.
assert stream_0.event_remote_closed.is_set()
assert stream_0.event_reset.is_set()
assert stream_0.event_local_closed.is_set()
assert conn_0.streams is None
# Test: All streams on the other side are also closed.
assert stream_1.event_remote_closed.is_set()
assert stream_1.event_reset.is_set()
assert stream_1.event_local_closed.is_set()
assert conn_1.streams is None
# Test: No effect to close more than once between two side.
await conn_0.close()


@ -1,25 +1,48 @@
import pytest
import trio
from trio.testing import wait_all_tasks_blocked
from libp2p.stream_muxer.mplex.exceptions import (
MplexStreamClosed,
MplexStreamEOF,
MplexStreamReset,
)
from libp2p.stream_muxer.mplex.mplex import MPLEX_MESSAGE_CHANNEL_SIZE
from libp2p.tools.constants import MAX_READ_LEN
DATA = b"data_123"
@pytest.mark.trio
async def test_mplex_stream_read_write(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_0.write(DATA)
assert (await stream_1.read(MAX_READ_LEN)) == DATA
@pytest.mark.trio
async def test_mplex_stream_full_buffer(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
# Test: The message channel is of size `MPLEX_MESSAGE_CHANNEL_SIZE`.
# It should be fine to read even there are already `MPLEX_MESSAGE_CHANNEL_SIZE`
# messages arriving.
for _ in range(MPLEX_MESSAGE_CHANNEL_SIZE):
await stream_0.write(DATA)
await wait_all_tasks_blocked()
# Sanity check
assert MAX_READ_LEN >= MPLEX_MESSAGE_CHANNEL_SIZE * len(DATA)
assert (await stream_1.read(MAX_READ_LEN)) == MPLEX_MESSAGE_CHANNEL_SIZE * DATA
# Test: Read after `MPLEX_MESSAGE_CHANNEL_SIZE + 1` messages has arrived, which
# exceeds the channel size. The stream should have been reset.
for _ in range(MPLEX_MESSAGE_CHANNEL_SIZE + 1):
await stream_0.write(DATA)
await wait_all_tasks_blocked()
with pytest.raises(MplexStreamReset):
await stream_1.read(MAX_READ_LEN)
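The full-buffer test relies on the Mplex message channel being a bounded trio memory channel: once `MPLEX_MESSAGE_CHANNEL_SIZE` frames are buffered for a stream, one more incoming frame cannot be queued and the stream is reset instead of blocking the connection. A minimal sketch of the underlying trio primitive (the reset-on-overflow handling shown in the except branch is the general idea, not the exact Mplex code):

import trio

async def demo_bounded_channel():
    send_channel, receive_channel = trio.open_memory_channel(2)
    send_channel.send_nowait(b"frame-1")
    send_channel.send_nowait(b"frame-2")
    try:
        # Buffer is full, so a non-blocking send fails immediately.
        send_channel.send_nowait(b"frame-3")
    except trio.WouldBlock:
        print("channel full: this is where a muxer would reset the stream")
    assert await receive_channel.receive() == b"frame-1"

trio.run(demo_bounded_channel)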
@pytest.mark.trio
async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair):
read_bytes = bytearray()
stream_0, stream_1 = mplex_stream_pair
@ -27,43 +50,46 @@ async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair):
async def read_until_eof():
read_bytes.extend(await stream_1.read())
expected_data = bytearray()
async with trio.open_nursery() as nursery:
nursery.start_soon(read_until_eof)
# Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA)
expected_data.extend(DATA)
await trio.sleep(0.01)
assert len(read_bytes) == 0
# Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA)
expected_data.extend(DATA)
await trio.sleep(0.01)
assert len(read_bytes) == 0
# Test: Close the stream, `read` returns, and receive previous sent data.
await stream_0.close()
assert read_bytes == expected_data
@pytest.mark.trio
async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
assert not stream_1.event_remote_closed.is_set()
await stream_0.write(DATA)
assert not stream_0.event_local_closed.is_set()
await trio.sleep(0.01)
await wait_all_tasks_blocked()
await stream_0.close()
assert stream_0.event_local_closed.is_set()
await trio.sleep(0.01)
await wait_all_tasks_blocked()
assert stream_1.event_remote_closed.is_set()
assert (await stream_1.read(MAX_READ_LEN)) == DATA
with pytest.raises(MplexStreamEOF):
await stream_1.read(MAX_READ_LEN)
@pytest.mark.asyncio @pytest.mark.trio
async def test_mplex_stream_read_after_local_reset(mplex_stream_pair): async def test_mplex_stream_read_after_local_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair stream_0, stream_1 = mplex_stream_pair
await stream_0.reset() await stream_0.reset()
@ -71,29 +97,30 @@ async def test_mplex_stream_read_after_local_reset(mplex_stream_pair):
await stream_0.read(MAX_READ_LEN) await stream_0.read(MAX_READ_LEN)
@pytest.mark.asyncio @pytest.mark.trio
async def test_mplex_stream_read_after_remote_reset(mplex_stream_pair): async def test_mplex_stream_read_after_remote_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair stream_0, stream_1 = mplex_stream_pair
await stream_0.write(DATA) await stream_0.write(DATA)
await stream_0.reset() await stream_0.reset()
# Sleep to let `stream_1` receive the message. # Sleep to let `stream_1` receive the message.
await asyncio.sleep(0.01) await trio.sleep(0.1)
await wait_all_tasks_blocked()
with pytest.raises(MplexStreamReset): with pytest.raises(MplexStreamReset):
await stream_1.read(MAX_READ_LEN) await stream_1.read(MAX_READ_LEN)
@pytest.mark.asyncio @pytest.mark.trio
async def test_mplex_stream_read_after_remote_closed_and_reset(mplex_stream_pair): async def test_mplex_stream_read_after_remote_closed_and_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair stream_0, stream_1 = mplex_stream_pair
await stream_0.write(DATA) await stream_0.write(DATA)
await stream_0.close() await stream_0.close()
await stream_0.reset() await stream_0.reset()
# Sleep to let `stream_1` receive the message. # Sleep to let `stream_1` receive the message.
await asyncio.sleep(0.01) await trio.sleep(0.01)
assert (await stream_1.read(MAX_READ_LEN)) == DATA assert (await stream_1.read(MAX_READ_LEN)) == DATA
@pytest.mark.asyncio @pytest.mark.trio
async def test_mplex_stream_write_after_local_closed(mplex_stream_pair): async def test_mplex_stream_write_after_local_closed(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair stream_0, stream_1 = mplex_stream_pair
await stream_0.write(DATA) await stream_0.write(DATA)
@ -102,7 +129,7 @@ async def test_mplex_stream_write_after_local_closed(mplex_stream_pair):
await stream_0.write(DATA) await stream_0.write(DATA)
@pytest.mark.asyncio @pytest.mark.trio
async def test_mplex_stream_write_after_local_reset(mplex_stream_pair): async def test_mplex_stream_write_after_local_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair stream_0, stream_1 = mplex_stream_pair
await stream_0.reset() await stream_0.reset()
@ -110,16 +137,16 @@ async def test_mplex_stream_write_after_local_reset(mplex_stream_pair):
await stream_0.write(DATA) await stream_0.write(DATA)
@pytest.mark.asyncio @pytest.mark.trio
async def test_mplex_stream_write_after_remote_reset(mplex_stream_pair): async def test_mplex_stream_write_after_remote_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair stream_0, stream_1 = mplex_stream_pair
await stream_1.reset() await stream_1.reset()
await asyncio.sleep(0.01) await trio.sleep(0.01)
with pytest.raises(MplexStreamClosed): with pytest.raises(MplexStreamClosed):
await stream_0.write(DATA) await stream_0.write(DATA)
@pytest.mark.asyncio @pytest.mark.trio
async def test_mplex_stream_both_close(mplex_stream_pair): async def test_mplex_stream_both_close(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair stream_0, stream_1 = mplex_stream_pair
# Flags are not set initially. # Flags are not set initially.
@ -133,7 +160,7 @@ async def test_mplex_stream_both_close(mplex_stream_pair):
# Test: Close one side. # Test: Close one side.
await stream_0.close() await stream_0.close()
await asyncio.sleep(0.01) await trio.sleep(0.01)
assert stream_0.event_local_closed.is_set() assert stream_0.event_local_closed.is_set()
assert not stream_1.event_local_closed.is_set() assert not stream_1.event_local_closed.is_set()
@ -145,7 +172,7 @@ async def test_mplex_stream_both_close(mplex_stream_pair):
# Test: Close the other side. # Test: Close the other side.
await stream_1.close() await stream_1.close()
await asyncio.sleep(0.01) await trio.sleep(0.01)
# Both sides are closed. # Both sides are closed.
assert stream_0.event_local_closed.is_set() assert stream_0.event_local_closed.is_set()
assert stream_1.event_local_closed.is_set() assert stream_1.event_local_closed.is_set()
@ -159,11 +186,11 @@ async def test_mplex_stream_both_close(mplex_stream_pair):
await stream_0.reset() await stream_0.reset()
@pytest.mark.asyncio @pytest.mark.trio
async def test_mplex_stream_reset(mplex_stream_pair): async def test_mplex_stream_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair stream_0, stream_1 = mplex_stream_pair
await stream_0.reset() await stream_0.reset()
await asyncio.sleep(0.01) await trio.sleep(0.01)
# Both sides are closed. # Both sides are closed.
assert stream_0.event_local_closed.is_set() assert stream_0.event_local_closed.is_set()
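The mplex stream tests above swap `@pytest.mark.asyncio`/`asyncio.sleep` for `@pytest.mark.trio`/`trio.sleep` and use `trio.testing.wait_all_tasks_blocked` to let background tasks settle before asserting. A minimal sketch of that pattern, assuming pytest-trio is installed and using a hypothetical `stream_pair` fixture in place of `mplex_stream_pair`:

import pytest
import trio
from trio.testing import wait_all_tasks_blocked


@pytest.mark.trio  # pytest-trio runs the coroutine inside a Trio run loop
async def test_remote_close_is_observed(stream_pair):
    stream_0, stream_1 = stream_pair
    await stream_0.write(b"data")
    await stream_0.close()
    # Yield to the scheduler so the remote side can process the close frame.
    await trio.sleep(0.01)
    await wait_all_tasks_blocked()  # wait until every runnable task is blocked
    assert (await stream_1.read(4)) == b"data"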

View File

@ -1,20 +1,53 @@
import asyncio from multiaddr import Multiaddr
import pytest import pytest
import trio
from libp2p.transport.tcp.tcp import _multiaddr_from_socket from libp2p.network.connection.raw_connection import RawConnection
from libp2p.tools.constants import LISTEN_MADDR
from libp2p.transport.exceptions import OpenConnectionError
from libp2p.transport.tcp.tcp import TCP
@pytest.mark.asyncio @pytest.mark.trio
async def test_multiaddr_from_socket(): async def test_tcp_listener(nursery):
def handler(r, w): transport = TCP()
async def handler(tcp_stream):
pass pass
server = await asyncio.start_server(handler, "127.0.0.1", 8000) listener = transport.create_listener(handler)
assert str(_multiaddr_from_socket(server.sockets[0])) == "/ip4/127.0.0.1/tcp/8000" assert len(listener.get_addrs()) == 0
await listener.listen(LISTEN_MADDR, nursery)
assert len(listener.get_addrs()) == 1
await listener.listen(LISTEN_MADDR, nursery)
assert len(listener.get_addrs()) == 2
server = await asyncio.start_server(handler, "127.0.0.1", 0)
addr = _multiaddr_from_socket(server.sockets[0]) @pytest.mark.trio
assert addr.value_for_protocol("ip4") == "127.0.0.1" async def test_tcp_dial(nursery):
port = addr.value_for_protocol("tcp") transport = TCP()
assert int(port) > 0 raw_conn_other_side = None
event = trio.Event()
async def handler(tcp_stream):
nonlocal raw_conn_other_side
raw_conn_other_side = RawConnection(tcp_stream, False)
event.set()
await trio.sleep_forever()
    # Test: `OpenConnectionError` is raised when trying to dial a port on which
    # no one is listening.
with pytest.raises(OpenConnectionError):
await transport.dial(Multiaddr("/ip4/127.0.0.1/tcp/1"))
listener = transport.create_listener(handler)
await listener.listen(LISTEN_MADDR, nursery)
addrs = listener.get_addrs()
assert len(addrs) == 1
listen_addr = addrs[0]
raw_conn = await transport.dial(listen_addr)
await event.wait()
data = b"123"
await raw_conn_other_side.write(data)
assert (await raw_conn.read(len(data))) == data
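The TCP transport tests above lean on the `nursery` fixture provided by pytest-trio: listeners are started into the test's nursery and are torn down when the nursery exits. A minimal sketch of the same idea using plain Trio networking (not the py-libp2p `TCP` transport), with port 0 so the OS picks a free port:

import functools

import pytest
import trio


@pytest.mark.trio
async def test_echo_server(nursery):
    async def handler(stream: trio.SocketStream) -> None:
        data = await stream.receive_some(1024)
        await stream.send_all(data)

    # `nursery.start` waits until the listener is bound, then returns it.
    listeners = await nursery.start(
        functools.partial(trio.serve_tcp, handler, 0, host="127.0.0.1")
    )
    port = listeners[0].socket.getsockname()[1]

    client = await trio.open_tcp_stream("127.0.0.1", port)
    await client.send_all(b"ping")
    assert (await client.receive_some(1024)) == b"ping"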

View File

@ -1,20 +1,13 @@
import asyncio import anyio
import sys from async_exit_stack import AsyncExitStack
from typing import Union
from p2pclient.datastructures import StreamInfo from p2pclient.datastructures import StreamInfo
import pexpect from p2pclient.utils import get_unused_tcp_port
import pytest import pytest
import trio
from libp2p.io.abc import ReadWriteCloser from libp2p.io.abc import ReadWriteCloser
from libp2p.tools.constants import GOSSIPSUB_PARAMS, LISTEN_MADDR from libp2p.tools.factories import HostFactory, PubsubFactory
from libp2p.tools.factories import ( from libp2p.tools.interop.daemon import make_p2pd
FloodsubFactory,
GossipsubFactory,
HostFactory,
PubsubFactory,
)
from libp2p.tools.interop.daemon import Daemon, make_p2pd
from libp2p.tools.interop.utils import connect from libp2p.tools.interop.utils import connect
@ -23,48 +16,6 @@ def is_host_secure():
return False return False
@pytest.fixture
def num_hosts():
return 3
@pytest.fixture
async def hosts(num_hosts, is_host_secure):
_hosts = HostFactory.create_batch(num_hosts, is_secure=is_host_secure)
await asyncio.gather(
*[_host.get_network().listen(LISTEN_MADDR) for _host in _hosts]
)
try:
yield _hosts
finally:
# TODO: It's possible that `close` raises exceptions currently,
# due to the connection reset things. Though we don't care much about that when
# cleaning up the tasks, it is probably better to handle the exceptions properly.
await asyncio.gather(
*[_host.close() for _host in _hosts], return_exceptions=True
)
@pytest.fixture
def proc_factory():
procs = []
def call_proc(cmd, args, logfile=None, encoding=None):
if logfile is None:
logfile = sys.stdout
if encoding is None:
encoding = "utf-8"
proc = pexpect.spawn(cmd, args, logfile=logfile, encoding=encoding)
procs.append(proc)
return proc
try:
yield call_proc
finally:
for proc in procs:
proc.close()
@pytest.fixture @pytest.fixture
def num_p2pds(): def num_p2pds():
return 1 return 1
@ -87,79 +38,57 @@ def is_pubsub_signing_strict():
@pytest.fixture @pytest.fixture
async def p2pds( async def p2pds(
num_p2pds, num_p2pds, is_host_secure, is_gossipsub, is_pubsub_signing, is_pubsub_signing_strict
is_host_secure,
is_gossipsub,
unused_tcp_port_factory,
is_pubsub_signing,
is_pubsub_signing_strict,
): ):
p2pds: Union[Daemon, Exception] = await asyncio.gather( async with AsyncExitStack() as stack:
*[ p2pds = [
make_p2pd( await stack.enter_async_context(
unused_tcp_port_factory(), make_p2pd(
unused_tcp_port_factory(), get_unused_tcp_port(),
is_host_secure, get_unused_tcp_port(),
is_gossipsub=is_gossipsub, is_host_secure,
is_pubsub_signing=is_pubsub_signing, is_gossipsub=is_gossipsub,
is_pubsub_signing_strict=is_pubsub_signing_strict, is_pubsub_signing=is_pubsub_signing,
is_pubsub_signing_strict=is_pubsub_signing_strict,
)
) )
for _ in range(num_p2pds) for _ in range(num_p2pds)
], ]
return_exceptions=True, try:
) yield p2pds
p2pds_succeeded = tuple(p2pd for p2pd in p2pds if isinstance(p2pd, Daemon)) finally:
if len(p2pds_succeeded) != len(p2pds): for p2pd in p2pds:
        # Not all succeeded. Close the succeeded ones and print the failed ones (exceptions).         finally:
await asyncio.gather(*[p2pd.close() for p2pd in p2pds_succeeded])
exceptions = tuple(p2pd for p2pd in p2pds if isinstance(p2pd, Exception))
raise Exception(f"not all p2pds succeed: first exception={exceptions[0]}")
try:
yield p2pds
finally:
await asyncio.gather(*[p2pd.close() for p2pd in p2pds])
@pytest.fixture @pytest.fixture
def pubsubs(num_hosts, hosts, is_gossipsub, is_pubsub_signing_strict): async def pubsubs(num_hosts, is_host_secure, is_gossipsub, is_pubsub_signing_strict):
if is_gossipsub: if is_gossipsub:
routers = GossipsubFactory.create_batch(num_hosts, **GOSSIPSUB_PARAMS._asdict()) yield PubsubFactory.create_batch_with_gossipsub(
num_hosts, is_secure=is_host_secure, strict_signing=is_pubsub_signing_strict
)
else: else:
routers = FloodsubFactory.create_batch(num_hosts) yield PubsubFactory.create_batch_with_floodsub(
_pubsubs = tuple( num_hosts, is_host_secure, strict_signing=is_pubsub_signing_strict
PubsubFactory(host=host, router=router, strict_signing=is_pubsub_signing_strict) )
for host, router in zip(hosts, routers)
)
yield _pubsubs
# TODO: Clean up
class DaemonStream(ReadWriteCloser): class DaemonStream(ReadWriteCloser):
stream_info: StreamInfo stream_info: StreamInfo
reader: asyncio.StreamReader stream: anyio.abc.SocketStream
writer: asyncio.StreamWriter
def __init__( def __init__(self, stream_info: StreamInfo, stream: anyio.abc.SocketStream) -> None:
self,
stream_info: StreamInfo,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
) -> None:
self.stream_info = stream_info self.stream_info = stream_info
self.reader = reader self.stream = stream
self.writer = writer
async def close(self) -> None: async def close(self) -> None:
self.writer.close() await self.stream.close()
if sys.version_info < (3, 7):
return
await self.writer.wait_closed()
async def read(self, n: int = -1) -> bytes: async def read(self, n: int = None) -> bytes:
return await self.reader.read(n) return await self.stream.receive_some(n)
async def write(self, data: bytes) -> int: async def write(self, data: bytes) -> None:
return self.writer.write(data) return await self.stream.send_all(data)
@pytest.fixture @pytest.fixture
@ -168,40 +97,38 @@ async def is_to_fail_daemon_stream():
@pytest.fixture @pytest.fixture
async def py_to_daemon_stream_pair(hosts, p2pds, is_to_fail_daemon_stream): async def py_to_daemon_stream_pair(p2pds, is_host_secure, is_to_fail_daemon_stream):
assert len(hosts) >= 1 async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts:
assert len(p2pds) >= 1 assert len(p2pds) >= 1
host = hosts[0] host = hosts[0]
p2pd = p2pds[0] p2pd = p2pds[0]
protocol_id = "/protocol/id/123" protocol_id = "/protocol/id/123"
stream_py = None stream_py = None
stream_daemon = None stream_daemon = None
event_stream_handled = asyncio.Event() event_stream_handled = trio.Event()
await connect(host, p2pd) await connect(host, p2pd)
async def daemon_stream_handler(stream_info, reader, writer): async def daemon_stream_handler(stream_info, stream):
nonlocal stream_daemon nonlocal stream_daemon
stream_daemon = DaemonStream(stream_info, reader, writer) stream_daemon = DaemonStream(stream_info, stream)
event_stream_handled.set() event_stream_handled.set()
await trio.hazmat.checkpoint()
await p2pd.control.stream_handler(protocol_id, daemon_stream_handler) await p2pd.control.stream_handler(protocol_id, daemon_stream_handler)
    # Sleep for a while to wait for the handler to be registered.         # Sleep for a while to wait for the handler to be registered.
await asyncio.sleep(0.01) await trio.sleep(0.01)
if is_to_fail_daemon_stream: if is_to_fail_daemon_stream:
# FIXME: This is a workaround to make daemon reset the stream. # FIXME: This is a workaround to make daemon reset the stream.
    # We intentionally close the listener on the python side; this makes the connection from         # We intentionally close the listener on the python side; this makes the connection from
# daemon to us fail, and therefore the daemon resets the opened stream on their side. # daemon to us fail, and therefore the daemon resets the opened stream on their side.
# Reference: https://github.com/libp2p/go-libp2p-daemon/blob/b95e77dbfcd186ccf817f51e95f73f9fd5982600/stream.go#L47-L50 # noqa: E501 # Reference: https://github.com/libp2p/go-libp2p-daemon/blob/b95e77dbfcd186ccf817f51e95f73f9fd5982600/stream.go#L47-L50 # noqa: E501
    # We need it because we want to test against `stream_py` after the remote side (daemon)         # We need it because we want to test against `stream_py` after the remote side (daemon)
# is reset. This should be removed after the API `stream.reset` is exposed in daemon # is reset. This should be removed after the API `stream.reset` is exposed in daemon
# some day. # some day.
listener = p2pds[0].control.control.listener await p2pds[0].control.control.close()
listener.close() stream_py = await host.new_stream(p2pd.peer_id, [protocol_id])
if sys.version_info[0:2] > (3, 6): if not is_to_fail_daemon_stream:
await listener.wait_closed() await event_stream_handled.wait()
stream_py = await host.new_stream(p2pd.peer_id, [protocol_id]) # NOTE: If `is_to_fail_daemon_stream == True`, then `stream_daemon == None`.
if not is_to_fail_daemon_stream: yield stream_py, stream_daemon
await event_stream_handled.wait()
# NOTE: If `is_to_fail_daemon_stream == True`, then `stream_daemon == None`.
yield stream_py, stream_daemon

View File

@ -1,26 +1,26 @@
import asyncio
import pytest import pytest
import trio
from libp2p.tools.factories import HostFactory
from libp2p.tools.interop.utils import connect from libp2p.tools.interop.utils import connect
@pytest.mark.parametrize("num_hosts", (1,)) @pytest.mark.trio
@pytest.mark.asyncio async def test_connect(is_host_secure, p2pds):
async def test_connect(hosts, p2pds): async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts:
p2pd = p2pds[0] p2pd = p2pds[0]
host = hosts[0] host = hosts[0]
assert len(await p2pd.control.list_peers()) == 0 assert len(await p2pd.control.list_peers()) == 0
# Test: connect from Py # Test: connect from Py
await connect(host, p2pd) await connect(host, p2pd)
assert len(await p2pd.control.list_peers()) == 1 assert len(await p2pd.control.list_peers()) == 1
# Test: `disconnect` from Py # Test: `disconnect` from Py
await host.disconnect(p2pd.peer_id) await host.disconnect(p2pd.peer_id)
assert len(await p2pd.control.list_peers()) == 0 assert len(await p2pd.control.list_peers()) == 0
# Test: connect from Go # Test: connect from Go
await connect(p2pd, host) await connect(p2pd, host)
assert len(host.get_network().connections) == 1 assert len(host.get_network().connections) == 1
# Test: `disconnect` from Go # Test: `disconnect` from Go
await p2pd.control.disconnect(host.get_id()) await p2pd.control.disconnect(host.get_id())
await asyncio.sleep(0.01) await trio.sleep(0.01)
assert len(host.get_network().connections) == 0 assert len(host.get_network().connections) == 0

View File

@ -1,82 +1,99 @@
import asyncio import re
from multiaddr import Multiaddr from multiaddr import Multiaddr
from p2pclient.utils import get_unused_tcp_port
import pytest import pytest
import trio
from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr
from libp2p.tools.interop.constants import PEXPECT_NEW_LINE from libp2p.tools.factories import HostFactory
from libp2p.tools.interop.envs import GO_BIN_PATH from libp2p.tools.interop.envs import GO_BIN_PATH
from libp2p.tools.interop.process import BaseInteractiveProcess
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
ECHO_PATH = GO_BIN_PATH / "echo" ECHO_PATH = GO_BIN_PATH / "echo"
ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0") ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0")
async def make_echo_proc( class EchoProcess(BaseInteractiveProcess):
proc_factory, port: int, is_secure: bool, destination: Multiaddr = None port: int
): _peer_info: PeerInfo
args = [f"-l={port}"]
if not is_secure: def __init__(
args.append("-insecure") self, port: int, is_secure: bool, destination: Multiaddr = None
if destination is not None: ) -> None:
args.append(f"-d={str(destination)}") args = [f"-l={port}"]
echo_proc = proc_factory(str(ECHO_PATH), args) if not is_secure:
await echo_proc.expect(r"I am ([\w\./]+)" + PEXPECT_NEW_LINE, async_=True) args.append("-insecure")
maddr_str_ipfs = echo_proc.match.group(1) if destination is not None:
maddr_str = maddr_str_ipfs.replace("ipfs", "p2p") args.append(f"-d={str(destination)}")
maddr = Multiaddr(maddr_str)
go_pinfo = info_from_p2p_addr(maddr) patterns = [b"I am"]
if destination is None: if destination is None:
await echo_proc.expect("listening for connections", async_=True) patterns.append(b"listening for connections")
return echo_proc, go_pinfo
self.args = args
self.cmd = str(ECHO_PATH)
self.patterns = patterns
self.bytes_read = bytearray()
self.event_ready = trio.Event()
self.port = port
self._peer_info = None
self.regex_pat = re.compile(br"I am ([\w\./]+)")
@property
    def peer_info(self) -> PeerInfo:
if self._peer_info is not None:
return self._peer_info
if not self.event_ready.is_set():
raise Exception("process is not ready yet. failed to parse the peer info")
# Example:
# b"I am /ip4/127.0.0.1/tcp/56171/ipfs/QmU41TRPs34WWqa1brJEojBLYZKrrBcJq9nyNfVvSrbZUJ\n"
m = re.search(br"I am ([\w\./]+)", self.bytes_read)
if m is None:
raise Exception("failed to find the pattern for the listening multiaddr")
maddr_bytes_str_ipfs = m.group(1)
maddr_str = maddr_bytes_str_ipfs.decode().replace("ipfs", "p2p")
maddr = Multiaddr(maddr_str)
self._peer_info = info_from_p2p_addr(maddr)
return self._peer_info
@pytest.mark.parametrize("num_hosts", (1,)) @pytest.mark.trio
@pytest.mark.asyncio async def test_insecure_conn_py_to_go(is_host_secure):
async def test_insecure_conn_py_to_go( async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts:
hosts, proc_factory, is_host_secure, unused_tcp_port go_proc = EchoProcess(get_unused_tcp_port(), is_host_secure)
): await go_proc.start()
go_proc, go_pinfo = await make_echo_proc(
proc_factory, unused_tcp_port, is_host_secure
)
host = hosts[0] host = hosts[0]
await host.connect(go_pinfo) peer_info = go_proc.peer_info
await go_proc.expect("swarm listener accepted connection", async_=True) await host.connect(peer_info)
s = await host.new_stream(go_pinfo.peer_id, [ECHO_PROTOCOL_ID]) s = await host.new_stream(peer_info.peer_id, [ECHO_PROTOCOL_ID])
data = "data321123\n"
await go_proc.expect("Got a new stream!", async_=True) await s.write(data.encode())
data = "data321123\n" echoed_resp = await s.read(len(data))
await s.write(data.encode()) assert echoed_resp.decode() == data
await go_proc.expect(f"read: {data[:-1]}", async_=True) await s.close()
echoed_resp = await s.read(len(data))
assert echoed_resp.decode() == data
await s.close()
@pytest.mark.parametrize("num_hosts", (1,)) @pytest.mark.trio
@pytest.mark.asyncio async def test_insecure_conn_go_to_py(is_host_secure):
async def test_insecure_conn_go_to_py( async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts:
hosts, proc_factory, is_host_secure, unused_tcp_port host = hosts[0]
): expected_data = "Hello, world!\n"
host = hosts[0] reply_data = "Replyooo!\n"
expected_data = "Hello, world!\n" event_handler_finished = trio.Event()
reply_data = "Replyooo!\n"
event_handler_finished = asyncio.Event()
async def _handle_echo(stream): async def _handle_echo(stream):
read_data = await stream.read(len(expected_data)) read_data = await stream.read(len(expected_data))
assert read_data == expected_data.encode() assert read_data == expected_data.encode()
event_handler_finished.set() event_handler_finished.set()
await stream.write(reply_data.encode()) await stream.write(reply_data.encode())
await stream.close() await stream.close()
host.set_stream_handler(ECHO_PROTOCOL_ID, _handle_echo) host.set_stream_handler(ECHO_PROTOCOL_ID, _handle_echo)
py_maddr = host.get_addrs()[0] py_maddr = host.get_addrs()[0]
go_proc, _ = await make_echo_proc( go_proc = EchoProcess(get_unused_tcp_port(), is_host_secure, py_maddr)
proc_factory, unused_tcp_port, is_host_secure, py_maddr await go_proc.start()
) await event_handler_finished.wait()
await go_proc.expect("connect with peer", async_=True)
await go_proc.expect("opened stream", async_=True)
await event_handler_finished.wait()
await go_proc.expect(f"read reply: .*{reply_data.rstrip()}.*", async_=True)

View File

@ -1,6 +1,5 @@
import asyncio
import pytest import pytest
import trio
from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset
from libp2p.tools.constants import MAX_READ_LEN from libp2p.tools.constants import MAX_READ_LEN
@ -8,7 +7,7 @@ from libp2p.tools.constants import MAX_READ_LEN
DATA = b"data" DATA = b"data"
@pytest.mark.asyncio @pytest.mark.trio
async def test_net_stream_read_write(py_to_daemon_stream_pair, p2pds): async def test_net_stream_read_write(py_to_daemon_stream_pair, p2pds):
stream_py, stream_daemon = py_to_daemon_stream_pair stream_py, stream_daemon = py_to_daemon_stream_pair
assert ( assert (
@ -19,19 +18,19 @@ async def test_net_stream_read_write(py_to_daemon_stream_pair, p2pds):
assert (await stream_daemon.read(MAX_READ_LEN)) == DATA assert (await stream_daemon.read(MAX_READ_LEN)) == DATA
@pytest.mark.asyncio @pytest.mark.trio
async def test_net_stream_read_after_remote_closed(py_to_daemon_stream_pair, p2pds): async def test_net_stream_read_after_remote_closed(py_to_daemon_stream_pair, p2pds):
stream_py, stream_daemon = py_to_daemon_stream_pair stream_py, stream_daemon = py_to_daemon_stream_pair
await stream_daemon.write(DATA) await stream_daemon.write(DATA)
await stream_daemon.close() await stream_daemon.close()
await asyncio.sleep(0.01) await trio.sleep(0.01)
assert (await stream_py.read(MAX_READ_LEN)) == DATA assert (await stream_py.read(MAX_READ_LEN)) == DATA
# EOF # EOF
with pytest.raises(StreamEOF): with pytest.raises(StreamEOF):
await stream_py.read(MAX_READ_LEN) await stream_py.read(MAX_READ_LEN)
@pytest.mark.asyncio @pytest.mark.trio
async def test_net_stream_read_after_local_reset(py_to_daemon_stream_pair, p2pds): async def test_net_stream_read_after_local_reset(py_to_daemon_stream_pair, p2pds):
stream_py, _ = py_to_daemon_stream_pair stream_py, _ = py_to_daemon_stream_pair
await stream_py.reset() await stream_py.reset()
@ -40,15 +39,15 @@ async def test_net_stream_read_after_local_reset(py_to_daemon_stream_pair, p2pds
@pytest.mark.parametrize("is_to_fail_daemon_stream", (True,)) @pytest.mark.parametrize("is_to_fail_daemon_stream", (True,))
@pytest.mark.asyncio @pytest.mark.trio
async def test_net_stream_read_after_remote_reset(py_to_daemon_stream_pair, p2pds): async def test_net_stream_read_after_remote_reset(py_to_daemon_stream_pair, p2pds):
stream_py, _ = py_to_daemon_stream_pair stream_py, _ = py_to_daemon_stream_pair
await asyncio.sleep(0.01) await trio.sleep(0.01)
with pytest.raises(StreamReset): with pytest.raises(StreamReset):
await stream_py.read(MAX_READ_LEN) await stream_py.read(MAX_READ_LEN)
@pytest.mark.asyncio @pytest.mark.trio
async def test_net_stream_write_after_local_closed(py_to_daemon_stream_pair, p2pds): async def test_net_stream_write_after_local_closed(py_to_daemon_stream_pair, p2pds):
stream_py, _ = py_to_daemon_stream_pair stream_py, _ = py_to_daemon_stream_pair
await stream_py.write(DATA) await stream_py.write(DATA)
@ -57,7 +56,7 @@ async def test_net_stream_write_after_local_closed(py_to_daemon_stream_pair, p2p
await stream_py.write(DATA) await stream_py.write(DATA)
@pytest.mark.asyncio @pytest.mark.trio
async def test_net_stream_write_after_local_reset(py_to_daemon_stream_pair, p2pds): async def test_net_stream_write_after_local_reset(py_to_daemon_stream_pair, p2pds):
stream_py, stream_daemon = py_to_daemon_stream_pair stream_py, stream_daemon = py_to_daemon_stream_pair
await stream_py.reset() await stream_py.reset()
@ -66,9 +65,9 @@ async def test_net_stream_write_after_local_reset(py_to_daemon_stream_pair, p2pd
@pytest.mark.parametrize("is_to_fail_daemon_stream", (True,)) @pytest.mark.parametrize("is_to_fail_daemon_stream", (True,))
@pytest.mark.asyncio @pytest.mark.trio
async def test_net_stream_write_after_remote_reset(py_to_daemon_stream_pair, p2pds): async def test_net_stream_write_after_remote_reset(py_to_daemon_stream_pair, p2pds):
stream_py, _ = py_to_daemon_stream_pair stream_py, _ = py_to_daemon_stream_pair
await asyncio.sleep(0.01) await trio.sleep(0.01)
with pytest.raises(StreamClosed): with pytest.raises(StreamClosed):
await stream_py.write(DATA) await stream_py.write(DATA)

View File

@ -1,11 +1,15 @@
import asyncio
import functools import functools
import math
from p2pclient.pb import p2pd_pb2 from p2pclient.pb import p2pd_pb2
import pytest import pytest
import trio
from libp2p.io.trio import TrioTCPStream
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.pubsub.pb import rpc_pb2 from libp2p.pubsub.pb import rpc_pb2
from libp2p.pubsub.subscription import TrioSubscriptionAPI
from libp2p.tools.factories import PubsubFactory
from libp2p.tools.interop.utils import connect from libp2p.tools.interop.utils import connect
from libp2p.utils import read_varint_prefixed_bytes from libp2p.utils import read_varint_prefixed_bytes
@ -13,26 +17,15 @@ TOPIC_0 = "ABALA"
TOPIC_1 = "YOOOO" TOPIC_1 = "YOOOO"
async def p2pd_subscribe(p2pd, topic) -> "asyncio.Queue[rpc_pb2.Message]": async def p2pd_subscribe(p2pd, topic, nursery):
reader, writer = await p2pd.control.pubsub_subscribe(topic) stream = TrioTCPStream(await p2pd.control.pubsub_subscribe(topic))
send_channel, receive_channel = trio.open_memory_channel(math.inf)
queue = asyncio.Queue() sub = TrioSubscriptionAPI(receive_channel, unsubscribe_fn=stream.close)
async def _read_pubsub_msg() -> None: async def _read_pubsub_msg() -> None:
writer_closed_task = asyncio.ensure_future(writer.wait_closed())
while True: while True:
done, pending = await asyncio.wait( msg_bytes = await read_varint_prefixed_bytes(stream)
[read_varint_prefixed_bytes(reader), writer_closed_task],
return_when=asyncio.FIRST_COMPLETED,
)
done_tasks = tuple(done)
if writer.is_closing():
return
read_task = done_tasks[0]
# Sanity check
assert read_task._coro.__name__ == "read_varint_prefixed_bytes"
msg_bytes = read_task.result()
ps_msg = p2pd_pb2.PSMessage() ps_msg = p2pd_pb2.PSMessage()
ps_msg.ParseFromString(msg_bytes) ps_msg.ParseFromString(msg_bytes)
# Fill in the message used in py-libp2p # Fill in the message used in py-libp2p
@ -44,11 +37,10 @@ async def p2pd_subscribe(p2pd, topic) -> "asyncio.Queue[rpc_pb2.Message]":
signature=ps_msg.signature, signature=ps_msg.signature,
key=ps_msg.key, key=ps_msg.key,
) )
queue.put_nowait(msg) await send_channel.send(msg)
asyncio.ensure_future(_read_pubsub_msg()) nursery.start_soon(_read_pubsub_msg)
await asyncio.sleep(0) return sub
return queue
def validate_pubsub_msg(msg: rpc_pb2.Message, data: bytes, from_peer_id: ID) -> None: def validate_pubsub_msg(msg: rpc_pb2.Message, data: bytes, from_peer_id: ID) -> None:
@ -59,108 +51,119 @@ def validate_pubsub_msg(msg: rpc_pb2.Message, data: bytes, from_peer_id: ID) ->
"is_pubsub_signing, is_pubsub_signing_strict", ((True, True), (False, False)) "is_pubsub_signing, is_pubsub_signing_strict", ((True, True), (False, False))
) )
@pytest.mark.parametrize("is_gossipsub", (True, False)) @pytest.mark.parametrize("is_gossipsub", (True, False))
@pytest.mark.parametrize("num_hosts, num_p2pds", ((1, 2),)) @pytest.mark.parametrize("num_p2pds", (2,))
@pytest.mark.asyncio @pytest.mark.trio
async def test_pubsub(pubsubs, p2pds): async def test_pubsub(
# p2pds, is_gossipsub, is_host_secure, is_pubsub_signing_strict, nursery
# Test: Recognize pubsub peers on connection. ):
# pubsub_factory = None
py_pubsub = pubsubs[0] if is_gossipsub:
# go0 <-> py <-> go1 pubsub_factory = PubsubFactory.create_batch_with_gossipsub
await connect(p2pds[0], py_pubsub.host) else:
await connect(py_pubsub.host, p2pds[1]) pubsub_factory = PubsubFactory.create_batch_with_floodsub
py_peer_id = py_pubsub.host.get_id()
# Check pubsub peers
pubsub_peers_0 = await p2pds[0].control.pubsub_list_peers("")
assert len(pubsub_peers_0) == 1 and pubsub_peers_0[0] == py_peer_id
pubsub_peers_1 = await p2pds[1].control.pubsub_list_peers("")
assert len(pubsub_peers_1) == 1 and pubsub_peers_1[0] == py_peer_id
assert (
len(py_pubsub.peers) == 2
and p2pds[0].peer_id in py_pubsub.peers
and p2pds[1].peer_id in py_pubsub.peers
)
# async with pubsub_factory(
# Test: `subscribe`. 1, is_secure=is_host_secure, strict_signing=is_pubsub_signing_strict
# ) as pubsubs:
# (name, topics) #
# (go_0, [0, 1]) <-> (py, [0, 1]) <-> (go_1, [1]) # Test: Recognize pubsub peers on connection.
sub_py_topic_0 = await py_pubsub.subscribe(TOPIC_0) #
sub_py_topic_1 = await py_pubsub.subscribe(TOPIC_1) py_pubsub = pubsubs[0]
sub_go_0_topic_0 = await p2pd_subscribe(p2pds[0], TOPIC_0) # go0 <-> py <-> go1
sub_go_0_topic_1 = await p2pd_subscribe(p2pds[0], TOPIC_1) await connect(p2pds[0], py_pubsub.host)
sub_go_1_topic_1 = await p2pd_subscribe(p2pds[1], TOPIC_1) await connect(py_pubsub.host, p2pds[1])
# Check topic peers py_peer_id = py_pubsub.host.get_id()
await asyncio.sleep(0.1) # Check pubsub peers
# go_0 pubsub_peers_0 = await p2pds[0].control.pubsub_list_peers("")
go_0_topic_0_peers = await p2pds[0].control.pubsub_list_peers(TOPIC_0) assert len(pubsub_peers_0) == 1 and pubsub_peers_0[0] == py_peer_id
assert len(go_0_topic_0_peers) == 1 and py_peer_id == go_0_topic_0_peers[0] pubsub_peers_1 = await p2pds[1].control.pubsub_list_peers("")
go_0_topic_1_peers = await p2pds[0].control.pubsub_list_peers(TOPIC_1) assert len(pubsub_peers_1) == 1 and pubsub_peers_1[0] == py_peer_id
assert len(go_0_topic_1_peers) == 1 and py_peer_id == go_0_topic_1_peers[0] assert (
# py len(py_pubsub.peers) == 2
py_topic_0_peers = list(py_pubsub.peer_topics[TOPIC_0]) and p2pds[0].peer_id in py_pubsub.peers
assert len(py_topic_0_peers) == 1 and p2pds[0].peer_id == py_topic_0_peers[0] and p2pds[1].peer_id in py_pubsub.peers
# go_1 )
go_1_topic_1_peers = await p2pds[1].control.pubsub_list_peers(TOPIC_1)
assert len(go_1_topic_1_peers) == 1 and py_peer_id == go_1_topic_1_peers[0]
# #
# Test: `publish` # Test: `subscribe`.
# #
# 1. py publishes # (name, topics)
    # - 1.1. py publishes data_11 to topic_0, py and go_0 receive.         # 1. py publishes
# - 1.2. py publishes data_12 to topic_1, all receive. sub_py_topic_0 = await py_pubsub.subscribe(TOPIC_0)
# 2. go publishes sub_py_topic_1 = await py_pubsub.subscribe(TOPIC_1)
# - 2.1. go_0 publishes data_21 to topic_0, py and go_0 receive. sub_go_0_topic_0 = await p2pd_subscribe(p2pds[0], TOPIC_0, nursery)
# - 2.2. go_1 publishes data_22 to topic_1, all receive. sub_go_0_topic_1 = await p2pd_subscribe(p2pds[0], TOPIC_1, nursery)
sub_go_1_topic_1 = await p2pd_subscribe(p2pds[1], TOPIC_1, nursery)
# Check topic peers
await trio.sleep(0.1)
# go_0
go_0_topic_0_peers = await p2pds[0].control.pubsub_list_peers(TOPIC_0)
assert len(go_0_topic_0_peers) == 1 and py_peer_id == go_0_topic_0_peers[0]
go_0_topic_1_peers = await p2pds[0].control.pubsub_list_peers(TOPIC_1)
assert len(go_0_topic_1_peers) == 1 and py_peer_id == go_0_topic_1_peers[0]
# py
py_topic_0_peers = list(py_pubsub.peer_topics[TOPIC_0])
assert len(py_topic_0_peers) == 1 and p2pds[0].peer_id == py_topic_0_peers[0]
# go_1
go_1_topic_1_peers = await p2pds[1].control.pubsub_list_peers(TOPIC_1)
assert len(go_1_topic_1_peers) == 1 and py_peer_id == go_1_topic_1_peers[0]
    # 1.1. py publishes data_11 to topic_0, py and go_0 receive.         #
data_11 = b"data_11" # Test: `publish`
await py_pubsub.publish(TOPIC_0, data_11) #
validate_11 = functools.partial( # 1. py publishes
validate_pubsub_msg, data=data_11, from_peer_id=py_peer_id # - 1.1. py publishes data_11 to topic_0, py and go_0 receives.
) # - 1.2. py publishes data_12 to topic_1, all receive.
validate_11(await sub_py_topic_0.get()) # 2. go publishes
validate_11(await sub_go_0_topic_0.get()) # - 2.1. go_0 publishes data_21 to topic_0, py and go_0 receive.
# - 2.2. go_1 publishes data_22 to topic_1, all receive.
    # 1.2. py publishes data_12 to topic_1, all receive.         # 1.1. py publishes data_11 to topic_0, py and go_0 receive.
data_12 = b"data_12" data_11 = b"data_11"
validate_12 = functools.partial( await py_pubsub.publish(TOPIC_0, data_11)
validate_pubsub_msg, data=data_12, from_peer_id=py_peer_id validate_11 = functools.partial(
) validate_pubsub_msg, data=data_11, from_peer_id=py_peer_id
await py_pubsub.publish(TOPIC_1, data_12) )
validate_12(await sub_py_topic_1.get()) validate_11(await sub_py_topic_0.get())
validate_12(await sub_go_0_topic_1.get()) validate_11(await sub_go_0_topic_0.get())
validate_12(await sub_go_1_topic_1.get())
# 2.1. go_0 publishes data_21 to topic_0, py and go_0 receive. # 1.2. py publishes data_12 to topic_1, all receive.
data_21 = b"data_21" data_12 = b"data_12"
validate_21 = functools.partial( validate_12 = functools.partial(
validate_pubsub_msg, data=data_21, from_peer_id=p2pds[0].peer_id validate_pubsub_msg, data=data_12, from_peer_id=py_peer_id
) )
await p2pds[0].control.pubsub_publish(TOPIC_0, data_21) await py_pubsub.publish(TOPIC_1, data_12)
validate_21(await sub_py_topic_0.get()) validate_12(await sub_py_topic_1.get())
validate_21(await sub_go_0_topic_0.get()) validate_12(await sub_go_0_topic_1.get())
validate_12(await sub_go_1_topic_1.get())
# 2.2. go_1 publishes data_22 to topic_1, all receive. # 2.1. go_0 publishes data_21 to topic_0, py and go_0 receive.
data_22 = b"data_22" data_21 = b"data_21"
validate_22 = functools.partial( validate_21 = functools.partial(
validate_pubsub_msg, data=data_22, from_peer_id=p2pds[1].peer_id validate_pubsub_msg, data=data_21, from_peer_id=p2pds[0].peer_id
) )
await p2pds[1].control.pubsub_publish(TOPIC_1, data_22) await p2pds[0].control.pubsub_publish(TOPIC_0, data_21)
validate_22(await sub_py_topic_1.get()) validate_21(await sub_py_topic_0.get())
validate_22(await sub_go_0_topic_1.get()) validate_21(await sub_go_0_topic_0.get())
validate_22(await sub_go_1_topic_1.get())
# # 2.2. go_1 publishes data_22 to topic_1, all receive.
# Test: `unsubscribe` and re`subscribe` data_22 = b"data_22"
# validate_22 = functools.partial(
await py_pubsub.unsubscribe(TOPIC_0) validate_pubsub_msg, data=data_22, from_peer_id=p2pds[1].peer_id
await asyncio.sleep(0.1) )
assert py_peer_id not in (await p2pds[0].control.pubsub_list_peers(TOPIC_0)) await p2pds[1].control.pubsub_publish(TOPIC_1, data_22)
assert py_peer_id not in (await p2pds[1].control.pubsub_list_peers(TOPIC_0)) validate_22(await sub_py_topic_1.get())
await py_pubsub.subscribe(TOPIC_0) validate_22(await sub_go_0_topic_1.get())
await asyncio.sleep(0.1) validate_22(await sub_go_1_topic_1.get())
assert py_peer_id in (await p2pds[0].control.pubsub_list_peers(TOPIC_0))
assert py_peer_id in (await p2pds[1].control.pubsub_list_peers(TOPIC_0)) #
# Test: `unsubscribe` and re`subscribe`
#
await py_pubsub.unsubscribe(TOPIC_0)
await trio.sleep(0.1)
assert py_peer_id not in (await p2pds[0].control.pubsub_list_peers(TOPIC_0))
assert py_peer_id not in (await p2pds[1].control.pubsub_list_peers(TOPIC_0))
await py_pubsub.subscribe(TOPIC_0)
await trio.sleep(0.1)
assert py_peer_id in (await p2pds[0].control.pubsub_list_peers(TOPIC_0))
assert py_peer_id in (await p2pds[1].control.pubsub_list_peers(TOPIC_0))
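In the rewritten `p2pd_subscribe` above, an unbounded Trio memory channel replaces `asyncio.Queue`: a reader task started with `nursery.start_soon` feeds the send side, and the receive side is handed to `TrioSubscriptionAPI`. A minimal sketch of the channel-plus-reader pattern on its own; `read_message` is a hypothetical stand-in for `read_varint_prefixed_bytes(stream)`:

import math

import trio


async def subscribe(read_message, nursery):
    # An unbounded channel, mirroring the unbounded asyncio.Queue it replaces.
    send_channel, receive_channel = trio.open_memory_channel(math.inf)

    async def _reader() -> None:
        while True:
            msg = await read_message()
            await send_channel.send(msg)

    nursery.start_soon(_reader)
    # `await receive_channel.receive()` then yields messages as they arrive.
    return receive_channel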

View File

@ -12,7 +12,7 @@ envlist =
combine_as_imports=False combine_as_imports=False
force_sort_within_sections=True force_sort_within_sections=True
include_trailing_comma=True include_trailing_comma=True
known_third_party=hypothesis,pytest,p2pclient,pexpect,factory,lru known_third_party=anyio,factory,lru,p2pclient,pytest
known_first_party=libp2p known_first_party=libp2p
line_length=88 line_length=88
multi_line_output=3 multi_line_output=3
@ -58,7 +58,6 @@ commands =
[testenv:py37-interop] [testenv:py37-interop]
deps = deps =
p2pclient p2pclient
pexpect
passenv = CI TRAVIS TRAVIS_* GOPATH passenv = CI TRAVIS TRAVIS_* GOPATH
extras = test extras = test
commands = commands =