mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-10 23:20:55 +00:00
Fix Mplex and Swarm
This commit is contained in:
@ -23,6 +23,10 @@ class TrioReadWriteCloser(ReadWriteCloser):
|
|||||||
raise IOException(error)
|
raise IOException(error)
|
||||||
|
|
||||||
async def read(self, n: int = -1) -> bytes:
|
async def read(self, n: int = -1) -> bytes:
|
||||||
|
if n == 0:
|
||||||
|
# Check point
|
||||||
|
await trio.sleep(0)
|
||||||
|
return b""
|
||||||
max_bytes = n if n != -1 else None
|
max_bytes = n if n != -1 else None
|
||||||
try:
|
try:
|
||||||
return await self.stream.receive_some(max_bytes)
|
return await self.stream.receive_some(max_bytes)
|
||||||
|
|||||||
@ -50,8 +50,11 @@ class SwarmConn(INetConn, Service):
|
|||||||
await self._notify_disconnected()
|
await self._notify_disconnected()
|
||||||
|
|
||||||
async def _handle_new_streams(self) -> None:
|
async def _handle_new_streams(self) -> None:
|
||||||
while True:
|
while self.manager.is_running:
|
||||||
try:
|
try:
|
||||||
|
print(
|
||||||
|
f"!@# SwarmConn._handle_new_streams: {self.muxed_conn._id}: waiting for new streams"
|
||||||
|
)
|
||||||
stream = await self.muxed_conn.accept_stream()
|
stream = await self.muxed_conn.accept_stream()
|
||||||
except MuxedConnUnavailable:
|
except MuxedConnUnavailable:
|
||||||
# If there is anything wrong in the MuxedConn,
|
# If there is anything wrong in the MuxedConn,
|
||||||
@ -60,6 +63,9 @@ class SwarmConn(INetConn, Service):
|
|||||||
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
|
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
|
||||||
self.manager.run_task(self._handle_muxed_stream, stream)
|
self.manager.run_task(self._handle_muxed_stream, stream)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"!@# SwarmConn._handle_new_streams: {self.muxed_conn._id}: out of the loop"
|
||||||
|
)
|
||||||
await self.close()
|
await self.close()
|
||||||
|
|
||||||
async def _call_stream_handler(self, net_stream: NetStream) -> None:
|
async def _call_stream_handler(self, net_stream: NetStream) -> None:
|
||||||
|
|||||||
@ -206,8 +206,7 @@ class Swarm(INetwork, Service):
|
|||||||
logger.debug("successfully opened connection to peer %s", peer_id)
|
logger.debug("successfully opened connection to peer %s", peer_id)
|
||||||
# FIXME: This is a intentional barrier to prevent from the handler exiting and
|
# FIXME: This is a intentional barrier to prevent from the handler exiting and
|
||||||
# closing the connection.
|
# closing the connection.
|
||||||
event = trio.Event()
|
await trio.sleep_forever()
|
||||||
await event.wait()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Success
|
# Success
|
||||||
@ -240,7 +239,7 @@ class Swarm(INetwork, Service):
|
|||||||
# await asyncio.gather(
|
# await asyncio.gather(
|
||||||
# *[connection.close() for connection in self.connections.values()]
|
# *[connection.close() for connection in self.connections.values()]
|
||||||
# )
|
# )
|
||||||
self.manager.stop()
|
await self.manager.stop()
|
||||||
await self.manager.wait_finished()
|
await self.manager.wait_finished()
|
||||||
logger.debug("swarm successfully closed")
|
logger.debug("swarm successfully closed")
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import math
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Any # noqa: F401
|
from typing import Any # noqa: F401
|
||||||
@ -18,7 +19,6 @@ from libp2p.utils import (
|
|||||||
encode_uvarint,
|
encode_uvarint,
|
||||||
encode_varint_prefixed,
|
encode_varint_prefixed,
|
||||||
read_varint_prefixed_bytes,
|
read_varint_prefixed_bytes,
|
||||||
TrioQueue,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from .constants import HeaderTags
|
from .constants import HeaderTags
|
||||||
@ -41,7 +41,10 @@ class Mplex(IMuxedConn, Service):
|
|||||||
next_channel_id: int
|
next_channel_id: int
|
||||||
streams: Dict[StreamID, MplexStream]
|
streams: Dict[StreamID, MplexStream]
|
||||||
streams_lock: trio.Lock
|
streams_lock: trio.Lock
|
||||||
new_stream_queue: "TrioQueue[IMuxedStream]"
|
streams_msg_channels: Dict[StreamID, "trio.MemorySendChannel[bytes]"]
|
||||||
|
new_stream_send_channel: "trio.MemorySendChannel[IMuxedStream]"
|
||||||
|
new_stream_receive_channel: "trio.MemoryReceiveChannel[IMuxedStream]"
|
||||||
|
|
||||||
event_shutting_down: trio.Event
|
event_shutting_down: trio.Event
|
||||||
event_closed: trio.Event
|
event_closed: trio.Event
|
||||||
|
|
||||||
@ -64,7 +67,10 @@ class Mplex(IMuxedConn, Service):
|
|||||||
# Mapping from stream ID -> buffer of messages for that stream
|
# Mapping from stream ID -> buffer of messages for that stream
|
||||||
self.streams = {}
|
self.streams = {}
|
||||||
self.streams_lock = trio.Lock()
|
self.streams_lock = trio.Lock()
|
||||||
self.new_stream_queue = TrioQueue()
|
self.streams_msg_channels = {}
|
||||||
|
send_channel, receive_channel = trio.open_memory_channel(math.inf)
|
||||||
|
self.new_stream_send_channel = send_channel
|
||||||
|
self.new_stream_receive_channel = receive_channel
|
||||||
self.event_shutting_down = trio.Event()
|
self.event_shutting_down = trio.Event()
|
||||||
self.event_closed = trio.Event()
|
self.event_closed = trio.Event()
|
||||||
|
|
||||||
@ -105,9 +111,13 @@ class Mplex(IMuxedConn, Service):
|
|||||||
return next_id
|
return next_id
|
||||||
|
|
||||||
async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream:
|
async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream:
|
||||||
stream = MplexStream(name, stream_id, self)
|
# Use an unbounded buffer, to avoid `handle_incoming` being blocked when doing
|
||||||
|
# `send_channel.send`.
|
||||||
|
send_channel, receive_channel = trio.open_memory_channel(math.inf)
|
||||||
|
stream = MplexStream(name, stream_id, self, receive_channel)
|
||||||
async with self.streams_lock:
|
async with self.streams_lock:
|
||||||
self.streams[stream_id] = stream
|
self.streams[stream_id] = stream
|
||||||
|
self.streams_msg_channels[stream_id] = send_channel
|
||||||
return stream
|
return stream
|
||||||
|
|
||||||
async def open_stream(self) -> IMuxedStream:
|
async def open_stream(self) -> IMuxedStream:
|
||||||
@ -126,7 +136,10 @@ class Mplex(IMuxedConn, Service):
|
|||||||
|
|
||||||
async def accept_stream(self) -> IMuxedStream:
|
async def accept_stream(self) -> IMuxedStream:
|
||||||
"""accepts a muxed stream opened by the other end."""
|
"""accepts a muxed stream opened by the other end."""
|
||||||
return await self.new_stream_queue.get()
|
try:
|
||||||
|
return await self.new_stream_receive_channel.receive()
|
||||||
|
except (trio.ClosedResourceError, trio.EndOfChannel):
|
||||||
|
raise MplexUnavailable
|
||||||
|
|
||||||
async def send_message(
|
async def send_message(
|
||||||
self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID
|
self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID
|
||||||
@ -138,6 +151,9 @@ class Mplex(IMuxedConn, Service):
|
|||||||
:param data: data to send in the message
|
:param data: data to send in the message
|
||||||
:param stream_id: stream the message is in
|
:param stream_id: stream the message is in
|
||||||
"""
|
"""
|
||||||
|
print(
|
||||||
|
f"!@# send_message: {self._id}: flag={flag}, data={data}, stream_id={stream_id}"
|
||||||
|
)
|
||||||
# << by 3, then or with flag
|
# << by 3, then or with flag
|
||||||
header = encode_uvarint((stream_id.channel_id << 3) | flag.value)
|
header = encode_uvarint((stream_id.channel_id << 3) | flag.value)
|
||||||
|
|
||||||
@ -162,14 +178,21 @@ class Mplex(IMuxedConn, Service):
|
|||||||
"""Read a message off of the secured connection and add it to the
|
"""Read a message off of the secured connection and add it to the
|
||||||
corresponding message buffer."""
|
corresponding message buffer."""
|
||||||
|
|
||||||
while True:
|
while self.manager.is_running:
|
||||||
try:
|
try:
|
||||||
|
print(
|
||||||
|
f"!@# handle_incoming: {self._id}: before _handle_incoming_message"
|
||||||
|
)
|
||||||
await self._handle_incoming_message()
|
await self._handle_incoming_message()
|
||||||
|
print(
|
||||||
|
f"!@# handle_incoming: {self._id}: after _handle_incoming_message"
|
||||||
|
)
|
||||||
except MplexUnavailable as e:
|
except MplexUnavailable as e:
|
||||||
logger.debug("mplex unavailable while waiting for incoming: %s", e)
|
logger.debug("mplex unavailable while waiting for incoming: %s", e)
|
||||||
|
print(f"!@# handle_incoming: {self._id}: MplexUnavailable: {e}")
|
||||||
break
|
break
|
||||||
# Force context switch
|
|
||||||
await trio.sleep(0)
|
print(f"!@# handle_incoming: {self._id}: leaving")
|
||||||
# If we enter here, it means this connection is shutting down.
|
# If we enter here, it means this connection is shutting down.
|
||||||
# We should clean things up.
|
# We should clean things up.
|
||||||
await self._cleanup()
|
await self._cleanup()
|
||||||
@ -181,51 +204,73 @@ class Mplex(IMuxedConn, Service):
|
|||||||
:return: stream_id, flag, message contents
|
:return: stream_id, flag, message contents
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# FIXME: No timeout is used in Go implementation.
|
|
||||||
try:
|
try:
|
||||||
header = await decode_uvarint_from_stream(self.secured_conn)
|
header = await decode_uvarint_from_stream(self.secured_conn)
|
||||||
|
except (ParseError, RawConnError, IncompleteReadError) as error:
|
||||||
|
raise MplexUnavailable(
|
||||||
|
f"failed to read the header correctly from the underlying connection: {error}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
message = await read_varint_prefixed_bytes(self.secured_conn)
|
message = await read_varint_prefixed_bytes(self.secured_conn)
|
||||||
except (ParseError, RawConnError, IncompleteReadError) as error:
|
except (ParseError, RawConnError, IncompleteReadError) as error:
|
||||||
raise MplexUnavailable(
|
raise MplexUnavailable(
|
||||||
"failed to read messages correctly from the underlying connection"
|
"failed to read the message body correctly from the underlying connection: "
|
||||||
) from error
|
f"{error}"
|
||||||
except asyncio.TimeoutError as error:
|
)
|
||||||
raise MplexUnavailable(
|
|
||||||
"failed to read more message body within the timeout"
|
|
||||||
) from error
|
|
||||||
|
|
||||||
flag = header & 0x07
|
flag = header & 0x07
|
||||||
channel_id = header >> 3
|
channel_id = header >> 3
|
||||||
|
|
||||||
return channel_id, flag, message
|
return channel_id, flag, message
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _id(self) -> int:
|
||||||
|
return 0 if self.is_initiator else 1
|
||||||
|
|
||||||
async def _handle_incoming_message(self) -> None:
|
async def _handle_incoming_message(self) -> None:
|
||||||
"""
|
"""
|
||||||
Read and handle a new incoming message.
|
Read and handle a new incoming message.
|
||||||
|
|
||||||
:raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down.
|
:raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down.
|
||||||
"""
|
"""
|
||||||
|
print(f"!@# _handle_incoming_message: {self._id}: before reading")
|
||||||
channel_id, flag, message = await self.read_message()
|
channel_id, flag, message = await self.read_message()
|
||||||
|
print(
|
||||||
|
f"!@# _handle_incoming_message: {self._id}: channel_id={channel_id}, flag={flag}, message={message}"
|
||||||
|
)
|
||||||
stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
|
stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
|
||||||
|
print(f"!@# _handle_incoming_message: {self._id}: 2")
|
||||||
|
|
||||||
if flag == HeaderTags.NewStream.value:
|
if flag == HeaderTags.NewStream.value:
|
||||||
|
print(f"!@# _handle_incoming_message: {self._id}: 3")
|
||||||
await self._handle_new_stream(stream_id, message)
|
await self._handle_new_stream(stream_id, message)
|
||||||
|
print(f"!@# _handle_incoming_message: {self._id}: 4")
|
||||||
elif flag in (
|
elif flag in (
|
||||||
HeaderTags.MessageInitiator.value,
|
HeaderTags.MessageInitiator.value,
|
||||||
HeaderTags.MessageReceiver.value,
|
HeaderTags.MessageReceiver.value,
|
||||||
):
|
):
|
||||||
|
print(f"!@# _handle_incoming_message: {self._id}: 5")
|
||||||
await self._handle_message(stream_id, message)
|
await self._handle_message(stream_id, message)
|
||||||
|
print(f"!@# _handle_incoming_message: {self._id}: 6")
|
||||||
elif flag in (HeaderTags.CloseInitiator.value, HeaderTags.CloseReceiver.value):
|
elif flag in (HeaderTags.CloseInitiator.value, HeaderTags.CloseReceiver.value):
|
||||||
|
print(f"!@# _handle_incoming_message: {self._id}: 7")
|
||||||
await self._handle_close(stream_id)
|
await self._handle_close(stream_id)
|
||||||
|
print(f"!@# _handle_incoming_message: {self._id}: 8")
|
||||||
elif flag in (HeaderTags.ResetInitiator.value, HeaderTags.ResetReceiver.value):
|
elif flag in (HeaderTags.ResetInitiator.value, HeaderTags.ResetReceiver.value):
|
||||||
|
print(f"!@# _handle_incoming_message: {self._id}: 9")
|
||||||
await self._handle_reset(stream_id)
|
await self._handle_reset(stream_id)
|
||||||
|
print(f"!@# _handle_incoming_message: {self._id}: 10")
|
||||||
else:
|
else:
|
||||||
|
print(f"!@# _handle_incoming_message: {self._id}: 11")
|
||||||
# Receives messages with an unknown flag
|
# Receives messages with an unknown flag
|
||||||
# TODO: logging
|
# TODO: logging
|
||||||
async with self.streams_lock:
|
async with self.streams_lock:
|
||||||
|
print(f"!@# _handle_incoming_message: {self._id}: 12")
|
||||||
if stream_id in self.streams:
|
if stream_id in self.streams:
|
||||||
|
print(f"!@# _handle_incoming_message: {self._id}: 13")
|
||||||
stream = self.streams[stream_id]
|
stream = self.streams[stream_id]
|
||||||
await stream.reset()
|
await stream.reset()
|
||||||
|
print(f"!@# _handle_incoming_message: {self._id}: 14")
|
||||||
|
|
||||||
async def _handle_new_stream(self, stream_id: StreamID, message: bytes) -> None:
|
async def _handle_new_stream(self, stream_id: StreamID, message: bytes) -> None:
|
||||||
async with self.streams_lock:
|
async with self.streams_lock:
|
||||||
@ -235,43 +280,65 @@ class Mplex(IMuxedConn, Service):
|
|||||||
f"received NewStream message for existing stream: {stream_id}"
|
f"received NewStream message for existing stream: {stream_id}"
|
||||||
)
|
)
|
||||||
mplex_stream = await self._initialize_stream(stream_id, message.decode())
|
mplex_stream = await self._initialize_stream(stream_id, message.decode())
|
||||||
await self.new_stream_queue.put(mplex_stream)
|
try:
|
||||||
|
await self.new_stream_send_channel.send(mplex_stream)
|
||||||
|
except (trio.BrokenResourceError, trio.EndOfChannel):
|
||||||
|
raise MplexUnavailable
|
||||||
|
|
||||||
async def _handle_message(self, stream_id: StreamID, message: bytes) -> None:
|
async def _handle_message(self, stream_id: StreamID, message: bytes) -> None:
|
||||||
|
print(
|
||||||
|
f"!@# _handle_message: {self._id}: stream_id={stream_id}, message={message}"
|
||||||
|
)
|
||||||
async with self.streams_lock:
|
async with self.streams_lock:
|
||||||
|
print(f"!@# _handle_message: {self._id}: 1")
|
||||||
if stream_id not in self.streams:
|
if stream_id not in self.streams:
|
||||||
# We receive a message of the stream `stream_id` which is not accepted
|
# We receive a message of the stream `stream_id` which is not accepted
|
||||||
# before. It is abnormal. Possibly disconnect?
|
# before. It is abnormal. Possibly disconnect?
|
||||||
# TODO: Warn and emit logs about this.
|
# TODO: Warn and emit logs about this.
|
||||||
|
print(f"!@# _handle_message: {self._id}: 2")
|
||||||
return
|
return
|
||||||
|
print(f"!@# _handle_message: {self._id}: 3")
|
||||||
stream = self.streams[stream_id]
|
stream = self.streams[stream_id]
|
||||||
|
send_channel = self.streams_msg_channels[stream_id]
|
||||||
async with stream.close_lock:
|
async with stream.close_lock:
|
||||||
|
print(f"!@# _handle_message: {self._id}: 4")
|
||||||
if stream.event_remote_closed.is_set():
|
if stream.event_remote_closed.is_set():
|
||||||
|
print(f"!@# _handle_message: {self._id}: 5")
|
||||||
# TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501
|
# TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501
|
||||||
return
|
return
|
||||||
await stream.incoming_data.put(message)
|
print(f"!@# _handle_message: {self._id}: 6")
|
||||||
|
await send_channel.send(message)
|
||||||
|
print(f"!@# _handle_message: {self._id}: 7")
|
||||||
|
|
||||||
async def _handle_close(self, stream_id: StreamID) -> None:
|
async def _handle_close(self, stream_id: StreamID) -> None:
|
||||||
|
print(f"!@# _handle_close: {self._id}: step=0")
|
||||||
async with self.streams_lock:
|
async with self.streams_lock:
|
||||||
if stream_id not in self.streams:
|
if stream_id not in self.streams:
|
||||||
# Ignore unmatched messages for now.
|
# Ignore unmatched messages for now.
|
||||||
return
|
return
|
||||||
stream = self.streams[stream_id]
|
stream = self.streams[stream_id]
|
||||||
|
send_channel = self.streams_msg_channels[stream_id]
|
||||||
|
print(f"!@# _handle_close: {self._id}: step=1")
|
||||||
|
await send_channel.aclose()
|
||||||
|
print(f"!@# _handle_close: {self._id}: step=2")
|
||||||
# NOTE: If remote is already closed, then return: Technically a bug
|
# NOTE: If remote is already closed, then return: Technically a bug
|
||||||
# on the other side. We should consider killing the connection.
|
# on the other side. We should consider killing the connection.
|
||||||
async with stream.close_lock:
|
async with stream.close_lock:
|
||||||
if stream.event_remote_closed.is_set():
|
if stream.event_remote_closed.is_set():
|
||||||
return
|
return
|
||||||
|
print(f"!@# _handle_close: {self._id}: step=3")
|
||||||
is_local_closed: bool
|
is_local_closed: bool
|
||||||
async with stream.close_lock:
|
async with stream.close_lock:
|
||||||
stream.event_remote_closed.set()
|
stream.event_remote_closed.set()
|
||||||
is_local_closed = stream.event_local_closed.is_set()
|
is_local_closed = stream.event_local_closed.is_set()
|
||||||
|
print(f"!@# _handle_close: {self._id}: step=4")
|
||||||
# If local is also closed, both sides are closed. Then, we should clean up
|
# If local is also closed, both sides are closed. Then, we should clean up
|
||||||
# the entry of this stream, to avoid others from accessing it.
|
# the entry of this stream, to avoid others from accessing it.
|
||||||
if is_local_closed:
|
if is_local_closed:
|
||||||
async with self.streams_lock:
|
async with self.streams_lock:
|
||||||
if stream_id in self.streams:
|
if stream_id in self.streams:
|
||||||
del self.streams[stream_id]
|
del self.streams[stream_id]
|
||||||
|
print(f"!@# _handle_close: {self._id}: step=5")
|
||||||
|
|
||||||
async def _handle_reset(self, stream_id: StreamID) -> None:
|
async def _handle_reset(self, stream_id: StreamID) -> None:
|
||||||
async with self.streams_lock:
|
async with self.streams_lock:
|
||||||
@ -279,11 +346,11 @@ class Mplex(IMuxedConn, Service):
|
|||||||
# This is *ok*. We forget the stream on reset.
|
# This is *ok*. We forget the stream on reset.
|
||||||
return
|
return
|
||||||
stream = self.streams[stream_id]
|
stream = self.streams[stream_id]
|
||||||
|
send_channel = self.streams_msg_channels[stream_id]
|
||||||
|
await send_channel.aclose()
|
||||||
async with stream.close_lock:
|
async with stream.close_lock:
|
||||||
if not stream.event_remote_closed.is_set():
|
if not stream.event_remote_closed.is_set():
|
||||||
stream.event_reset.set()
|
stream.event_reset.set()
|
||||||
|
|
||||||
stream.event_remote_closed.set()
|
stream.event_remote_closed.set()
|
||||||
# If local is not closed, we should close it.
|
# If local is not closed, we should close it.
|
||||||
if not stream.event_local_closed.is_set():
|
if not stream.event_local_closed.is_set():
|
||||||
@ -291,16 +358,21 @@ class Mplex(IMuxedConn, Service):
|
|||||||
async with self.streams_lock:
|
async with self.streams_lock:
|
||||||
if stream_id in self.streams:
|
if stream_id in self.streams:
|
||||||
del self.streams[stream_id]
|
del self.streams[stream_id]
|
||||||
|
del self.streams_msg_channels[stream_id]
|
||||||
|
|
||||||
async def _cleanup(self) -> None:
|
async def _cleanup(self) -> None:
|
||||||
if not self.event_shutting_down.is_set():
|
if not self.event_shutting_down.is_set():
|
||||||
self.event_shutting_down.set()
|
self.event_shutting_down.set()
|
||||||
async with self.streams_lock:
|
async with self.streams_lock:
|
||||||
for stream in self.streams.values():
|
for stream_id, stream in self.streams.items():
|
||||||
async with stream.close_lock:
|
async with stream.close_lock:
|
||||||
if not stream.event_remote_closed.is_set():
|
if not stream.event_remote_closed.is_set():
|
||||||
stream.event_remote_closed.set()
|
stream.event_remote_closed.set()
|
||||||
stream.event_reset.set()
|
stream.event_reset.set()
|
||||||
stream.event_local_closed.set()
|
stream.event_local_closed.set()
|
||||||
|
send_channel = self.streams_msg_channels[stream_id]
|
||||||
|
await send_channel.aclose()
|
||||||
self.streams = None
|
self.streams = None
|
||||||
self.event_closed.set()
|
self.event_closed.set()
|
||||||
|
await self.new_stream_send_channel.aclose()
|
||||||
|
await self.new_stream_receive_channel.aclose()
|
||||||
|
|||||||
@ -3,7 +3,6 @@ from typing import TYPE_CHECKING
|
|||||||
import trio
|
import trio
|
||||||
|
|
||||||
from libp2p.stream_muxer.abc import IMuxedStream
|
from libp2p.stream_muxer.abc import IMuxedStream
|
||||||
from libp2p.utils import IQueue, TrioQueue
|
|
||||||
|
|
||||||
from .constants import HeaderTags
|
from .constants import HeaderTags
|
||||||
from .datastructures import StreamID
|
from .datastructures import StreamID
|
||||||
@ -24,10 +23,11 @@ class MplexStream(IMuxedStream):
|
|||||||
read_deadline: int
|
read_deadline: int
|
||||||
write_deadline: int
|
write_deadline: int
|
||||||
|
|
||||||
|
# TODO: Add lock for read/write to avoid interleaving receiving messages?
|
||||||
close_lock: trio.Lock
|
close_lock: trio.Lock
|
||||||
|
|
||||||
# NOTE: `dataIn` is size of 8 in Go implementation.
|
# NOTE: `dataIn` is size of 8 in Go implementation.
|
||||||
incoming_data: IQueue[bytes]
|
incoming_data_channel: "trio.MemoryReceiveChannel[bytes]"
|
||||||
|
|
||||||
event_local_closed: trio.Event
|
event_local_closed: trio.Event
|
||||||
event_remote_closed: trio.Event
|
event_remote_closed: trio.Event
|
||||||
@ -35,7 +35,13 @@ class MplexStream(IMuxedStream):
|
|||||||
|
|
||||||
_buf: bytearray
|
_buf: bytearray
|
||||||
|
|
||||||
def __init__(self, name: str, stream_id: StreamID, muxed_conn: "Mplex") -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
stream_id: StreamID,
|
||||||
|
muxed_conn: "Mplex",
|
||||||
|
incoming_data_channel: "trio.MemoryReceiveChannel[bytes]",
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
create new MuxedStream in muxer.
|
create new MuxedStream in muxer.
|
||||||
|
|
||||||
@ -51,13 +57,30 @@ class MplexStream(IMuxedStream):
|
|||||||
self.event_remote_closed = trio.Event()
|
self.event_remote_closed = trio.Event()
|
||||||
self.event_reset = trio.Event()
|
self.event_reset = trio.Event()
|
||||||
self.close_lock = trio.Lock()
|
self.close_lock = trio.Lock()
|
||||||
self.incoming_data = TrioQueue()
|
self.incoming_data_channel = incoming_data_channel
|
||||||
self._buf = bytearray()
|
self._buf = bytearray()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_initiator(self) -> bool:
|
def is_initiator(self) -> bool:
|
||||||
return self.stream_id.is_initiator
|
return self.stream_id.is_initiator
|
||||||
|
|
||||||
|
async def _read_until_eof(self) -> bytes:
|
||||||
|
async for data in self.incoming_data_channel:
|
||||||
|
self._buf.extend(data)
|
||||||
|
payload = self._buf
|
||||||
|
self._buf = self._buf[len(payload) :]
|
||||||
|
return bytes(payload)
|
||||||
|
|
||||||
|
def _read_return_when_blocked(self) -> bytes:
|
||||||
|
buf = bytearray()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
data = self.incoming_data_channel.receive_nowait()
|
||||||
|
buf.extend(data)
|
||||||
|
except (trio.WouldBlock, trio.EndOfChannel):
|
||||||
|
break
|
||||||
|
return buf
|
||||||
|
|
||||||
async def read(self, n: int = -1) -> bytes:
|
async def read(self, n: int = -1) -> bytes:
|
||||||
"""
|
"""
|
||||||
Read up to n bytes. Read possibly returns fewer than `n` bytes, if
|
Read up to n bytes. Read possibly returns fewer than `n` bytes, if
|
||||||
@ -73,7 +96,40 @@ class MplexStream(IMuxedStream):
|
|||||||
)
|
)
|
||||||
if self.event_reset.is_set():
|
if self.event_reset.is_set():
|
||||||
raise MplexStreamReset
|
raise MplexStreamReset
|
||||||
return await self.incoming_data.get()
|
if n == -1:
|
||||||
|
return await self._read_until_eof()
|
||||||
|
if len(self._buf) == 0:
|
||||||
|
data: bytes
|
||||||
|
# Peek whether there is data available. If yes, we just read until there is no data,
|
||||||
|
# and then return.
|
||||||
|
try:
|
||||||
|
data = self.incoming_data_channel.receive_nowait()
|
||||||
|
except trio.EndOfChannel:
|
||||||
|
raise MplexStreamEOF
|
||||||
|
except trio.WouldBlock:
|
||||||
|
# We know `receive` will be blocked here. Wait for data here with `receive` and
|
||||||
|
# catch all kinds of errors here.
|
||||||
|
try:
|
||||||
|
data = await self.incoming_data_channel.receive()
|
||||||
|
except trio.EndOfChannel:
|
||||||
|
if self.event_reset.is_set():
|
||||||
|
raise MplexStreamReset
|
||||||
|
if self.event_remote_closed.is_set():
|
||||||
|
raise MplexStreamEOF
|
||||||
|
except trio.ClosedResourceError as error:
|
||||||
|
# Probably `incoming_data_channel` is closed in `reset` when we are waiting
|
||||||
|
# for `receive`.
|
||||||
|
if self.event_reset.is_set():
|
||||||
|
raise MplexStreamReset
|
||||||
|
raise Exception(
|
||||||
|
"`incoming_data_channel` is closed but stream is not reset. "
|
||||||
|
"This should never happen."
|
||||||
|
) from error
|
||||||
|
self._buf.extend(data)
|
||||||
|
self._buf.extend(self._read_return_when_blocked())
|
||||||
|
payload = self._buf[:n]
|
||||||
|
self._buf = self._buf[len(payload) :]
|
||||||
|
return bytes(payload)
|
||||||
|
|
||||||
async def write(self, data: bytes) -> int:
|
async def write(self, data: bytes) -> int:
|
||||||
"""
|
"""
|
||||||
@ -99,22 +155,26 @@ class MplexStream(IMuxedStream):
|
|||||||
if self.event_local_closed.is_set():
|
if self.event_local_closed.is_set():
|
||||||
return
|
return
|
||||||
|
|
||||||
|
print(f"!@# stream.close: {self.muxed_conn._id}: step=0")
|
||||||
flag = (
|
flag = (
|
||||||
HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver
|
HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver
|
||||||
)
|
)
|
||||||
# TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown.
|
# TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown.
|
||||||
await self.muxed_conn.send_message(flag, None, self.stream_id)
|
await self.muxed_conn.send_message(flag, None, self.stream_id)
|
||||||
|
|
||||||
|
print(f"!@# stream.close: {self.muxed_conn._id}: step=1")
|
||||||
_is_remote_closed: bool
|
_is_remote_closed: bool
|
||||||
async with self.close_lock:
|
async with self.close_lock:
|
||||||
self.event_local_closed.set()
|
self.event_local_closed.set()
|
||||||
_is_remote_closed = self.event_remote_closed.is_set()
|
_is_remote_closed = self.event_remote_closed.is_set()
|
||||||
|
|
||||||
|
print(f"!@# stream.close: {self.muxed_conn._id}: step=2")
|
||||||
if _is_remote_closed:
|
if _is_remote_closed:
|
||||||
# Both sides are closed, we can safely remove the buffer from the dict.
|
# Both sides are closed, we can safely remove the buffer from the dict.
|
||||||
async with self.muxed_conn.streams_lock:
|
async with self.muxed_conn.streams_lock:
|
||||||
if self.stream_id in self.muxed_conn.streams:
|
if self.stream_id in self.muxed_conn.streams:
|
||||||
del self.muxed_conn.streams[self.stream_id]
|
del self.muxed_conn.streams[self.stream_id]
|
||||||
|
print(f"!@# stream.close: {self.muxed_conn._id}: step=3")
|
||||||
|
|
||||||
async def reset(self) -> None:
|
async def reset(self) -> None:
|
||||||
"""closes both ends of the stream tells this remote side to hang up."""
|
"""closes both ends of the stream tells this remote side to hang up."""
|
||||||
@ -132,14 +192,15 @@ class MplexStream(IMuxedStream):
|
|||||||
if self.is_initiator
|
if self.is_initiator
|
||||||
else HeaderTags.ResetReceiver
|
else HeaderTags.ResetReceiver
|
||||||
)
|
)
|
||||||
async with trio.open_nursery() as nursery:
|
self.muxed_conn.manager.run_task(
|
||||||
nursery.start_soon(
|
self.muxed_conn.send_message, flag, None, self.stream_id
|
||||||
self.muxed_conn.send_message, flag, None, self.stream_id
|
)
|
||||||
)
|
|
||||||
|
|
||||||
self.event_local_closed.set()
|
self.event_local_closed.set()
|
||||||
self.event_remote_closed.set()
|
self.event_remote_closed.set()
|
||||||
|
|
||||||
|
await self.incoming_data_channel.aclose()
|
||||||
|
|
||||||
async with self.muxed_conn.streams_lock:
|
async with self.muxed_conn.streams_lock:
|
||||||
if (
|
if (
|
||||||
self.muxed_conn.streams is not None
|
self.muxed_conn.streams is not None
|
||||||
|
|||||||
@ -205,7 +205,7 @@ async def mplex_stream_pair_factory(is_secure: bool) -> Tuple[MplexStream, Mplex
|
|||||||
stream_1: MplexStream
|
stream_1: MplexStream
|
||||||
async with mplex_conn_1.streams_lock:
|
async with mplex_conn_1.streams_lock:
|
||||||
if len(mplex_conn_1.streams) != 1:
|
if len(mplex_conn_1.streams) != 1:
|
||||||
raise Exception("Mplex should not have any stream upon connection")
|
raise Exception("Mplex should not have any other stream")
|
||||||
stream_1 = tuple(mplex_conn_1.streams.values())[0]
|
stream_1 = tuple(mplex_conn_1.streams.values())[0]
|
||||||
yield cast(MplexStream, stream_0), cast(MplexStream, stream_1)
|
yield cast(MplexStream, stream_0), cast(MplexStream, stream_1)
|
||||||
|
|
||||||
|
|||||||
@ -1,8 +1,5 @@
|
|||||||
import itertools
|
import itertools
|
||||||
import math
|
import math
|
||||||
from typing import Generic, TypeVar
|
|
||||||
|
|
||||||
import trio
|
|
||||||
|
|
||||||
from libp2p.exceptions import ParseError
|
from libp2p.exceptions import ParseError
|
||||||
from libp2p.io.abc import Reader
|
from libp2p.io.abc import Reader
|
||||||
@ -98,25 +95,3 @@ async def read_fixedint_prefixed(reader: Reader) -> bytes:
|
|||||||
len_bytes = await reader.read(SIZE_LEN_BYTES)
|
len_bytes = await reader.read(SIZE_LEN_BYTES)
|
||||||
len_int = int.from_bytes(len_bytes, "big")
|
len_int = int.from_bytes(len_bytes, "big")
|
||||||
return await reader.read(len_int)
|
return await reader.read(len_int)
|
||||||
|
|
||||||
|
|
||||||
TItem = TypeVar("TItem")
|
|
||||||
|
|
||||||
|
|
||||||
class IQueue(Generic[TItem]):
|
|
||||||
async def put(self, item: TItem):
|
|
||||||
...
|
|
||||||
|
|
||||||
async def get(self) -> TItem:
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class TrioQueue(IQueue):
|
|
||||||
def __init__(self):
|
|
||||||
self.send_channel, self.receive_channel = trio.open_memory_channel(0)
|
|
||||||
|
|
||||||
async def put(self, item: TItem):
|
|
||||||
await self.send_channel.send(item)
|
|
||||||
|
|
||||||
async def get(self) -> TItem:
|
|
||||||
return await self.receive_channel.receive()
|
|
||||||
|
|||||||
@ -1,8 +1,5 @@
|
|||||||
import asyncio
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from libp2p.tools.constants import LISTEN_MADDR
|
|
||||||
from libp2p.tools.factories import HostFactory
|
from libp2p.tools.factories import HostFactory
|
||||||
|
|
||||||
|
|
||||||
@ -17,17 +14,6 @@ def num_hosts():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def hosts(num_hosts, is_host_secure):
|
async def hosts(num_hosts, is_host_secure, nursery):
|
||||||
_hosts = HostFactory.create_batch(num_hosts, is_secure=is_host_secure)
|
async with HostFactory.create_batch_and_listen(is_host_secure, num_hosts) as _hosts:
|
||||||
await asyncio.gather(
|
|
||||||
*[_host.get_network().listen(LISTEN_MADDR) for _host in _hosts]
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
yield _hosts
|
yield _hosts
|
||||||
finally:
|
|
||||||
# TODO: It's possible that `close` raises exceptions currently,
|
|
||||||
# due to the connection reset things. Though we don't care much about that when
|
|
||||||
# cleaning up the tasks, it is probably better to handle the exceptions properly.
|
|
||||||
await asyncio.gather(
|
|
||||||
*[_host.close() for _host in _hosts], return_exceptions=True
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1,5 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from libp2p.tools.factories import (
|
from libp2p.tools.factories import (
|
||||||
|
|||||||
@ -1,5 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from libp2p.tools.factories import mplex_conn_pair_factory, mplex_stream_pair_factory
|
from libp2p.tools.factories import mplex_conn_pair_factory, mplex_stream_pair_factory
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
import asyncio
|
import trio
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.trio
|
||||||
async def test_mplex_conn(mplex_conn_pair):
|
async def test_mplex_conn(mplex_conn_pair):
|
||||||
conn_0, conn_1 = mplex_conn_pair
|
conn_0, conn_1 = mplex_conn_pair
|
||||||
|
|
||||||
@ -16,19 +16,19 @@ async def test_mplex_conn(mplex_conn_pair):
|
|||||||
|
|
||||||
# Test: Open a stream, and both side get 1 more stream.
|
# Test: Open a stream, and both side get 1 more stream.
|
||||||
stream_0 = await conn_0.open_stream()
|
stream_0 = await conn_0.open_stream()
|
||||||
await asyncio.sleep(0.01)
|
await trio.sleep(0.01)
|
||||||
assert len(conn_0.streams) == 1
|
assert len(conn_0.streams) == 1
|
||||||
assert len(conn_1.streams) == 1
|
assert len(conn_1.streams) == 1
|
||||||
# Test: From another side.
|
# Test: From another side.
|
||||||
stream_1 = await conn_1.open_stream()
|
stream_1 = await conn_1.open_stream()
|
||||||
await asyncio.sleep(0.01)
|
await trio.sleep(0.01)
|
||||||
assert len(conn_0.streams) == 2
|
assert len(conn_0.streams) == 2
|
||||||
assert len(conn_1.streams) == 2
|
assert len(conn_1.streams) == 2
|
||||||
|
|
||||||
# Close from one side.
|
# Close from one side.
|
||||||
await conn_0.close()
|
await conn_0.close()
|
||||||
# Sleep for a while for both side to handle `close`.
|
# Sleep for a while for both side to handle `close`.
|
||||||
await asyncio.sleep(0.01)
|
await trio.sleep(0.01)
|
||||||
# Test: Both side is closed.
|
# Test: Both side is closed.
|
||||||
assert conn_0.event_shutting_down.is_set()
|
assert conn_0.event_shutting_down.is_set()
|
||||||
assert conn_0.event_closed.is_set()
|
assert conn_0.event_closed.is_set()
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import trio
|
import trio
|
||||||
|
from trio.testing import wait_all_tasks_blocked
|
||||||
|
|
||||||
from libp2p.stream_muxer.mplex.exceptions import (
|
from libp2p.stream_muxer.mplex.exceptions import (
|
||||||
MplexStreamClosed,
|
MplexStreamClosed,
|
||||||
@ -37,37 +38,65 @@ async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair):
|
|||||||
async def read_until_eof():
|
async def read_until_eof():
|
||||||
read_bytes.extend(await stream_1.read())
|
read_bytes.extend(await stream_1.read())
|
||||||
|
|
||||||
task = trio.ensure_future(read_until_eof())
|
|
||||||
|
|
||||||
expected_data = bytearray()
|
expected_data = bytearray()
|
||||||
|
|
||||||
# Test: `read` doesn't return before `close` is called.
|
async with trio.open_nursery() as nursery:
|
||||||
await stream_0.write(DATA)
|
nursery.start_soon(read_until_eof)
|
||||||
expected_data.extend(DATA)
|
# Test: `read` doesn't return before `close` is called.
|
||||||
await trio.sleep(0.01)
|
await stream_0.write(DATA)
|
||||||
assert len(read_bytes) == 0
|
expected_data.extend(DATA)
|
||||||
# Test: `read` doesn't return before `close` is called.
|
await trio.sleep(0.01)
|
||||||
await stream_0.write(DATA)
|
assert len(read_bytes) == 0
|
||||||
expected_data.extend(DATA)
|
# Test: `read` doesn't return before `close` is called.
|
||||||
await trio.sleep(0.01)
|
await stream_0.write(DATA)
|
||||||
assert len(read_bytes) == 0
|
expected_data.extend(DATA)
|
||||||
|
await trio.sleep(0.01)
|
||||||
|
assert len(read_bytes) == 0
|
||||||
|
|
||||||
|
# Test: Close the stream, `read` returns, and receive previous sent data.
|
||||||
|
await stream_0.close()
|
||||||
|
|
||||||
# Test: Close the stream, `read` returns, and receive previous sent data.
|
|
||||||
await stream_0.close()
|
|
||||||
await trio.sleep(0.01)
|
|
||||||
assert read_bytes == expected_data
|
assert read_bytes == expected_data
|
||||||
|
|
||||||
task.cancel()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair):
|
async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair):
|
||||||
stream_0, stream_1 = mplex_stream_pair
|
stream_0, stream_1 = mplex_stream_pair
|
||||||
assert not stream_1.event_remote_closed.is_set()
|
assert not stream_1.event_remote_closed.is_set()
|
||||||
await stream_0.write(DATA)
|
await stream_0.write(DATA)
|
||||||
await stream_0.close()
|
assert not stream_0.event_local_closed.is_set()
|
||||||
await trio.sleep(0.01)
|
await trio.sleep(0.01)
|
||||||
|
await wait_all_tasks_blocked()
|
||||||
|
await stream_0.close()
|
||||||
|
assert stream_0.event_local_closed.is_set()
|
||||||
|
await trio.sleep(0.01)
|
||||||
|
print(
|
||||||
|
"!@# ",
|
||||||
|
stream_0.muxed_conn.event_shutting_down.is_set(),
|
||||||
|
stream_0.muxed_conn.event_closed.is_set(),
|
||||||
|
stream_1.muxed_conn.event_shutting_down.is_set(),
|
||||||
|
stream_1.muxed_conn.event_closed.is_set(),
|
||||||
|
)
|
||||||
|
# await trio.sleep(100000)
|
||||||
|
await wait_all_tasks_blocked()
|
||||||
|
print(
|
||||||
|
"!@# ",
|
||||||
|
stream_0.muxed_conn.event_shutting_down.is_set(),
|
||||||
|
stream_0.muxed_conn.event_closed.is_set(),
|
||||||
|
stream_1.muxed_conn.event_shutting_down.is_set(),
|
||||||
|
stream_1.muxed_conn.event_closed.is_set(),
|
||||||
|
)
|
||||||
|
print("!@# sleeping")
|
||||||
|
print("!@# result=", stream_1.event_remote_closed.is_set())
|
||||||
|
# await trio.sleep_forever()
|
||||||
assert stream_1.event_remote_closed.is_set()
|
assert stream_1.event_remote_closed.is_set()
|
||||||
|
print(
|
||||||
|
"!@# ",
|
||||||
|
stream_0.muxed_conn.event_shutting_down.is_set(),
|
||||||
|
stream_0.muxed_conn.event_closed.is_set(),
|
||||||
|
stream_1.muxed_conn.event_shutting_down.is_set(),
|
||||||
|
stream_1.muxed_conn.event_closed.is_set(),
|
||||||
|
)
|
||||||
assert (await stream_1.read(MAX_READ_LEN)) == DATA
|
assert (await stream_1.read(MAX_READ_LEN)) == DATA
|
||||||
with pytest.raises(MplexStreamEOF):
|
with pytest.raises(MplexStreamEOF):
|
||||||
await stream_1.read(MAX_READ_LEN)
|
await stream_1.read(MAX_READ_LEN)
|
||||||
@ -87,7 +116,8 @@ async def test_mplex_stream_read_after_remote_reset(mplex_stream_pair):
|
|||||||
await stream_0.write(DATA)
|
await stream_0.write(DATA)
|
||||||
await stream_0.reset()
|
await stream_0.reset()
|
||||||
# Sleep to let `stream_1` receive the message.
|
# Sleep to let `stream_1` receive the message.
|
||||||
await trio.sleep(0.01)
|
await trio.sleep(0.1)
|
||||||
|
await wait_all_tasks_blocked()
|
||||||
with pytest.raises(MplexStreamReset):
|
with pytest.raises(MplexStreamReset):
|
||||||
await stream_1.read(MAX_READ_LEN)
|
await stream_1.read(MAX_READ_LEN)
|
||||||
|
|
||||||
|
|||||||
@ -1,19 +0,0 @@
|
|||||||
import pytest
|
|
||||||
import trio
|
|
||||||
|
|
||||||
from libp2p.utils import TrioQueue
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_trio_queue():
|
|
||||||
queue = TrioQueue()
|
|
||||||
|
|
||||||
async def queue_get(task_status=None):
|
|
||||||
result = await queue.get()
|
|
||||||
task_status.started(result)
|
|
||||||
|
|
||||||
async with trio.open_nursery() as nursery:
|
|
||||||
nursery.start_soon(queue.put, 123)
|
|
||||||
result = await nursery.start(queue_get)
|
|
||||||
|
|
||||||
assert result == 123
|
|
||||||
Reference in New Issue
Block a user