diff --git a/libp2p/pubsub/exceptions.py b/libp2p/pubsub/exceptions.py new file mode 100644 index 00000000..a47446de --- /dev/null +++ b/libp2p/pubsub/exceptions.py @@ -0,0 +1,9 @@ +from libp2p.exceptions import BaseLibp2pError + + +class PubsubRouterError(BaseLibp2pError): + ... + + +class NoPubsubAttached(PubsubRouterError): + ... diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index 568a2762..e1798898 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -9,6 +9,9 @@ from libp2p import generate_new_rsa_identity, generate_peer_id_from from libp2p.crypto.keys import KeyPair from libp2p.host.basic_host import BasicHost from libp2p.host.host_interface import IHost +from libp2p.io.abc import ReadWriteCloser +from libp2p.network.connection.raw_connection import RawConnection +from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.swarm_connection import SwarmConn from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.swarm import Swarm @@ -51,6 +54,27 @@ def security_transport_factory( return {secio.ID: secio.Transport(key_pair)} +@asynccontextmanager +async def raw_conn_factory( + nursery: trio.Nursery +) -> AsyncIterator[Tuple[IRawConnection, IRawConnection]]: + conn_0 = None + conn_1 = None + + async def tcp_stream_handler(stream: ReadWriteCloser) -> None: + nonlocal conn_1 + conn_1 = RawConnection(stream, initiator=False) + await trio.sleep_forever() + + tcp_transport = TCP() + listener = tcp_transport.create_listener(tcp_stream_handler) + await listener.listen(LISTEN_MADDR, nursery) + listening_maddr = listener.multiaddrs[0] + conn_0 = await tcp_transport.dial(listening_maddr) + print("raw_conn_factory") + yield conn_0, conn_1 + + class SwarmFactory(factory.Factory): class Meta: model = Swarm diff --git a/setup.py b/setup.py index cbb8eaf8..edfd1aae 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,6 @@ extras_require = { "test": [ "factory-boy>=2.12.0,<3.0.0", "pytest>=4.6.3,<5.0.0", - "pytest-asyncio>=0.10.0,<1.0.0", "pytest-xdist>=1.30.0", "pytest-trio>=0.5.2", ], diff --git a/tests/security/test_secio.py b/tests/security/test_secio.py index c7808b46..d009a738 100644 --- a/tests/security/test_secio.py +++ b/tests/security/test_secio.py @@ -1,70 +1,15 @@ -import asyncio - import pytest +import trio from libp2p.crypto.secp256k1 import create_new_key_pair -from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID from libp2p.security.secio.transport import NONCE_SIZE, create_secure_session +from libp2p.tools.constants import MAX_READ_LEN +from libp2p.tools.factories import raw_conn_factory -class InMemoryConnection(IRawConnection): - def __init__(self, peer, is_initiator=False): - self.peer = peer - self.recv_queue = asyncio.Queue() - self.send_queue = asyncio.Queue() - self.is_initiator = is_initiator - - self.current_msg = None - self.current_position = 0 - - self.closed = False - - async def write(self, data: bytes) -> int: - if self.closed: - raise Exception("InMemoryConnection is closed for writing") - - await self.send_queue.put(data) - return len(data) - - async def read(self, n: int = -1) -> bytes: - """ - NOTE: have to buffer the current message and juggle packets - off the recv queue to satisfy the semantics of this function. - """ - if self.closed: - raise Exception("InMemoryConnection is closed for reading") - - if not self.current_msg: - self.current_msg = await self.recv_queue.get() - self.current_position = 0 - - if n < 0: - msg = self.current_msg - self.current_msg = None - return msg - - next_msg = self.current_msg[self.current_position : self.current_position + n] - self.current_position += n - if self.current_position == len(self.current_msg): - self.current_msg = None - return next_msg - - async def close(self) -> None: - self.closed = True - - -async def create_pipe(local_conn, remote_conn): - try: - while True: - next_msg = await local_conn.send_queue.get() - await remote_conn.recv_queue.put(next_msg) - except asyncio.CancelledError: - return - - -@pytest.mark.asyncio -async def test_create_secure_session(): +@pytest.mark.trio +async def test_create_secure_session(nursery): local_nonce = b"\x01" * NONCE_SIZE local_key_pair = create_new_key_pair(b"a") local_peer = ID.from_pubkey(local_key_pair.public_key) @@ -73,30 +18,32 @@ async def test_create_secure_session(): remote_key_pair = create_new_key_pair(b"b") remote_peer = ID.from_pubkey(remote_key_pair.public_key) - local_conn = InMemoryConnection(local_peer, is_initiator=True) - remote_conn = InMemoryConnection(remote_peer) + async with raw_conn_factory(nursery) as conns: + local_conn, remote_conn = conns - local_pipe_task = asyncio.create_task(create_pipe(local_conn, remote_conn)) - remote_pipe_task = asyncio.create_task(create_pipe(remote_conn, local_conn)) + local_secure_conn, remote_secure_conn = None, None - local_session_builder = create_secure_session( - local_nonce, local_peer, local_key_pair.private_key, local_conn, remote_peer - ) - remote_session_builder = create_secure_session( - remote_nonce, remote_peer, remote_key_pair.private_key, remote_conn - ) - local_secure_conn, remote_secure_conn = await asyncio.gather( - local_session_builder, remote_session_builder - ) + async def local_create_secure_session(): + nonlocal local_secure_conn + local_secure_conn = await create_secure_session( + local_nonce, + local_peer, + local_key_pair.private_key, + local_conn, + remote_peer, + ) - msg = b"abc" - await local_secure_conn.write(msg) - received_msg = await remote_secure_conn.read() - assert received_msg == msg + async def remote_create_secure_session(): + nonlocal remote_secure_conn + remote_secure_conn = await create_secure_session( + remote_nonce, remote_peer, remote_key_pair.private_key, remote_conn + ) - await asyncio.gather(local_secure_conn.close(), remote_secure_conn.close()) + async with trio.open_nursery() as nursery_1: + nursery_1.start_soon(local_create_secure_session) + nursery_1.start_soon(remote_create_secure_session) - local_pipe_task.cancel() - remote_pipe_task.cancel() - await local_pipe_task - await remote_pipe_task + msg = b"abc" + await local_secure_conn.write(msg) + received_msg = await remote_secure_conn.read(MAX_READ_LEN) + assert received_msg == msg diff --git a/tests/security/test_security_multistream.py b/tests/security/test_security_multistream.py index c4eb3ecb..5c751f92 100644 --- a/tests/security/test_security_multistream.py +++ b/tests/security/test_security_multistream.py @@ -1,6 +1,6 @@ -import asyncio - +from async_service import background_trio_service import pytest +import trio from libp2p import new_node from libp2p.crypto.rsa import create_new_key_pair @@ -24,42 +24,39 @@ noninitiator_key_pair = create_new_key_pair() async def perform_simple_test( assertion_func, transports_for_initiator, transports_for_noninitiator ): - # Create libp2p nodes and connect them, then secure the connection, then check # the proper security was chosen # TODO: implement -- note we need to introduce the notion of communicating over a raw connection # for testing, we do NOT want to communicate over a stream so we can't just create two nodes # and use their conn because our mplex will internally relay messages to a stream - node1 = await new_node( - key_pair=initiator_key_pair, sec_opt=transports_for_initiator - ) - node2 = await new_node( + node1 = new_node(key_pair=initiator_key_pair, sec_opt=transports_for_initiator) + node2 = new_node( key_pair=noninitiator_key_pair, sec_opt=transports_for_noninitiator ) + swarm1 = node1.get_network() + swarm2 = node2.get_network() + async with background_trio_service(swarm1), background_trio_service(swarm2): + await swarm1.listen(LISTEN_MADDR) + await swarm2.listen(LISTEN_MADDR) - await node1.get_network().listen(LISTEN_MADDR) - await node2.get_network().listen(LISTEN_MADDR) + await connect(node1, node2) - await connect(node1, node2) + # Wait a very short period to allow conns to be stored (since the functions + # storing the conns are async, they may happen at slightly different times + # on each node) + await trio.sleep(0.1) - # Wait a very short period to allow conns to be stored (since the functions - # storing the conns are async, they may happen at slightly different times - # on each node) - await asyncio.sleep(0.1) + # Get conns + node1_conn = node1.get_network().connections[peer_id_for_node(node2)] + node2_conn = node2.get_network().connections[peer_id_for_node(node1)] - # Get conns - node1_conn = node1.get_network().connections[peer_id_for_node(node2)] - node2_conn = node2.get_network().connections[peer_id_for_node(node1)] - - # Perform assertion - assertion_func(node1_conn.muxed_conn.secured_conn) - assertion_func(node2_conn.muxed_conn.secured_conn) - - # Success, terminate pending tasks. + # Perform assertion + assertion_func(node1_conn.muxed_conn.secured_conn) + assertion_func(node2_conn.muxed_conn.secured_conn) -@pytest.mark.asyncio +@pytest.mark.trio async def test_single_insecure_security_transport_succeeds(): transports_for_initiator = {"foo": InsecureTransport(initiator_key_pair)} transports_for_noninitiator = {"foo": InsecureTransport(noninitiator_key_pair)} @@ -72,7 +69,7 @@ async def test_single_insecure_security_transport_succeeds(): ) -@pytest.mark.asyncio +@pytest.mark.trio async def test_default_insecure_security(): transports_for_initiator = None transports_for_noninitiator = None