Fix all modules except for security

mhchia
2019-12-06 17:06:37 +08:00
parent e9ab0646e3
commit 1929f307fb
28 changed files with 764 additions and 955 deletions

View File

@ -1,6 +1,7 @@
import trio
import logging
import trio
from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.peer.id import ID as PeerID

View File

@ -1,7 +1,6 @@
import logging
import trio
from trio import SocketStream
from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException
@ -9,29 +8,48 @@ from libp2p.io.exceptions import IOException
logger = logging.getLogger("libp2p.io.trio")
class TrioReadWriteCloser(ReadWriteCloser):
stream: SocketStream
class TrioTCPStream(ReadWriteCloser):
stream: trio.SocketStream
# NOTE: Add both read and write locks to avoid `trio.BusyResourceError`
read_lock: trio.Lock
write_lock: trio.Lock
def __init__(self, stream: SocketStream) -> None:
def __init__(self, stream: trio.SocketStream) -> None:
self.stream = stream
self.read_lock = trio.Lock()
self.write_lock = trio.Lock()
async def write(self, data: bytes) -> None:
"""Raise `RawConnError` if the underlying connection breaks."""
try:
await self.stream.send_all(data)
except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
raise IOException(error)
async with self.write_lock:
try:
await self.stream.send_all(data)
except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
raise IOException from error
except trio.BusyResourceError as error:
# This should never happen, since we already access streams with read/write locks.
raise Exception(
"this should never happen "
"since we already access streams with read/write locks."
) from error
async def read(self, n: int = -1) -> bytes:
if n == 0:
# Check point
await trio.sleep(0)
return b""
max_bytes = n if n != -1 else None
try:
return await self.stream.receive_some(max_bytes)
except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
raise IOException(error)
async with self.read_lock:
if n == 0:
# Checkpoint
await trio.hazmat.checkpoint()
return b""
max_bytes = n if n != -1 else None
try:
return await self.stream.receive_some(max_bytes)
except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
raise IOException from error
except trio.BusyResourceError as error:
# This should never happen, since we already access streams with read/write locks.
raise Exception(
"this should never happen "
"since we already access streams with read/write locks."
) from error
async def close(self) -> None:
await self.stream.aclose()
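
The lock comments above are easiest to verify in isolation: without the write lock, two tasks calling `send_all` on the same trio stream raise `trio.BusyResourceError`. A minimal sketch, not part of this commit, using trio's in-memory stream pair:

import trio
import trio.testing


async def main() -> None:
    left, right = trio.testing.memory_stream_pair()
    write_lock = trio.Lock()

    async def writer(payload: bytes) -> None:
        # Serialize access the same way TrioTCPStream.write does; dropping
        # this lock makes concurrent send_all calls raise BusyResourceError.
        async with write_lock:
            await left.send_all(payload)

    async with trio.open_nursery() as nursery:
        nursery.start_soon(writer, b"hello")
        nursery.start_soon(writer, b"world")
        received = b""
        while len(received) < 10:
            received += await right.receive_some(10)
    # Each writer sends its whole payload under the lock, so the two
    # 5-byte payloads arrive contiguously, in either order.
    assert sorted([received[:5], received[5:]]) == [b"hello", b"world"]


trio.run(main)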

View File

@ -1,5 +1,3 @@
import trio
from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException
@ -8,17 +6,17 @@ from .raw_connection_interface import IRawConnection
class RawConnection(IRawConnection):
read_write_closer: ReadWriteCloser
stream: ReadWriteCloser
is_initiator: bool
def __init__(self, read_write_closer: ReadWriteCloser, initiator: bool) -> None:
self.read_write_closer = read_write_closer
def __init__(self, stream: ReadWriteCloser, initiator: bool) -> None:
self.stream = stream
self.is_initiator = initiator
async def write(self, data: bytes) -> None:
"""Raise `RawConnError` if the underlying connection breaks."""
try:
await self.read_write_closer.write(data)
await self.stream.write(data)
except IOException as error:
raise RawConnError(error)
@ -30,9 +28,9 @@ class RawConnection(IRawConnection):
Raise `RawConnError` if the underlying connection breaks
"""
try:
return await self.read_write_closer.read(n)
return await self.stream.read(n)
except IOException as error:
raise RawConnError(error)
async def close(self) -> None:
await self.read_write_closer.close()
await self.stream.close()
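
How `RawConnection` narrows errors can be shown with a stub: any `IOException` from the wrapped stream surfaces to callers as `RawConnError`. A sketch (the failing stream is invented for illustration; import paths follow the tree shown in this commit):

import trio

from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException
from libp2p.network.connection.exceptions import RawConnError
from libp2p.network.connection.raw_connection import RawConnection


class BrokenStream(ReadWriteCloser):
    """A stream whose I/O always fails, standing in for a dead socket."""

    async def write(self, data: bytes) -> None:
        raise IOException

    async def read(self, n: int = -1) -> bytes:
        raise IOException

    async def close(self) -> None:
        pass


async def main() -> None:
    conn = RawConnection(BrokenStream(), True)
    try:
        await conn.write(b"ping")
    except RawConnError:
        print("low-level IOException surfaced as RawConnError")


trio.run(main)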

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple
from typing import TYPE_CHECKING, Set, Tuple
from async_service import Service
import trio
@ -45,16 +45,11 @@ class SwarmConn(INetConn, Service):
# before we cancel the stream handler tasks.
await trio.sleep(0.1)
# FIXME: Now let `_notify_disconnected` finish first.
# Schedule `self._notify_disconnected` to make it execute after `close` is finished.
await self._notify_disconnected()
async def _handle_new_streams(self) -> None:
while self.manager.is_running:
try:
print(
f"!@# SwarmConn._handle_new_streams: {self.muxed_conn._id}: waiting for new streams"
)
stream = await self.muxed_conn.accept_stream()
except MuxedConnUnavailable:
# If there is anything wrong in the MuxedConn,
@ -63,9 +58,6 @@ class SwarmConn(INetConn, Service):
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
self.manager.run_task(self._handle_muxed_stream, stream)
print(
f"!@# SwarmConn._handle_new_streams: {self.muxed_conn._id}: out of the loop"
)
await self.close()
async def _call_stream_handler(self, net_stream: NetStream) -> None:
@ -92,8 +84,7 @@ class SwarmConn(INetConn, Service):
await self.swarm.notify_disconnected(self)
async def run(self) -> None:
self.manager.run_task(self._handle_new_streams)
await self.manager.wait_finished()
await self._handle_new_streams()
async def new_stream(self) -> NetStream:
muxed_stream = await self.muxed_conn.open_stream()
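
`SwarmConn.run` now awaits its accept loop directly instead of parking on `wait_finished`; either shape is valid under `async-service`. A minimal sketch of the service lifecycle this commit leans on throughout (the toy service is invented for illustration):

import trio
from async_service import Service, background_trio_service


class Ticker(Service):
    """Toy service: one daemon task, cancelled when the service exits."""

    def __init__(self) -> None:
        self.ticks = 0

    async def run(self) -> None:
        self.manager.run_task(self._tick)   # background task
        await self.manager.wait_finished()  # park until cancelled

    async def _tick(self) -> None:
        while self.manager.is_running:
            self.ticks += 1
            await trio.sleep(0.01)


async def main() -> None:
    ticker = Ticker()
    async with background_trio_service(ticker):
        await trio.sleep(0.05)  # the service runs in the background here
    # Leaving the context cancels run() and every manager-run task.
    print("ticks observed:", ticker.ticks)


trio.run(main)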

View File

@ -203,16 +203,17 @@ class Swarm(INetwork, Service):
await self.add_conn(muxed_conn)
logger.debug("successfully opened connection to peer %s", peer_id)
# FIXME: This is an intentional barrier to prevent the handler from exiting and
# closing the connection. Probably change to `Service.manager.wait_finished`?
await trio.sleep_forever()
# NOTE: This is an intentional barrier to prevent the handler from exiting and
# closing the connection.
await self.manager.wait_finished()
try:
# Success
listener = self.transport.create_listener(conn_handler)
self.listeners[str(maddr)] = listener
# FIXME: Hack
await listener.listen(maddr, self.manager._task_nursery)
# TODO: `listener.listen` is not bound to a nursery. If we want to be
# I/O agnostic, we should change the API.
await listener.listen(maddr, self.manager._task_nursery) # type: ignore
# Call notifiers since event occurred
await self.notify_listen(maddr)
@ -278,6 +279,7 @@ class Swarm(INetwork, Service):
"""
self.notifees.append(notifee)
# TODO: Use `run_task`.
async def notify_opened_stream(self, stream: INetStream) -> None:
async with trio.open_nursery() as nursery:
for notifee in self.notifees:

View File

@ -64,7 +64,7 @@ class FloodSub(IPubsubRouter):
:param rpc: rpc message
"""
# Checkpoint
await trio.sleep(0)
await trio.hazmat.checkpoint()
async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None:
"""
@ -107,7 +107,7 @@ class FloodSub(IPubsubRouter):
:param topic: topic to join
"""
# Checkpoint
await trio.sleep(0)
await trio.hazmat.checkpoint()
async def leave(self, topic: str) -> None:
"""
@ -117,7 +117,7 @@ class FloodSub(IPubsubRouter):
:param topic: topic to leave
"""
# Checkpoint
await trio.sleep(0)
await trio.hazmat.checkpoint()
def _get_peers_to_send(
self, topic_ids: Iterable[str], msg_forwarder: ID, origin: ID
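
`trio.sleep(0)` and `trio.hazmat.checkpoint()` are equivalent yield points; the commit switches to the latter because it states the intent directly. (trio later renamed the module to `trio.lowlevel`; `hazmat` is the name current as of this commit.) Why no-op async methods still need a checkpoint, in a sketch:

import trio


async def leave(topic: str) -> None:
    # No bookkeeping needed for floodsub, but still yield so callers get
    # scheduling and cancellation even on this no-op path.
    await trio.hazmat.checkpoint()


async def main() -> None:
    # Without the checkpoint this loop could never be cancelled.
    with trio.move_on_after(0.01):
        while True:
            await leave("astrophysics")


trio.run(main)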

View File

@ -1,15 +1,18 @@
from ast import literal_eval
import asyncio
import logging
import random
from typing import Any, Dict, Iterable, List, Sequence, Set
from async_service import Service
import trio
from libp2p.network.stream.exceptions import StreamClosed
from libp2p.peer.id import ID
from libp2p.pubsub import floodsub
from libp2p.typing import TProtocol
from libp2p.utils import encode_varint_prefixed
from .exceptions import NoPubsubAttached
from .mcache import MessageCache
from .pb import rpc_pb2
from .pubsub import Pubsub
@ -20,8 +23,7 @@ PROTOCOL_ID = TProtocol("/meshsub/1.0.0")
logger = logging.getLogger("libp2p.pubsub.gossipsub")
class GossipSub(IPubsubRouter):
class GossipSub(IPubsubRouter, Service):
protocols: List[TProtocol]
pubsub: Pubsub
@ -86,6 +88,12 @@ class GossipSub(IPubsubRouter):
# Create heartbeat timer
self.heartbeat_interval = heartbeat_interval
async def run(self) -> None:
if self.pubsub is None:
raise NoPubsubAttached
self.manager.run_task(self.heartbeat)
await self.manager.wait_finished()
# Interface functions
def get_protocols(self) -> List[TProtocol]:
@ -105,10 +113,6 @@ class GossipSub(IPubsubRouter):
logger.debug("attached to pusub")
# Start heartbeat now that we have a pubsub instance
# TODO: Start after delay
asyncio.ensure_future(self.heartbeat())
def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None:
"""
Notifies the router that a new peer has been connected.
@ -310,7 +314,7 @@ class GossipSub(IPubsubRouter):
await self.fanout_heartbeat()
await self.gossip_heartbeat()
await asyncio.sleep(self.heartbeat_interval)
await trio.sleep(self.heartbeat_interval)
async def mesh_heartbeat(self) -> None:
# Note: the comments here are the exact pseudocode from the spec
@ -338,7 +342,7 @@ class GossipSub(IPubsubRouter):
if num_mesh_peers_in_topic > self.degree_high:
# Select |mesh[topic]| - D peers from mesh[topic]
selected_peers = GossipSub.select_from_minus(
selected_peers = self.select_from_minus(
num_mesh_peers_in_topic - self.degree, self.mesh[topic], []
)
for peer in selected_peers:
@ -353,7 +357,10 @@ class GossipSub(IPubsubRouter):
for topic in self.fanout:
# If time since last published > ttl
# TODO: there's no way time_since_last_publish gets set anywhere yet
if self.time_since_last_publish[topic] > self.time_to_live:
if (
topic in self.time_since_last_publish
and self.time_since_last_publish[topic] > self.time_to_live
):
# Remove topic from fanout
del self.fanout[topic]
del self.time_since_last_publish[topic]
@ -407,11 +414,7 @@ class GossipSub(IPubsubRouter):
topic, self.degree, []
)
for peer in peers_to_emit_ihave_to:
if (
peer not in self.mesh[topic]
and peer not in self.fanout[topic]
):
if peer not in self.fanout[topic]:
msg_id_strs = [str(msg) for msg in msg_ids]
await self.emit_ihave(topic, msg_id_strs, peer)

View File

@ -1,4 +1,4 @@
from abc import ABC, abstractmethod
from abc import ABC
import logging
import math
import time
@ -57,6 +57,7 @@ class TopicValidator(NamedTuple):
is_async: bool
# TODO: Add interface for Pubsub
class BasePubsub(ABC):
pass
@ -103,20 +104,24 @@ class Pubsub(BasePubsub, Service):
# Attach this new Pubsub object to the router
self.router.attach(self)
peer_send_channel, peer_receive_channel = trio.open_memory_channel(0)
dead_peer_send_channel, dead_peer_receive_channel = trio.open_memory_channel(0)
peer_channels: Tuple[
"trio.MemorySendChannel[ID]", "trio.MemoryReceiveChannel[ID]"
] = trio.open_memory_channel(0)
dead_peer_channels: Tuple[
"trio.MemorySendChannel[ID]", "trio.MemoryReceiveChannel[ID]"
] = trio.open_memory_channel(0)
# Only keep the receive channels in `Pubsub`.
# Therefore, we can only close from the receive side.
self.peer_receive_channel = peer_receive_channel
self.dead_peer_receive_channel = dead_peer_receive_channel
self.peer_receive_channel = peer_channels[1]
self.dead_peer_receive_channel = dead_peer_channels[1]
# Register a notifee
self.host.get_network().register_notifee(
PubsubNotifee(peer_send_channel, dead_peer_send_channel)
PubsubNotifee(peer_channels[0], dead_peer_channels[0])
)
# Register stream handlers for each pubsub router protocol to handle
# the pubsub streams opened on those protocols
for protocol in router.protocols:
for protocol in router.get_protocols():
self.host.set_stream_handler(protocol, self.stream_handler)
# keeps track of seen messages as LRU cache
@ -328,8 +333,9 @@ class Pubsub(BasePubsub, Service):
self.manager.run_task(self._handle_new_peer, peer_id)
async def handle_dead_peer_queue(self) -> None:
"""Continuously read from dead peer channel and close the stream between
that peer and remove peer info from pubsub and pubsub router."""
"""Continuously read from dead peer channel and close the stream
between that peer and remove peer info from pubsub and pubsub
router."""
async with self.dead_peer_receive_channel:
while self.manager.is_running:
peer_id: ID = await self.dead_peer_receive_channel.receive()
@ -391,7 +397,11 @@ class Pubsub(BasePubsub, Service):
return self.subscribed_topics_receive[topic_id]
# Map topic_id to a blocking channel
send_channel, receive_channel = trio.open_memory_channel(math.inf)
channels: Tuple[
"trio.MemorySendChannel[rpc_pb2.Message]",
"trio.MemoryReceiveChannel[rpc_pb2.Message]",
] = trio.open_memory_channel(math.inf)
send_channel, receive_channel = channels
self.subscribed_topics_send[topic_id] = send_channel
self.subscribed_topics_receive[topic_id] = receive_channel
@ -506,7 +516,7 @@ class Pubsub(BasePubsub, Service):
if len(async_topic_validators) > 0:
# TODO: Use a better pattern
final_result = True
final_result: bool = True
async def run_async_validator(func: AsyncValidatorFn) -> None:
nonlocal final_result
@ -514,8 +524,8 @@ class Pubsub(BasePubsub, Service):
final_result = final_result and result
async with trio.open_nursery() as nursery:
for validator in async_topic_validators:
nursery.start_soon(run_async_validator, validator)
for async_validator in async_topic_validators:
nursery.start_soon(run_async_validator, async_validator)
if not final_result:
raise ValidationError(f"Validation failed for msg={msg}")
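
The verbose channel annotations added above exist for mypy's benefit: `trio.open_memory_channel` is generic, and unpacking its result directly into two names loses the type parameter. The pattern, reduced to its core (sketch):

from typing import Tuple

import trio
from libp2p.peer.id import ID

# Annotate the pair first, then unpack; `send_channel` and `receive_channel`
# keep their `ID` parameter this way.
channels: Tuple[
    "trio.MemorySendChannel[ID]", "trio.MemoryReceiveChannel[ID]"
] = trio.open_memory_channel(0)
send_channel, receive_channel = channels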

View File

@ -1,11 +1,13 @@
from abc import ABC, abstractmethod
from abc import abstractmethod
from async_service import ServiceAPI
from libp2p.io.abc import ReadWriteCloser
from libp2p.peer.id import ID
from libp2p.security.secure_conn_interface import ISecureConn
class IMuxedConn(ABC):
class IMuxedConn(ServiceAPI):
"""
reference: https://github.com/libp2p/go-stream-muxer/blob/master/muxer.go
"""

View File

@ -1,7 +1,6 @@
import logging
import math
from typing import Any # noqa: F401
from typing import Awaitable, Dict, List, Optional, Tuple
from typing import Dict, Optional, Tuple
from async_service import Service
import trio
@ -67,13 +66,15 @@ class Mplex(IMuxedConn, Service):
self.streams = {}
self.streams_lock = trio.Lock()
self.streams_msg_channels = {}
send_channel, receive_channel = trio.open_memory_channel(math.inf)
self.new_stream_send_channel = send_channel
self.new_stream_receive_channel = receive_channel
channels: Tuple[
"trio.MemorySendChannel[IMuxedStream]",
"trio.MemoryReceiveChannel[IMuxedStream]",
] = trio.open_memory_channel(math.inf)
self.new_stream_send_channel, self.new_stream_receive_channel = channels
self.event_shutting_down = trio.Event()
self.event_closed = trio.Event()
async def run(self):
async def run(self) -> None:
self.manager.run_task(self.handle_incoming)
await self.manager.wait_finished()
@ -112,11 +113,13 @@ class Mplex(IMuxedConn, Service):
async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream:
# Use an unbounded buffer, to avoid `handle_incoming` being blocked when doing
# `send_channel.send`.
send_channel, receive_channel = trio.open_memory_channel(math.inf)
stream = MplexStream(name, stream_id, self, receive_channel)
channels: Tuple[
"trio.MemorySendChannel[bytes]", "trio.MemoryReceiveChannel[bytes]"
] = trio.open_memory_channel(math.inf)
stream = MplexStream(name, stream_id, self, channels[1])
async with self.streams_lock:
self.streams[stream_id] = stream
self.streams_msg_channels[stream_id] = send_channel
self.streams_msg_channels[stream_id] = channels[0]
return stream
async def open_stream(self) -> IMuxedStream:
@ -150,9 +153,6 @@ class Mplex(IMuxedConn, Service):
:param data: data to send in the message
:param stream_id: stream the message is in
"""
print(
f"!@# send_message: {self._id}: flag={flag}, data={data}, stream_id={stream_id}"
)
# << by 3, then or with flag
header = encode_uvarint((stream_id.channel_id << 3) | flag.value)
@ -179,19 +179,10 @@ class Mplex(IMuxedConn, Service):
while self.manager.is_running:
try:
print(
f"!@# handle_incoming: {self._id}: before _handle_incoming_message"
)
await self._handle_incoming_message()
print(
f"!@# handle_incoming: {self._id}: after _handle_incoming_message"
)
except MplexUnavailable as e:
logger.debug("mplex unavailable while waiting for incoming: %s", e)
print(f"!@# handle_incoming: {self._id}: MplexUnavailable: {e}")
break
print(f"!@# handle_incoming: {self._id}: leaving")
# If we enter here, it means this connection is shutting down.
# We should clean things up.
await self._cleanup()
@ -232,44 +223,27 @@ class Mplex(IMuxedConn, Service):
:raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down.
"""
print(f"!@# _handle_incoming_message: {self._id}: before reading")
channel_id, flag, message = await self.read_message()
print(
f"!@# _handle_incoming_message: {self._id}: channel_id={channel_id}, flag={flag}, message={message}"
)
stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
print(f"!@# _handle_incoming_message: {self._id}: 2")
if flag == HeaderTags.NewStream.value:
print(f"!@# _handle_incoming_message: {self._id}: 3")
await self._handle_new_stream(stream_id, message)
print(f"!@# _handle_incoming_message: {self._id}: 4")
elif flag in (
HeaderTags.MessageInitiator.value,
HeaderTags.MessageReceiver.value,
):
print(f"!@# _handle_incoming_message: {self._id}: 5")
await self._handle_message(stream_id, message)
print(f"!@# _handle_incoming_message: {self._id}: 6")
elif flag in (HeaderTags.CloseInitiator.value, HeaderTags.CloseReceiver.value):
print(f"!@# _handle_incoming_message: {self._id}: 7")
await self._handle_close(stream_id)
print(f"!@# _handle_incoming_message: {self._id}: 8")
elif flag in (HeaderTags.ResetInitiator.value, HeaderTags.ResetReceiver.value):
print(f"!@# _handle_incoming_message: {self._id}: 9")
await self._handle_reset(stream_id)
print(f"!@# _handle_incoming_message: {self._id}: 10")
else:
print(f"!@# _handle_incoming_message: {self._id}: 11")
# Receives messages with an unknown flag
# TODO: logging
async with self.streams_lock:
print(f"!@# _handle_incoming_message: {self._id}: 12")
if stream_id in self.streams:
print(f"!@# _handle_incoming_message: {self._id}: 13")
stream = self.streams[stream_id]
await stream.reset()
print(f"!@# _handle_incoming_message: {self._id}: 14")
async def _handle_new_stream(self, stream_id: StreamID, message: bytes) -> None:
async with self.streams_lock:
@ -285,59 +259,43 @@ class Mplex(IMuxedConn, Service):
raise MplexUnavailable
async def _handle_message(self, stream_id: StreamID, message: bytes) -> None:
print(
f"!@# _handle_message: {self._id}: stream_id={stream_id}, message={message}"
)
async with self.streams_lock:
print(f"!@# _handle_message: {self._id}: 1")
if stream_id not in self.streams:
# We receive a message of the stream `stream_id` which is not accepted
# before. It is abnormal. Possibly disconnect?
# TODO: Warn and emit logs about this.
print(f"!@# _handle_message: {self._id}: 2")
return
print(f"!@# _handle_message: {self._id}: 3")
stream = self.streams[stream_id]
send_channel = self.streams_msg_channels[stream_id]
async with stream.close_lock:
print(f"!@# _handle_message: {self._id}: 4")
if stream.event_remote_closed.is_set():
print(f"!@# _handle_message: {self._id}: 5")
# TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501
return
print(f"!@# _handle_message: {self._id}: 6")
await send_channel.send(message)
print(f"!@# _handle_message: {self._id}: 7")
async def _handle_close(self, stream_id: StreamID) -> None:
print(f"!@# _handle_close: {self._id}: step=0")
async with self.streams_lock:
if stream_id not in self.streams:
# Ignore unmatched messages for now.
return
stream = self.streams[stream_id]
send_channel = self.streams_msg_channels[stream_id]
print(f"!@# _handle_close: {self._id}: step=1")
await send_channel.aclose()
print(f"!@# _handle_close: {self._id}: step=2")
# NOTE: If remote is already closed, then return: Technically a bug
# on the other side. We should consider killing the connection.
async with stream.close_lock:
if stream.event_remote_closed.is_set():
return
print(f"!@# _handle_close: {self._id}: step=3")
is_local_closed: bool
async with stream.close_lock:
stream.event_remote_closed.set()
is_local_closed = stream.event_local_closed.is_set()
print(f"!@# _handle_close: {self._id}: step=4")
# If local is also closed, both sides are closed. Then, we should clean up
# the entry of this stream, to avoid others from accessing it.
if is_local_closed:
async with self.streams_lock:
if stream_id in self.streams:
del self.streams[stream_id]
print(f"!@# _handle_close: {self._id}: step=5")
async def _handle_reset(self, stream_id: StreamID) -> None:
async with self.streams_lock:
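
The `math.inf` buffer in `_initialize_stream` is load-bearing: `handle_incoming` is the single demultiplexing loop for the whole connection, so a bounded per-stream inbox would let one slow stream consumer stall every other stream. A sketch of the behavior the unbounded channel buys:

import math
from typing import Tuple

import trio


async def main() -> None:
    channels: Tuple[
        "trio.MemorySendChannel[bytes]", "trio.MemoryReceiveChannel[bytes]"
    ] = trio.open_memory_channel(math.inf)
    send_channel, receive_channel = channels
    # The demultiplexer can enqueue many frames without ever blocking...
    for i in range(1000):
        await send_channel.send(b"frame %d" % i)
    # ...while the stream's consumer drains at its own pace.
    assert await receive_channel.receive() == b"frame 0"


trio.run(main)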

View File

@ -1,30 +1,29 @@
from contextlib import AsyncExitStack, asynccontextmanager
from typing import Any, AsyncIterator, Dict, Tuple, cast
from typing import Any, AsyncIterator, Dict, Sequence, Tuple, cast
from async_service import background_trio_service
import factory
import trio
from libp2p.tools.constants import GOSSIPSUB_PARAMS
from libp2p import generate_new_rsa_identity, generate_peer_id_from
from libp2p.crypto.keys import KeyPair
from libp2p.host.basic_host import BasicHost
from libp2p.host.routed_host import RoutedHost
from libp2p.tools.utils import set_up_routers
from libp2p.kademlia.network import KademliaServer
from libp2p.host.host_interface import IHost
from libp2p.network.connection.swarm_connection import SwarmConn
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.network.swarm import Swarm
from libp2p.peer.peerstore import PeerStore
from libp2p.peer.id import ID
from libp2p.peer.peerstore import PeerStore
from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.gossipsub import GossipSub
from libp2p.pubsub.pubsub import Pubsub
from libp2p.pubsub.pubsub_router_interface import IPubsubRouter
from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
import libp2p.security.secio.transport as secio
from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex
from libp2p.stream_muxer.mplex.mplex_stream import MplexStream
from libp2p.tools.constants import GOSSIPSUB_PARAMS
from libp2p.transport.tcp.tcp import TCP
from libp2p.transport.typing import TMuxerOptions
from libp2p.transport.upgrader import TransportUpgrader
@ -74,7 +73,7 @@ class SwarmFactory(factory.Factory):
@asynccontextmanager
async def create_and_listen(
cls, is_secure: bool, key_pair: KeyPair = None, muxer_opt: TMuxerOptions = None
) -> Swarm:
) -> AsyncIterator[Swarm]:
# `factory.Factory.__init__` does *not* prepare a *default value* if we pass
# an argument explicitly with `None`. If an argument is `None`, we don't pass it to
# `factory.Factory.__init__`, in order to let the function initialize it.
@ -92,7 +91,7 @@ class SwarmFactory(factory.Factory):
@asynccontextmanager
async def create_batch_and_listen(
cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None
) -> Tuple[Swarm, ...]:
) -> AsyncIterator[Tuple[Swarm, ...]]:
async with AsyncExitStack() as stack:
ctx_mgrs = [
await stack.enter_async_context(
@ -100,7 +99,7 @@ class SwarmFactory(factory.Factory):
)
for _ in range(number)
]
yield ctx_mgrs
yield tuple(ctx_mgrs)
class HostFactory(factory.Factory):
@ -120,7 +119,7 @@ class HostFactory(factory.Factory):
@asynccontextmanager
async def create_batch_and_listen(
cls, is_secure: bool, number: int
) -> Tuple[BasicHost, ...]:
) -> AsyncIterator[Tuple[BasicHost, ...]]:
key_pairs = [generate_new_rsa_identity() for _ in range(number)]
async with AsyncExitStack() as stack:
swarms = [
@ -136,30 +135,6 @@ class HostFactory(factory.Factory):
yield hosts
class RoutedHostFactory(factory.Factory):
class Meta:
model = RoutedHost
public_key = factory.LazyAttribute(lambda o: o.key_pair.public_key)
network = factory.LazyAttribute(
lambda o: SwarmFactory(is_secure=o.is_secure, key_pair=o.key_pair)
)
router = factory.LazyFunction(KademliaServer)
@classmethod
@asynccontextmanager
async def create_batch_and_listen(
cls, is_secure: bool, number: int
) -> Tuple[RoutedHost, ...]:
key_pairs = [generate_new_rsa_identity() for _ in range(number)]
routers = await set_up_routers((0,) * number)
async with SwarmFactory.create_batch_and_listen(is_secure, number) as swarms:
yield tuple(
RoutedHost(key_pair.public_key, swarm, router)
for key_pair, swarm, router in zip(key_pairs, swarms, routers)
)
class FloodsubFactory(factory.Factory):
class Meta:
model = FloodSub
@ -191,17 +166,22 @@ class PubsubFactory(factory.Factory):
@classmethod
@asynccontextmanager
async def create_and_start(cls, host, router, cache_size):
async def create_and_start(
cls, host: IHost, router: IPubsubRouter, cache_size: int
) -> AsyncIterator[Pubsub]:
pubsub = PubsubFactory(host=host, router=router, cache_size=cache_size)
async with background_trio_service(pubsub):
yield pubsub
@classmethod
@asynccontextmanager
async def create_batch_with_floodsub(
cls, number: int, is_secure: bool = False, cache_size: int = None
):
floodsubs = FloodsubFactory.create_batch(number)
async def _create_batch_with_router(
cls,
number: int,
routers: Sequence[IPubsubRouter],
is_secure: bool = False,
cache_size: int = None,
) -> AsyncIterator[Tuple[Pubsub, ...]]:
async with HostFactory.create_batch_and_listen(is_secure, number) as hosts:
# Pubsubs should exit before hosts
async with AsyncExitStack() as stack:
@ -209,21 +189,80 @@ class PubsubFactory(factory.Factory):
await stack.enter_async_context(
cls.create_and_start(host, router, cache_size)
)
for host, router in zip(hosts, floodsubs)
for host, router in zip(hosts, routers)
]
yield pubsubs
yield tuple(pubsubs)
# @classmethod
# async def create_batch_with_gossipsub(
# cls, number: int, cache_size: int = None, gossipsub_params=GOSSIPSUB_PARAMS
# ):
# ...
@classmethod
@asynccontextmanager
async def create_batch_with_floodsub(
cls,
number: int,
is_secure: bool = False,
cache_size: int = None,
protocols: Sequence[TProtocol] = None,
) -> AsyncIterator[Tuple[Pubsub, ...]]:
if protocols is not None:
floodsubs = FloodsubFactory.create_batch(number, protocols=list(protocols))
else:
floodsubs = FloodsubFactory.create_batch(number)
async with cls._create_batch_with_router(
number, floodsubs, is_secure, cache_size
) as pubsubs:
yield pubsubs
@classmethod
@asynccontextmanager
async def create_batch_with_gossipsub(
cls,
number: int,
*,
is_secure: bool = False,
cache_size: int = None,
protocols: Sequence[TProtocol] = None,
degree: int = GOSSIPSUB_PARAMS.degree,
degree_low: int = GOSSIPSUB_PARAMS.degree_low,
degree_high: int = GOSSIPSUB_PARAMS.degree_high,
time_to_live: int = GOSSIPSUB_PARAMS.time_to_live,
gossip_window: int = GOSSIPSUB_PARAMS.gossip_window,
gossip_history: int = GOSSIPSUB_PARAMS.gossip_history,
heartbeat_interval: float = GOSSIPSUB_PARAMS.heartbeat_interval,
) -> AsyncIterator[Tuple[Pubsub, ...]]:
if protocols is not None:
gossipsubs = GossipsubFactory.create_batch(
number,
protocols=protocols,
degree=degree,
degree_low=degree_low,
degree_high=degree_high,
time_to_live=time_to_live,
gossip_window=gossip_window,
heartbeat_interval=heartbeat_interval,
)
else:
gossipsubs = GossipsubFactory.create_batch(
number,
degree=degree,
degree_low=degree_low,
degree_high=degree_high,
time_to_live=time_to_live,
gossip_window=gossip_window,
heartbeat_interval=heartbeat_interval,
)
async with cls._create_batch_with_router(
number, gossipsubs, is_secure, cache_size
) as pubsubs:
async with AsyncExitStack() as stack:
for router in gossipsubs:
await stack.enter_async_context(background_trio_service(router))
yield pubsubs
@asynccontextmanager
async def swarm_pair_factory(
is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Tuple[Swarm, Swarm]:
) -> AsyncIterator[Tuple[Swarm, Swarm]]:
async with SwarmFactory.create_batch_and_listen(
is_secure, 2, muxer_opt=muxer_opt
) as swarms:
@ -232,7 +271,9 @@ async def swarm_pair_factory(
@asynccontextmanager
async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]:
async def host_pair_factory(
is_secure: bool
) -> AsyncIterator[Tuple[BasicHost, BasicHost]]:
async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts:
await connect(hosts[0], hosts[1])
yield hosts[0], hosts[1]
@ -241,7 +282,7 @@ async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]:
@asynccontextmanager
async def swarm_conn_pair_factory(
is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Tuple[SwarmConn, SwarmConn]:
) -> AsyncIterator[Tuple[SwarmConn, SwarmConn]]:
async with swarm_pair_factory(is_secure) as swarms:
conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
@ -249,7 +290,9 @@ async def swarm_conn_pair_factory(
@asynccontextmanager
async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Mplex]:
async def mplex_conn_pair_factory(
is_secure: bool
) -> AsyncIterator[Tuple[Mplex, Mplex]]:
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex}
async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair:
yield (
@ -259,21 +302,25 @@ async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Mplex]:
@asynccontextmanager
async def mplex_stream_pair_factory(is_secure: bool) -> Tuple[MplexStream, MplexStream]:
async def mplex_stream_pair_factory(
is_secure: bool
) -> AsyncIterator[Tuple[MplexStream, MplexStream]]:
async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info:
mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info
stream_0 = await mplex_conn_0.open_stream()
stream_0 = cast(MplexStream, await mplex_conn_0.open_stream())
await trio.sleep(0.01)
stream_1: MplexStream
async with mplex_conn_1.streams_lock:
if len(mplex_conn_1.streams) != 1:
raise Exception("Mplex should not have any other stream")
stream_1 = tuple(mplex_conn_1.streams.values())[0]
yield cast(MplexStream, stream_0), cast(MplexStream, stream_1)
yield stream_0, stream_1
@asynccontextmanager
async def net_stream_pair_factory(is_secure: bool) -> Tuple[INetStream, INetStream]:
async def net_stream_pair_factory(
is_secure: bool
) -> AsyncIterator[Tuple[INetStream, INetStream]]:
protocol_id = TProtocol("/example/id/1")
stream_1: INetStream
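
Usage of the reworked factories, for orientation (a sketch; the `async with` owns setup and teardown, and the `AsyncExitStack` ordering above guarantees pubsubs exit before their hosts):

import trio

from libp2p.tools.factories import PubsubFactory


async def main() -> None:
    async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs:
        # Two hosts are listening and their pubsub services are running here.
        assert len(pubsubs) == 2
    # Everything has been torn down, in reverse creation order.


trio.run(main)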

View File

@ -1,12 +1,11 @@
import asyncio
from typing import Dict
import uuid
from contextlib import AsyncExitStack, asynccontextmanager
from typing import AsyncIterator, Dict, Tuple
from async_service import Service, background_trio_service
from libp2p.host.host_interface import IHost
from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.pubsub import Pubsub
from libp2p.tools.constants import LISTEN_MADDR
from libp2p.tools.factories import FloodsubFactory, PubsubFactory
from libp2p.tools.factories import PubsubFactory
CRYPTO_TOPIC = "ethereum"
@ -18,7 +17,7 @@ CRYPTO_TOPIC = "ethereum"
# Determine message type by looking at first item before first comma
class DummyAccountNode:
class DummyAccountNode(Service):
"""
Node which has an internal balance mapping, meant to serve as a dummy
crypto blockchain.
@ -27,19 +26,24 @@ class DummyAccountNode:
crypto each user in the mappings holds
"""
libp2p_node: IHost
pubsub: Pubsub
floodsub: FloodSub
def __init__(self, libp2p_node: IHost, pubsub: Pubsub, floodsub: FloodSub):
self.libp2p_node = libp2p_node
def __init__(self, pubsub: Pubsub) -> None:
self.pubsub = pubsub
self.floodsub = floodsub
self.balances: Dict[str, int] = {}
self.node_id = str(uuid.uuid1())
@property
def host(self) -> IHost:
return self.pubsub.host
async def run(self) -> None:
self.subscription = await self.pubsub.subscribe(CRYPTO_TOPIC)
self.manager.run_daemon_task(self.handle_incoming_msgs)
await self.manager.wait_finished()
@classmethod
async def create(cls) -> "DummyAccountNode":
@asynccontextmanager
async def create(cls, number: int) -> AsyncIterator[Tuple["DummyAccountNode", ...]]:
"""
Create a new DummyAccountNode and attach a libp2p node, a floodsub, and
a pubsub instance to this new node.
@ -47,15 +51,17 @@ class DummyAccountNode:
We use `create` since it serves as a factory function and allows us
to use async/await, unlike `__init__`
"""
pubsub = PubsubFactory(router=FloodsubFactory())
await pubsub.host.get_network().listen(LISTEN_MADDR)
return cls(libp2p_node=pubsub.host, pubsub=pubsub, floodsub=pubsub.router)
async with PubsubFactory.create_batch_with_floodsub(number) as pubsubs:
async with AsyncExitStack() as stack:
dummy_account_nodes = tuple(cls(pubsub) for pubsub in pubsubs)
for node in dummy_account_nodes:
await stack.enter_async_context(background_trio_service(node))
yield dummy_account_nodes
async def handle_incoming_msgs(self) -> None:
"""Handle all incoming messages on the CRYPTO_TOPIC from peers."""
while True:
incoming = await self.q.get()
incoming = await self.subscription.receive()
msg_comps = incoming.data.decode("utf-8").split(",")
if msg_comps[0] == "send":
@ -63,13 +69,6 @@ class DummyAccountNode:
elif msg_comps[0] == "set":
self.handle_set_crypto(msg_comps[1], int(msg_comps[2]))
async def setup_crypto_networking(self) -> None:
"""Subscribe to CRYPTO_TOPIC and perform call to function that handles
all incoming messages on said topic."""
self.q = await self.pubsub.subscribe(CRYPTO_TOPIC)
asyncio.ensure_future(self.handle_incoming_msgs())
async def publish_send_crypto(
self, source_user: str, dest_user: str, amount: int
) -> None:
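
`create` is now an async context manager yielding a batch of running nodes rather than a one-shot coroutine. A usage sketch (the import path is assumed; this helper lives alongside the pubsub tests):

import trio

from tests.pubsub.dummy_account_node import DummyAccountNode  # path assumed


async def main() -> None:
    async with DummyAccountNode.create(2) as nodes:
        # Each node's handle_incoming_msgs daemon is running here.
        await trio.sleep(0.1)
        assert nodes[0].balances == {}
    # Leaving the context stops every node's service.


trio.run(main)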

View File

@ -1,12 +1,10 @@
# type: ignore
# To add typing to this module, it's better to do it after refactoring test cases into classes
import asyncio
import pytest
import trio
from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID, LISTEN_MADDR
from libp2p.tools.factories import PubsubFactory
from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID
from libp2p.tools.utils import connect
SUPPORTED_PROTOCOLS = [FLOODSUB_PROTOCOL_ID]
@ -15,6 +13,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "simple_two_nodes",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["A", "B"],
"adj_list": {"A": ["B"]},
"topic_map": {"topic1": ["B"]},
"messages": [{"topics": ["topic1"], "data": b"foo", "node_id": "A"}],
@ -22,6 +21,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "three_nodes_two_topics",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["A", "B", "C"],
"adj_list": {"A": ["B"], "B": ["C"]},
"topic_map": {"topic1": ["B", "C"], "topic2": ["B", "C"]},
"messages": [
@ -32,6 +32,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "two_nodes_one_topic_single_subscriber_is_sender",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["A", "B"],
"adj_list": {"A": ["B"]},
"topic_map": {"topic1": ["B"]},
"messages": [{"topics": ["topic1"], "data": b"Alex is tall", "node_id": "B"}],
@ -39,6 +40,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "two_nodes_one_topic_two_msgs",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["A", "B"],
"adj_list": {"A": ["B"]},
"topic_map": {"topic1": ["B"]},
"messages": [
@ -49,6 +51,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "seven_nodes_tree_one_topics",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4", "5", "6", "7"],
"adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]},
"topic_map": {"astrophysics": ["2", "3", "4", "5", "6", "7"]},
"messages": [{"topics": ["astrophysics"], "data": b"e=mc^2", "node_id": "1"}],
@ -56,6 +59,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "seven_nodes_tree_three_topics",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4", "5", "6", "7"],
"adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]},
"topic_map": {
"astrophysics": ["2", "3", "4", "5", "6", "7"],
@ -71,6 +75,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "seven_nodes_tree_three_topics_diff_origin",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4", "5", "6", "7"],
"adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]},
"topic_map": {
"astrophysics": ["1", "2", "3", "4", "5", "6", "7"],
@ -86,6 +91,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "three_nodes_clique_two_topic_diff_origin",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3"],
"adj_list": {"1": ["2", "3"], "2": ["3"]},
"topic_map": {"astrophysics": ["1", "2", "3"], "school": ["1", "2", "3"]},
"messages": [
@ -97,6 +103,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "four_nodes_clique_two_topic_diff_origin_many_msgs",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4"],
"adj_list": {
"1": ["2", "3", "4"],
"2": ["1", "3", "4"],
@ -120,6 +127,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "five_nodes_ring_two_topic_diff_origin_many_msgs",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4", "5"],
"adj_list": {"1": ["2"], "2": ["3"], "3": ["4"], "4": ["5"], "5": ["1"]},
"topic_map": {
"astrophysics": ["1", "2", "3", "4", "5"],
@ -143,7 +151,7 @@ floodsub_protocol_pytest_params = [
]
async def perform_test_from_obj(obj, router_factory) -> None:
async def perform_test_from_obj(obj, pubsub_factory) -> None:
"""
Perform pubsub tests from a test obj.
A test obj is composed as follows:
@ -174,88 +182,75 @@ async def perform_test_from_obj(obj, router_factory) -> None:
# Step 1) Create graph
adj_list = obj["adj_list"]
node_list = obj["nodes"]
node_map = {}
pubsub_map = {}
async def add_node(node_id_str: str) -> None:
pubsub_router = router_factory(protocols=obj["supported_protocols"])
pubsub = PubsubFactory(router=pubsub_router)
await pubsub.host.get_network().listen(LISTEN_MADDR)
node_map[node_id_str] = pubsub.host
pubsub_map[node_id_str] = pubsub
async with pubsub_factory(
number=len(node_list), protocols=obj["supported_protocols"]
) as pubsubs:
for node_id_str, pubsub in zip(node_list, pubsubs):
node_map[node_id_str] = pubsub.host
pubsub_map[node_id_str] = pubsub
tasks_connect = []
for start_node_id in adj_list:
# Create node if node does not yet exist
if start_node_id not in node_map:
await add_node(start_node_id)
# Connect nodes and wait for at least 2 seconds
async with trio.open_nursery() as nursery:
for start_node_id in adj_list:
# For each neighbor of start_node, create if does not yet exist,
# then connect start_node to neighbor
for neighbor_id in adj_list[start_node_id]:
nursery.start_soon(
connect, node_map[start_node_id], node_map[neighbor_id]
)
nursery.start_soon(trio.sleep, 2)
# For each neighbor of start_node, create if does not yet exist,
# then connect start_node to neighbor
for neighbor_id in adj_list[start_node_id]:
# Create neighbor if neighbor does not yet exist
if neighbor_id not in node_map:
await add_node(neighbor_id)
tasks_connect.append(
connect(node_map[start_node_id], node_map[neighbor_id])
)
# Connect nodes and wait at least for 2 seconds
await asyncio.gather(*tasks_connect, asyncio.sleep(2))
# Step 2) Subscribe to topics
queues_map = {}
topic_map = obj["topic_map"]
# Step 2) Subscribe to topics
queues_map = {}
topic_map = obj["topic_map"]
async def subscribe_node(node_id, topic):
if node_id not in queues_map:
queues_map[node_id] = {}
# Avoid repeated work
if topic in queues_map[node_id]:
# Checkpoint
await trio.hazmat.checkpoint()
return
sub = await pubsub_map[node_id].subscribe(topic)
queues_map[node_id][topic] = sub
tasks_topic = []
tasks_topic_data = []
for topic, node_ids in topic_map.items():
for node_id in node_ids:
tasks_topic.append(pubsub_map[node_id].subscribe(topic))
tasks_topic_data.append((node_id, topic))
tasks_topic.append(asyncio.sleep(2))
async with trio.open_nursery() as nursery:
for topic, node_ids in topic_map.items():
for node_id in node_ids:
nursery.start_soon(subscribe_node, node_id, topic)
nursery.start_soon(trio.sleep, 2)
# Gather is like Promise.all
responses = await asyncio.gather(*tasks_topic)
for i in range(len(responses) - 1):
node_id, topic = tasks_topic_data[i]
if node_id not in queues_map:
queues_map[node_id] = {}
# Store queue in topic-queue map for node
queues_map[node_id][topic] = responses[i]
# Step 3) Publish messages
topics_in_msgs_ordered = []
messages = obj["messages"]
# Allow time for subscribing before continuing
await asyncio.sleep(0.01)
for msg in messages:
topics = msg["topics"]
data = msg["data"]
node_id = msg["node_id"]
# Step 3) Publish messages
topics_in_msgs_ordered = []
messages = obj["messages"]
tasks_publish = []
# Publish message
# TODO: Should be single RPC package with several topics
for topic in topics:
await pubsub_map[node_id].publish(topic, data)
for msg in messages:
topics = msg["topics"]
data = msg["data"]
node_id = msg["node_id"]
# For each topic in topics, add (topic, node_id, data) tuple to ordered test list
for topic in topics:
topics_in_msgs_ordered.append((topic, node_id, data))
# Allow time for publishing before continuing
await trio.sleep(1)
# Publish message
# TODO: Should be single RPC package with several topics
for topic in topics:
tasks_publish.append(pubsub_map[node_id].publish(topic, data))
# For each topic in topics, add (topic, node_id, data) tuple to ordered test list
for topic in topics:
topics_in_msgs_ordered.append((topic, node_id, data))
# Allow time for publishing before continuing
await asyncio.gather(*tasks_publish, asyncio.sleep(2))
# Step 4) Check that all messages were received correctly.
for topic, origin_node_id, data in topics_in_msgs_ordered:
# Look at each node in each topic
for node_id in topic_map[topic]:
# Get message from subscription queue
msg = await queues_map[node_id][topic].get()
assert data == msg.data
# Check the message origin
assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id
# Success, terminate pending tasks.
# Step 4) Check that all messages were received correctly.
for topic, origin_node_id, data in topics_in_msgs_ordered:
# Look at each node in each topic
for node_id in topic_map[topic]:
# Get message from subscription queue
msg = await queues_map[node_id][topic].receive()
assert data == msg.data
# Check the message origin
assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id
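
With `router_factory` replaced by `pubsub_factory`, callers now pass one of the factory context managers directly; `perform_test_from_obj` invokes it with `number=` and `protocols=`. A sketch of a call site, assuming it lives in this same module and uses the suite's pytest-trio plugin:

import pytest

from libp2p.tools.factories import PubsubFactory


@pytest.mark.trio
async def test_simple_two_nodes():
    test_obj = {
        "supported_protocols": SUPPORTED_PROTOCOLS,
        "nodes": ["A", "B"],
        "adj_list": {"A": ["B"]},
        "topic_map": {"topic1": ["B"]},
        "messages": [{"topics": ["topic1"], "data": b"foo", "node_id": "A"}],
    }
    await perform_test_from_obj(test_obj, PubsubFactory.create_batch_with_floodsub)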

View File

@ -1,17 +1,9 @@
from typing import Callable, List, Sequence, Tuple
from typing import Awaitable, Callable
import multiaddr
import trio
from libp2p import new_node
from libp2p.host.basic_host import BasicHost
from libp2p.host.host_interface import IHost
from libp2p.kademlia.network import KademliaServer
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.network.swarm import Swarm
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.routing.interfaces import IPeerRouting
from libp2p.routing.kademlia.kademlia_peer_router import KadmeliaPeerRouter
from .constants import MAX_READ_LEN
@ -36,49 +28,9 @@ async def connect(node1: IHost, node2: IHost) -> None:
await node1.connect(info)
async def set_up_nodes_by_transport_opt(
transport_opt_list: Sequence[Sequence[str]], nursery: trio.Nursery
) -> Tuple[BasicHost, ...]:
nodes_list = []
for transport_opt in transport_opt_list:
node = new_node(transport_opt=transport_opt)
await node.get_network().listen(
multiaddr.Multiaddr(transport_opt[0]), nursery=nursery
)
nodes_list.append(node)
return tuple(nodes_list)
async def set_up_nodes_by_transport_and_disc_opt(
transport_disc_opt_list: Sequence[Tuple[Sequence[str], IPeerRouting]]
) -> Tuple[BasicHost, ...]:
nodes_list = []
for transport_opt, disc_opt in transport_disc_opt_list:
node = await new_node(transport_opt=transport_opt, disc_opt=disc_opt)
await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0]))
nodes_list.append(node)
return tuple(nodes_list)
async def set_up_routers(
router_ports: Tuple[int, ...] = (0, 0)
) -> List[KadmeliaPeerRouter]:
"""The default ``router_confs`` selects two free ports local to this
machine."""
bootstrap_node = KademliaServer() # type: ignore
await bootstrap_node.listen(router_ports[0])
routers = [KadmeliaPeerRouter(bootstrap_node)]
for port in router_ports[1:]:
node = KademliaServer() # type: ignore
await node.listen(port)
await node.bootstrap_node(bootstrap_node.address)
routers.append(KadmeliaPeerRouter(node))
return routers
def create_echo_stream_handler(ack_prefix: str) -> Callable[[INetStream], None]:
def create_echo_stream_handler(
ack_prefix: str
) -> Callable[[INetStream], Awaitable[None]]:
async def echo_stream_handler(stream: INetStream) -> None:
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()

View File

@ -1,12 +1,13 @@
from abc import ABC, abstractmethod
from typing import List
from typing import Tuple
from multiaddr import Multiaddr
import trio
class IListener(ABC):
@abstractmethod
async def listen(self, maddr: Multiaddr) -> bool:
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
"""
put listener in listening mode and wait for incoming connections.
@ -15,14 +16,9 @@ class IListener(ABC):
"""
@abstractmethod
def get_addrs(self) -> List[Multiaddr]:
def get_addrs(self) -> Tuple[Multiaddr, ...]:
"""
retrieve list of addresses the listener is listening on.
:return: return list of addrs
"""
@abstractmethod
async def close(self) -> None:
"""close the listener such that no more connections can be open on this
transport instance."""

View File

@ -1,14 +1,13 @@
import logging
from socket import socket
from typing import List
from typing import Awaitable, Callable, List, Sequence, Tuple
from multiaddr import Multiaddr
import trio
from trio_typing import TaskStatus
from libp2p.io.trio import TrioReadWriteCloser
from libp2p.io.trio import TrioTCPStream
from libp2p.network.connection.raw_connection import RawConnection
from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.transport.exceptions import OpenConnectionError
from libp2p.transport.listener_interface import IListener
from libp2p.transport.transport_interface import ITransport
from libp2p.transport.typing import THandler
@ -18,14 +17,12 @@ logger = logging.getLogger("libp2p.transport.tcp")
class TCPListener(IListener):
multiaddrs: List[Multiaddr]
server = None
def __init__(self, handler_function: THandler) -> None:
self.multiaddrs = []
self.server = None
self.handler = handler_function
# TODO: Fix handling?
# TODO: Get rid of `nursery`?
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> None:
"""
put listener in listening mode and wait for incoming connections.
@ -34,13 +31,18 @@ class TCPListener(IListener):
:return: return True if successful
"""
async def serve_tcp(handler, port, host, task_status=None):
async def serve_tcp(
handler: Callable[[trio.SocketStream], Awaitable[None]],
port: int,
host: str,
task_status: TaskStatus[Sequence[trio.SocketListener]] = None,
) -> None:
logger.debug("serve_tcp %s %s", host, port)
await trio.serve_tcp(handler, port, host=host, task_status=task_status)
async def handler(stream):
read_write_closer = TrioReadWriteCloser(stream)
await self.handler(read_write_closer)
async def handler(stream: trio.SocketStream) -> None:
tcp_stream = TrioTCPStream(stream)
await self.handler(tcp_stream)
listeners = await nursery.start(
serve_tcp,
@ -51,7 +53,7 @@ class TCPListener(IListener):
socket = listeners[0].socket
self.multiaddrs.append(_multiaddr_from_socket(socket))
def get_addrs(self) -> List[Multiaddr]:
def get_addrs(self) -> Tuple[Multiaddr, ...]:
"""
retrieve list of addresses the listener is listening on.
@ -59,15 +61,6 @@ class TCPListener(IListener):
"""
return tuple(self.multiaddrs)
async def close(self) -> None:
"""close the listener such that no more connections can be open on this
transport instance."""
if self.server is None:
return
self.server.close()
await self.server.wait_closed()
self.server = None
class TCP(ITransport):
async def dial(self, maddr: Multiaddr) -> IRawConnection:
@ -82,7 +75,7 @@ class TCP(ITransport):
self.port = int(maddr.value_for_protocol("tcp"))
stream = await trio.open_tcp_stream(self.host, self.port)
read_write_closer = TrioReadWriteCloser(stream)
read_write_closer = TrioTCPStream(stream)
return RawConnection(read_write_closer, True)
@ -97,5 +90,6 @@ class TCP(ITransport):
return TCPListener(handler_function)
def _multiaddr_from_socket(socket: socket) -> Multiaddr:
return Multiaddr("/ip4/%s/tcp/%s" % socket.getsockname())
def _multiaddr_from_socket(socket: trio.socket.SocketType) -> Multiaddr:
ip, port = socket.getsockname() # type: ignore
return Multiaddr(f"/ip4/{ip}/tcp/{port}")