Rewrite factories, made some of the test running

This commit is contained in:
mhchia
2019-11-26 19:24:30 +08:00
parent 417b5e7d61
commit ec43c25b45
13 changed files with 260 additions and 282 deletions

View File

@ -1,8 +1,9 @@
import asyncio
from contextlib import asynccontextmanager
import trio
from contextlib import asynccontextmanager, AsyncExitStack
from typing import Any, AsyncIterator, Dict, Tuple, cast
import factory
from async_service import background_trio_service
from libp2p import generate_new_rsa_identity, generate_peer_id_from
from libp2p.crypto.keys import KeyPair
@ -61,6 +62,7 @@ class SwarmFactory(factory.Factory):
transport = factory.LazyFunction(TCP)
@classmethod
@asynccontextmanager
async def create_and_listen(
cls, is_secure: bool, key_pair: KeyPair = None, muxer_opt: TMuxerOptions = None
) -> Swarm:
@ -73,20 +75,23 @@ class SwarmFactory(factory.Factory):
if muxer_opt is not None:
optional_kwargs["muxer_opt"] = muxer_opt
swarm = cls(is_secure=is_secure, **optional_kwargs)
await swarm.listen(LISTEN_MADDR)
return swarm
async with background_trio_service(swarm):
await swarm.listen(LISTEN_MADDR)
yield swarm
@classmethod
@asynccontextmanager
async def create_batch_and_listen(
cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None
) -> Tuple[Swarm, ...]:
# Ignore typing since we are removing asyncio soon
return await asyncio.gather( # type: ignore
*[
cls.create_and_listen(is_secure=is_secure, muxer_opt=muxer_opt)
async with AsyncExitStack() as stack:
ctx_mgrs = [
await stack.enter_async_context(
cls.create_and_listen(is_secure=is_secure, muxer_opt=muxer_opt)
)
for _ in range(number)
]
)
yield ctx_mgrs
class HostFactory(factory.Factory):
@ -103,20 +108,23 @@ class HostFactory(factory.Factory):
)
@classmethod
@asynccontextmanager
async def create_batch_and_listen(
cls, is_secure: bool, number: int
) -> Tuple[BasicHost, ...]:
key_pairs = [generate_new_rsa_identity() for _ in range(number)]
swarms = await asyncio.gather(
*[
SwarmFactory.create_and_listen(is_secure, key_pair)
async with AsyncExitStack() as stack:
swarms = [
await stack.enter_async_context(
SwarmFactory.create_and_listen(is_secure, key_pair)
)
for key_pair in key_pairs
]
)
return tuple(
BasicHost(key_pair.public_key, swarm)
for key_pair, swarm in zip(key_pairs, swarms)
)
hosts = tuple(
BasicHost(key_pair.public_key, swarm)
for key_pair, swarm in zip(key_pairs, swarms)
)
yield hosts
class FloodsubFactory(factory.Factory):
@ -150,73 +158,60 @@ class PubsubFactory(factory.Factory):
cache_size = None
@asynccontextmanager
async def swarm_pair_factory(
is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Tuple[Swarm, Swarm]:
swarms = await SwarmFactory.create_batch_and_listen(
async with SwarmFactory.create_batch_and_listen(
is_secure, 2, muxer_opt=muxer_opt
)
await connect_swarm(swarms[0], swarms[1])
return swarms[0], swarms[1]
async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]:
hosts = await HostFactory.create_batch_and_listen(is_secure, 2)
await connect(hosts[0], hosts[1])
return hosts[0], hosts[1]
) as swarms:
await connect_swarm(swarms[0], swarms[1])
yield swarms[0], swarms[1]
@asynccontextmanager
async def pair_of_connected_hosts(
is_secure: bool = True
) -> AsyncIterator[Tuple[BasicHost, BasicHost]]:
a, b = await host_pair_factory(is_secure)
yield a, b
close_tasks = (a.close(), b.close())
await asyncio.gather(*close_tasks)
async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]:
async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts:
await connect(hosts[0], hosts[1])
yield hosts[0], hosts[1]
@asynccontextmanager
async def swarm_conn_pair_factory(
is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]:
swarms = await swarm_pair_factory(is_secure)
conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
return cast(SwarmConn, conn_0), swarms[0], cast(SwarmConn, conn_1), swarms[1]
) -> Tuple[SwarmConn, SwarmConn]:
async with swarm_pair_factory(is_secure) as swarms:
conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
yield cast(SwarmConn, conn_0), cast(SwarmConn, conn_1)
async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Swarm, Mplex, Swarm]:
@asynccontextmanager
async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Mplex]:
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex}
conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory(
is_secure, muxer_opt=muxer_opt
)
return (
cast(Mplex, conn_0.muxed_conn),
swarm_0,
cast(Mplex, conn_1.muxed_conn),
swarm_1,
)
async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair:
yield (
cast(Mplex, swarm_pair[0].muxed_conn),
cast(Mplex, swarm_pair[1].muxed_conn),
)
async def mplex_stream_pair_factory(
is_secure: bool
) -> Tuple[MplexStream, Swarm, MplexStream, Swarm]:
mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory(
is_secure
)
stream_0 = await mplex_conn_0.open_stream()
await asyncio.sleep(0.01)
stream_1: MplexStream
async with mplex_conn_1.streams_lock:
if len(mplex_conn_1.streams) != 1:
raise Exception("Mplex should not have any stream upon connection")
stream_1 = tuple(mplex_conn_1.streams.values())[0]
return cast(MplexStream, stream_0), swarm_0, stream_1, swarm_1
@asynccontextmanager
async def mplex_stream_pair_factory(is_secure: bool) -> Tuple[MplexStream, MplexStream]:
async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info:
mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info
stream_0 = await mplex_conn_0.open_stream()
await trio.sleep(0.01)
stream_1: MplexStream
async with mplex_conn_1.streams_lock:
if len(mplex_conn_1.streams) != 1:
raise Exception("Mplex should not have any stream upon connection")
stream_1 = tuple(mplex_conn_1.streams.values())[0]
yield cast(MplexStream, stream_0), cast(MplexStream, stream_1)
async def net_stream_pair_factory(
is_secure: bool
) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]:
@asynccontextmanager
async def net_stream_pair_factory(is_secure: bool) -> Tuple[INetStream, INetStream]:
protocol_id = TProtocol("/example/id/1")
stream_1: INetStream
@ -226,8 +221,8 @@ async def net_stream_pair_factory(
nonlocal stream_1
stream_1 = stream
host_0, host_1 = await host_pair_factory(is_secure)
host_1.set_stream_handler(protocol_id, handler)
async with host_pair_factory(is_secure) as hosts:
hosts[1].set_stream_handler(protocol_id, handler)
stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id])
return stream_0, host_0, stream_1, host_1
stream_0 = await hosts[0].new_stream(hosts[1].get_id(), [protocol_id])
yield stream_0, stream_1

View File

@ -17,7 +17,7 @@ from libp2p.typing import StreamHandlerFn, TProtocol
from .constants import MAX_READ_LEN
async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm, nursery: trio.Nursery) -> None:
async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm) -> None:
peer_id = swarm_1.get_peer_id()
addrs = tuple(
addr
@ -25,7 +25,7 @@ async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm, nursery: trio.Nursery) -
for addr in transport.get_addrs()
)
swarm_0.peerstore.add_addrs(peer_id, addrs, 10000)
await swarm_0.dial_peer(peer_id, nursery)
await swarm_0.dial_peer(peer_id)
assert swarm_0.get_peer_id() in swarm_1.connections
assert swarm_1.get_peer_id() in swarm_0.connections