Fix all modules except for security

This commit is contained in:
mhchia
2019-12-06 17:06:37 +08:00
parent e9ab0646e3
commit 1929f307fb
28 changed files with 764 additions and 955 deletions

View File

@ -1,6 +1,7 @@
import trio
import logging import logging
import trio
from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset
from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.peer.id import ID as PeerID from libp2p.peer.id import ID as PeerID

View File

@ -1,7 +1,6 @@
import logging import logging
import trio import trio
from trio import SocketStream
from libp2p.io.abc import ReadWriteCloser from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException from libp2p.io.exceptions import IOException
@ -9,29 +8,48 @@ from libp2p.io.exceptions import IOException
logger = logging.getLogger("libp2p.io.trio") logger = logging.getLogger("libp2p.io.trio")
class TrioReadWriteCloser(ReadWriteCloser): class TrioTCPStream(ReadWriteCloser):
stream: SocketStream stream: trio.SocketStream
# NOTE: Add both read and write lock to avoid `trio.BusyResourceError`
read_lock: trio.Lock
write_lock: trio.Lock
def __init__(self, stream: SocketStream) -> None: def __init__(self, stream: trio.SocketStream) -> None:
self.stream = stream self.stream = stream
self.read_lock = trio.Lock()
self.write_lock = trio.Lock()
async def write(self, data: bytes) -> None: async def write(self, data: bytes) -> None:
"""Raise `RawConnError` if the underlying connection breaks.""" """Raise `RawConnError` if the underlying connection breaks."""
try: async with self.write_lock:
await self.stream.send_all(data) try:
except (trio.ClosedResourceError, trio.BrokenResourceError) as error: await self.stream.send_all(data)
raise IOException(error) except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
raise IOException from error
except trio.BusyResourceError as error:
# This should never happen, since we already access streams with read/write locks.
raise Exception(
"this should never happen "
"since we already access streams with read/write locks."
) from error
async def read(self, n: int = -1) -> bytes: async def read(self, n: int = -1) -> bytes:
if n == 0: async with self.read_lock:
# Check point if n == 0:
await trio.sleep(0) # Checkpoint
return b"" await trio.hazmat.checkpoint()
max_bytes = n if n != -1 else None return b""
try: max_bytes = n if n != -1 else None
return await self.stream.receive_some(max_bytes) try:
except (trio.ClosedResourceError, trio.BrokenResourceError) as error: return await self.stream.receive_some(max_bytes)
raise IOException(error) except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
raise IOException from error
except trio.BusyResourceError as error:
# This should never happen, since we already access streams with read/write locks.
raise Exception(
"this should never happen "
"since we already access streams with read/write locks."
) from error
async def close(self) -> None: async def close(self) -> None:
await self.stream.aclose() await self.stream.aclose()

View File

@ -1,5 +1,3 @@
import trio
from libp2p.io.abc import ReadWriteCloser from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException from libp2p.io.exceptions import IOException
@ -8,17 +6,17 @@ from .raw_connection_interface import IRawConnection
class RawConnection(IRawConnection): class RawConnection(IRawConnection):
read_write_closer: ReadWriteCloser stream: ReadWriteCloser
is_initiator: bool is_initiator: bool
def __init__(self, read_write_closer: ReadWriteCloser, initiator: bool) -> None: def __init__(self, stream: ReadWriteCloser, initiator: bool) -> None:
self.read_write_closer = read_write_closer self.stream = stream
self.is_initiator = initiator self.is_initiator = initiator
async def write(self, data: bytes) -> None: async def write(self, data: bytes) -> None:
"""Raise `RawConnError` if the underlying connection breaks.""" """Raise `RawConnError` if the underlying connection breaks."""
try: try:
await self.read_write_closer.write(data) await self.stream.write(data)
except IOException as error: except IOException as error:
raise RawConnError(error) raise RawConnError(error)
@ -30,9 +28,9 @@ class RawConnection(IRawConnection):
Raise `RawConnError` if the underlying connection breaks Raise `RawConnError` if the underlying connection breaks
""" """
try: try:
return await self.read_write_closer.read(n) return await self.stream.read(n)
except IOException as error: except IOException as error:
raise RawConnError(error) raise RawConnError(error)
async def close(self) -> None: async def close(self) -> None:
await self.read_write_closer.close() await self.stream.close()

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple from typing import TYPE_CHECKING, Set, Tuple
from async_service import Service from async_service import Service
import trio import trio
@ -45,16 +45,11 @@ class SwarmConn(INetConn, Service):
# before we cancel the stream handler tasks. # before we cancel the stream handler tasks.
await trio.sleep(0.1) await trio.sleep(0.1)
# FIXME: Now let `_notify_disconnected` finish first.
# Schedule `self._notify_disconnected` to make it execute after `close` is finished.
await self._notify_disconnected() await self._notify_disconnected()
async def _handle_new_streams(self) -> None: async def _handle_new_streams(self) -> None:
while self.manager.is_running: while self.manager.is_running:
try: try:
print(
f"!@# SwarmConn._handle_new_streams: {self.muxed_conn._id}: waiting for new streams"
)
stream = await self.muxed_conn.accept_stream() stream = await self.muxed_conn.accept_stream()
except MuxedConnUnavailable: except MuxedConnUnavailable:
# If there is anything wrong in the MuxedConn, # If there is anything wrong in the MuxedConn,
@ -63,9 +58,6 @@ class SwarmConn(INetConn, Service):
# Asynchronously handle the accepted stream, to avoid blocking the next stream. # Asynchronously handle the accepted stream, to avoid blocking the next stream.
self.manager.run_task(self._handle_muxed_stream, stream) self.manager.run_task(self._handle_muxed_stream, stream)
print(
f"!@# SwarmConn._handle_new_streams: {self.muxed_conn._id}: out of the loop"
)
await self.close() await self.close()
async def _call_stream_handler(self, net_stream: NetStream) -> None: async def _call_stream_handler(self, net_stream: NetStream) -> None:
@ -92,8 +84,7 @@ class SwarmConn(INetConn, Service):
await self.swarm.notify_disconnected(self) await self.swarm.notify_disconnected(self)
async def run(self) -> None: async def run(self) -> None:
self.manager.run_task(self._handle_new_streams) await self._handle_new_streams()
await self.manager.wait_finished()
async def new_stream(self) -> NetStream: async def new_stream(self) -> NetStream:
muxed_stream = await self.muxed_conn.open_stream() muxed_stream = await self.muxed_conn.open_stream()

View File

@ -203,16 +203,17 @@ class Swarm(INetwork, Service):
await self.add_conn(muxed_conn) await self.add_conn(muxed_conn)
logger.debug("successfully opened connection to peer %s", peer_id) logger.debug("successfully opened connection to peer %s", peer_id)
# FIXME: This is a intentional barrier to prevent from the handler exiting and # NOTE: This is a intentional barrier to prevent from the handler exiting and
# closing the connection. Probably change to `Service.manager.wait_finished`? # closing the connection.
await trio.sleep_forever() await self.manager.wait_finished()
try: try:
# Success # Success
listener = self.transport.create_listener(conn_handler) listener = self.transport.create_listener(conn_handler)
self.listeners[str(maddr)] = listener self.listeners[str(maddr)] = listener
# FIXME: Hack # TODO: `listener.listen` is not bounded with nursery. If we want to be
await listener.listen(maddr, self.manager._task_nursery) # I/O agnostic, we should change the API.
await listener.listen(maddr, self.manager._task_nursery) # type: ignore
# Call notifiers since event occurred # Call notifiers since event occurred
await self.notify_listen(maddr) await self.notify_listen(maddr)
@ -278,6 +279,7 @@ class Swarm(INetwork, Service):
""" """
self.notifees.append(notifee) self.notifees.append(notifee)
# TODO: Use `run_task`.
async def notify_opened_stream(self, stream: INetStream) -> None: async def notify_opened_stream(self, stream: INetStream) -> None:
async with trio.open_nursery() as nursery: async with trio.open_nursery() as nursery:
for notifee in self.notifees: for notifee in self.notifees:

View File

@ -64,7 +64,7 @@ class FloodSub(IPubsubRouter):
:param rpc: rpc message :param rpc: rpc message
""" """
# Checkpoint # Checkpoint
await trio.sleep(0) await trio.hazmat.checkpoint()
async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None: async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None:
""" """
@ -107,7 +107,7 @@ class FloodSub(IPubsubRouter):
:param topic: topic to join :param topic: topic to join
""" """
# Checkpoint # Checkpoint
await trio.sleep(0) await trio.hazmat.checkpoint()
async def leave(self, topic: str) -> None: async def leave(self, topic: str) -> None:
""" """
@ -117,7 +117,7 @@ class FloodSub(IPubsubRouter):
:param topic: topic to leave :param topic: topic to leave
""" """
# Checkpoint # Checkpoint
await trio.sleep(0) await trio.hazmat.checkpoint()
def _get_peers_to_send( def _get_peers_to_send(
self, topic_ids: Iterable[str], msg_forwarder: ID, origin: ID self, topic_ids: Iterable[str], msg_forwarder: ID, origin: ID

View File

@ -1,15 +1,18 @@
from ast import literal_eval from ast import literal_eval
import asyncio
import logging import logging
import random import random
from typing import Any, Dict, Iterable, List, Sequence, Set from typing import Any, Dict, Iterable, List, Sequence, Set
from async_service import Service
import trio
from libp2p.network.stream.exceptions import StreamClosed from libp2p.network.stream.exceptions import StreamClosed
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.pubsub import floodsub from libp2p.pubsub import floodsub
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from libp2p.utils import encode_varint_prefixed from libp2p.utils import encode_varint_prefixed
from .exceptions import NoPubsubAttached
from .mcache import MessageCache from .mcache import MessageCache
from .pb import rpc_pb2 from .pb import rpc_pb2
from .pubsub import Pubsub from .pubsub import Pubsub
@ -20,8 +23,7 @@ PROTOCOL_ID = TProtocol("/meshsub/1.0.0")
logger = logging.getLogger("libp2p.pubsub.gossipsub") logger = logging.getLogger("libp2p.pubsub.gossipsub")
class GossipSub(IPubsubRouter): class GossipSub(IPubsubRouter, Service):
protocols: List[TProtocol] protocols: List[TProtocol]
pubsub: Pubsub pubsub: Pubsub
@ -86,6 +88,12 @@ class GossipSub(IPubsubRouter):
# Create heartbeat timer # Create heartbeat timer
self.heartbeat_interval = heartbeat_interval self.heartbeat_interval = heartbeat_interval
async def run(self) -> None:
if self.pubsub is None:
raise NoPubsubAttached
self.manager.run_task(self.heartbeat)
await self.manager.wait_finished()
# Interface functions # Interface functions
def get_protocols(self) -> List[TProtocol]: def get_protocols(self) -> List[TProtocol]:
@ -105,10 +113,6 @@ class GossipSub(IPubsubRouter):
logger.debug("attached to pusub") logger.debug("attached to pusub")
# Start heartbeat now that we have a pubsub instance
# TODO: Start after delay
asyncio.ensure_future(self.heartbeat())
def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None: def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None:
""" """
Notifies the router that a new peer has been connected. Notifies the router that a new peer has been connected.
@ -310,7 +314,7 @@ class GossipSub(IPubsubRouter):
await self.fanout_heartbeat() await self.fanout_heartbeat()
await self.gossip_heartbeat() await self.gossip_heartbeat()
await asyncio.sleep(self.heartbeat_interval) await trio.sleep(self.heartbeat_interval)
async def mesh_heartbeat(self) -> None: async def mesh_heartbeat(self) -> None:
# Note: the comments here are the exact pseudocode from the spec # Note: the comments here are the exact pseudocode from the spec
@ -338,7 +342,7 @@ class GossipSub(IPubsubRouter):
if num_mesh_peers_in_topic > self.degree_high: if num_mesh_peers_in_topic > self.degree_high:
# Select |mesh[topic]| - D peers from mesh[topic] # Select |mesh[topic]| - D peers from mesh[topic]
selected_peers = GossipSub.select_from_minus( selected_peers = self.select_from_minus(
num_mesh_peers_in_topic - self.degree, self.mesh[topic], [] num_mesh_peers_in_topic - self.degree, self.mesh[topic], []
) )
for peer in selected_peers: for peer in selected_peers:
@ -353,7 +357,10 @@ class GossipSub(IPubsubRouter):
for topic in self.fanout: for topic in self.fanout:
# If time since last published > ttl # If time since last published > ttl
# TODO: there's no way time_since_last_publish gets set anywhere yet # TODO: there's no way time_since_last_publish gets set anywhere yet
if self.time_since_last_publish[topic] > self.time_to_live: if (
topic in self.time_since_last_publish
and self.time_since_last_publish[topic] > self.time_to_live
):
# Remove topic from fanout # Remove topic from fanout
del self.fanout[topic] del self.fanout[topic]
del self.time_since_last_publish[topic] del self.time_since_last_publish[topic]
@ -407,11 +414,7 @@ class GossipSub(IPubsubRouter):
topic, self.degree, [] topic, self.degree, []
) )
for peer in peers_to_emit_ihave_to: for peer in peers_to_emit_ihave_to:
if ( if peer not in self.fanout[topic]:
peer not in self.mesh[topic]
and peer not in self.fanout[topic]
):
msg_id_strs = [str(msg) for msg in msg_ids] msg_id_strs = [str(msg) for msg in msg_ids]
await self.emit_ihave(topic, msg_id_strs, peer) await self.emit_ihave(topic, msg_id_strs, peer)

View File

@ -1,4 +1,4 @@
from abc import ABC, abstractmethod from abc import ABC
import logging import logging
import math import math
import time import time
@ -57,6 +57,7 @@ class TopicValidator(NamedTuple):
is_async: bool is_async: bool
# TODO: Add interface for Pubsub
class BasePubsub(ABC): class BasePubsub(ABC):
pass pass
@ -103,20 +104,24 @@ class Pubsub(BasePubsub, Service):
# Attach this new Pubsub object to the router # Attach this new Pubsub object to the router
self.router.attach(self) self.router.attach(self)
peer_send_channel, peer_receive_channel = trio.open_memory_channel(0) peer_channels: Tuple[
dead_peer_send_channel, dead_peer_receive_channel = trio.open_memory_channel(0) "trio.MemorySendChannel[ID]", "trio.MemoryReceiveChannel[ID]"
] = trio.open_memory_channel(0)
dead_peer_channels: Tuple[
"trio.MemorySendChannel[ID]", "trio.MemoryReceiveChannel[ID]"
] = trio.open_memory_channel(0)
# Only keep the receive channels in `Pubsub`. # Only keep the receive channels in `Pubsub`.
# Therefore, we can only close from the receive side. # Therefore, we can only close from the receive side.
self.peer_receive_channel = peer_receive_channel self.peer_receive_channel = peer_channels[1]
self.dead_peer_receive_channel = dead_peer_receive_channel self.dead_peer_receive_channel = dead_peer_channels[1]
# Register a notifee # Register a notifee
self.host.get_network().register_notifee( self.host.get_network().register_notifee(
PubsubNotifee(peer_send_channel, dead_peer_send_channel) PubsubNotifee(peer_channels[0], dead_peer_channels[0])
) )
# Register stream handlers for each pubsub router protocol to handle # Register stream handlers for each pubsub router protocol to handle
# the pubsub streams opened on those protocols # the pubsub streams opened on those protocols
for protocol in router.protocols: for protocol in router.get_protocols():
self.host.set_stream_handler(protocol, self.stream_handler) self.host.set_stream_handler(protocol, self.stream_handler)
# keeps track of seen messages as LRU cache # keeps track of seen messages as LRU cache
@ -328,8 +333,9 @@ class Pubsub(BasePubsub, Service):
self.manager.run_task(self._handle_new_peer, peer_id) self.manager.run_task(self._handle_new_peer, peer_id)
async def handle_dead_peer_queue(self) -> None: async def handle_dead_peer_queue(self) -> None:
"""Continuously read from dead peer channel and close the stream between """Continuously read from dead peer channel and close the stream
that peer and remove peer info from pubsub and pubsub router.""" between that peer and remove peer info from pubsub and pubsub
router."""
async with self.dead_peer_receive_channel: async with self.dead_peer_receive_channel:
while self.manager.is_running: while self.manager.is_running:
peer_id: ID = await self.dead_peer_receive_channel.receive() peer_id: ID = await self.dead_peer_receive_channel.receive()
@ -391,7 +397,11 @@ class Pubsub(BasePubsub, Service):
return self.subscribed_topics_receive[topic_id] return self.subscribed_topics_receive[topic_id]
# Map topic_id to a blocking channel # Map topic_id to a blocking channel
send_channel, receive_channel = trio.open_memory_channel(math.inf) channels: Tuple[
"trio.MemorySendChannel[rpc_pb2.Message]",
"trio.MemoryReceiveChannel[rpc_pb2.Message]",
] = trio.open_memory_channel(math.inf)
send_channel, receive_channel = channels
self.subscribed_topics_send[topic_id] = send_channel self.subscribed_topics_send[topic_id] = send_channel
self.subscribed_topics_receive[topic_id] = receive_channel self.subscribed_topics_receive[topic_id] = receive_channel
@ -506,7 +516,7 @@ class Pubsub(BasePubsub, Service):
if len(async_topic_validators) > 0: if len(async_topic_validators) > 0:
# TODO: Use a better pattern # TODO: Use a better pattern
final_result = True final_result: bool = True
async def run_async_validator(func: AsyncValidatorFn) -> None: async def run_async_validator(func: AsyncValidatorFn) -> None:
nonlocal final_result nonlocal final_result
@ -514,8 +524,8 @@ class Pubsub(BasePubsub, Service):
final_result = final_result and result final_result = final_result and result
async with trio.open_nursery() as nursery: async with trio.open_nursery() as nursery:
for validator in async_topic_validators: for async_validator in async_topic_validators:
nursery.start_soon(run_async_validator, validator) nursery.start_soon(run_async_validator, async_validator)
if not final_result: if not final_result:
raise ValidationError(f"Validation failed for msg={msg}") raise ValidationError(f"Validation failed for msg={msg}")

View File

@ -1,11 +1,13 @@
from abc import ABC, abstractmethod from abc import abstractmethod
from async_service import ServiceAPI
from libp2p.io.abc import ReadWriteCloser from libp2p.io.abc import ReadWriteCloser
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
class IMuxedConn(ABC): class IMuxedConn(ServiceAPI):
""" """
reference: https://github.com/libp2p/go-stream-muxer/blob/master/muxer.go reference: https://github.com/libp2p/go-stream-muxer/blob/master/muxer.go
""" """

View File

@ -1,7 +1,6 @@
import logging import logging
import math import math
from typing import Any # noqa: F401 from typing import Dict, Optional, Tuple
from typing import Awaitable, Dict, List, Optional, Tuple
from async_service import Service from async_service import Service
import trio import trio
@ -67,13 +66,15 @@ class Mplex(IMuxedConn, Service):
self.streams = {} self.streams = {}
self.streams_lock = trio.Lock() self.streams_lock = trio.Lock()
self.streams_msg_channels = {} self.streams_msg_channels = {}
send_channel, receive_channel = trio.open_memory_channel(math.inf) channels: Tuple[
self.new_stream_send_channel = send_channel "trio.MemorySendChannel[IMuxedStream]",
self.new_stream_receive_channel = receive_channel "trio.MemoryReceiveChannel[IMuxedStream]",
] = trio.open_memory_channel(math.inf)
self.new_stream_send_channel, self.new_stream_receive_channel = channels
self.event_shutting_down = trio.Event() self.event_shutting_down = trio.Event()
self.event_closed = trio.Event() self.event_closed = trio.Event()
async def run(self): async def run(self) -> None:
self.manager.run_task(self.handle_incoming) self.manager.run_task(self.handle_incoming)
await self.manager.wait_finished() await self.manager.wait_finished()
@ -112,11 +113,13 @@ class Mplex(IMuxedConn, Service):
async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream: async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream:
# Use an unbounded buffer, to avoid `handle_incoming` being blocked when doing # Use an unbounded buffer, to avoid `handle_incoming` being blocked when doing
# `send_channel.send`. # `send_channel.send`.
send_channel, receive_channel = trio.open_memory_channel(math.inf) channels: Tuple[
stream = MplexStream(name, stream_id, self, receive_channel) "trio.MemorySendChannel[bytes]", "trio.MemoryReceiveChannel[bytes]"
] = trio.open_memory_channel(math.inf)
stream = MplexStream(name, stream_id, self, channels[1])
async with self.streams_lock: async with self.streams_lock:
self.streams[stream_id] = stream self.streams[stream_id] = stream
self.streams_msg_channels[stream_id] = send_channel self.streams_msg_channels[stream_id] = channels[0]
return stream return stream
async def open_stream(self) -> IMuxedStream: async def open_stream(self) -> IMuxedStream:
@ -150,9 +153,6 @@ class Mplex(IMuxedConn, Service):
:param data: data to send in the message :param data: data to send in the message
:param stream_id: stream the message is in :param stream_id: stream the message is in
""" """
print(
f"!@# send_message: {self._id}: flag={flag}, data={data}, stream_id={stream_id}"
)
# << by 3, then or with flag # << by 3, then or with flag
header = encode_uvarint((stream_id.channel_id << 3) | flag.value) header = encode_uvarint((stream_id.channel_id << 3) | flag.value)
@ -179,19 +179,10 @@ class Mplex(IMuxedConn, Service):
while self.manager.is_running: while self.manager.is_running:
try: try:
print(
f"!@# handle_incoming: {self._id}: before _handle_incoming_message"
)
await self._handle_incoming_message() await self._handle_incoming_message()
print(
f"!@# handle_incoming: {self._id}: after _handle_incoming_message"
)
except MplexUnavailable as e: except MplexUnavailable as e:
logger.debug("mplex unavailable while waiting for incoming: %s", e) logger.debug("mplex unavailable while waiting for incoming: %s", e)
print(f"!@# handle_incoming: {self._id}: MplexUnavailable: {e}")
break break
print(f"!@# handle_incoming: {self._id}: leaving")
# If we enter here, it means this connection is shutting down. # If we enter here, it means this connection is shutting down.
# We should clean things up. # We should clean things up.
await self._cleanup() await self._cleanup()
@ -232,44 +223,27 @@ class Mplex(IMuxedConn, Service):
:raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down. :raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down.
""" """
print(f"!@# _handle_incoming_message: {self._id}: before reading")
channel_id, flag, message = await self.read_message() channel_id, flag, message = await self.read_message()
print(
f"!@# _handle_incoming_message: {self._id}: channel_id={channel_id}, flag={flag}, message={message}"
)
stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1)) stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
print(f"!@# _handle_incoming_message: {self._id}: 2")
if flag == HeaderTags.NewStream.value: if flag == HeaderTags.NewStream.value:
print(f"!@# _handle_incoming_message: {self._id}: 3")
await self._handle_new_stream(stream_id, message) await self._handle_new_stream(stream_id, message)
print(f"!@# _handle_incoming_message: {self._id}: 4")
elif flag in ( elif flag in (
HeaderTags.MessageInitiator.value, HeaderTags.MessageInitiator.value,
HeaderTags.MessageReceiver.value, HeaderTags.MessageReceiver.value,
): ):
print(f"!@# _handle_incoming_message: {self._id}: 5")
await self._handle_message(stream_id, message) await self._handle_message(stream_id, message)
print(f"!@# _handle_incoming_message: {self._id}: 6")
elif flag in (HeaderTags.CloseInitiator.value, HeaderTags.CloseReceiver.value): elif flag in (HeaderTags.CloseInitiator.value, HeaderTags.CloseReceiver.value):
print(f"!@# _handle_incoming_message: {self._id}: 7")
await self._handle_close(stream_id) await self._handle_close(stream_id)
print(f"!@# _handle_incoming_message: {self._id}: 8")
elif flag in (HeaderTags.ResetInitiator.value, HeaderTags.ResetReceiver.value): elif flag in (HeaderTags.ResetInitiator.value, HeaderTags.ResetReceiver.value):
print(f"!@# _handle_incoming_message: {self._id}: 9")
await self._handle_reset(stream_id) await self._handle_reset(stream_id)
print(f"!@# _handle_incoming_message: {self._id}: 10")
else: else:
print(f"!@# _handle_incoming_message: {self._id}: 11")
# Receives messages with an unknown flag # Receives messages with an unknown flag
# TODO: logging # TODO: logging
async with self.streams_lock: async with self.streams_lock:
print(f"!@# _handle_incoming_message: {self._id}: 12")
if stream_id in self.streams: if stream_id in self.streams:
print(f"!@# _handle_incoming_message: {self._id}: 13")
stream = self.streams[stream_id] stream = self.streams[stream_id]
await stream.reset() await stream.reset()
print(f"!@# _handle_incoming_message: {self._id}: 14")
async def _handle_new_stream(self, stream_id: StreamID, message: bytes) -> None: async def _handle_new_stream(self, stream_id: StreamID, message: bytes) -> None:
async with self.streams_lock: async with self.streams_lock:
@ -285,59 +259,43 @@ class Mplex(IMuxedConn, Service):
raise MplexUnavailable raise MplexUnavailable
async def _handle_message(self, stream_id: StreamID, message: bytes) -> None: async def _handle_message(self, stream_id: StreamID, message: bytes) -> None:
print(
f"!@# _handle_message: {self._id}: stream_id={stream_id}, message={message}"
)
async with self.streams_lock: async with self.streams_lock:
print(f"!@# _handle_message: {self._id}: 1")
if stream_id not in self.streams: if stream_id not in self.streams:
# We receive a message of the stream `stream_id` which is not accepted # We receive a message of the stream `stream_id` which is not accepted
# before. It is abnormal. Possibly disconnect? # before. It is abnormal. Possibly disconnect?
# TODO: Warn and emit logs about this. # TODO: Warn and emit logs about this.
print(f"!@# _handle_message: {self._id}: 2")
return return
print(f"!@# _handle_message: {self._id}: 3")
stream = self.streams[stream_id] stream = self.streams[stream_id]
send_channel = self.streams_msg_channels[stream_id] send_channel = self.streams_msg_channels[stream_id]
async with stream.close_lock: async with stream.close_lock:
print(f"!@# _handle_message: {self._id}: 4")
if stream.event_remote_closed.is_set(): if stream.event_remote_closed.is_set():
print(f"!@# _handle_message: {self._id}: 5")
# TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501 # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501
return return
print(f"!@# _handle_message: {self._id}: 6")
await send_channel.send(message) await send_channel.send(message)
print(f"!@# _handle_message: {self._id}: 7")
async def _handle_close(self, stream_id: StreamID) -> None: async def _handle_close(self, stream_id: StreamID) -> None:
print(f"!@# _handle_close: {self._id}: step=0")
async with self.streams_lock: async with self.streams_lock:
if stream_id not in self.streams: if stream_id not in self.streams:
# Ignore unmatched messages for now. # Ignore unmatched messages for now.
return return
stream = self.streams[stream_id] stream = self.streams[stream_id]
send_channel = self.streams_msg_channels[stream_id] send_channel = self.streams_msg_channels[stream_id]
print(f"!@# _handle_close: {self._id}: step=1")
await send_channel.aclose() await send_channel.aclose()
print(f"!@# _handle_close: {self._id}: step=2")
# NOTE: If remote is already closed, then return: Technically a bug # NOTE: If remote is already closed, then return: Technically a bug
# on the other side. We should consider killing the connection. # on the other side. We should consider killing the connection.
async with stream.close_lock: async with stream.close_lock:
if stream.event_remote_closed.is_set(): if stream.event_remote_closed.is_set():
return return
print(f"!@# _handle_close: {self._id}: step=3")
is_local_closed: bool is_local_closed: bool
async with stream.close_lock: async with stream.close_lock:
stream.event_remote_closed.set() stream.event_remote_closed.set()
is_local_closed = stream.event_local_closed.is_set() is_local_closed = stream.event_local_closed.is_set()
print(f"!@# _handle_close: {self._id}: step=4")
# If local is also closed, both sides are closed. Then, we should clean up # If local is also closed, both sides are closed. Then, we should clean up
# the entry of this stream, to avoid others from accessing it. # the entry of this stream, to avoid others from accessing it.
if is_local_closed: if is_local_closed:
async with self.streams_lock: async with self.streams_lock:
if stream_id in self.streams: if stream_id in self.streams:
del self.streams[stream_id] del self.streams[stream_id]
print(f"!@# _handle_close: {self._id}: step=5")
async def _handle_reset(self, stream_id: StreamID) -> None: async def _handle_reset(self, stream_id: StreamID) -> None:
async with self.streams_lock: async with self.streams_lock:

View File

@ -1,30 +1,29 @@
from contextlib import AsyncExitStack, asynccontextmanager from contextlib import AsyncExitStack, asynccontextmanager
from typing import Any, AsyncIterator, Dict, Tuple, cast from typing import Any, AsyncIterator, Dict, Sequence, Tuple, cast
from async_service import background_trio_service from async_service import background_trio_service
import factory import factory
import trio import trio
from libp2p.tools.constants import GOSSIPSUB_PARAMS
from libp2p import generate_new_rsa_identity, generate_peer_id_from from libp2p import generate_new_rsa_identity, generate_peer_id_from
from libp2p.crypto.keys import KeyPair from libp2p.crypto.keys import KeyPair
from libp2p.host.basic_host import BasicHost from libp2p.host.basic_host import BasicHost
from libp2p.host.routed_host import RoutedHost from libp2p.host.host_interface import IHost
from libp2p.tools.utils import set_up_routers
from libp2p.kademlia.network import KademliaServer
from libp2p.network.connection.swarm_connection import SwarmConn from libp2p.network.connection.swarm_connection import SwarmConn
from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.network.swarm import Swarm from libp2p.network.swarm import Swarm
from libp2p.peer.peerstore import PeerStore
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.peer.peerstore import PeerStore
from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.gossipsub import GossipSub
from libp2p.pubsub.pubsub import Pubsub from libp2p.pubsub.pubsub import Pubsub
from libp2p.pubsub.pubsub_router_interface import IPubsubRouter
from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
import libp2p.security.secio.transport as secio import libp2p.security.secio.transport as secio
from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex
from libp2p.stream_muxer.mplex.mplex_stream import MplexStream from libp2p.stream_muxer.mplex.mplex_stream import MplexStream
from libp2p.tools.constants import GOSSIPSUB_PARAMS
from libp2p.transport.tcp.tcp import TCP from libp2p.transport.tcp.tcp import TCP
from libp2p.transport.typing import TMuxerOptions from libp2p.transport.typing import TMuxerOptions
from libp2p.transport.upgrader import TransportUpgrader from libp2p.transport.upgrader import TransportUpgrader
@ -74,7 +73,7 @@ class SwarmFactory(factory.Factory):
@asynccontextmanager @asynccontextmanager
async def create_and_listen( async def create_and_listen(
cls, is_secure: bool, key_pair: KeyPair = None, muxer_opt: TMuxerOptions = None cls, is_secure: bool, key_pair: KeyPair = None, muxer_opt: TMuxerOptions = None
) -> Swarm: ) -> AsyncIterator[Swarm]:
# `factory.Factory.__init__` does *not* prepare a *default value* if we pass # `factory.Factory.__init__` does *not* prepare a *default value* if we pass
# an argument explicitly with `None`. If an argument is `None`, we don't pass it to # an argument explicitly with `None`. If an argument is `None`, we don't pass it to
# `factory.Factory.__init__`, in order to let the function initialize it. # `factory.Factory.__init__`, in order to let the function initialize it.
@ -92,7 +91,7 @@ class SwarmFactory(factory.Factory):
@asynccontextmanager @asynccontextmanager
async def create_batch_and_listen( async def create_batch_and_listen(
cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None
) -> Tuple[Swarm, ...]: ) -> AsyncIterator[Tuple[Swarm, ...]]:
async with AsyncExitStack() as stack: async with AsyncExitStack() as stack:
ctx_mgrs = [ ctx_mgrs = [
await stack.enter_async_context( await stack.enter_async_context(
@ -100,7 +99,7 @@ class SwarmFactory(factory.Factory):
) )
for _ in range(number) for _ in range(number)
] ]
yield ctx_mgrs yield tuple(ctx_mgrs)
class HostFactory(factory.Factory): class HostFactory(factory.Factory):
@ -120,7 +119,7 @@ class HostFactory(factory.Factory):
@asynccontextmanager @asynccontextmanager
async def create_batch_and_listen( async def create_batch_and_listen(
cls, is_secure: bool, number: int cls, is_secure: bool, number: int
) -> Tuple[BasicHost, ...]: ) -> AsyncIterator[Tuple[BasicHost, ...]]:
key_pairs = [generate_new_rsa_identity() for _ in range(number)] key_pairs = [generate_new_rsa_identity() for _ in range(number)]
async with AsyncExitStack() as stack: async with AsyncExitStack() as stack:
swarms = [ swarms = [
@ -136,30 +135,6 @@ class HostFactory(factory.Factory):
yield hosts yield hosts
class RoutedHostFactory(factory.Factory):
class Meta:
model = RoutedHost
public_key = factory.LazyAttribute(lambda o: o.key_pair.public_key)
network = factory.LazyAttribute(
lambda o: SwarmFactory(is_secure=o.is_secure, key_pair=o.key_pair)
)
router = factory.LazyFunction(KademliaServer)
@classmethod
@asynccontextmanager
async def create_batch_and_listen(
cls, is_secure: bool, number: int
) -> Tuple[RoutedHost, ...]:
key_pairs = [generate_new_rsa_identity() for _ in range(number)]
routers = await set_up_routers((0,) * number)
async with SwarmFactory.create_batch_and_listen(is_secure, number) as swarms:
yield tuple(
RoutedHost(key_pair.public_key, swarm, router)
for key_pair, swarm, router in zip(key_pairs, swarms, routers)
)
class FloodsubFactory(factory.Factory): class FloodsubFactory(factory.Factory):
class Meta: class Meta:
model = FloodSub model = FloodSub
@ -191,17 +166,22 @@ class PubsubFactory(factory.Factory):
@classmethod @classmethod
@asynccontextmanager @asynccontextmanager
async def create_and_start(cls, host, router, cache_size): async def create_and_start(
cls, host: IHost, router: IPubsubRouter, cache_size: int
) -> AsyncIterator[Pubsub]:
pubsub = PubsubFactory(host=host, router=router, cache_size=cache_size) pubsub = PubsubFactory(host=host, router=router, cache_size=cache_size)
async with background_trio_service(pubsub): async with background_trio_service(pubsub):
yield pubsub yield pubsub
@classmethod @classmethod
@asynccontextmanager @asynccontextmanager
async def create_batch_with_floodsub( async def _create_batch_with_router(
cls, number: int, is_secure: bool = False, cache_size: int = None cls,
): number: int,
floodsubs = FloodsubFactory.create_batch(number) routers: Sequence[IPubsubRouter],
is_secure: bool = False,
cache_size: int = None,
) -> AsyncIterator[Tuple[Pubsub, ...]]:
async with HostFactory.create_batch_and_listen(is_secure, number) as hosts: async with HostFactory.create_batch_and_listen(is_secure, number) as hosts:
# Pubsubs should exit before hosts # Pubsubs should exit before hosts
async with AsyncExitStack() as stack: async with AsyncExitStack() as stack:
@ -209,21 +189,80 @@ class PubsubFactory(factory.Factory):
await stack.enter_async_context( await stack.enter_async_context(
cls.create_and_start(host, router, cache_size) cls.create_and_start(host, router, cache_size)
) )
for host, router in zip(hosts, floodsubs) for host, router in zip(hosts, routers)
] ]
yield pubsubs yield tuple(pubsubs)
# @classmethod @classmethod
# async def create_batch_with_gossipsub( @asynccontextmanager
# cls, number: int, cache_size: int = None, gossipsub_params=GOSSIPSUB_PARAMS async def create_batch_with_floodsub(
# ): cls,
# ... number: int,
is_secure: bool = False,
cache_size: int = None,
protocols: Sequence[TProtocol] = None,
) -> AsyncIterator[Tuple[Pubsub, ...]]:
if protocols is not None:
floodsubs = FloodsubFactory.create_batch(number, protocols=list(protocols))
else:
floodsubs = FloodsubFactory.create_batch(number)
async with cls._create_batch_with_router(
number, floodsubs, is_secure, cache_size
) as pubsubs:
yield pubsubs
@classmethod
@asynccontextmanager
async def create_batch_with_gossipsub(
cls,
number: int,
*,
is_secure: bool = False,
cache_size: int = None,
protocols: Sequence[TProtocol] = None,
degree: int = GOSSIPSUB_PARAMS.degree,
degree_low: int = GOSSIPSUB_PARAMS.degree_low,
degree_high: int = GOSSIPSUB_PARAMS.degree_high,
time_to_live: int = GOSSIPSUB_PARAMS.time_to_live,
gossip_window: int = GOSSIPSUB_PARAMS.gossip_window,
gossip_history: int = GOSSIPSUB_PARAMS.gossip_history,
heartbeat_interval: float = GOSSIPSUB_PARAMS.heartbeat_interval,
) -> AsyncIterator[Tuple[Pubsub, ...]]:
if protocols is not None:
gossipsubs = GossipsubFactory.create_batch(
number,
protocols=protocols,
degree=degree,
degree_low=degree_low,
degree_high=degree_high,
time_to_live=time_to_live,
gossip_window=gossip_window,
heartbeat_interval=heartbeat_interval,
)
else:
gossipsubs = GossipsubFactory.create_batch(
number,
degree=degree,
degree_low=degree_low,
degree_high=degree_high,
time_to_live=time_to_live,
gossip_window=gossip_window,
heartbeat_interval=heartbeat_interval,
)
async with cls._create_batch_with_router(
number, gossipsubs, is_secure, cache_size
) as pubsubs:
async with AsyncExitStack() as stack:
for router in gossipsubs:
await stack.enter_async_context(background_trio_service(router))
yield pubsubs
@asynccontextmanager @asynccontextmanager
async def swarm_pair_factory( async def swarm_pair_factory(
is_secure: bool, muxer_opt: TMuxerOptions = None is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Tuple[Swarm, Swarm]: ) -> AsyncIterator[Tuple[Swarm, Swarm]]:
async with SwarmFactory.create_batch_and_listen( async with SwarmFactory.create_batch_and_listen(
is_secure, 2, muxer_opt=muxer_opt is_secure, 2, muxer_opt=muxer_opt
) as swarms: ) as swarms:
@ -232,7 +271,9 @@ async def swarm_pair_factory(
@asynccontextmanager @asynccontextmanager
async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]: async def host_pair_factory(
is_secure: bool
) -> AsyncIterator[Tuple[BasicHost, BasicHost]]:
async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts: async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts:
await connect(hosts[0], hosts[1]) await connect(hosts[0], hosts[1])
yield hosts[0], hosts[1] yield hosts[0], hosts[1]
@ -241,7 +282,7 @@ async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]:
@asynccontextmanager @asynccontextmanager
async def swarm_conn_pair_factory( async def swarm_conn_pair_factory(
is_secure: bool, muxer_opt: TMuxerOptions = None is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Tuple[SwarmConn, SwarmConn]: ) -> AsyncIterator[Tuple[SwarmConn, SwarmConn]]:
async with swarm_pair_factory(is_secure) as swarms: async with swarm_pair_factory(is_secure) as swarms:
conn_0 = swarms[0].connections[swarms[1].get_peer_id()] conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
conn_1 = swarms[1].connections[swarms[0].get_peer_id()] conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
@ -249,7 +290,9 @@ async def swarm_conn_pair_factory(
@asynccontextmanager @asynccontextmanager
async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Mplex]: async def mplex_conn_pair_factory(
is_secure: bool
) -> AsyncIterator[Tuple[Mplex, Mplex]]:
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex} muxer_opt = {MPLEX_PROTOCOL_ID: Mplex}
async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair: async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair:
yield ( yield (
@ -259,21 +302,25 @@ async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Mplex]:
@asynccontextmanager @asynccontextmanager
async def mplex_stream_pair_factory(is_secure: bool) -> Tuple[MplexStream, MplexStream]: async def mplex_stream_pair_factory(
is_secure: bool
) -> AsyncIterator[Tuple[MplexStream, MplexStream]]:
async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info: async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info:
mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info
stream_0 = await mplex_conn_0.open_stream() stream_0 = cast(MplexStream, await mplex_conn_0.open_stream())
await trio.sleep(0.01) await trio.sleep(0.01)
stream_1: MplexStream stream_1: MplexStream
async with mplex_conn_1.streams_lock: async with mplex_conn_1.streams_lock:
if len(mplex_conn_1.streams) != 1: if len(mplex_conn_1.streams) != 1:
raise Exception("Mplex should not have any other stream") raise Exception("Mplex should not have any other stream")
stream_1 = tuple(mplex_conn_1.streams.values())[0] stream_1 = tuple(mplex_conn_1.streams.values())[0]
yield cast(MplexStream, stream_0), cast(MplexStream, stream_1) yield stream_0, stream_1
@asynccontextmanager @asynccontextmanager
async def net_stream_pair_factory(is_secure: bool) -> Tuple[INetStream, INetStream]: async def net_stream_pair_factory(
is_secure: bool
) -> AsyncIterator[Tuple[INetStream, INetStream]]:
protocol_id = TProtocol("/example/id/1") protocol_id = TProtocol("/example/id/1")
stream_1: INetStream stream_1: INetStream

View File

@ -1,12 +1,11 @@
import asyncio from contextlib import AsyncExitStack, asynccontextmanager
from typing import Dict from typing import AsyncIterator, Dict, Tuple
import uuid
from async_service import Service, background_trio_service
from libp2p.host.host_interface import IHost from libp2p.host.host_interface import IHost
from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.pubsub import Pubsub from libp2p.pubsub.pubsub import Pubsub
from libp2p.tools.constants import LISTEN_MADDR from libp2p.tools.factories import PubsubFactory
from libp2p.tools.factories import FloodsubFactory, PubsubFactory
CRYPTO_TOPIC = "ethereum" CRYPTO_TOPIC = "ethereum"
@ -18,7 +17,7 @@ CRYPTO_TOPIC = "ethereum"
# Determine message type by looking at first item before first comma # Determine message type by looking at first item before first comma
class DummyAccountNode: class DummyAccountNode(Service):
""" """
Node which has an internal balance mapping, meant to serve as a dummy Node which has an internal balance mapping, meant to serve as a dummy
crypto blockchain. crypto blockchain.
@ -27,19 +26,24 @@ class DummyAccountNode:
crypto each user in the mappings holds crypto each user in the mappings holds
""" """
libp2p_node: IHost
pubsub: Pubsub pubsub: Pubsub
floodsub: FloodSub
def __init__(self, libp2p_node: IHost, pubsub: Pubsub, floodsub: FloodSub): def __init__(self, pubsub: Pubsub) -> None:
self.libp2p_node = libp2p_node
self.pubsub = pubsub self.pubsub = pubsub
self.floodsub = floodsub
self.balances: Dict[str, int] = {} self.balances: Dict[str, int] = {}
self.node_id = str(uuid.uuid1())
@property
def host(self) -> IHost:
return self.pubsub.host
async def run(self) -> None:
self.subscription = await self.pubsub.subscribe(CRYPTO_TOPIC)
self.manager.run_daemon_task(self.handle_incoming_msgs)
await self.manager.wait_finished()
@classmethod @classmethod
async def create(cls) -> "DummyAccountNode": @asynccontextmanager
async def create(cls, number: int) -> AsyncIterator[Tuple["DummyAccountNode", ...]]:
""" """
Create a new DummyAccountNode and attach a libp2p node, a floodsub, and Create a new DummyAccountNode and attach a libp2p node, a floodsub, and
a pubsub instance to this new node. a pubsub instance to this new node.
@ -47,15 +51,17 @@ class DummyAccountNode:
We use create as this serves as a factory function and allows us We use create as this serves as a factory function and allows us
to use async await, unlike the init function to use async await, unlike the init function
""" """
async with PubsubFactory.create_batch_with_floodsub(number) as pubsubs:
pubsub = PubsubFactory(router=FloodsubFactory()) async with AsyncExitStack() as stack:
await pubsub.host.get_network().listen(LISTEN_MADDR) dummy_acount_nodes = tuple(cls(pubsub) for pubsub in pubsubs)
return cls(libp2p_node=pubsub.host, pubsub=pubsub, floodsub=pubsub.router) for node in dummy_acount_nodes:
await stack.enter_async_context(background_trio_service(node))
yield dummy_acount_nodes
async def handle_incoming_msgs(self) -> None: async def handle_incoming_msgs(self) -> None:
"""Handle all incoming messages on the CRYPTO_TOPIC from peers.""" """Handle all incoming messages on the CRYPTO_TOPIC from peers."""
while True: while True:
incoming = await self.q.get() incoming = await self.subscription.receive()
msg_comps = incoming.data.decode("utf-8").split(",") msg_comps = incoming.data.decode("utf-8").split(",")
if msg_comps[0] == "send": if msg_comps[0] == "send":
@ -63,13 +69,6 @@ class DummyAccountNode:
elif msg_comps[0] == "set": elif msg_comps[0] == "set":
self.handle_set_crypto(msg_comps[1], int(msg_comps[2])) self.handle_set_crypto(msg_comps[1], int(msg_comps[2]))
async def setup_crypto_networking(self) -> None:
"""Subscribe to CRYPTO_TOPIC and perform call to function that handles
all incoming messages on said topic."""
self.q = await self.pubsub.subscribe(CRYPTO_TOPIC)
asyncio.ensure_future(self.handle_incoming_msgs())
async def publish_send_crypto( async def publish_send_crypto(
self, source_user: str, dest_user: str, amount: int self, source_user: str, dest_user: str, amount: int
) -> None: ) -> None:

View File

@ -1,12 +1,10 @@
# type: ignore # type: ignore
# To add typing to this module, it's better to do it after refactoring test cases into classes # To add typing to this module, it's better to do it after refactoring test cases into classes
import asyncio
import pytest import pytest
import trio
from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID, LISTEN_MADDR from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID
from libp2p.tools.factories import PubsubFactory
from libp2p.tools.utils import connect from libp2p.tools.utils import connect
SUPPORTED_PROTOCOLS = [FLOODSUB_PROTOCOL_ID] SUPPORTED_PROTOCOLS = [FLOODSUB_PROTOCOL_ID]
@ -15,6 +13,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{ {
"name": "simple_two_nodes", "name": "simple_two_nodes",
"supported_protocols": SUPPORTED_PROTOCOLS, "supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["A", "B"],
"adj_list": {"A": ["B"]}, "adj_list": {"A": ["B"]},
"topic_map": {"topic1": ["B"]}, "topic_map": {"topic1": ["B"]},
"messages": [{"topics": ["topic1"], "data": b"foo", "node_id": "A"}], "messages": [{"topics": ["topic1"], "data": b"foo", "node_id": "A"}],
@ -22,6 +21,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{ {
"name": "three_nodes_two_topics", "name": "three_nodes_two_topics",
"supported_protocols": SUPPORTED_PROTOCOLS, "supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["A", "B", "C"],
"adj_list": {"A": ["B"], "B": ["C"]}, "adj_list": {"A": ["B"], "B": ["C"]},
"topic_map": {"topic1": ["B", "C"], "topic2": ["B", "C"]}, "topic_map": {"topic1": ["B", "C"], "topic2": ["B", "C"]},
"messages": [ "messages": [
@ -32,6 +32,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{ {
"name": "two_nodes_one_topic_single_subscriber_is_sender", "name": "two_nodes_one_topic_single_subscriber_is_sender",
"supported_protocols": SUPPORTED_PROTOCOLS, "supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["A", "B"],
"adj_list": {"A": ["B"]}, "adj_list": {"A": ["B"]},
"topic_map": {"topic1": ["B"]}, "topic_map": {"topic1": ["B"]},
"messages": [{"topics": ["topic1"], "data": b"Alex is tall", "node_id": "B"}], "messages": [{"topics": ["topic1"], "data": b"Alex is tall", "node_id": "B"}],
@ -39,6 +40,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{ {
"name": "two_nodes_one_topic_two_msgs", "name": "two_nodes_one_topic_two_msgs",
"supported_protocols": SUPPORTED_PROTOCOLS, "supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["A", "B"],
"adj_list": {"A": ["B"]}, "adj_list": {"A": ["B"]},
"topic_map": {"topic1": ["B"]}, "topic_map": {"topic1": ["B"]},
"messages": [ "messages": [
@ -49,6 +51,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{ {
"name": "seven_nodes_tree_one_topics", "name": "seven_nodes_tree_one_topics",
"supported_protocols": SUPPORTED_PROTOCOLS, "supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4", "5", "6", "7"],
"adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]}, "adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]},
"topic_map": {"astrophysics": ["2", "3", "4", "5", "6", "7"]}, "topic_map": {"astrophysics": ["2", "3", "4", "5", "6", "7"]},
"messages": [{"topics": ["astrophysics"], "data": b"e=mc^2", "node_id": "1"}], "messages": [{"topics": ["astrophysics"], "data": b"e=mc^2", "node_id": "1"}],
@ -56,6 +59,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{ {
"name": "seven_nodes_tree_three_topics", "name": "seven_nodes_tree_three_topics",
"supported_protocols": SUPPORTED_PROTOCOLS, "supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4", "5", "6", "7"],
"adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]}, "adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]},
"topic_map": { "topic_map": {
"astrophysics": ["2", "3", "4", "5", "6", "7"], "astrophysics": ["2", "3", "4", "5", "6", "7"],
@ -71,6 +75,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{ {
"name": "seven_nodes_tree_three_topics_diff_origin", "name": "seven_nodes_tree_three_topics_diff_origin",
"supported_protocols": SUPPORTED_PROTOCOLS, "supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4", "5", "6", "7"],
"adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]}, "adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]},
"topic_map": { "topic_map": {
"astrophysics": ["1", "2", "3", "4", "5", "6", "7"], "astrophysics": ["1", "2", "3", "4", "5", "6", "7"],
@ -86,6 +91,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{ {
"name": "three_nodes_clique_two_topic_diff_origin", "name": "three_nodes_clique_two_topic_diff_origin",
"supported_protocols": SUPPORTED_PROTOCOLS, "supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3"],
"adj_list": {"1": ["2", "3"], "2": ["3"]}, "adj_list": {"1": ["2", "3"], "2": ["3"]},
"topic_map": {"astrophysics": ["1", "2", "3"], "school": ["1", "2", "3"]}, "topic_map": {"astrophysics": ["1", "2", "3"], "school": ["1", "2", "3"]},
"messages": [ "messages": [
@ -97,6 +103,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{ {
"name": "four_nodes_clique_two_topic_diff_origin_many_msgs", "name": "four_nodes_clique_two_topic_diff_origin_many_msgs",
"supported_protocols": SUPPORTED_PROTOCOLS, "supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4"],
"adj_list": { "adj_list": {
"1": ["2", "3", "4"], "1": ["2", "3", "4"],
"2": ["1", "3", "4"], "2": ["1", "3", "4"],
@ -120,6 +127,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{ {
"name": "five_nodes_ring_two_topic_diff_origin_many_msgs", "name": "five_nodes_ring_two_topic_diff_origin_many_msgs",
"supported_protocols": SUPPORTED_PROTOCOLS, "supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4", "5"],
"adj_list": {"1": ["2"], "2": ["3"], "3": ["4"], "4": ["5"], "5": ["1"]}, "adj_list": {"1": ["2"], "2": ["3"], "3": ["4"], "4": ["5"], "5": ["1"]},
"topic_map": { "topic_map": {
"astrophysics": ["1", "2", "3", "4", "5"], "astrophysics": ["1", "2", "3", "4", "5"],
@ -143,7 +151,7 @@ floodsub_protocol_pytest_params = [
] ]
async def perform_test_from_obj(obj, router_factory) -> None: async def perform_test_from_obj(obj, pubsub_factory) -> None:
""" """
Perform pubsub tests from a test obj. Perform pubsub tests from a test obj.
test obj are composed as follows: test obj are composed as follows:
@ -174,88 +182,75 @@ async def perform_test_from_obj(obj, router_factory) -> None:
# Step 1) Create graph # Step 1) Create graph
adj_list = obj["adj_list"] adj_list = obj["adj_list"]
node_list = obj["nodes"]
node_map = {} node_map = {}
pubsub_map = {} pubsub_map = {}
async def add_node(node_id_str: str) -> None: async with pubsub_factory(
pubsub_router = router_factory(protocols=obj["supported_protocols"]) number=len(node_list), protocols=obj["supported_protocols"]
pubsub = PubsubFactory(router=pubsub_router) ) as pubsubs:
await pubsub.host.get_network().listen(LISTEN_MADDR) for node_id_str, pubsub in zip(node_list, pubsubs):
node_map[node_id_str] = pubsub.host node_map[node_id_str] = pubsub.host
pubsub_map[node_id_str] = pubsub pubsub_map[node_id_str] = pubsub
tasks_connect = [] # Connect nodes and wait at least for 2 seconds
for start_node_id in adj_list: async with trio.open_nursery() as nursery:
# Create node if node does not yet exist for start_node_id in adj_list:
if start_node_id not in node_map: # For each neighbor of start_node, create if does not yet exist,
await add_node(start_node_id) # then connect start_node to neighbor
for neighbor_id in adj_list[start_node_id]:
nursery.start_soon(
connect, node_map[start_node_id], node_map[neighbor_id]
)
nursery.start_soon(trio.sleep, 2)
# For each neighbor of start_node, create if does not yet exist, # Step 2) Subscribe to topics
# then connect start_node to neighbor queues_map = {}
for neighbor_id in adj_list[start_node_id]: topic_map = obj["topic_map"]
# Create neighbor if neighbor does not yet exist
if neighbor_id not in node_map:
await add_node(neighbor_id)
tasks_connect.append(
connect(node_map[start_node_id], node_map[neighbor_id])
)
# Connect nodes and wait at least for 2 seconds
await asyncio.gather(*tasks_connect, asyncio.sleep(2))
# Step 2) Subscribe to topics async def subscribe_node(node_id, topic):
queues_map = {} if node_id not in queues_map:
topic_map = obj["topic_map"] queues_map[node_id] = {}
# Avoid repeated works
if topic in queues_map[node_id]:
# Checkpoint
await trio.hazmat.checkpoint()
return
sub = await pubsub_map[node_id].subscribe(topic)
queues_map[node_id][topic] = sub
tasks_topic = [] async with trio.open_nursery() as nursery:
tasks_topic_data = [] for topic, node_ids in topic_map.items():
for topic, node_ids in topic_map.items(): for node_id in node_ids:
for node_id in node_ids: nursery.start_soon(subscribe_node, node_id, topic)
tasks_topic.append(pubsub_map[node_id].subscribe(topic)) nursery.start_soon(trio.sleep, 2)
tasks_topic_data.append((node_id, topic))
tasks_topic.append(asyncio.sleep(2))
# Gather is like Promise.all # Step 3) Publish messages
responses = await asyncio.gather(*tasks_topic) topics_in_msgs_ordered = []
for i in range(len(responses) - 1): messages = obj["messages"]
node_id, topic = tasks_topic_data[i]
if node_id not in queues_map:
queues_map[node_id] = {}
# Store queue in topic-queue map for node
queues_map[node_id][topic] = responses[i]
# Allow time for subscribing before continuing for msg in messages:
await asyncio.sleep(0.01) topics = msg["topics"]
data = msg["data"]
node_id = msg["node_id"]
# Step 3) Publish messages # Publish message
topics_in_msgs_ordered = [] # TODO: Should be single RPC package with several topics
messages = obj["messages"] for topic in topics:
tasks_publish = [] await pubsub_map[node_id].publish(topic, data)
for msg in messages: # For each topic in topics, add (topic, node_id, data) tuple to ordered test list
topics = msg["topics"] for topic in topics:
data = msg["data"] topics_in_msgs_ordered.append((topic, node_id, data))
node_id = msg["node_id"] # Allow time for publishing before continuing
await trio.sleep(1)
# Publish message # Step 4) Check that all messages were received correctly.
# TODO: Should be single RPC package with several topics for topic, origin_node_id, data in topics_in_msgs_ordered:
for topic in topics: # Look at each node in each topic
tasks_publish.append(pubsub_map[node_id].publish(topic, data)) for node_id in topic_map[topic]:
# Get message from subscription queue
# For each topic in topics, add (topic, node_id, data) tuple to ordered test list msg = await queues_map[node_id][topic].receive()
for topic in topics: assert data == msg.data
topics_in_msgs_ordered.append((topic, node_id, data)) # Check the message origin
assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id
# Allow time for publishing before continuing
await asyncio.gather(*tasks_publish, asyncio.sleep(2))
# Step 4) Check that all messages were received correctly.
for topic, origin_node_id, data in topics_in_msgs_ordered:
# Look at each node in each topic
for node_id in topic_map[topic]:
# Get message from subscription queue
msg = await queues_map[node_id][topic].get()
assert data == msg.data
# Check the message origin
assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id
# Success, terminate pending tasks.

View File

@ -1,17 +1,9 @@
from typing import Callable, List, Sequence, Tuple from typing import Awaitable, Callable
import multiaddr
import trio
from libp2p import new_node
from libp2p.host.basic_host import BasicHost
from libp2p.host.host_interface import IHost from libp2p.host.host_interface import IHost
from libp2p.kademlia.network import KademliaServer
from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.network.swarm import Swarm from libp2p.network.swarm import Swarm
from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.routing.interfaces import IPeerRouting
from libp2p.routing.kademlia.kademlia_peer_router import KadmeliaPeerRouter
from .constants import MAX_READ_LEN from .constants import MAX_READ_LEN
@ -36,49 +28,9 @@ async def connect(node1: IHost, node2: IHost) -> None:
await node1.connect(info) await node1.connect(info)
async def set_up_nodes_by_transport_opt( def create_echo_stream_handler(
transport_opt_list: Sequence[Sequence[str]], nursery: trio.Nursery ack_prefix: str
) -> Tuple[BasicHost, ...]: ) -> Callable[[INetStream], Awaitable[None]]:
nodes_list = []
for transport_opt in transport_opt_list:
node = new_node(transport_opt=transport_opt)
await node.get_network().listen(
multiaddr.Multiaddr(transport_opt[0]), nursery=nursery
)
nodes_list.append(node)
return tuple(nodes_list)
async def set_up_nodes_by_transport_and_disc_opt(
transport_disc_opt_list: Sequence[Tuple[Sequence[str], IPeerRouting]]
) -> Tuple[BasicHost, ...]:
nodes_list = []
for transport_opt, disc_opt in transport_disc_opt_list:
node = await new_node(transport_opt=transport_opt, disc_opt=disc_opt)
await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0]))
nodes_list.append(node)
return tuple(nodes_list)
async def set_up_routers(
router_ports: Tuple[int, ...] = (0, 0)
) -> List[KadmeliaPeerRouter]:
"""The default ``router_confs`` selects two free ports local to this
machine."""
bootstrap_node = KademliaServer() # type: ignore
await bootstrap_node.listen(router_ports[0])
routers = [KadmeliaPeerRouter(bootstrap_node)]
for port in router_ports[1:]:
node = KademliaServer() # type: ignore
await node.listen(port)
await node.bootstrap_node(bootstrap_node.address)
routers.append(KadmeliaPeerRouter(node))
return routers
def create_echo_stream_handler(ack_prefix: str) -> Callable[[INetStream], None]:
async def echo_stream_handler(stream: INetStream) -> None: async def echo_stream_handler(stream: INetStream) -> None:
while True: while True:
read_string = (await stream.read(MAX_READ_LEN)).decode() read_string = (await stream.read(MAX_READ_LEN)).decode()

View File

@ -1,12 +1,13 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List from typing import Tuple
from multiaddr import Multiaddr from multiaddr import Multiaddr
import trio
class IListener(ABC): class IListener(ABC):
@abstractmethod @abstractmethod
async def listen(self, maddr: Multiaddr) -> bool: async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
""" """
put listener in listening mode and wait for incoming connections. put listener in listening mode and wait for incoming connections.
@ -15,14 +16,9 @@ class IListener(ABC):
""" """
@abstractmethod @abstractmethod
def get_addrs(self) -> List[Multiaddr]: def get_addrs(self) -> Tuple[Multiaddr, ...]:
""" """
retrieve list of addresses the listener is listening on. retrieve list of addresses the listener is listening on.
:return: return list of addrs :return: return list of addrs
""" """
@abstractmethod
async def close(self) -> None:
"""close the listener such that no more connections can be open on this
transport instance."""

View File

@ -1,14 +1,13 @@
import logging import logging
from socket import socket from typing import Awaitable, Callable, List, Sequence, Tuple
from typing import List
from multiaddr import Multiaddr from multiaddr import Multiaddr
import trio import trio
from trio_typing import TaskStatus
from libp2p.io.trio import TrioReadWriteCloser from libp2p.io.trio import TrioTCPStream
from libp2p.network.connection.raw_connection import RawConnection from libp2p.network.connection.raw_connection import RawConnection
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.transport.exceptions import OpenConnectionError
from libp2p.transport.listener_interface import IListener from libp2p.transport.listener_interface import IListener
from libp2p.transport.transport_interface import ITransport from libp2p.transport.transport_interface import ITransport
from libp2p.transport.typing import THandler from libp2p.transport.typing import THandler
@ -18,14 +17,12 @@ logger = logging.getLogger("libp2p.transport.tcp")
class TCPListener(IListener): class TCPListener(IListener):
multiaddrs: List[Multiaddr] multiaddrs: List[Multiaddr]
server = None
def __init__(self, handler_function: THandler) -> None: def __init__(self, handler_function: THandler) -> None:
self.multiaddrs = [] self.multiaddrs = []
self.server = None
self.handler = handler_function self.handler = handler_function
# TODO: Fix handling? # TODO: Get rid of `nursery`?
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> None: async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> None:
""" """
put listener in listening mode and wait for incoming connections. put listener in listening mode and wait for incoming connections.
@ -34,13 +31,18 @@ class TCPListener(IListener):
:return: return True if successful :return: return True if successful
""" """
async def serve_tcp(handler, port, host, task_status=None): async def serve_tcp(
handler: Callable[[trio.SocketStream], Awaitable[None]],
port: int,
host: str,
task_status: TaskStatus[Sequence[trio.SocketListener]] = None,
) -> None:
logger.debug("serve_tcp %s %s", host, port) logger.debug("serve_tcp %s %s", host, port)
await trio.serve_tcp(handler, port, host=host, task_status=task_status) await trio.serve_tcp(handler, port, host=host, task_status=task_status)
async def handler(stream): async def handler(stream: trio.SocketStream) -> None:
read_write_closer = TrioReadWriteCloser(stream) tcp_stream = TrioTCPStream(stream)
await self.handler(read_write_closer) await self.handler(tcp_stream)
listeners = await nursery.start( listeners = await nursery.start(
serve_tcp, serve_tcp,
@ -51,7 +53,7 @@ class TCPListener(IListener):
socket = listeners[0].socket socket = listeners[0].socket
self.multiaddrs.append(_multiaddr_from_socket(socket)) self.multiaddrs.append(_multiaddr_from_socket(socket))
def get_addrs(self) -> List[Multiaddr]: def get_addrs(self) -> Tuple[Multiaddr, ...]:
""" """
retrieve list of addresses the listener is listening on. retrieve list of addresses the listener is listening on.
@ -59,15 +61,6 @@ class TCPListener(IListener):
""" """
return tuple(self.multiaddrs) return tuple(self.multiaddrs)
async def close(self) -> None:
"""close the listener such that no more connections can be open on this
transport instance."""
if self.server is None:
return
self.server.close()
await self.server.wait_closed()
self.server = None
class TCP(ITransport): class TCP(ITransport):
async def dial(self, maddr: Multiaddr) -> IRawConnection: async def dial(self, maddr: Multiaddr) -> IRawConnection:
@ -82,7 +75,7 @@ class TCP(ITransport):
self.port = int(maddr.value_for_protocol("tcp")) self.port = int(maddr.value_for_protocol("tcp"))
stream = await trio.open_tcp_stream(self.host, self.port) stream = await trio.open_tcp_stream(self.host, self.port)
read_write_closer = TrioReadWriteCloser(stream) read_write_closer = TrioTCPStream(stream)
return RawConnection(read_write_closer, True) return RawConnection(read_write_closer, True)
@ -97,5 +90,6 @@ class TCP(ITransport):
return TCPListener(handler_function) return TCPListener(handler_function)
def _multiaddr_from_socket(socket: socket) -> Multiaddr: def _multiaddr_from_socket(socket: trio.socket.SocketType) -> Multiaddr:
return Multiaddr("/ip4/%s/tcp/%s" % socket.getsockname()) ip, port = socket.getsockname() # type: ignore
return Multiaddr(f"/ip4/{ip}/tcp/{port}")

View File

@ -1,7 +1,7 @@
import trio
import secrets import secrets
import pytest import pytest
import trio
from libp2p.host.ping import ID, PING_LENGTH from libp2p.host.ping import ID, PING_LENGTH
from libp2p.tools.factories import host_pair_factory from libp2p.tools.factories import host_pair_factory

View File

@ -1,73 +0,0 @@
import pytest
from libp2p.host.exceptions import ConnectionFailure
from libp2p.peer.peerinfo import PeerInfo
from libp2p.routing.kademlia.kademlia_peer_router import peer_info_to_str
from libp2p.tools.utils import (
set_up_nodes_by_transport_and_disc_opt,
set_up_nodes_by_transport_opt,
set_up_routers,
)
from libp2p.tools.factories import RoutedHostFactory
# FIXME:
# TODO: Kademlia is full of asyncio code. Skip it for now
@pytest.mark.skip
@pytest.mark.trio
async def test_host_routing_success(is_host_secure):
async with RoutedHostFactory.create_batch_and_listen(
is_host_secure, 2
) as routed_hosts:
# Set routing info
await routed_hosts[0]._router.server.set(
routed_hosts[0].get_id().xor_id,
peer_info_to_str(
PeerInfo(routed_hosts[0].get_id(), routed_hosts[0].get_addrs())
),
)
await routed_hosts[1]._router.server.set(
routed_hosts[1].get_id().xor_id,
peer_info_to_str(
PeerInfo(routed_hosts[1].get_id(), routed_hosts[1].get_addrs())
),
)
# forces to use routing as no addrs are provided
await routed_hosts[0].connect(PeerInfo(routed_hosts[1].get_id(), []))
await routed_hosts[1].connect(PeerInfo(routed_hosts[0].get_id(), []))
# TODO: Kademlia is full of asyncio code. Skip it for now
@pytest.mark.skip
@pytest.mark.trio
async def test_host_routing_fail():
routers = await set_up_routers()
transports = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
transport_disc_opt_list = zip(transports, routers)
(host_a, host_b) = await set_up_nodes_by_transport_and_disc_opt(
transport_disc_opt_list
)
host_c = (await set_up_nodes_by_transport_opt([["/ip4/127.0.0.1/tcp/0"]]))[0]
# Set routing info
await routers[0].server.set(
host_a.get_id().xor_id,
peer_info_to_str(PeerInfo(host_a.get_id(), host_a.get_addrs())),
)
await routers[1].server.set(
host_b.get_id().xor_id,
peer_info_to_str(PeerInfo(host_b.get_id(), host_b.get_addrs())),
)
# routing fails because host_c does not use routing
with pytest.raises(ConnectionFailure):
await host_a.connect(PeerInfo(host_c.get_id(), []))
with pytest.raises(ConnectionFailure):
await host_b.connect(PeerInfo(host_c.get_id(), []))
# Clean up
routers[0].server.stop()
routers[1].server.stop()

View File

@ -4,7 +4,6 @@ from libp2p.host.exceptions import StreamFailure
from libp2p.tools.factories import HostFactory from libp2p.tools.factories import HostFactory
from libp2p.tools.utils import create_echo_stream_handler from libp2p.tools.utils import create_echo_stream_handler
PROTOCOL_ECHO = "/echo/1.0.0" PROTOCOL_ECHO = "/echo/1.0.0"
PROTOCOL_POTATO = "/potato/1.0.0" PROTOCOL_POTATO = "/potato/1.0.0"
PROTOCOL_FOO = "/foo/1.0.0" PROTOCOL_FOO = "/foo/1.0.0"

View File

@ -1,22 +0,0 @@
import pytest
from libp2p.tools.constants import GOSSIPSUB_PARAMS
from libp2p.tools.factories import FloodsubFactory, GossipsubFactory, PubsubFactory
@pytest.fixture
def pubsub_cache_size():
return None # default
@pytest.fixture
def gossipsub_params():
return GOSSIPSUB_PARAMS
# @pytest.fixture
# def pubsubs_gsub(num_hosts, hosts, pubsub_cache_size, gossipsub_params):
# gossipsubs = GossipsubFactory.create_batch(num_hosts, **gossipsub_params._asdict())
# _pubsubs_gsub = _make_pubsubs(hosts, gossipsubs, pubsub_cache_size)
# yield _pubsubs_gsub
# # TODO: Clean up

View File

@ -1,19 +1,10 @@
import asyncio
from threading import Thread
import pytest import pytest
import trio
from libp2p.tools.pubsub.dummy_account_node import DummyAccountNode from libp2p.tools.pubsub.dummy_account_node import DummyAccountNode
from libp2p.tools.utils import connect from libp2p.tools.utils import connect
def create_setup_in_new_thread_func(dummy_node):
def setup_in_new_thread():
asyncio.ensure_future(dummy_node.setup_crypto_networking())
return setup_in_new_thread
async def perform_test(num_nodes, adjacency_map, action_func, assertion_func): async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
""" """
Helper function to allow for easy construction of custom tests for dummy Helper function to allow for easy construction of custom tests for dummy
@ -26,47 +17,35 @@ async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
:param assertion_func: assertions for testing the results of the actions are correct :param assertion_func: assertions for testing the results of the actions are correct
""" """
# Create nodes async with DummyAccountNode.create(num_nodes) as dummy_nodes:
dummy_nodes = [] # Create connections between nodes according to `adjacency_map`
for _ in range(num_nodes): async with trio.open_nursery() as nursery:
dummy_nodes.append(await DummyAccountNode.create()) for source_num in adjacency_map:
target_nums = adjacency_map[source_num]
for target_num in target_nums:
nursery.start_soon(
connect,
dummy_nodes[source_num].host,
dummy_nodes[target_num].host,
)
# Create network # Allow time for network creation to take place
for source_num in adjacency_map: await trio.sleep(0.25)
target_nums = adjacency_map[source_num]
for target_num in target_nums:
await connect(
dummy_nodes[source_num].libp2p_node, dummy_nodes[target_num].libp2p_node
)
# Allow time for network creation to take place # Perform action function
await asyncio.sleep(0.25) await action_func(dummy_nodes)
# Start a thread for each node so that each node can listen and respond # Allow time for action function to be performed (i.e. messages to propogate)
# to messages on its own thread, which will avoid waiting indefinitely await trio.sleep(1)
# on the main thread. On this thread, call the setup func for the node,
# which subscribes the node to the CRYPTO_TOPIC topic
for dummy_node in dummy_nodes:
thread = Thread(target=create_setup_in_new_thread_func(dummy_node))
thread.run()
# Allow time for nodes to subscribe to CRYPTO_TOPIC topic # Perform assertion function
await asyncio.sleep(0.25) for dummy_node in dummy_nodes:
assertion_func(dummy_node)
# Perform action function
await action_func(dummy_nodes)
# Allow time for action function to be performed (i.e. messages to propogate)
await asyncio.sleep(1)
# Perform assertion function
for dummy_node in dummy_nodes:
assertion_func(dummy_node)
# Success, terminate pending tasks. # Success, terminate pending tasks.
@pytest.mark.asyncio @pytest.mark.trio
async def test_simple_two_nodes(): async def test_simple_two_nodes():
num_nodes = 2 num_nodes = 2
adj_map = {0: [1]} adj_map = {0: [1]}
@ -80,7 +59,7 @@ async def test_simple_two_nodes():
await perform_test(num_nodes, adj_map, action_func, assertion_func) await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio @pytest.mark.trio
async def test_simple_three_nodes_line_topography(): async def test_simple_three_nodes_line_topography():
num_nodes = 3 num_nodes = 3
adj_map = {0: [1], 1: [2]} adj_map = {0: [1], 1: [2]}
@ -94,7 +73,7 @@ async def test_simple_three_nodes_line_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func) await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio @pytest.mark.trio
async def test_simple_three_nodes_triangle_topography(): async def test_simple_three_nodes_triangle_topography():
num_nodes = 3 num_nodes = 3
adj_map = {0: [1, 2], 1: [2]} adj_map = {0: [1, 2], 1: [2]}
@ -108,7 +87,7 @@ async def test_simple_three_nodes_triangle_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func) await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio @pytest.mark.trio
async def test_simple_seven_nodes_tree_topography(): async def test_simple_seven_nodes_tree_topography():
num_nodes = 7 num_nodes = 7
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]} adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
@ -122,14 +101,14 @@ async def test_simple_seven_nodes_tree_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func) await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio @pytest.mark.trio
async def test_set_then_send_from_root_seven_nodes_tree_topography(): async def test_set_then_send_from_root_seven_nodes_tree_topography():
num_nodes = 7 num_nodes = 7
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]} adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
async def action_func(dummy_nodes): async def action_func(dummy_nodes):
await dummy_nodes[0].publish_set_crypto("aspyn", 20) await dummy_nodes[0].publish_set_crypto("aspyn", 20)
await asyncio.sleep(0.25) await trio.sleep(0.25)
await dummy_nodes[0].publish_send_crypto("aspyn", "alex", 5) await dummy_nodes[0].publish_send_crypto("aspyn", "alex", 5)
def assertion_func(dummy_node): def assertion_func(dummy_node):
@ -139,14 +118,14 @@ async def test_set_then_send_from_root_seven_nodes_tree_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func) await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio @pytest.mark.trio
async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography(): async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography():
num_nodes = 7 num_nodes = 7
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]} adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
async def action_func(dummy_nodes): async def action_func(dummy_nodes):
await dummy_nodes[6].publish_set_crypto("aspyn", 20) await dummy_nodes[6].publish_set_crypto("aspyn", 20)
await asyncio.sleep(0.25) await trio.sleep(0.25)
await dummy_nodes[4].publish_send_crypto("aspyn", "alex", 5) await dummy_nodes[4].publish_send_crypto("aspyn", "alex", 5)
def assertion_func(dummy_node): def assertion_func(dummy_node):
@ -156,7 +135,7 @@ async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func) await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio @pytest.mark.trio
async def test_simple_five_nodes_ring_topography(): async def test_simple_five_nodes_ring_topography():
num_nodes = 5 num_nodes = 5
adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]} adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
@ -170,14 +149,14 @@ async def test_simple_five_nodes_ring_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func) await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio @pytest.mark.trio
async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography(): async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography():
num_nodes = 5 num_nodes = 5
adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]} adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
async def action_func(dummy_nodes): async def action_func(dummy_nodes):
await dummy_nodes[0].publish_set_crypto("alex", 20) await dummy_nodes[0].publish_set_crypto("alex", 20)
await asyncio.sleep(0.25) await trio.sleep(0.25)
await dummy_nodes[3].publish_send_crypto("alex", "rob", 12) await dummy_nodes[3].publish_send_crypto("alex", "rob", 12)
def assertion_func(dummy_node): def assertion_func(dummy_node):
@ -187,7 +166,7 @@ async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func) await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio @pytest.mark.trio
@pytest.mark.slow @pytest.mark.slow
async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography(): async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography():
num_nodes = 5 num_nodes = 5
@ -195,13 +174,13 @@ async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography():
async def action_func(dummy_nodes): async def action_func(dummy_nodes):
await dummy_nodes[0].publish_set_crypto("alex", 20) await dummy_nodes[0].publish_set_crypto("alex", 20)
await asyncio.sleep(1) await trio.sleep(1)
await dummy_nodes[1].publish_send_crypto("alex", "rob", 3) await dummy_nodes[1].publish_send_crypto("alex", "rob", 3)
await asyncio.sleep(1) await trio.sleep(1)
await dummy_nodes[2].publish_send_crypto("rob", "aspyn", 2) await dummy_nodes[2].publish_send_crypto("rob", "aspyn", 2)
await asyncio.sleep(1) await trio.sleep(1)
await dummy_nodes[3].publish_send_crypto("aspyn", "zx", 1) await dummy_nodes[3].publish_send_crypto("aspyn", "zx", 1)
await asyncio.sleep(1) await trio.sleep(1)
await dummy_nodes[4].publish_send_crypto("zx", "raul", 1) await dummy_nodes[4].publish_send_crypto("zx", "raul", 1)
def assertion_func(dummy_node): def assertion_func(dummy_node):

View File

@ -1,9 +1,10 @@
import asyncio import functools
import pytest import pytest
import trio
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.tools.factories import FloodsubFactory from libp2p.tools.factories import PubsubFactory
from libp2p.tools.pubsub.floodsub_integration_test_settings import ( from libp2p.tools.pubsub.floodsub_integration_test_settings import (
floodsub_protocol_pytest_params, floodsub_protocol_pytest_params,
perform_test_from_obj, perform_test_from_obj,
@ -11,79 +12,83 @@ from libp2p.tools.pubsub.floodsub_integration_test_settings import (
from libp2p.tools.utils import connect from libp2p.tools.utils import connect
@pytest.mark.parametrize("num_hosts", (2,)) @pytest.mark.trio
@pytest.mark.asyncio async def test_simple_two_nodes():
async def test_simple_two_nodes(pubsubs_fsub): async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
topic = "my_topic" topic = "my_topic"
data = b"some data" data = b"some data"
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
await asyncio.sleep(0.25) await trio.sleep(0.25)
sub_b = await pubsubs_fsub[1].subscribe(topic) sub_b = await pubsubs_fsub[1].subscribe(topic)
# Sleep to let a know of b's subscription # Sleep to let a know of b's subscription
await asyncio.sleep(0.25) await trio.sleep(0.25)
await pubsubs_fsub[0].publish(topic, data) await pubsubs_fsub[0].publish(topic, data)
res_b = await sub_b.get() res_b = await sub_b.receive()
# Check that the msg received by node_b is the same # Check that the msg received by node_b is the same
# as the message sent by node_a # as the message sent by node_a
assert ID(res_b.from_id) == pubsubs_fsub[0].host.get_id() assert ID(res_b.from_id) == pubsubs_fsub[0].host.get_id()
assert res_b.data == data assert res_b.data == data
assert res_b.topicIDs == [topic] assert res_b.topicIDs == [topic]
# Success, terminate pending tasks.
# Initialize Pubsub with a cache_size of 4 @pytest.mark.trio
@pytest.mark.parametrize("num_hosts, pubsub_cache_size", ((2, 4),)) async def test_lru_cache_two_nodes(monkeypatch):
@pytest.mark.asyncio
async def test_lru_cache_two_nodes(pubsubs_fsub, monkeypatch):
# two nodes with cache_size of 4 # two nodes with cache_size of 4
# `node_a` send the following messages to node_b async with PubsubFactory.create_batch_with_floodsub(
message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1] 2, cache_size=4
# `node_b` should only receive the following ) as pubsubs_fsub:
expected_received_indices = [1, 2, 3, 4, 5, 1] # `node_a` send the following messages to node_b
message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
# `node_b` should only receive the following
expected_received_indices = [1, 2, 3, 4, 5, 1]
topic = "my_topic" topic = "my_topic"
# Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`. # Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`.
def get_msg_id(msg): def get_msg_id(msg):
# Originally it is `(msg.seqno, msg.from_id)` # Originally it is `(msg.seqno, msg.from_id)`
return (msg.data, msg.from_id) return (msg.data, msg.from_id)
import libp2p.pubsub.pubsub import libp2p.pubsub.pubsub
monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id) monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id)
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
await asyncio.sleep(0.25) await trio.sleep(0.25)
sub_b = await pubsubs_fsub[1].subscribe(topic) sub_b = await pubsubs_fsub[1].subscribe(topic)
await asyncio.sleep(0.25) await trio.sleep(0.25)
def _make_testing_data(i: int) -> bytes: def _make_testing_data(i: int) -> bytes:
num_int_bytes = 4 num_int_bytes = 4
if i >= 2 ** (num_int_bytes * 8): if i >= 2 ** (num_int_bytes * 8):
raise ValueError("integer is too large to be serialized") raise ValueError("integer is too large to be serialized")
return b"data" + i.to_bytes(num_int_bytes, "big") return b"data" + i.to_bytes(num_int_bytes, "big")
for index in message_indices: for index in message_indices:
await pubsubs_fsub[0].publish(topic, _make_testing_data(index)) await pubsubs_fsub[0].publish(topic, _make_testing_data(index))
await asyncio.sleep(0.25) await trio.sleep(0.25)
for index in expected_received_indices: for index in expected_received_indices:
res_b = await sub_b.get() res_b = await sub_b.receive()
assert res_b.data == _make_testing_data(index) assert res_b.data == _make_testing_data(index)
assert sub_b.empty()
# Success, terminate pending tasks. with pytest.raises(trio.WouldBlock):
sub_b.receive_nowait()
@pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params) @pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params)
@pytest.mark.asyncio @pytest.mark.trio
@pytest.mark.slow @pytest.mark.slow
async def test_gossipsub_run_with_floodsub_tests(test_case_obj): async def test_gossipsub_run_with_floodsub_tests(test_case_obj, is_host_secure):
await perform_test_from_obj(test_case_obj, FloodsubFactory) await perform_test_from_obj(
test_case_obj,
functools.partial(
PubsubFactory.create_batch_with_floodsub, is_secure=is_host_secure
),
)

View File

@ -1,368 +1,350 @@
import asyncio
import random import random
import pytest import pytest
import trio
from libp2p.tools.constants import GossipsubParams from libp2p.tools.factories import PubsubFactory
from libp2p.tools.pubsub.utils import dense_connect, one_to_all_connect from libp2p.tools.pubsub.utils import dense_connect, one_to_all_connect
from libp2p.tools.utils import connect from libp2p.tools.utils import connect
@pytest.mark.parametrize( @pytest.mark.trio
"num_hosts, gossipsub_params", async def test_join():
((4, GossipsubParams(degree=4, degree_low=3, degree_high=5)),), async with PubsubFactory.create_batch_with_gossipsub(
) 4, degree=4, degree_low=3, degree_high=5
@pytest.mark.asyncio ) as pubsubs_gsub:
async def test_join(num_hosts, hosts, pubsubs_gsub): gossipsubs = [pubsub.router for pubsub in pubsubs_gsub]
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub) hosts = [pubsub.host for pubsub in pubsubs_gsub]
hosts_indices = list(range(num_hosts)) hosts_indices = list(range(len(pubsubs_gsub)))
topic = "test_join" topic = "test_join"
central_node_index = 0 central_node_index = 0
# Remove index of central host from the indices # Remove index of central host from the indices
hosts_indices.remove(central_node_index) hosts_indices.remove(central_node_index)
num_subscribed_peer = 2 num_subscribed_peer = 2
subscribed_peer_indices = random.sample(hosts_indices, num_subscribed_peer) subscribed_peer_indices = random.sample(hosts_indices, num_subscribed_peer)
# All pubsub except the one of central node subscribe to topic # All pubsub except the one of central node subscribe to topic
for i in subscribed_peer_indices: for i in subscribed_peer_indices:
await pubsubs_gsub[i].subscribe(topic) await pubsubs_gsub[i].subscribe(topic)
# Connect central host to all other hosts # Connect central host to all other hosts
await one_to_all_connect(hosts, central_node_index) await one_to_all_connect(hosts, central_node_index)
# Wait 2 seconds for heartbeat to allow mesh to connect # Wait 2 seconds for heartbeat to allow mesh to connect
await asyncio.sleep(2) await trio.sleep(2)
# Central node publish to the topic so that this topic
# is added to central node's fanout
# publish from the randomly chosen host
await pubsubs_gsub[central_node_index].publish(topic, b"data")
# Check that the gossipsub of central node has fanout for the topic
assert topic in gossipsubs[central_node_index].fanout
# Check that the gossipsub of central node does not have a mesh for the topic
assert topic not in gossipsubs[central_node_index].mesh
# Central node subscribes the topic
await pubsubs_gsub[central_node_index].subscribe(topic)
await asyncio.sleep(2)
# Check that the gossipsub of central node no longer has fanout for the topic
assert topic not in gossipsubs[central_node_index].fanout
for i in hosts_indices:
if i in subscribed_peer_indices:
assert hosts[i].get_id() in gossipsubs[central_node_index].mesh[topic]
assert hosts[central_node_index].get_id() in gossipsubs[i].mesh[topic]
else:
assert hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic]
assert topic not in gossipsubs[i].mesh
@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.asyncio
async def test_leave(pubsubs_gsub):
gossipsub = pubsubs_gsub[0].router
topic = "test_leave"
assert topic not in gossipsub.mesh
await gossipsub.join(topic)
assert topic in gossipsub.mesh
await gossipsub.leave(topic)
assert topic not in gossipsub.mesh
# Test re-leave
await gossipsub.leave(topic)
@pytest.mark.parametrize("num_hosts", (2,))
@pytest.mark.asyncio
async def test_handle_graft(pubsubs_gsub, hosts, event_loop, monkeypatch):
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
index_alice = 0
id_alice = hosts[index_alice].get_id()
index_bob = 1
id_bob = hosts[index_bob].get_id()
await connect(hosts[index_alice], hosts[index_bob])
# Wait 2 seconds for heartbeat to allow mesh to connect
await asyncio.sleep(2)
topic = "test_handle_graft"
# Only lice subscribe to the topic
await gossipsubs[index_alice].join(topic)
# Monkey patch bob's `emit_prune` function so we can
# check if it is called in `handle_graft`
event_emit_prune = asyncio.Event()
async def emit_prune(topic, sender_peer_id):
event_emit_prune.set()
monkeypatch.setattr(gossipsubs[index_bob], "emit_prune", emit_prune)
# Check that alice is bob's peer but not his mesh peer
assert id_alice in gossipsubs[index_bob].peers_gossipsub
assert topic not in gossipsubs[index_bob].mesh
await gossipsubs[index_alice].emit_graft(topic, id_bob)
# Check that `emit_prune` is called
await asyncio.wait_for(event_emit_prune.wait(), timeout=1, loop=event_loop)
assert event_emit_prune.is_set()
# Check that bob is alice's peer but not her mesh peer
assert topic in gossipsubs[index_alice].mesh
assert id_bob not in gossipsubs[index_alice].mesh[topic]
assert id_bob in gossipsubs[index_alice].peers_gossipsub
await gossipsubs[index_bob].emit_graft(topic, id_alice)
await asyncio.sleep(1)
# Check that bob is now alice's mesh peer
assert id_bob in gossipsubs[index_alice].mesh[topic]
@pytest.mark.parametrize(
"num_hosts, gossipsub_params", ((2, GossipsubParams(heartbeat_interval=3)),)
)
@pytest.mark.asyncio
async def test_handle_prune(pubsubs_gsub, hosts):
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
index_alice = 0
id_alice = hosts[index_alice].get_id()
index_bob = 1
id_bob = hosts[index_bob].get_id()
topic = "test_handle_prune"
for pubsub in pubsubs_gsub:
await pubsub.subscribe(topic)
await connect(hosts[index_alice], hosts[index_bob])
# Wait 3 seconds for heartbeat to allow mesh to connect
await asyncio.sleep(3)
# Check that they are each other's mesh peer
assert id_alice in gossipsubs[index_bob].mesh[topic]
assert id_bob in gossipsubs[index_alice].mesh[topic]
# alice emit prune message to bob, alice should be removed
# from bob's mesh peer
await gossipsubs[index_alice].emit_prune(topic, id_bob)
# FIXME: This test currently works because the heartbeat interval
# is increased to 3 seconds, so alice won't get add back into
# bob's mesh peer during heartbeat.
await asyncio.sleep(1)
# Check that alice is no longer bob's mesh peer
assert id_alice not in gossipsubs[index_bob].mesh[topic]
assert id_bob in gossipsubs[index_alice].mesh[topic]
@pytest.mark.parametrize("num_hosts", (10,))
@pytest.mark.asyncio
async def test_dense(num_hosts, pubsubs_gsub, hosts):
num_msgs = 5
# All pubsub subscribe to foobar
queues = []
for pubsub in pubsubs_gsub:
q = await pubsub.subscribe("foobar")
# Add each blocking queue to an array of blocking queues
queues.append(q)
# Densely connect libp2p hosts in a random way
await dense_connect(hosts)
# Wait 2 seconds for heartbeat to allow mesh to connect
await asyncio.sleep(2)
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# randomly pick a message origin
origin_idx = random.randint(0, num_hosts - 1)
# Central node publish to the topic so that this topic
# is added to central node's fanout
# publish from the randomly chosen host # publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish("foobar", msg_content) await pubsubs_gsub[central_node_index].publish(topic, b"data")
await asyncio.sleep(0.5) # Check that the gossipsub of central node has fanout for the topic
# Assert that all blocking queues receive the message assert topic in gossipsubs[central_node_index].fanout
for queue in queues: # Check that the gossipsub of central node does not have a mesh for the topic
msg = await queue.get() assert topic not in gossipsubs[central_node_index].mesh
assert msg.data == msg_content
# Central node subscribes the topic
await pubsubs_gsub[central_node_index].subscribe(topic)
await trio.sleep(2)
# Check that the gossipsub of central node no longer has fanout for the topic
assert topic not in gossipsubs[central_node_index].fanout
for i in hosts_indices:
if i in subscribed_peer_indices:
assert hosts[i].get_id() in gossipsubs[central_node_index].mesh[topic]
assert hosts[central_node_index].get_id() in gossipsubs[i].mesh[topic]
else:
assert (
hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic]
)
assert topic not in gossipsubs[i].mesh
@pytest.mark.parametrize("num_hosts", (10,)) @pytest.mark.trio
@pytest.mark.asyncio async def test_leave():
async def test_fanout(hosts, pubsubs_gsub): async with PubsubFactory.create_batch_with_gossipsub(1) as pubsubs_gsub:
num_msgs = 5 gossipsub = pubsubs_gsub[0].router
topic = "test_leave"
# All pubsub subscribe to foobar except for `pubsubs_gsub[0]` assert topic not in gossipsub.mesh
queues = []
for i in range(1, len(pubsubs_gsub)):
q = await pubsubs_gsub[i].subscribe("foobar")
# Add each blocking queue to an array of blocking queues await gossipsub.join(topic)
queues.append(q) assert topic in gossipsub.mesh
# Sparsely connect libp2p hosts in random way await gossipsub.leave(topic)
await dense_connect(hosts) assert topic not in gossipsub.mesh
# Wait 2 seconds for heartbeat to allow mesh to connect # Test re-leave
await asyncio.sleep(2) await gossipsub.leave(topic)
topic = "foobar"
# Send messages with origin not subscribed
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await asyncio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
# Subscribe message origin
queues.insert(0, await pubsubs_gsub[0].subscribe(topic))
# Send messages again
for i in range(num_msgs):
msg_content = b"bar " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await asyncio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
@pytest.mark.parametrize("num_hosts", (10,)) @pytest.mark.trio
@pytest.mark.asyncio async def test_handle_graft(monkeypatch):
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
index_alice = 0
id_alice = pubsubs_gsub[index_alice].my_id
index_bob = 1
id_bob = pubsubs_gsub[index_bob].my_id
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
# Wait 2 seconds for heartbeat to allow mesh to connect
await trio.sleep(2)
topic = "test_handle_graft"
# Only lice subscribe to the topic
await gossipsubs[index_alice].join(topic)
# Monkey patch bob's `emit_prune` function so we can
# check if it is called in `handle_graft`
event_emit_prune = trio.Event()
async def emit_prune(topic, sender_peer_id):
event_emit_prune.set()
monkeypatch.setattr(gossipsubs[index_bob], "emit_prune", emit_prune)
# Check that alice is bob's peer but not his mesh peer
assert id_alice in gossipsubs[index_bob].peers_gossipsub
assert topic not in gossipsubs[index_bob].mesh
await gossipsubs[index_alice].emit_graft(topic, id_bob)
# Check that `emit_prune` is called
await event_emit_prune.wait()
# Check that bob is alice's peer but not her mesh peer
assert topic in gossipsubs[index_alice].mesh
assert id_bob not in gossipsubs[index_alice].mesh[topic]
assert id_bob in gossipsubs[index_alice].peers_gossipsub
await gossipsubs[index_bob].emit_graft(topic, id_alice)
await trio.sleep(1)
# Check that bob is now alice's mesh peer
assert id_bob in gossipsubs[index_alice].mesh[topic]
@pytest.mark.trio
async def test_handle_prune():
async with PubsubFactory.create_batch_with_gossipsub(
2, heartbeat_interval=3
) as pubsubs_gsub:
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
index_alice = 0
id_alice = pubsubs_gsub[index_alice].my_id
index_bob = 1
id_bob = pubsubs_gsub[index_bob].my_id
topic = "test_handle_prune"
for pubsub in pubsubs_gsub:
await pubsub.subscribe(topic)
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
# Wait 3 seconds for heartbeat to allow mesh to connect
await trio.sleep(3)
# Check that they are each other's mesh peer
assert id_alice in gossipsubs[index_bob].mesh[topic]
assert id_bob in gossipsubs[index_alice].mesh[topic]
# alice emit prune message to bob, alice should be removed
# from bob's mesh peer
await gossipsubs[index_alice].emit_prune(topic, id_bob)
# FIXME: This test currently works because the heartbeat interval
# is increased to 3 seconds, so alice won't get add back into
# bob's mesh peer during heartbeat.
await trio.sleep(1)
# Check that alice is no longer bob's mesh peer
assert id_alice not in gossipsubs[index_bob].mesh[topic]
assert id_bob in gossipsubs[index_alice].mesh[topic]
@pytest.mark.trio
async def test_dense():
async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
hosts = [pubsub.host for pubsub in pubsubs_gsub]
num_msgs = 5
# All pubsub subscribe to foobar
queues = [await pubsub.subscribe("foobar") for pubsub in pubsubs_gsub]
# Densely connect libp2p hosts in a random way
await dense_connect(hosts)
# Wait 2 seconds for heartbeat to allow mesh to connect
await trio.sleep(2)
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# randomly pick a message origin
origin_idx = random.randint(0, len(hosts) - 1)
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish("foobar", msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.receive()
assert msg.data == msg_content
@pytest.mark.trio
async def test_fanout():
async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
hosts = [pubsub.host for pubsub in pubsubs_gsub]
num_msgs = 5
# All pubsub subscribe to foobar except for `pubsubs_gsub[0]`
subs = [await pubsub.subscribe("foobar") for pubsub in pubsubs_gsub[1:]]
# Sparsely connect libp2p hosts in random way
await dense_connect(hosts)
# Wait 2 seconds for heartbeat to allow mesh to connect
await trio.sleep(2)
topic = "foobar"
# Send messages with origin not subscribed
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for sub in subs:
msg = await sub.receive()
assert msg.data == msg_content
# Subscribe message origin
subs.insert(0, await pubsubs_gsub[0].subscribe(topic))
# Send messages again
for i in range(num_msgs):
msg_content = b"bar " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for sub in subs:
msg = await sub.receive()
assert msg.data == msg_content
@pytest.mark.trio
@pytest.mark.slow @pytest.mark.slow
async def test_fanout_maintenance(hosts, pubsubs_gsub): async def test_fanout_maintenance():
num_msgs = 5 async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
hosts = [pubsub.host for pubsub in pubsubs_gsub]
num_msgs = 5
# All pubsub subscribe to foobar # All pubsub subscribe to foobar
queues = [] queues = []
topic = "foobar" topic = "foobar"
for i in range(1, len(pubsubs_gsub)): for i in range(1, len(pubsubs_gsub)):
q = await pubsubs_gsub[i].subscribe(topic) q = await pubsubs_gsub[i].subscribe(topic)
# Add each blocking queue to an array of blocking queues # Add each blocking queue to an array of blocking queues
queues.append(q) queues.append(q)
# Sparsely connect libp2p hosts in random way # Sparsely connect libp2p hosts in random way
await dense_connect(hosts) await dense_connect(hosts)
# Wait 2 seconds for heartbeat to allow mesh to connect # Wait 2 seconds for heartbeat to allow mesh to connect
await asyncio.sleep(2) await trio.sleep(2)
# Send messages with origin not subscribed # Send messages with origin not subscribed
for i in range(num_msgs): for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big") msg_content = b"foo " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar' # Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0 origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.receive()
assert msg.data == msg_content
for sub in pubsubs_gsub:
await sub.unsubscribe(topic)
queues = []
await trio.sleep(2)
# Resub and repeat
for i in range(1, len(pubsubs_gsub)):
q = await pubsubs_gsub[i].subscribe(topic)
# Add each blocking queue to an array of blocking queues
queues.append(q)
await trio.sleep(2)
# Check messages can still be sent
for i in range(num_msgs):
msg_content = b"bar " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.receive()
assert msg.data == msg_content
@pytest.mark.trio
async def test_gossip_propagation():
async with PubsubFactory.create_batch_with_gossipsub(
2, degree=1, degree_low=0, degree_high=2, gossip_window=50, gossip_history=100
) as pubsubs_gsub:
topic = "foo"
await pubsubs_gsub[0].subscribe(topic)
# node 0 publish to topic
msg_content = b"foo_msg"
# publish from the randomly chosen host # publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content) await pubsubs_gsub[0].publish(topic, msg_content)
await asyncio.sleep(0.5) # now node 1 subscribes
# Assert that all blocking queues receive the message queue_1 = await pubsubs_gsub[1].subscribe(topic)
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
for sub in pubsubs_gsub: await connect(pubsubs_gsub[0].host, pubsubs_gsub[1].host)
await sub.unsubscribe(topic)
queues = [] # wait for gossip heartbeat
await trio.sleep(2)
await asyncio.sleep(2) # should be able to read message
msg = await queue_1.receive()
# Resub and repeat assert msg.data == msg_content
for i in range(1, len(pubsubs_gsub)):
q = await pubsubs_gsub[i].subscribe(topic)
# Add each blocking queue to an array of blocking queues
queues.append(q)
await asyncio.sleep(2)
# Check messages can still be sent
for i in range(num_msgs):
msg_content = b"bar " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await asyncio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
@pytest.mark.parametrize(
"num_hosts, gossipsub_params",
(
(
2,
GossipsubParams(
degree=1,
degree_low=0,
degree_high=2,
gossip_window=50,
gossip_history=100,
),
),
),
)
@pytest.mark.asyncio
async def test_gossip_propagation(hosts, pubsubs_gsub):
topic = "foo"
await pubsubs_gsub[0].subscribe(topic)
# node 0 publish to topic
msg_content = b"foo_msg"
# publish from the randomly chosen host
await pubsubs_gsub[0].publish(topic, msg_content)
# now node 1 subscribes
queue_1 = await pubsubs_gsub[1].subscribe(topic)
await connect(hosts[0], hosts[1])
# wait for gossip heartbeat
await asyncio.sleep(2)
# should be able to read message
msg = await queue_1.get()
assert msg.data == msg_content

View File

@ -3,25 +3,25 @@ import functools
import pytest import pytest
from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID
from libp2p.tools.factories import GossipsubFactory from libp2p.tools.factories import PubsubFactory
from libp2p.tools.pubsub.floodsub_integration_test_settings import ( from libp2p.tools.pubsub.floodsub_integration_test_settings import (
floodsub_protocol_pytest_params, floodsub_protocol_pytest_params,
perform_test_from_obj, perform_test_from_obj,
) )
@pytest.mark.asyncio
async def test_gossipsub_initialize_with_floodsub_protocol():
GossipsubFactory(protocols=[FLOODSUB_PROTOCOL_ID])
@pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params) @pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params)
@pytest.mark.asyncio @pytest.mark.trio
@pytest.mark.slow @pytest.mark.slow
async def test_gossipsub_run_with_floodsub_tests(test_case_obj): async def test_gossipsub_run_with_floodsub_tests(test_case_obj):
await perform_test_from_obj( await perform_test_from_obj(
test_case_obj, test_case_obj,
functools.partial( functools.partial(
GossipsubFactory, degree=3, degree_low=2, degree_high=4, time_to_live=30 PubsubFactory.create_batch_with_gossipsub,
protocols=[FLOODSUB_PROTOCOL_ID],
degree=3,
degree_low=2,
degree_high=4,
time_to_live=30,
), ),
) )

View File

@ -1,5 +1,3 @@
import pytest
from libp2p.pubsub.mcache import MessageCache from libp2p.pubsub.mcache import MessageCache
@ -12,8 +10,7 @@ class Msg:
self.from_id = from_id self.from_id = from_id
@pytest.mark.asyncio def test_mcache():
async def test_mcache():
# Ported from: # Ported from:
# https://github.com/libp2p/go-libp2p-pubsub/blob/51b7501433411b5096cac2b4994a36a68515fc03/mcache_test.go # https://github.com/libp2p/go-libp2p-pubsub/blob/51b7501433411b5096cac2b4994a36a68515fc03/mcache_test.go
mcache = MessageCache(3, 5) mcache = MessageCache(3, 5)

View File

@ -5,12 +5,11 @@ import pytest
import trio import trio
from libp2p.exceptions import ValidationError from libp2p.exceptions import ValidationError
from libp2p.peer.id import ID
from libp2p.pubsub.pb import rpc_pb2 from libp2p.pubsub.pb import rpc_pb2
from libp2p.tools.constants import MAX_READ_LEN
from libp2p.tools.factories import IDFactory, PubsubFactory, net_stream_pair_factory
from libp2p.tools.pubsub.utils import make_pubsub_msg from libp2p.tools.pubsub.utils import make_pubsub_msg
from libp2p.tools.utils import connect from libp2p.tools.utils import connect
from libp2p.tools.constants import MAX_READ_LEN
from libp2p.tools.factories import PubsubFactory, net_stream_pair_factory, IDFactory
from libp2p.utils import encode_varint_prefixed from libp2p.utils import encode_varint_prefixed
TESTING_TOPIC = "TEST_SUBSCRIBE" TESTING_TOPIC = "TEST_SUBSCRIBE"
@ -250,14 +249,14 @@ async def test_continuously_read_stream(monkeypatch, nursery, is_host_secure):
async def mock_push_msg(msg_forwarder, msg): async def mock_push_msg(msg_forwarder, msg):
event_push_msg.set() event_push_msg.set()
await trio.sleep(0) await trio.hazmat.checkpoint()
def mock_handle_subscription(origin_id, sub_message): def mock_handle_subscription(origin_id, sub_message):
event_handle_subscription.set() event_handle_subscription.set()
async def mock_handle_rpc(rpc, sender_peer_id): async def mock_handle_rpc(rpc, sender_peer_id):
event_handle_rpc.set() event_handle_rpc.set()
await trio.sleep(0) await trio.hazmat.checkpoint()
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg) m.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg)

View File

@ -69,33 +69,10 @@ async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair):
await stream_0.close() await stream_0.close()
assert stream_0.event_local_closed.is_set() assert stream_0.event_local_closed.is_set()
await trio.sleep(0.01) await trio.sleep(0.01)
print(
"!@# ",
stream_0.muxed_conn.event_shutting_down.is_set(),
stream_0.muxed_conn.event_closed.is_set(),
stream_1.muxed_conn.event_shutting_down.is_set(),
stream_1.muxed_conn.event_closed.is_set(),
)
# await trio.sleep(100000) # await trio.sleep(100000)
await wait_all_tasks_blocked() await wait_all_tasks_blocked()
print(
"!@# ",
stream_0.muxed_conn.event_shutting_down.is_set(),
stream_0.muxed_conn.event_closed.is_set(),
stream_1.muxed_conn.event_shutting_down.is_set(),
stream_1.muxed_conn.event_closed.is_set(),
)
print("!@# sleeping")
print("!@# result=", stream_1.event_remote_closed.is_set())
# await trio.sleep_forever() # await trio.sleep_forever()
assert stream_1.event_remote_closed.is_set() assert stream_1.event_remote_closed.is_set()
print(
"!@# ",
stream_0.muxed_conn.event_shutting_down.is_set(),
stream_0.muxed_conn.event_closed.is_set(),
stream_1.muxed_conn.event_shutting_down.is_set(),
stream_1.muxed_conn.event_closed.is_set(),
)
assert (await stream_1.read(MAX_READ_LEN)) == DATA assert (await stream_1.read(MAX_READ_LEN)) == DATA
with pytest.raises(MplexStreamEOF): with pytest.raises(MplexStreamEOF):
await stream_1.read(MAX_READ_LEN) await stream_1.read(MAX_READ_LEN)

View File

@ -3,7 +3,7 @@ import pytest
import trio import trio
from libp2p.network.connection.raw_connection import RawConnection from libp2p.network.connection.raw_connection import RawConnection
from libp2p.tools.constants import LISTEN_MADDR, MAX_READ_LEN from libp2p.tools.constants import LISTEN_MADDR
from libp2p.transport.tcp.tcp import TCP from libp2p.transport.tcp.tcp import TCP