From 24b2704d8c8204809f5cf8497b56a663572b505d Mon Sep 17 00:00:00 2001 From: pacrob <5199899+pacrob@users.noreply.github.com> Date: Mon, 24 Mar 2025 18:02:28 -0600 Subject: [PATCH] move factories to tests/utils/factories --- libp2p/tools/pubsub/dummy_account_node.py | 2 +- tests/conftest.py | 2 +- tests/core/examples/test_examples.py | 2 +- tests/core/host/test_connected_peers.py | 2 +- tests/core/host/test_live_peers.py | 2 +- tests/core/host/test_ping.py | 2 +- tests/core/host/test_routed_host.py | 2 +- tests/core/identity/identify/test_identify.py | 2 +- tests/core/network/conftest.py | 2 +- tests/core/network/test_notify.py | 2 +- tests/core/network/test_swarm.py | 2 +- .../protocol_muxer/test_protocol_muxer.py | 2 +- tests/core/pubsub/test_floodsub.py | 2 +- tests/core/pubsub/test_gossipsub.py | 2 +- .../test_gossipsub_backward_compatibility.py | 2 +- tests/core/pubsub/test_pubsub.py | 2 +- .../security/noise/test_msg_read_writer.py | 2 +- tests/core/security/noise/test_noise.py | 2 +- tests/core/security/test_secio.py | 2 +- .../security/test_security_multistream.py | 2 +- tests/core/stream_muxer/conftest.py | 2 +- tests/core/test_libp2p/test_libp2p.py | 2 +- tests/factories.py | 680 ------------------ tests/interop/conftest.py | 2 +- tests/interop/test_bindings.py | 2 +- tests/interop/test_echo.py | 2 +- tests/interop/test_pubsub.py | 2 +- 27 files changed, 26 insertions(+), 706 deletions(-) delete mode 100644 tests/factories.py diff --git a/libp2p/tools/pubsub/dummy_account_node.py b/libp2p/tools/pubsub/dummy_account_node.py index c208d327..a1149bd5 100644 --- a/libp2p/tools/pubsub/dummy_account_node.py +++ b/libp2p/tools/pubsub/dummy_account_node.py @@ -16,7 +16,7 @@ from libp2p.tools.async_service import ( Service, background_trio_service, ) -from tests.factories import ( +from tests.utils.factories import ( PubsubFactory, ) diff --git a/tests/conftest.py b/tests/conftest.py index c1993b88..6fc24415 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ import pytest -from tests.factories import ( +from tests.utils.factories import ( HostFactory, ) diff --git a/tests/core/examples/test_examples.py b/tests/core/examples/test_examples.py index 4dc526f6..d0c7fcf0 100644 --- a/tests/core/examples/test_examples.py +++ b/tests/core/examples/test_examples.py @@ -10,7 +10,7 @@ from libp2p.peer.peerinfo import ( from libp2p.tools.utils import ( MAX_READ_LEN, ) -from tests.factories import ( +from tests.utils.factories import ( HostFactory, ) diff --git a/tests/core/host/test_connected_peers.py b/tests/core/host/test_connected_peers.py index 990c7079..60b3750d 100644 --- a/tests/core/host/test_connected_peers.py +++ b/tests/core/host/test_connected_peers.py @@ -3,7 +3,7 @@ import pytest from libp2p.peer.peerinfo import ( info_from_p2p_addr, ) -from tests.factories import ( +from tests.utils.factories import ( HostFactory, ) diff --git a/tests/core/host/test_live_peers.py b/tests/core/host/test_live_peers.py index e58897a0..1d7948ad 100644 --- a/tests/core/host/test_live_peers.py +++ b/tests/core/host/test_live_peers.py @@ -4,7 +4,7 @@ import trio from libp2p.peer.peerinfo import ( info_from_p2p_addr, ) -from tests.factories import ( +from tests.utils.factories import ( HostFactory, ) diff --git a/tests/core/host/test_ping.py b/tests/core/host/test_ping.py index 358ab9cf..97b19e27 100644 --- a/tests/core/host/test_ping.py +++ b/tests/core/host/test_ping.py @@ -8,7 +8,7 @@ from libp2p.host.ping import ( PING_LENGTH, PingService, ) -from tests.factories import ( +from tests.utils.factories import ( host_pair_factory, ) diff --git a/tests/core/host/test_routed_host.py b/tests/core/host/test_routed_host.py index e0f87bce..1c0d21db 100644 --- a/tests/core/host/test_routed_host.py +++ b/tests/core/host/test_routed_host.py @@ -6,7 +6,7 @@ from libp2p.host.exceptions import ( from libp2p.peer.peerinfo import ( PeerInfo, ) -from tests.factories import ( +from tests.utils.factories import ( HostFactory, RoutedHostFactory, ) diff --git a/tests/core/identity/identify/test_identify.py b/tests/core/identity/identify/test_identify.py index 8dd58521..5c8bf8e3 100644 --- a/tests/core/identity/identify/test_identify.py +++ b/tests/core/identity/identify/test_identify.py @@ -15,7 +15,7 @@ from libp2p.identity.identify.identify import ( from libp2p.identity.identify.pb.identify_pb2 import ( Identify, ) -from tests.factories import ( +from tests.utils.factories import ( host_pair_factory, ) diff --git a/tests/core/network/conftest.py b/tests/core/network/conftest.py index dac403b1..05b2f6c1 100644 --- a/tests/core/network/conftest.py +++ b/tests/core/network/conftest.py @@ -1,6 +1,6 @@ import pytest -from tests.factories import ( +from tests.utils.factories import ( net_stream_pair_factory, swarm_conn_pair_factory, swarm_pair_factory, diff --git a/tests/core/network/test_notify.py b/tests/core/network/test_notify.py index 381b70f8..da00e613 100644 --- a/tests/core/network/test_notify.py +++ b/tests/core/network/test_notify.py @@ -25,7 +25,7 @@ from libp2p.tools.constants import ( from libp2p.tools.utils import ( connect_swarm, ) -from tests.factories import ( +from tests.utils.factories import ( SwarmFactory, ) diff --git a/tests/core/network/test_swarm.py b/tests/core/network/test_swarm.py index c2e908a2..8ccf7392 100644 --- a/tests/core/network/test_swarm.py +++ b/tests/core/network/test_swarm.py @@ -13,7 +13,7 @@ from libp2p.network.exceptions import ( from libp2p.tools.utils import ( connect_swarm, ) -from tests.factories import ( +from tests.utils.factories import ( SwarmFactory, ) diff --git a/tests/core/protocol_muxer/test_protocol_muxer.py b/tests/core/protocol_muxer/test_protocol_muxer.py index 1bea50a5..151dfd07 100644 --- a/tests/core/protocol_muxer/test_protocol_muxer.py +++ b/tests/core/protocol_muxer/test_protocol_muxer.py @@ -9,7 +9,7 @@ from libp2p.host.exceptions import ( from libp2p.tools.utils import ( create_echo_stream_handler, ) -from tests.factories import ( +from tests.utils.factories import ( HostFactory, ) diff --git a/tests/core/pubsub/test_floodsub.py b/tests/core/pubsub/test_floodsub.py index c056a62a..1927b614 100644 --- a/tests/core/pubsub/test_floodsub.py +++ b/tests/core/pubsub/test_floodsub.py @@ -13,7 +13,7 @@ from libp2p.tools.pubsub.floodsub_integration_test_settings import ( from libp2p.tools.utils import ( connect, ) -from tests.factories import ( +from tests.utils.factories import ( PubsubFactory, ) diff --git a/tests/core/pubsub/test_gossipsub.py b/tests/core/pubsub/test_gossipsub.py index fff2029a..aa7df33e 100644 --- a/tests/core/pubsub/test_gossipsub.py +++ b/tests/core/pubsub/test_gossipsub.py @@ -13,7 +13,7 @@ from libp2p.tools.pubsub.utils import ( from libp2p.tools.utils import ( connect, ) -from tests.factories import ( +from tests.utils.factories import ( IDFactory, PubsubFactory, ) diff --git a/tests/core/pubsub/test_gossipsub_backward_compatibility.py b/tests/core/pubsub/test_gossipsub_backward_compatibility.py index af115d17..8df16811 100644 --- a/tests/core/pubsub/test_gossipsub_backward_compatibility.py +++ b/tests/core/pubsub/test_gossipsub_backward_compatibility.py @@ -9,7 +9,7 @@ from libp2p.tools.pubsub.floodsub_integration_test_settings import ( floodsub_protocol_pytest_params, perform_test_from_obj, ) -from tests.factories import ( +from tests.utils.factories import ( PubsubFactory, ) diff --git a/tests/core/pubsub/test_pubsub.py b/tests/core/pubsub/test_pubsub.py index 02fbe661..47be6f67 100644 --- a/tests/core/pubsub/test_pubsub.py +++ b/tests/core/pubsub/test_pubsub.py @@ -30,7 +30,7 @@ from libp2p.tools.utils import ( from libp2p.utils import ( encode_varint_prefixed, ) -from tests.factories import ( +from tests.utils.factories import ( IDFactory, PubsubFactory, net_stream_pair_factory, diff --git a/tests/core/security/noise/test_msg_read_writer.py b/tests/core/security/noise/test_msg_read_writer.py index 6d84fffc..93b959cc 100644 --- a/tests/core/security/noise/test_msg_read_writer.py +++ b/tests/core/security/noise/test_msg_read_writer.py @@ -4,7 +4,7 @@ from libp2p.security.noise.io import ( MAX_NOISE_MESSAGE_LEN, NoisePacketReadWriter, ) -from tests.factories import ( +from tests.utils.factories import ( raw_conn_factory, ) diff --git a/tests/core/security/noise/test_noise.py b/tests/core/security/noise/test_noise.py index 94dde19b..37a728f7 100644 --- a/tests/core/security/noise/test_noise.py +++ b/tests/core/security/noise/test_noise.py @@ -3,7 +3,7 @@ import pytest from libp2p.security.noise.messages import ( NoiseHandshakePayload, ) -from tests.factories import ( +from tests.utils.factories import ( noise_conn_factory, noise_handshake_payload_factory, ) diff --git a/tests/core/security/test_secio.py b/tests/core/security/test_secio.py index f2df71a6..ac1a03a3 100644 --- a/tests/core/security/test_secio.py +++ b/tests/core/security/test_secio.py @@ -14,7 +14,7 @@ from libp2p.security.secio.transport import ( from libp2p.tools.constants import ( MAX_READ_LEN, ) -from tests.factories import ( +from tests.utils.factories import ( raw_conn_factory, ) diff --git a/tests/core/security/test_security_multistream.py b/tests/core/security/test_security_multistream.py index 62fb959f..c0bf3711 100644 --- a/tests/core/security/test_security_multistream.py +++ b/tests/core/security/test_security_multistream.py @@ -12,7 +12,7 @@ from libp2p.security.secio.transport import ID as SECIO_PROTOCOL_ID from libp2p.security.secure_session import ( SecureSession, ) -from tests.factories import ( +from tests.utils.factories import ( host_pair_factory, ) diff --git a/tests/core/stream_muxer/conftest.py b/tests/core/stream_muxer/conftest.py index c0bcbc32..5acf97bc 100644 --- a/tests/core/stream_muxer/conftest.py +++ b/tests/core/stream_muxer/conftest.py @@ -1,6 +1,6 @@ import pytest -from tests.factories import ( +from tests.utils.factories import ( mplex_conn_pair_factory, mplex_stream_pair_factory, ) diff --git a/tests/core/test_libp2p/test_libp2p.py b/tests/core/test_libp2p/test_libp2p.py index cd7a2156..9bcea612 100644 --- a/tests/core/test_libp2p/test_libp2p.py +++ b/tests/core/test_libp2p/test_libp2p.py @@ -14,7 +14,7 @@ from libp2p.tools.utils import ( connect, create_echo_stream_handler, ) -from tests.factories import ( +from tests.utils.factories import ( HostFactory, ) diff --git a/tests/factories.py b/tests/factories.py deleted file mode 100644 index 08a5b67e..00000000 --- a/tests/factories.py +++ /dev/null @@ -1,680 +0,0 @@ -from collections.abc import ( - AsyncIterator, - Sequence, -) -from contextlib import ( - AsyncExitStack, - asynccontextmanager, -) -from typing import ( - Any, - Callable, - cast, -) - -import factory -from multiaddr import ( - Multiaddr, -) -import trio - -from libp2p import ( - generate_new_rsa_identity, - generate_peer_id_from, -) -from libp2p.abc import ( - IHost, - INetStream, - IPeerRouting, - IPubsubRouter, - IRawConnection, - ISecureConn, - ISecureTransport, -) -from libp2p.crypto.ed25519 import create_new_key_pair as create_ed25519_key_pair -from libp2p.crypto.keys import ( - KeyPair, - PrivateKey, -) -from libp2p.crypto.secp256k1 import create_new_key_pair as create_secp256k1_key_pair -from libp2p.custom_types import ( - TMuxerOptions, - TProtocol, - TSecurityOptions, -) -from libp2p.host.basic_host import ( - BasicHost, -) -from libp2p.host.routed_host import ( - RoutedHost, -) -from libp2p.io.abc import ( - ReadWriteCloser, -) -from libp2p.network.connection.raw_connection import ( - RawConnection, -) -from libp2p.network.connection.swarm_connection import ( - SwarmConn, -) -from libp2p.network.swarm import ( - Swarm, -) -from libp2p.peer.id import ( - ID, -) -from libp2p.peer.peerinfo import ( - PeerInfo, -) -from libp2p.peer.peerstore import ( - PeerStore, -) -from libp2p.pubsub.floodsub import ( - FloodSub, -) -from libp2p.pubsub.gossipsub import ( - GossipSub, -) -import libp2p.pubsub.pb.rpc_pb2 as rpc_pb2 -from libp2p.pubsub.pubsub import ( - Pubsub, - get_peer_and_seqno_msg_id, -) -from libp2p.security.insecure.transport import ( - PLAINTEXT_PROTOCOL_ID, - InsecureTransport, -) -from libp2p.security.noise.messages import ( - NoiseHandshakePayload, - make_handshake_payload_sig, -) -from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID -from libp2p.security.noise.transport import Transport as NoiseTransport -import libp2p.security.secio.transport as secio -from libp2p.stream_muxer.mplex.mplex import ( - MPLEX_PROTOCOL_ID, - Mplex, -) -from libp2p.stream_muxer.mplex.mplex_stream import ( - MplexStream, -) -from libp2p.tools.async_service import ( - background_trio_service, -) -from libp2p.tools.constants import ( - FLOODSUB_PROTOCOL_ID, - GOSSIPSUB_PARAMS, - GOSSIPSUB_PROTOCOL_ID, - LISTEN_MADDR, -) -from libp2p.tools.utils import ( - connect, - connect_swarm, -) -from libp2p.transport.tcp.tcp import ( - TCP, -) -from libp2p.transport.upgrader import ( - TransportUpgrader, -) - -DEFAULT_SECURITY_PROTOCOL_ID = PLAINTEXT_PROTOCOL_ID - - -def default_key_pair_factory() -> KeyPair: - return generate_new_rsa_identity() - - -class IDFactory(factory.Factory): - class Meta: - model = ID - - peer_id_bytes = factory.LazyFunction( - lambda: generate_peer_id_from(default_key_pair_factory()) - ) - - -def initialize_peerstore_with_our_keypair(self_id: ID, key_pair: KeyPair) -> PeerStore: - peer_store = PeerStore() - peer_store.add_key_pair(self_id, key_pair) - return peer_store - - -def noise_static_key_factory() -> PrivateKey: - return create_ed25519_key_pair().private_key - - -def noise_handshake_payload_factory() -> NoiseHandshakePayload: - libp2p_keypair = create_secp256k1_key_pair() - noise_static_privkey = noise_static_key_factory() - return NoiseHandshakePayload( - libp2p_keypair.public_key, - make_handshake_payload_sig( - libp2p_keypair.private_key, noise_static_privkey.get_public_key() - ), - ) - - -def plaintext_transport_factory(key_pair: KeyPair) -> ISecureTransport: - return InsecureTransport(key_pair) - - -def secio_transport_factory(key_pair: KeyPair) -> ISecureTransport: - return secio.Transport(key_pair) - - -def noise_transport_factory(key_pair: KeyPair) -> ISecureTransport: - return NoiseTransport( - libp2p_keypair=key_pair, - noise_privkey=noise_static_key_factory(), - early_data=None, - with_noise_pipes=False, - ) - - -def security_options_factory_factory( - protocol_id: TProtocol = None, -) -> Callable[[KeyPair], TSecurityOptions]: - if protocol_id is None: - protocol_id = DEFAULT_SECURITY_PROTOCOL_ID - - def security_options_factory(key_pair: KeyPair) -> TSecurityOptions: - transport_factory: Callable[[KeyPair], ISecureTransport] - if protocol_id == PLAINTEXT_PROTOCOL_ID: - transport_factory = plaintext_transport_factory - elif protocol_id == secio.ID: - transport_factory = secio_transport_factory - elif protocol_id == NOISE_PROTOCOL_ID: - transport_factory = noise_transport_factory - else: - raise Exception(f"security transport {protocol_id} is not supported") - return {protocol_id: transport_factory(key_pair)} - - return security_options_factory - - -def mplex_transport_factory() -> TMuxerOptions: - return {MPLEX_PROTOCOL_ID: Mplex} - - -def default_muxer_transport_factory() -> TMuxerOptions: - return mplex_transport_factory() - - -@asynccontextmanager -async def raw_conn_factory( - nursery: trio.Nursery, -) -> AsyncIterator[tuple[IRawConnection, IRawConnection]]: - conn_0 = None - conn_1 = None - event = trio.Event() - - async def tcp_stream_handler(stream: ReadWriteCloser) -> None: - nonlocal conn_1 - conn_1 = RawConnection(stream, initiator=False) - event.set() - await trio.sleep_forever() - - tcp_transport = TCP() - listener = tcp_transport.create_listener(tcp_stream_handler) - await listener.listen(LISTEN_MADDR, nursery) - listening_maddr = listener.get_addrs()[0] - conn_0 = await tcp_transport.dial(listening_maddr) - await event.wait() - yield conn_0, conn_1 - - -@asynccontextmanager -async def noise_conn_factory( - nursery: trio.Nursery, -) -> AsyncIterator[tuple[ISecureConn, ISecureConn]]: - local_transport = cast( - NoiseTransport, noise_transport_factory(create_secp256k1_key_pair()) - ) - remote_transport = cast( - NoiseTransport, noise_transport_factory(create_secp256k1_key_pair()) - ) - - local_secure_conn: ISecureConn = None - remote_secure_conn: ISecureConn = None - - async def upgrade_local_conn() -> None: - nonlocal local_secure_conn - local_secure_conn = await local_transport.secure_outbound( - local_conn, remote_transport.local_peer - ) - - async def upgrade_remote_conn() -> None: - nonlocal remote_secure_conn - remote_secure_conn = await remote_transport.secure_inbound(remote_conn) - - async with raw_conn_factory(nursery) as conns: - local_conn, remote_conn = conns - async with trio.open_nursery() as nursery: - nursery.start_soon(upgrade_local_conn) - nursery.start_soon(upgrade_remote_conn) - if local_secure_conn is None or remote_secure_conn is None: - raise Exception( - "local or remote secure conn has not been successfully upgraded" - f"local_secure_conn={local_secure_conn}, " - f"remote_secure_conn={remote_secure_conn}" - ) - yield local_secure_conn, remote_secure_conn - - -class SwarmFactory(factory.Factory): - class Meta: - model = Swarm - - class Params: - key_pair = factory.LazyFunction(default_key_pair_factory) - security_protocol = DEFAULT_SECURITY_PROTOCOL_ID - muxer_opt = factory.LazyFunction(default_muxer_transport_factory) - - peer_id = factory.LazyAttribute(lambda o: generate_peer_id_from(o.key_pair)) - peerstore = factory.LazyAttribute( - lambda o: initialize_peerstore_with_our_keypair(o.peer_id, o.key_pair) - ) - upgrader = factory.LazyAttribute( - lambda o: TransportUpgrader( - (security_options_factory_factory(o.security_protocol))(o.key_pair), - o.muxer_opt, - ) - ) - transport = factory.LazyFunction(TCP) - - @classmethod - @asynccontextmanager - async def create_and_listen( - cls, - key_pair: KeyPair = None, - security_protocol: TProtocol = None, - muxer_opt: TMuxerOptions = None, - ) -> AsyncIterator[Swarm]: - # `factory.Factory.__init__` does *not* prepare a *default value* if we pass - # an argument explicitly with `None`. If an argument is `None`, we don't pass it - # to `factory.Factory.__init__`, in order to let the function initialize it. - optional_kwargs: dict[str, Any] = {} - if key_pair is not None: - optional_kwargs["key_pair"] = key_pair - if security_protocol is not None: - optional_kwargs["security_protocol"] = security_protocol - if muxer_opt is not None: - optional_kwargs["muxer_opt"] = muxer_opt - swarm = cls(**optional_kwargs) - async with background_trio_service(swarm): - await swarm.listen(LISTEN_MADDR) - yield swarm - - @classmethod - @asynccontextmanager - async def create_batch_and_listen( - cls, - number: int, - security_protocol: TProtocol = None, - muxer_opt: TMuxerOptions = None, - ) -> AsyncIterator[tuple[Swarm, ...]]: - async with AsyncExitStack() as stack: - ctx_mgrs = [ - await stack.enter_async_context( - cls.create_and_listen( - security_protocol=security_protocol, muxer_opt=muxer_opt - ) - ) - for _ in range(number) - ] - yield tuple(ctx_mgrs) - - -class HostFactory(factory.Factory): - class Meta: - model = BasicHost - - class Params: - key_pair = factory.LazyFunction(default_key_pair_factory) - security_protocol: TProtocol = None - muxer_opt = factory.LazyFunction(default_muxer_transport_factory) - - network = factory.LazyAttribute( - lambda o: SwarmFactory( - security_protocol=o.security_protocol, muxer_opt=o.muxer_opt - ) - ) - - @classmethod - @asynccontextmanager - async def create_batch_and_listen( - cls, - number: int, - security_protocol: TProtocol = None, - muxer_opt: TMuxerOptions = None, - ) -> AsyncIterator[tuple[BasicHost, ...]]: - async with SwarmFactory.create_batch_and_listen( - number, security_protocol=security_protocol, muxer_opt=muxer_opt - ) as swarms: - hosts = tuple(BasicHost(swarm) for swarm in swarms) - yield hosts - - -class DummyRouter(IPeerRouting): - _routing_table: dict[ID, PeerInfo] - - def __init__(self) -> None: - self._routing_table = dict() - - def _add_peer(self, peer_id: ID, addrs: list[Multiaddr]) -> None: - self._routing_table[peer_id] = PeerInfo(peer_id, addrs) - - async def find_peer(self, peer_id: ID) -> PeerInfo: - await trio.lowlevel.checkpoint() - return self._routing_table.get(peer_id, None) - - -class RoutedHostFactory(factory.Factory): - class Meta: - model = RoutedHost - - class Params: - key_pair = factory.LazyFunction(default_key_pair_factory) - security_protocol: TProtocol = None - muxer_opt = factory.LazyFunction(default_muxer_transport_factory) - - network = factory.LazyAttribute( - lambda o: HostFactory( - security_protocol=o.security_protocol, muxer_opt=o.muxer_opt - ).get_network() - ) - router = factory.LazyFunction(DummyRouter) - - @classmethod - @asynccontextmanager - async def create_batch_and_listen( - cls, - number: int, - security_protocol: TProtocol = None, - muxer_opt: TMuxerOptions = None, - ) -> AsyncIterator[tuple[RoutedHost, ...]]: - routing_table = DummyRouter() - async with HostFactory.create_batch_and_listen( - number, security_protocol=security_protocol, muxer_opt=muxer_opt - ) as hosts: - for host in hosts: - routing_table._add_peer(host.get_id(), host.get_addrs()) - routed_hosts = tuple( - RoutedHost(host.get_network(), routing_table) for host in hosts - ) - yield routed_hosts - - -class FloodsubFactory(factory.Factory): - class Meta: - model = FloodSub - - protocols = (FLOODSUB_PROTOCOL_ID,) - - -class GossipsubFactory(factory.Factory): - class Meta: - model = GossipSub - - protocols = (GOSSIPSUB_PROTOCOL_ID,) - degree = GOSSIPSUB_PARAMS.degree - degree_low = GOSSIPSUB_PARAMS.degree_low - degree_high = GOSSIPSUB_PARAMS.degree_high - gossip_window = GOSSIPSUB_PARAMS.gossip_window - gossip_history = GOSSIPSUB_PARAMS.gossip_history - heartbeat_initial_delay = GOSSIPSUB_PARAMS.heartbeat_initial_delay - heartbeat_interval = GOSSIPSUB_PARAMS.heartbeat_interval - - -class PubsubFactory(factory.Factory): - class Meta: - model = Pubsub - - host = factory.SubFactory(HostFactory) - router = None - cache_size = None - strict_signing = False - - @classmethod - @asynccontextmanager - async def create_and_start( - cls, - host: IHost, - router: IPubsubRouter, - cache_size: int, - seen_ttl: int, - sweep_interval: int, - strict_signing: bool, - msg_id_constructor: Callable[[rpc_pb2.Message], bytes] = None, - ) -> AsyncIterator[Pubsub]: - pubsub = cls( - host=host, - router=router, - cache_size=cache_size, - seen_ttl=seen_ttl, - sweep_interval=sweep_interval, - strict_signing=strict_signing, - msg_id_constructor=msg_id_constructor, - ) - async with background_trio_service(pubsub): - await pubsub.wait_until_ready() - yield pubsub - - @classmethod - @asynccontextmanager - async def _create_batch_with_router( - cls, - number: int, - routers: Sequence[IPubsubRouter], - cache_size: int = None, - seen_ttl: int = 120, - sweep_interval: int = 60, - strict_signing: bool = False, - security_protocol: TProtocol = None, - muxer_opt: TMuxerOptions = None, - msg_id_constructor: Callable[[rpc_pb2.Message], bytes] = None, - ) -> AsyncIterator[tuple[Pubsub, ...]]: - async with HostFactory.create_batch_and_listen( - number, security_protocol=security_protocol, muxer_opt=muxer_opt - ) as hosts: - # Pubsubs should exit before hosts - async with AsyncExitStack() as stack: - pubsubs = [ - await stack.enter_async_context( - cls.create_and_start( - host, - router, - cache_size, - seen_ttl, - sweep_interval, - strict_signing, - msg_id_constructor, - ) - ) - for host, router in zip(hosts, routers) - ] - yield tuple(pubsubs) - - @classmethod - @asynccontextmanager - async def create_batch_with_floodsub( - cls, - number: int, - cache_size: int = None, - seen_ttl: int = 120, - sweep_interval: int = 60, - strict_signing: bool = False, - protocols: Sequence[TProtocol] = None, - security_protocol: TProtocol = None, - muxer_opt: TMuxerOptions = None, - msg_id_constructor: Callable[ - [rpc_pb2.Message], bytes - ] = get_peer_and_seqno_msg_id, - ) -> AsyncIterator[tuple[Pubsub, ...]]: - if protocols is not None: - floodsubs = FloodsubFactory.create_batch(number, protocols=list(protocols)) - else: - floodsubs = FloodsubFactory.create_batch(number) - async with cls._create_batch_with_router( - number, - floodsubs, - cache_size, - seen_ttl, - sweep_interval, - strict_signing, - security_protocol=security_protocol, - muxer_opt=muxer_opt, - msg_id_constructor=msg_id_constructor, - ) as pubsubs: - yield pubsubs - - @classmethod - @asynccontextmanager - async def create_batch_with_gossipsub( - cls, - number: int, - *, - cache_size: int = None, - strict_signing: bool = False, - protocols: Sequence[TProtocol] = None, - degree: int = GOSSIPSUB_PARAMS.degree, - degree_low: int = GOSSIPSUB_PARAMS.degree_low, - degree_high: int = GOSSIPSUB_PARAMS.degree_high, - time_to_live: int = GOSSIPSUB_PARAMS.time_to_live, - gossip_window: int = GOSSIPSUB_PARAMS.gossip_window, - gossip_history: int = GOSSIPSUB_PARAMS.gossip_history, - heartbeat_interval: float = GOSSIPSUB_PARAMS.heartbeat_interval, - heartbeat_initial_delay: float = GOSSIPSUB_PARAMS.heartbeat_initial_delay, - security_protocol: TProtocol = None, - muxer_opt: TMuxerOptions = None, - msg_id_constructor: Callable[ - [rpc_pb2.Message], bytes - ] = get_peer_and_seqno_msg_id, - ) -> AsyncIterator[tuple[Pubsub, ...]]: - if protocols is not None: - gossipsubs = GossipsubFactory.create_batch( - number, - protocols=protocols, - degree=degree, - degree_low=degree_low, - degree_high=degree_high, - time_to_live=time_to_live, - gossip_window=gossip_window, - heartbeat_interval=heartbeat_interval, - ) - else: - gossipsubs = GossipsubFactory.create_batch( - number, - degree=degree, - degree_low=degree_low, - degree_high=degree_high, - gossip_window=gossip_window, - heartbeat_interval=heartbeat_interval, - ) - - async with cls._create_batch_with_router( - number, - gossipsubs, - cache_size, - strict_signing, - security_protocol=security_protocol, - muxer_opt=muxer_opt, - msg_id_constructor=msg_id_constructor, - ) as pubsubs: - async with AsyncExitStack() as stack: - for router in gossipsubs: - await stack.enter_async_context(background_trio_service(router)) - yield pubsubs - - -@asynccontextmanager -async def swarm_pair_factory( - security_protocol: TProtocol = None, muxer_opt: TMuxerOptions = None -) -> AsyncIterator[tuple[Swarm, Swarm]]: - async with SwarmFactory.create_batch_and_listen( - 2, security_protocol=security_protocol, muxer_opt=muxer_opt - ) as swarms: - await connect_swarm(swarms[0], swarms[1]) - yield swarms[0], swarms[1] - - -@asynccontextmanager -async def host_pair_factory( - security_protocol: TProtocol = None, muxer_opt: TMuxerOptions = None -) -> AsyncIterator[tuple[BasicHost, BasicHost]]: - async with HostFactory.create_batch_and_listen( - 2, security_protocol=security_protocol, muxer_opt=muxer_opt - ) as hosts: - await connect(hosts[0], hosts[1]) - yield hosts[0], hosts[1] - - -@asynccontextmanager -async def swarm_conn_pair_factory( - security_protocol: TProtocol = None, muxer_opt: TMuxerOptions = None -) -> AsyncIterator[tuple[SwarmConn, SwarmConn]]: - async with swarm_pair_factory( - security_protocol=security_protocol, muxer_opt=muxer_opt - ) as swarms: - conn_0 = swarms[0].connections[swarms[1].get_peer_id()] - conn_1 = swarms[1].connections[swarms[0].get_peer_id()] - yield cast(SwarmConn, conn_0), cast(SwarmConn, conn_1) - - -@asynccontextmanager -async def mplex_conn_pair_factory( - security_protocol: TProtocol = None, -) -> AsyncIterator[tuple[Mplex, Mplex]]: - async with swarm_conn_pair_factory( - security_protocol=security_protocol, muxer_opt=default_muxer_transport_factory() - ) as swarm_pair: - yield ( - cast(Mplex, swarm_pair[0].muxed_conn), - cast(Mplex, swarm_pair[1].muxed_conn), - ) - - -@asynccontextmanager -async def mplex_stream_pair_factory( - security_protocol: TProtocol = None, -) -> AsyncIterator[tuple[MplexStream, MplexStream]]: - async with mplex_conn_pair_factory( - security_protocol=security_protocol - ) as mplex_conn_pair_info: - mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info - stream_0 = cast(MplexStream, await mplex_conn_0.open_stream()) - await trio.sleep(0.01) - stream_1: MplexStream - async with mplex_conn_1.streams_lock: - if len(mplex_conn_1.streams) != 1: - raise Exception("Mplex should not have any other stream") - stream_1 = tuple(mplex_conn_1.streams.values())[0] - yield stream_0, stream_1 - - -@asynccontextmanager -async def net_stream_pair_factory( - security_protocol: TProtocol = None, muxer_opt: TMuxerOptions = None -) -> AsyncIterator[tuple[INetStream, INetStream]]: - protocol_id = TProtocol("/example/id/1") - - stream_1: INetStream - - # Just a proxy, we only care about the stream. - # Add a barrier to avoid stream being removed. - event_handler_finished = trio.Event() - - async def handler(stream: INetStream) -> None: - nonlocal stream_1 - stream_1 = stream - await event_handler_finished.wait() - - async with host_pair_factory( - security_protocol=security_protocol, muxer_opt=muxer_opt - ) as hosts: - hosts[1].set_stream_handler(protocol_id, handler) - - stream_0 = await hosts[0].new_stream(hosts[1].get_id(), [protocol_id]) - yield stream_0, stream_1 - event_handler_finished.set() diff --git a/tests/interop/conftest.py b/tests/interop/conftest.py index 728dddda..803ca539 100644 --- a/tests/interop/conftest.py +++ b/tests/interop/conftest.py @@ -17,7 +17,7 @@ from libp2p.io.abc import ( ) from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID from libp2p.security.secio.transport import ID as SECIO_PROTOCOL_ID -from tests.factories import ( +from tests.utils.factories import ( HostFactory, PubsubFactory, ) diff --git a/tests/interop/test_bindings.py b/tests/interop/test_bindings.py index f38fad5c..851d298f 100644 --- a/tests/interop/test_bindings.py +++ b/tests/interop/test_bindings.py @@ -1,7 +1,7 @@ import pytest import trio -from tests.factories import ( +from tests.utils.factories import ( HostFactory, ) from tests.utils.interop.utils import ( diff --git a/tests/interop/test_echo.py b/tests/interop/test_echo.py index 164a90c6..2276463e 100644 --- a/tests/interop/test_echo.py +++ b/tests/interop/test_echo.py @@ -16,7 +16,7 @@ from libp2p.peer.peerinfo import ( PeerInfo, info_from_p2p_addr, ) -from tests.factories import ( +from tests.utils.factories import ( HostFactory, ) from tests.utils.interop.envs import ( diff --git a/tests/interop/test_pubsub.py b/tests/interop/test_pubsub.py index 485d48f4..6f7ec34f 100644 --- a/tests/interop/test_pubsub.py +++ b/tests/interop/test_pubsub.py @@ -22,7 +22,7 @@ from libp2p.pubsub.subscription import ( from libp2p.utils import ( read_varint_prefixed_bytes, ) -from tests.factories import ( +from tests.utils.factories import ( PubsubFactory, ) from tests.utils.interop.utils import (