Fix Mplex and Swarm

Author: mhchia
Date: 2019-11-29 19:09:56 +08:00
parent ec43c25b45
commit 1e600ea7e0

13 changed files with 232 additions and 122 deletions

View File

@@ -23,6 +23,10 @@ class TrioReadWriteCloser(ReadWriteCloser):
             raise IOException(error)

     async def read(self, n: int = -1) -> bytes:
+        if n == 0:
+            # Check point
+            await trio.sleep(0)
+            return b""
         max_bytes = n if n != -1 else None
         try:
             return await self.stream.receive_some(max_bytes)
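Why the zero-byte path still awaits: `trio.sleep(0)` is a checkpoint, so even a
`read(0)` call yields to the scheduler and stays cancellable. A minimal sketch of
the effect (not part of this commit; `conn` and the timeout value are illustrative):

import trio

async def poll_zero_reads(conn):
    # Without the checkpoint inside `read(0)`, this loop would never yield,
    # and the surrounding cancel scope could never fire.
    with trio.move_on_after(1):
        while True:
            await conn.read(0)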

View File

@@ -50,8 +50,11 @@ class SwarmConn(INetConn, Service):
         await self._notify_disconnected()

     async def _handle_new_streams(self) -> None:
-        while True:
+        while self.manager.is_running:
             try:
+                print(
+                    f"!@# SwarmConn._handle_new_streams: {self.muxed_conn._id}: waiting for new streams"
+                )
                 stream = await self.muxed_conn.accept_stream()
             except MuxedConnUnavailable:
                 # If there is anything wrong in the MuxedConn,
@@ -60,6 +63,9 @@ class SwarmConn(INetConn, Service):
             # Asynchronously handle the accepted stream, to avoid blocking the next stream.
             self.manager.run_task(self._handle_muxed_stream, stream)

+        print(
+            f"!@# SwarmConn._handle_new_streams: {self.muxed_conn._id}: out of the loop"
+        )
         await self.close()

     async def _call_stream_handler(self, net_stream: NetStream) -> None:

View File

@@ -206,8 +206,7 @@ class Swarm(INetwork, Service):
             logger.debug("successfully opened connection to peer %s", peer_id)

             # FIXME: This is a intentional barrier to prevent from the handler exiting and
             # closing the connection.
-            event = trio.Event()
-            await event.wait()
+            await trio.sleep_forever()

         try:
             # Success
@@ -240,7 +239,7 @@ class Swarm(INetwork, Service):
         # await asyncio.gather(
         #     *[connection.close() for connection in self.connections.values()]
         # )
-        self.manager.stop()
+        await self.manager.stop()
         await self.manager.wait_finished()
         logger.debug("swarm successfully closed")

View File

@@ -1,3 +1,4 @@
+import math
 import asyncio
 import logging
 from typing import Any  # noqa: F401
@@ -18,7 +19,6 @@ from libp2p.utils import (
     encode_uvarint,
     encode_varint_prefixed,
     read_varint_prefixed_bytes,
-    TrioQueue,
 )

 from .constants import HeaderTags
@@ -41,7 +41,10 @@ class Mplex(IMuxedConn, Service):
     next_channel_id: int
     streams: Dict[StreamID, MplexStream]
     streams_lock: trio.Lock
-    new_stream_queue: "TrioQueue[IMuxedStream]"
+    streams_msg_channels: Dict[StreamID, "trio.MemorySendChannel[bytes]"]
+    new_stream_send_channel: "trio.MemorySendChannel[IMuxedStream]"
+    new_stream_receive_channel: "trio.MemoryReceiveChannel[IMuxedStream]"
     event_shutting_down: trio.Event
     event_closed: trio.Event
@@ -64,7 +67,10 @@ class Mplex(IMuxedConn, Service):
         # Mapping from stream ID -> buffer of messages for that stream
         self.streams = {}
         self.streams_lock = trio.Lock()
-        self.new_stream_queue = TrioQueue()
+        self.streams_msg_channels = {}
+        send_channel, receive_channel = trio.open_memory_channel(math.inf)
+        self.new_stream_send_channel = send_channel
+        self.new_stream_receive_channel = receive_channel
         self.event_shutting_down = trio.Event()
         self.event_closed = trio.Event()
@@ -105,9 +111,13 @@ class Mplex(IMuxedConn, Service):
         return next_id

     async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream:
-        stream = MplexStream(name, stream_id, self)
+        # Use an unbounded buffer, to avoid `handle_incoming` being blocked when doing
+        # `send_channel.send`.
+        send_channel, receive_channel = trio.open_memory_channel(math.inf)
+        stream = MplexStream(name, stream_id, self, receive_channel)
         async with self.streams_lock:
             self.streams[stream_id] = stream
+            self.streams_msg_channels[stream_id] = send_channel
         return stream

     async def open_stream(self) -> IMuxedStream:
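The pattern adopted here replaces the removed TrioQueue with an unbounded trio
memory channel, so the single reader loop in `handle_incoming` never blocks on
`send`. A standalone sketch of that behaviour (illustrative only, not from the
commit):

import math
import trio

async def demo():
    send, receive = trio.open_memory_channel(math.inf)  # unbounded buffer
    await send.send(b"msg")  # never blocks, regardless of reader progress
    assert await receive.receive() == b"msg"

trio.run(demo)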
@@ -126,7 +136,10 @@ class Mplex(IMuxedConn, Service):

     async def accept_stream(self) -> IMuxedStream:
         """accepts a muxed stream opened by the other end."""
-        return await self.new_stream_queue.get()
+        try:
+            return await self.new_stream_receive_channel.receive()
+        except (trio.ClosedResourceError, trio.EndOfChannel):
+            raise MplexUnavailable

     async def send_message(
         self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID
@@ -138,6 +151,9 @@ class Mplex(IMuxedConn, Service):
         :param data: data to send in the message
         :param stream_id: stream the message is in
         """
+        print(
+            f"!@# send_message: {self._id}: flag={flag}, data={data}, stream_id={stream_id}"
+        )
         # << by 3, then or with flag
         header = encode_uvarint((stream_id.channel_id << 3) | flag.value)
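A worked example of the header packing used above and unpacked again in
`read_message` (the concrete values are illustrative; flag 2 corresponds to
MessageInitiator in the mplex framing):

channel_id, flag = 5, 2
header = (channel_id << 3) | flag
assert header == 42
assert header & 0x07 == flag      # read_message: flag = header & 0x07
assert header >> 3 == channel_id  # read_message: channel_id = header >> 3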
@@ -162,14 +178,21 @@ class Mplex(IMuxedConn, Service):
         """Read a message off of the secured connection and add it to the
         corresponding message buffer."""
-        while True:
+        while self.manager.is_running:
             try:
+                print(
+                    f"!@# handle_incoming: {self._id}: before _handle_incoming_message"
+                )
                 await self._handle_incoming_message()
+                print(
+                    f"!@# handle_incoming: {self._id}: after _handle_incoming_message"
+                )
             except MplexUnavailable as e:
                 logger.debug("mplex unavailable while waiting for incoming: %s", e)
+                print(f"!@# handle_incoming: {self._id}: MplexUnavailable: {e}")
                 break
-            # Force context switch
-            await trio.sleep(0)
+        print(f"!@# handle_incoming: {self._id}: leaving")
         # If we enter here, it means this connection is shutting down.
         # We should clean things up.
         await self._cleanup()
@@ -181,51 +204,73 @@ class Mplex(IMuxedConn, Service):
         :return: stream_id, flag, message contents
         """
-        # FIXME: No timeout is used in Go implementation.
         try:
             header = await decode_uvarint_from_stream(self.secured_conn)
+        except (ParseError, RawConnError, IncompleteReadError) as error:
+            raise MplexUnavailable(
+                f"failed to read the header correctly from the underlying connection: {error}"
+            )
+        try:
             message = await read_varint_prefixed_bytes(self.secured_conn)
         except (ParseError, RawConnError, IncompleteReadError) as error:
             raise MplexUnavailable(
-                "failed to read messages correctly from the underlying connection"
-            ) from error
-        except asyncio.TimeoutError as error:
-            raise MplexUnavailable(
-                "failed to read more message body within the timeout"
-            ) from error
+                "failed to read the message body correctly from the underlying connection: "
+                f"{error}"
+            )

         flag = header & 0x07
         channel_id = header >> 3

         return channel_id, flag, message

+    @property
+    def _id(self) -> int:
+        return 0 if self.is_initiator else 1
+
     async def _handle_incoming_message(self) -> None:
         """
         Read and handle a new incoming message.

         :raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down.
         """
+        print(f"!@# _handle_incoming_message: {self._id}: before reading")
         channel_id, flag, message = await self.read_message()
+        print(
+            f"!@# _handle_incoming_message: {self._id}: channel_id={channel_id}, flag={flag}, message={message}"
+        )
         stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
+        print(f"!@# _handle_incoming_message: {self._id}: 2")

         if flag == HeaderTags.NewStream.value:
+            print(f"!@# _handle_incoming_message: {self._id}: 3")
             await self._handle_new_stream(stream_id, message)
+            print(f"!@# _handle_incoming_message: {self._id}: 4")
         elif flag in (
             HeaderTags.MessageInitiator.value,
             HeaderTags.MessageReceiver.value,
         ):
+            print(f"!@# _handle_incoming_message: {self._id}: 5")
             await self._handle_message(stream_id, message)
+            print(f"!@# _handle_incoming_message: {self._id}: 6")
         elif flag in (HeaderTags.CloseInitiator.value, HeaderTags.CloseReceiver.value):
+            print(f"!@# _handle_incoming_message: {self._id}: 7")
             await self._handle_close(stream_id)
+            print(f"!@# _handle_incoming_message: {self._id}: 8")
         elif flag in (HeaderTags.ResetInitiator.value, HeaderTags.ResetReceiver.value):
+            print(f"!@# _handle_incoming_message: {self._id}: 9")
             await self._handle_reset(stream_id)
+            print(f"!@# _handle_incoming_message: {self._id}: 10")
         else:
+            print(f"!@# _handle_incoming_message: {self._id}: 11")
             # Receives messages with an unknown flag
             # TODO: logging
             async with self.streams_lock:
+                print(f"!@# _handle_incoming_message: {self._id}: 12")
                 if stream_id in self.streams:
+                    print(f"!@# _handle_incoming_message: {self._id}: 13")
                     stream = self.streams[stream_id]
                     await stream.reset()
+        print(f"!@# _handle_incoming_message: {self._id}: 14")

     async def _handle_new_stream(self, stream_id: StreamID, message: bytes) -> None:
         async with self.streams_lock:
@@ -235,43 +280,65 @@ class Mplex(IMuxedConn, Service):
                     f"received NewStream message for existing stream: {stream_id}"
                 )
         mplex_stream = await self._initialize_stream(stream_id, message.decode())
-        await self.new_stream_queue.put(mplex_stream)
+        try:
+            await self.new_stream_send_channel.send(mplex_stream)
+        except (trio.BrokenResourceError, trio.EndOfChannel):
+            raise MplexUnavailable

     async def _handle_message(self, stream_id: StreamID, message: bytes) -> None:
+        print(
+            f"!@# _handle_message: {self._id}: stream_id={stream_id}, message={message}"
+        )
         async with self.streams_lock:
+            print(f"!@# _handle_message: {self._id}: 1")
             if stream_id not in self.streams:
                 # We receive a message of the stream `stream_id` which is not accepted
                 # before. It is abnormal. Possibly disconnect?
                 # TODO: Warn and emit logs about this.
+                print(f"!@# _handle_message: {self._id}: 2")
                 return
+            print(f"!@# _handle_message: {self._id}: 3")
             stream = self.streams[stream_id]
+            send_channel = self.streams_msg_channels[stream_id]
         async with stream.close_lock:
+            print(f"!@# _handle_message: {self._id}: 4")
             if stream.event_remote_closed.is_set():
+                print(f"!@# _handle_message: {self._id}: 5")
                 # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)"  # noqa: E501
                 return
-        await stream.incoming_data.put(message)
+        print(f"!@# _handle_message: {self._id}: 6")
+        await send_channel.send(message)
+        print(f"!@# _handle_message: {self._id}: 7")

     async def _handle_close(self, stream_id: StreamID) -> None:
+        print(f"!@# _handle_close: {self._id}: step=0")
         async with self.streams_lock:
             if stream_id not in self.streams:
                 # Ignore unmatched messages for now.
                 return
             stream = self.streams[stream_id]
+            send_channel = self.streams_msg_channels[stream_id]
+        print(f"!@# _handle_close: {self._id}: step=1")
+        await send_channel.aclose()
+        print(f"!@# _handle_close: {self._id}: step=2")
         # NOTE: If remote is already closed, then return: Technically a bug
         # on the other side. We should consider killing the connection.
         async with stream.close_lock:
             if stream.event_remote_closed.is_set():
                 return
+        print(f"!@# _handle_close: {self._id}: step=3")
         is_local_closed: bool
         async with stream.close_lock:
             stream.event_remote_closed.set()
             is_local_closed = stream.event_local_closed.is_set()
+        print(f"!@# _handle_close: {self._id}: step=4")
         # If local is also closed, both sides are closed. Then, we should clean up
         # the entry of this stream, to avoid others from accessing it.
         if is_local_closed:
             async with self.streams_lock:
                 if stream_id in self.streams:
                     del self.streams[stream_id]
+        print(f"!@# _handle_close: {self._id}: step=5")

     async def _handle_reset(self, stream_id: StreamID) -> None:
         async with self.streams_lock:
@@ -279,11 +346,11 @@ class Mplex(IMuxedConn, Service):
                 # This is *ok*. We forget the stream on reset.
                 return
             stream = self.streams[stream_id]
+            send_channel = self.streams_msg_channels[stream_id]
+            await send_channel.aclose()
         async with stream.close_lock:
             if not stream.event_remote_closed.is_set():
                 stream.event_reset.set()
                 stream.event_remote_closed.set()
             # If local is not closed, we should close it.
             if not stream.event_local_closed.is_set():
@@ -291,16 +358,21 @@ class Mplex(IMuxedConn, Service):
         async with self.streams_lock:
             if stream_id in self.streams:
                 del self.streams[stream_id]
+                del self.streams_msg_channels[stream_id]

     async def _cleanup(self) -> None:
         if not self.event_shutting_down.is_set():
             self.event_shutting_down.set()
         async with self.streams_lock:
-            for stream in self.streams.values():
+            for stream_id, stream in self.streams.items():
                 async with stream.close_lock:
                     if not stream.event_remote_closed.is_set():
                         stream.event_remote_closed.set()
                         stream.event_reset.set()
                         stream.event_local_closed.set()
+                send_channel = self.streams_msg_channels[stream_id]
+                await send_channel.aclose()
             self.streams = None
         self.event_closed.set()
+        await self.new_stream_send_channel.aclose()
+        await self.new_stream_receive_channel.aclose()
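The close/cleanup paths above lean on one trio property: closing a memory
channel's send side wakes any blocked receiver with `trio.EndOfChannel`, which
is how a pending `MplexStream.read` learns about EOF or reset. A small sketch
of that behaviour (illustrative, not part of the commit):

import math
import trio

async def demo():
    send, receive = trio.open_memory_channel(math.inf)

    async def reader():
        try:
            await receive.receive()
        except trio.EndOfChannel:
            print("reader saw EOF")

    async with trio.open_nursery() as nursery:
        nursery.start_soon(reader)
        await trio.sleep(0.01)  # let the reader block on receive()
        await send.aclose()     # wakes the reader with EndOfChannel

trio.run(demo)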

View File

@@ -3,7 +3,6 @@ from typing import TYPE_CHECKING

 import trio

 from libp2p.stream_muxer.abc import IMuxedStream
-from libp2p.utils import IQueue, TrioQueue

 from .constants import HeaderTags
 from .datastructures import StreamID
@@ -24,10 +23,11 @@ class MplexStream(IMuxedStream):
     read_deadline: int
     write_deadline: int

+    # TODO: Add lock for read/write to avoid interleaving receiving messages?
     close_lock: trio.Lock

     # NOTE: `dataIn` is size of 8 in Go implementation.
-    incoming_data: IQueue[bytes]
+    incoming_data_channel: "trio.MemoryReceiveChannel[bytes]"

     event_local_closed: trio.Event
     event_remote_closed: trio.Event
@@ -35,7 +35,13 @@ class MplexStream(IMuxedStream):
     _buf: bytearray

-    def __init__(self, name: str, stream_id: StreamID, muxed_conn: "Mplex") -> None:
+    def __init__(
+        self,
+        name: str,
+        stream_id: StreamID,
+        muxed_conn: "Mplex",
+        incoming_data_channel: "trio.MemoryReceiveChannel[bytes]",
+    ) -> None:
         """
         create new MuxedStream in muxer.
@@ -51,13 +57,30 @@ class MplexStream(IMuxedStream):
         self.event_remote_closed = trio.Event()
         self.event_reset = trio.Event()
         self.close_lock = trio.Lock()
-        self.incoming_data = TrioQueue()
+        self.incoming_data_channel = incoming_data_channel
         self._buf = bytearray()

     @property
     def is_initiator(self) -> bool:
         return self.stream_id.is_initiator

+    async def _read_until_eof(self) -> bytes:
+        async for data in self.incoming_data_channel:
+            self._buf.extend(data)
+        payload = self._buf
+        self._buf = self._buf[len(payload) :]
+        return bytes(payload)
+
+    def _read_return_when_blocked(self) -> bytes:
+        buf = bytearray()
+        while True:
+            try:
+                data = self.incoming_data_channel.receive_nowait()
+                buf.extend(data)
+            except (trio.WouldBlock, trio.EndOfChannel):
+                break
+        return buf
+
     async def read(self, n: int = -1) -> bytes:
         """
         Read up to n bytes. Read possibly returns fewer than `n` bytes, if
@@ -73,7 +96,40 @@ class MplexStream(IMuxedStream):
             )
         if self.event_reset.is_set():
             raise MplexStreamReset
-        return await self.incoming_data.get()
+        if n == -1:
+            return await self._read_until_eof()
+        if len(self._buf) == 0:
+            data: bytes
+            # Peek whether there is data available. If yes, we just read until there is no data,
+            # and then return.
+            try:
+                data = self.incoming_data_channel.receive_nowait()
+            except trio.EndOfChannel:
+                raise MplexStreamEOF
+            except trio.WouldBlock:
+                # We know `receive` will be blocked here. Wait for data here with `receive` and
+                # catch all kinds of errors here.
+                try:
+                    data = await self.incoming_data_channel.receive()
+                except trio.EndOfChannel:
+                    if self.event_reset.is_set():
+                        raise MplexStreamReset
+                    if self.event_remote_closed.is_set():
+                        raise MplexStreamEOF
+                except trio.ClosedResourceError as error:
+                    # Probably `incoming_data_channel` is closed in `reset` when we are waiting
+                    # for `receive`.
+                    if self.event_reset.is_set():
+                        raise MplexStreamReset
+                    raise Exception(
+                        "`incoming_data_channel` is closed but stream is not reset. "
+                        "This should never happen."
+                    ) from error
+            self._buf.extend(data)
+        self._buf.extend(self._read_return_when_blocked())
+        payload = self._buf[:n]
+        self._buf = self._buf[len(payload) :]
+        return bytes(payload)

     async def write(self, data: bytes) -> int:
         """
@@ -99,22 +155,26 @@ class MplexStream(IMuxedStream):
             if self.event_local_closed.is_set():
                 return

+        print(f"!@# stream.close: {self.muxed_conn._id}: step=0")
         flag = (
             HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver
         )
         # TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown.
         await self.muxed_conn.send_message(flag, None, self.stream_id)
+        print(f"!@# stream.close: {self.muxed_conn._id}: step=1")

         _is_remote_closed: bool
         async with self.close_lock:
             self.event_local_closed.set()
             _is_remote_closed = self.event_remote_closed.is_set()
+        print(f"!@# stream.close: {self.muxed_conn._id}: step=2")

         if _is_remote_closed:
             # Both sides are closed, we can safely remove the buffer from the dict.
             async with self.muxed_conn.streams_lock:
                 if self.stream_id in self.muxed_conn.streams:
                     del self.muxed_conn.streams[self.stream_id]
+        print(f"!@# stream.close: {self.muxed_conn._id}: step=3")

     async def reset(self) -> None:
         """closes both ends of the stream tells this remote side to hang up."""
@@ -132,14 +192,15 @@ class MplexStream(IMuxedStream):
                     if self.is_initiator
                     else HeaderTags.ResetReceiver
                 )
-                async with trio.open_nursery() as nursery:
-                    nursery.start_soon(
-                        self.muxed_conn.send_message, flag, None, self.stream_id
-                    )
+                self.muxed_conn.manager.run_task(
+                    self.muxed_conn.send_message, flag, None, self.stream_id
+                )

             self.event_local_closed.set()
             self.event_remote_closed.set()
+        await self.incoming_data_channel.aclose()

         async with self.muxed_conn.streams_lock:
             if (
                 self.muxed_conn.streams is not None

View File

@@ -205,7 +205,7 @@ async def mplex_stream_pair_factory(is_secure: bool) -> Tuple[MplexStream, Mplex
     stream_1: MplexStream
     async with mplex_conn_1.streams_lock:
         if len(mplex_conn_1.streams) != 1:
-            raise Exception("Mplex should not have any stream upon connection")
+            raise Exception("Mplex should not have any other stream")
         stream_1 = tuple(mplex_conn_1.streams.values())[0]
     yield cast(MplexStream, stream_0), cast(MplexStream, stream_1)

View File

@@ -1,8 +1,5 @@
 import itertools
 import math
-from typing import Generic, TypeVar
-
-import trio

 from libp2p.exceptions import ParseError
 from libp2p.io.abc import Reader
@@ -98,25 +95,3 @@ async def read_fixedint_prefixed(reader: Reader) -> bytes:
     len_bytes = await reader.read(SIZE_LEN_BYTES)
     len_int = int.from_bytes(len_bytes, "big")
     return await reader.read(len_int)
-
-
-TItem = TypeVar("TItem")
-
-
-class IQueue(Generic[TItem]):
-    async def put(self, item: TItem):
-        ...
-
-    async def get(self) -> TItem:
-        ...
-
-
-class TrioQueue(IQueue):
-    def __init__(self):
-        self.send_channel, self.receive_channel = trio.open_memory_channel(0)
-
-    async def put(self, item: TItem):
-        await self.send_channel.send(item)
-
-    async def get(self) -> TItem:
-        return await self.receive_channel.receive()

View File

@@ -1,8 +1,5 @@
-import asyncio
-
 import pytest

-from libp2p.tools.constants import LISTEN_MADDR
 from libp2p.tools.factories import HostFactory
@@ -17,17 +14,6 @@ def num_hosts():

 @pytest.fixture
-async def hosts(num_hosts, is_host_secure):
-    _hosts = HostFactory.create_batch(num_hosts, is_secure=is_host_secure)
-    await asyncio.gather(
-        *[_host.get_network().listen(LISTEN_MADDR) for _host in _hosts]
-    )
-    try:
+async def hosts(num_hosts, is_host_secure, nursery):
+    async with HostFactory.create_batch_and_listen(is_host_secure, num_hosts) as _hosts:
         yield _hosts
-    finally:
-        # TODO: It's possible that `close` raises exceptions currently,
-        # due to the connection reset things. Though we don't care much about that when
-        # cleaning up the tasks, it is probably better to handle the exceptions properly.
-        await asyncio.gather(
-            *[_host.close() for _host in _hosts], return_exceptions=True
-        )

View File

@@ -1,5 +1,3 @@
-import asyncio
-
 import pytest

 from libp2p.tools.factories import (

View File

@@ -1,5 +1,3 @@
-import asyncio
-
 import pytest

 from libp2p.tools.factories import mplex_conn_pair_factory, mplex_stream_pair_factory

View File

@@ -1,9 +1,9 @@
-import asyncio
+import trio

 import pytest


-@pytest.mark.asyncio
+@pytest.mark.trio
 async def test_mplex_conn(mplex_conn_pair):
     conn_0, conn_1 = mplex_conn_pair
@@ -16,19 +16,19 @@ async def test_mplex_conn(mplex_conn_pair):

     # Test: Open a stream, and both side get 1 more stream.
     stream_0 = await conn_0.open_stream()
-    await asyncio.sleep(0.01)
+    await trio.sleep(0.01)
     assert len(conn_0.streams) == 1
     assert len(conn_1.streams) == 1
     # Test: From another side.
     stream_1 = await conn_1.open_stream()
-    await asyncio.sleep(0.01)
+    await trio.sleep(0.01)
     assert len(conn_0.streams) == 2
     assert len(conn_1.streams) == 2

     # Close from one side.
     await conn_0.close()
     # Sleep for a while for both side to handle `close`.
-    await asyncio.sleep(0.01)
+    await trio.sleep(0.01)
     # Test: Both side is closed.
     assert conn_0.event_shutting_down.is_set()
     assert conn_0.event_closed.is_set()

View File

@@ -1,5 +1,6 @@
 import pytest
 import trio
+from trio.testing import wait_all_tasks_blocked

 from libp2p.stream_muxer.mplex.exceptions import (
     MplexStreamClosed,
@@ -37,37 +38,65 @@ async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair):
     async def read_until_eof():
         read_bytes.extend(await stream_1.read())

-    task = trio.ensure_future(read_until_eof())
-
     expected_data = bytearray()
-    # Test: `read` doesn't return before `close` is called.
-    await stream_0.write(DATA)
-    expected_data.extend(DATA)
-    await trio.sleep(0.01)
-    assert len(read_bytes) == 0
-    # Test: `read` doesn't return before `close` is called.
-    await stream_0.write(DATA)
-    expected_data.extend(DATA)
-    await trio.sleep(0.01)
-    assert len(read_bytes) == 0

-    # Test: Close the stream, `read` returns, and receive previous sent data.
-    await stream_0.close()
-    await trio.sleep(0.01)
+    async with trio.open_nursery() as nursery:
+        nursery.start_soon(read_until_eof)
+        # Test: `read` doesn't return before `close` is called.
+        await stream_0.write(DATA)
+        expected_data.extend(DATA)
+        await trio.sleep(0.01)
+        assert len(read_bytes) == 0
+        # Test: `read` doesn't return before `close` is called.
+        await stream_0.write(DATA)
+        expected_data.extend(DATA)
+        await trio.sleep(0.01)
+        assert len(read_bytes) == 0
+
+        # Test: Close the stream, `read` returns, and receive previous sent data.
+        await stream_0.close()
+
     assert read_bytes == expected_data
-    task.cancel()


 @pytest.mark.trio
 async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair):
     stream_0, stream_1 = mplex_stream_pair
     assert not stream_1.event_remote_closed.is_set()
     await stream_0.write(DATA)
-    await stream_0.close()
+    assert not stream_0.event_local_closed.is_set()
     await trio.sleep(0.01)
+    await wait_all_tasks_blocked()
+    await stream_0.close()
+    assert stream_0.event_local_closed.is_set()
+    await trio.sleep(0.01)
+    print(
+        "!@# ",
+        stream_0.muxed_conn.event_shutting_down.is_set(),
+        stream_0.muxed_conn.event_closed.is_set(),
+        stream_1.muxed_conn.event_shutting_down.is_set(),
+        stream_1.muxed_conn.event_closed.is_set(),
+    )
+    # await trio.sleep(100000)
+    await wait_all_tasks_blocked()
+    print(
+        "!@# ",
+        stream_0.muxed_conn.event_shutting_down.is_set(),
+        stream_0.muxed_conn.event_closed.is_set(),
+        stream_1.muxed_conn.event_shutting_down.is_set(),
+        stream_1.muxed_conn.event_closed.is_set(),
+    )
+    print("!@# sleeping")
+    print("!@# result=", stream_1.event_remote_closed.is_set())
+    # await trio.sleep_forever()
     assert stream_1.event_remote_closed.is_set()
+    print(
+        "!@# ",
+        stream_0.muxed_conn.event_shutting_down.is_set(),
+        stream_0.muxed_conn.event_closed.is_set(),
+        stream_1.muxed_conn.event_shutting_down.is_set(),
+        stream_1.muxed_conn.event_closed.is_set(),
+    )
     assert (await stream_1.read(MAX_READ_LEN)) == DATA
     with pytest.raises(MplexStreamEOF):
         await stream_1.read(MAX_READ_LEN)
@@ -87,7 +116,8 @@ async def test_mplex_stream_read_after_remote_reset(mplex_stream_pair):
     await stream_0.write(DATA)
     await stream_0.reset()
     # Sleep to let `stream_1` receive the message.
-    await trio.sleep(0.01)
+    await trio.sleep(0.1)
+    await wait_all_tasks_blocked()
     with pytest.raises(MplexStreamReset):
         await stream_1.read(MAX_READ_LEN)
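The tests lean on `trio.testing.wait_all_tasks_blocked` rather than only fixed
sleeps: it returns once every other task is blocked, which is a deterministic
synchronization point. A small sketch of that behaviour (illustrative, not part
of the commit):

import trio
from trio.testing import wait_all_tasks_blocked

async def main():
    done = trio.Event()

    async def worker():
        await trio.sleep(0)  # yield once, then finish
        done.set()

    async with trio.open_nursery() as nursery:
        nursery.start_soon(worker)
        await wait_all_tasks_blocked()  # worker has run as far as it can
        assert done.is_set()

trio.run(main)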

View File

@@ -1,19 +0,0 @@
-import pytest
-import trio
-
-from libp2p.utils import TrioQueue
-
-
-@pytest.mark.trio
-async def test_trio_queue():
-    queue = TrioQueue()
-
-    async def queue_get(task_status=None):
-        result = await queue.get()
-        task_status.started(result)
-
-    async with trio.open_nursery() as nursery:
-        nursery.start_soon(queue.put, 123)
-        result = await nursery.start(queue_get)
-        assert result == 123