From 96230758e42a8c5cc357cd68f558842da7839815 Mon Sep 17 00:00:00 2001 From: mhchia Date: Thu, 5 Sep 2019 18:18:08 +0800 Subject: [PATCH 1/9] Add events in MplexStream And modify a little bit of `close` and `reset` --- libp2p/stream_muxer/mplex/mplex.py | 24 +++--- libp2p/stream_muxer/mplex/mplex_stream.py | 91 ++++++++++++----------- 2 files changed, 64 insertions(+), 51 deletions(-) diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 5f55a665..cf1ec91d 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -10,6 +10,7 @@ from libp2p.typing import TProtocol from libp2p.utils import ( decode_uvarint_from_stream, encode_uvarint, + encode_varint_prefixed, read_varint_prefixed_bytes, ) @@ -34,6 +35,8 @@ class Mplex(IMuxedConn): buffers: Dict[StreamID, "asyncio.Queue[bytes]"] stream_queue: "asyncio.Queue[StreamID]" next_channel_id: int + buffers_lock: asyncio.Lock + shutdown: asyncio.Event _tasks: List["asyncio.Future[Any]"] @@ -63,6 +66,8 @@ class Mplex(IMuxedConn): # Mapping from stream ID -> buffer of messages for that stream self.buffers = {} + self.buffers_lock = asyncio.Lock() + self.shutdown = asyncio.Event() self.stream_queue = asyncio.Queue() @@ -145,7 +150,7 @@ class Mplex(IMuxedConn): self._tasks.append(asyncio.ensure_future(self.generic_protocol_handler(stream))) async def send_message( - self, flag: HeaderTags, data: bytes, stream_id: StreamID + self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID ) -> int: """ sends a message over the connection @@ -154,19 +159,16 @@ class Mplex(IMuxedConn): :param stream_id: stream the message is in """ # << by 3, then or with flag - header = (stream_id.channel_id << 3) | flag.value - header = encode_uvarint(header) + header = encode_uvarint((stream_id.channel_id << 3) | flag.value) if data is None: - data_length = encode_uvarint(0) - _bytes = header + data_length - else: - data_length = encode_uvarint(len(data)) - _bytes = header + data_length + data + data = b"" + + _bytes = header + encode_varint_prefixed(data) return await self.write_to_stream(_bytes) - async def write_to_stream(self, _bytes: bytearray) -> int: + async def write_to_stream(self, _bytes: bytes) -> int: """ writes a byte array to a secured connection :param _bytes: byte array to write @@ -199,6 +201,10 @@ class Mplex(IMuxedConn): HeaderTags.MessageReceiver.value, ): await self.buffers[stream_id].put(message) + # elif flag in ( + # HeaderTags.CloseInitiator.value, + # HeaderTags.CloseReceiver.value + # ): # Force context switch await asyncio.sleep(0) diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index fe0261be..d0f0801d 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -1,10 +1,14 @@ import asyncio +from typing import TYPE_CHECKING -from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream +from libp2p.stream_muxer.abc import IMuxedStream from .constants import HeaderTags from .datastructures import StreamID +if TYPE_CHECKING: + from libp2p.stream_muxer.mplex.mplex import Mplex + class MplexStream(IMuxedStream): """ @@ -13,16 +17,19 @@ class MplexStream(IMuxedStream): name: str stream_id: StreamID - mplex_conn: IMuxedConn + mplex_conn: "Mplex" read_deadline: int write_deadline: int - local_closed: bool - remote_closed: bool - stream_lock: asyncio.Lock + + close_lock: asyncio.Lock + + event_local_closed: asyncio.Event + event_remote_closed: asyncio.Event + event_reset: 
asyncio.Event _buf: bytearray - def __init__(self, name: str, stream_id: StreamID, mplex_conn: IMuxedConn) -> None: + def __init__(self, name: str, stream_id: StreamID, mplex_conn: "Mplex") -> None: """ create new MuxedStream in muxer :param stream_id: stream id of this stream @@ -33,9 +40,10 @@ class MplexStream(IMuxedStream): self.mplex_conn = mplex_conn self.read_deadline = None self.write_deadline = None - self.local_closed = False - self.remote_closed = False - self.stream_lock = asyncio.Lock() + self.event_local_closed = asyncio.Event() + self.event_remote_closed = asyncio.Event() + self.event_reset = asyncio.Event() + self.close_lock = asyncio.Lock() self._buf = bytearray() @property @@ -90,63 +98,62 @@ class MplexStream(IMuxedStream): ) return await self.mplex_conn.send_message(flag, data, self.stream_id) - async def close(self) -> bool: + async def close(self) -> None: """ Closing a stream closes it for writing and closes the remote end for reading but allows writing in the other direction. - :return: true if successful """ # TODO error handling with timeout - # TODO understand better how mutexes are used from go repo + + async with self.close_lock: + if self.event_local_closed.is_set(): + return + flag = ( HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver ) + # TODO: Raise when `mplex_conn.send_message` fails and `Mplex` isn't shutdown. await self.mplex_conn.send_message(flag, None, self.stream_id) - remote_lock = False - async with self.stream_lock: - if self.local_closed: - return True - self.local_closed = True - remote_lock = self.remote_closed + _is_remote_closed: bool + async with self.close_lock: + self.event_local_closed.set() + _is_remote_closed = self.event_remote_closed.is_set() - if remote_lock: - # FIXME: mplex_conn has no conn_lock! - async with self.mplex_conn.conn_lock: # type: ignore - # FIXME: Don't access to buffers directly - self.mplex_conn.buffers.pop(self.stream_id) # type: ignore + if _is_remote_closed: + # Both sides are closed, we can safely remove the buffer from the dict. + async with self.mplex_conn.buffers_lock: + del self.mplex_conn.buffers[self.stream_id] - return True - - async def reset(self) -> bool: + async def reset(self) -> None: """ closes both ends of the stream tells this remote side to hang up - :return: true if successful """ - # TODO understand better how mutexes are used here - # TODO understand the difference between close and reset - async with self.stream_lock: - if self.remote_closed and self.local_closed: - return True + async with self.close_lock: + # Both sides have been closed. No need to event_reset. + if self.event_remote_closed.is_set() and self.event_local_closed.is_set(): + return + if self.event_reset.is_set(): + return + self.event_reset.set() - if not self.remote_closed: + if not self.event_remote_closed.is_set(): flag = ( HeaderTags.ResetInitiator if self.is_initiator else HeaderTags.ResetReceiver ) - await self.mplex_conn.send_message(flag, None, self.stream_id) + asyncio.ensure_future( + self.mplex_conn.send_message(flag, None, self.stream_id) + ) + await asyncio.sleep(0) - self.local_closed = True - self.remote_closed = True + self.event_local_closed.set() + self.event_remote_closed.set() - # FIXME: mplex_conn has no conn_lock! 
- async with self.mplex_conn.conn_lock: # type: ignore - # FIXME: Don't access to buffers directly - self.mplex_conn.buffers.pop(self.stream_id, None) # type: ignore - - return True + async with self.mplex_conn.buffers_lock: + del self.mplex_conn.buffers[self.stream_id] # TODO deadline not in use def set_deadline(self, ttl: int) -> bool: From eac159c527ea83fb51f5a32b986cfccf4c2776e8 Mon Sep 17 00:00:00 2001 From: mhchia Date: Thu, 5 Sep 2019 22:29:33 +0800 Subject: [PATCH 2/9] Restructure mplex and mplex_stream --- libp2p/stream_muxer/abc.py | 18 +----- libp2p/stream_muxer/mplex/exceptions.py | 13 +++- libp2p/stream_muxer/mplex/mplex.py | 76 +++++++++-------------- libp2p/stream_muxer/mplex/mplex_stream.py | 19 +++--- 4 files changed, 55 insertions(+), 71 deletions(-) diff --git a/libp2p/stream_muxer/abc.py b/libp2p/stream_muxer/abc.py index 0600deeb..6e7737ee 100644 --- a/libp2p/stream_muxer/abc.py +++ b/libp2p/stream_muxer/abc.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from libp2p.peer.id import ID from libp2p.security.secure_conn_interface import ISecureConn @@ -51,20 +51,6 @@ class IMuxedConn(ABC): :return: true if successful """ - @abstractmethod - async def read_buffer(self, stream_id: StreamID) -> bytes: - """ - Read a message from stream_id's buffer, check raw connection for new messages - :param stream_id: stream id of stream to read from - :return: message read - """ - - @abstractmethod - async def read_buffer_nonblocking(self, stream_id: StreamID) -> Optional[bytes]: - """ - Read a message from `stream_id`'s buffer, non-blockingly. - """ - @abstractmethod async def open_stream(self) -> "IMuxedStream": """ @@ -73,7 +59,7 @@ class IMuxedConn(ABC): """ @abstractmethod - async def accept_stream(self, name: str) -> None: + async def accept_stream(self, stream_id: StreamID, name: str) -> None: """ accepts a muxed stream opened by the other end """ diff --git a/libp2p/stream_muxer/mplex/exceptions.py b/libp2p/stream_muxer/mplex/exceptions.py index 74a6ade8..bd4ceb56 100644 --- a/libp2p/stream_muxer/mplex/exceptions.py +++ b/libp2p/stream_muxer/mplex/exceptions.py @@ -1,2 +1,13 @@ -class StreamNotFound(Exception): +from libp2p.exceptions import BaseLibp2pError + + +class MplexError(BaseLibp2pError): + pass + + +class MplexShutdown(MplexError): + pass + + +class StreamNotFound(MplexError): pass diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index cf1ec91d..af1282e9 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -16,7 +16,6 @@ from libp2p.utils import ( from .constants import HeaderTags from .datastructures import StreamID -from .exceptions import StreamNotFound from .mplex_stream import MplexStream MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0") @@ -32,10 +31,9 @@ class Mplex(IMuxedConn): # TODO: `dataIn` in go implementation. Should be size of 8. # TODO: Also, `dataIn` is closed indicating EOF in Go. We don't have similar strategies # to let the `MplexStream`s know that EOF arrived (#235). 
- buffers: Dict[StreamID, "asyncio.Queue[bytes]"] - stream_queue: "asyncio.Queue[StreamID]" next_channel_id: int - buffers_lock: asyncio.Lock + streams: Dict[StreamID, MplexStream] + streams_lock: asyncio.Lock shutdown: asyncio.Event _tasks: List["asyncio.Future[Any]"] @@ -65,12 +63,10 @@ class Mplex(IMuxedConn): self.peer_id = peer_id # Mapping from stream ID -> buffer of messages for that stream - self.buffers = {} - self.buffers_lock = asyncio.Lock() + self.streams = {} + self.streams_lock = asyncio.Lock() self.shutdown = asyncio.Event() - self.stream_queue = asyncio.Queue() - self._tasks = [] # Kick off reading @@ -95,29 +91,6 @@ class Mplex(IMuxedConn): """ raise NotImplementedError() - async def read_buffer(self, stream_id: StreamID) -> bytes: - """ - Read a message from buffer of the stream specified by `stream_id`, - check secured connection for new messages. - `StreamNotFound` is raised when stream `stream_id` is not found in `Mplex`. - :param stream_id: stream id of stream to read from - :return: message read - """ - if stream_id not in self.buffers: - raise StreamNotFound(f"stream {stream_id} is not found") - return await self.buffers[stream_id].get() - - async def read_buffer_nonblocking(self, stream_id: StreamID) -> Optional[bytes]: - """ - Read a message from buffer of the stream specified by `stream_id`, non-blockingly. - `StreamNotFound` is raised when stream `stream_id` is not found in `Mplex`. - """ - if stream_id not in self.buffers: - raise StreamNotFound(f"stream {stream_id} is not found") - if self.buffers[stream_id].empty(): - return None - return await self.buffers[stream_id].get() - def _get_next_channel_id(self) -> int: """ Get next available stream id @@ -127,6 +100,12 @@ class Mplex(IMuxedConn): self.next_channel_id += 1 return next_id + async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream: + async with self.streams_lock: + stream = MplexStream(name, stream_id, self) + self.streams[stream_id] = stream + return stream + async def open_stream(self) -> IMuxedStream: """ creates a new muxed_stream @@ -134,19 +113,18 @@ class Mplex(IMuxedConn): """ channel_id = self._get_next_channel_id() stream_id = StreamID(channel_id=channel_id, is_initiator=True) - name = str(channel_id) - stream = MplexStream(name, stream_id, self) - self.buffers[stream_id] = asyncio.Queue() # Default stream name is the `channel_id` + name = str(channel_id) + stream = await self._initialize_stream(stream_id, name) await self.send_message(HeaderTags.NewStream, name.encode(), stream_id) return stream - async def accept_stream(self, name: str) -> None: + async def accept_stream(self, stream_id: StreamID, name: str) -> None: """ accepts a muxed stream opened by the other end """ - stream_id = await self.stream_queue.get() - stream = MplexStream(name, stream_id, self) + stream = await self._initialize_stream(stream_id, name) + # Perform protocol negotiation for the stream. 
self._tasks.append(asyncio.ensure_future(self.generic_protocol_handler(stream))) async def send_message( @@ -185,22 +163,30 @@ class Mplex(IMuxedConn): while True: channel_id, flag, message = await self.read_message() - if channel_id is not None and flag is not None and message is not None: stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1)) - if stream_id not in self.buffers: - self.buffers[stream_id] = asyncio.Queue() - await self.stream_queue.put(stream_id) - + is_stream_id_seen: bool + async with self.streams_lock: + is_stream_id_seen = stream_id in self.streams + # Other consequent stream message should wait until the stream get accepted # TODO: Handle more tags, and refactor `HeaderTags` if flag == HeaderTags.NewStream.value: - # new stream detected on connection - await self.accept_stream(message.decode()) + if is_stream_id_seen: + # `NewStream` for the same id is received twice... + pass + await self.accept_stream(stream_id, message.decode()) elif flag in ( HeaderTags.MessageInitiator.value, HeaderTags.MessageReceiver.value, ): - await self.buffers[stream_id].put(message) + if not is_stream_id_seen: + # We receive a message of the stream `stream_id` which is not accepted + # before. It is abnormal. Possibly disconnect? + # TODO: Warn and emit logs about this. + continue + async with self.streams_lock: + stream = self.streams[stream_id] + await stream.incoming_data.put(message) # elif flag in ( # HeaderTags.CloseInitiator.value, # HeaderTags.CloseReceiver.value diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index d0f0801d..d257297f 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -23,6 +23,8 @@ class MplexStream(IMuxedStream): close_lock: asyncio.Lock + incoming_data: "asyncio.Queue[bytes]" + event_local_closed: asyncio.Event event_remote_closed: asyncio.Event event_reset: asyncio.Event @@ -44,6 +46,7 @@ class MplexStream(IMuxedStream): self.event_remote_closed = asyncio.Event() self.event_reset = asyncio.Event() self.close_lock = asyncio.Lock() + self.incoming_data = asyncio.Queue() self._buf = bytearray() @property @@ -58,7 +61,6 @@ class MplexStream(IMuxedStream): :param n: number of bytes to read :return: bytes actually read """ - # TODO: Handle `StreamNotFound` raised in `self.mplex_conn.read_buffer`. # TODO: Add exceptions and handle/raise them in this class. if n < 0 and n != -1: raise ValueError( @@ -66,17 +68,16 @@ class MplexStream(IMuxedStream): ) # If the buffer is empty at first, blocking wait for data. if len(self._buf) == 0: - self._buf.extend(await self.mplex_conn.read_buffer(self.stream_id)) + self._buf.extend(await self.incoming_data.get()) # FIXME: If `n == -1`, we should blocking read until EOF, instead of returning when # no message is available. # If `n >= 0`, read up to `n` bytes. # Else, read until no message is available. while len(self._buf) < n or n == -1: - new_bytes = await self.mplex_conn.read_buffer_nonblocking(self.stream_id) - if new_bytes is None: - # Nothing to read in the `MplexConn` buffer + if self.incoming_data.empty(): break + new_bytes = await self.incoming_data.get() self._buf.extend(new_bytes) payload: bytearray if n == -1: @@ -122,8 +123,8 @@ class MplexStream(IMuxedStream): if _is_remote_closed: # Both sides are closed, we can safely remove the buffer from the dict. 
- async with self.mplex_conn.buffers_lock: - del self.mplex_conn.buffers[self.stream_id] + async with self.mplex_conn.streams_lock: + del self.mplex_conn.streams[self.stream_id] async def reset(self) -> None: """ @@ -152,8 +153,8 @@ class MplexStream(IMuxedStream): self.event_local_closed.set() self.event_remote_closed.set() - async with self.mplex_conn.buffers_lock: - del self.mplex_conn.buffers[self.stream_id] + async with self.mplex_conn.streams_lock: + del self.mplex_conn.streams[self.stream_id] # TODO deadline not in use def set_deadline(self, ttl: int) -> bool: From 10415cb95638acbfb2c6ae392a768b04ab4d0446 Mon Sep 17 00:00:00 2001 From: mhchia Date: Thu, 5 Sep 2019 23:24:17 +0800 Subject: [PATCH 3/9] Use `ReadWriteCloser` for conns and streams --- libp2p/network/stream/net_stream.py | 3 +- libp2p/network/stream/net_stream_interface.py | 27 ++--------------- libp2p/network/swarm.py | 6 ++-- .../multiselect_communicator.py | 30 +++++-------------- libp2p/security/security_multistream.py | 4 +-- libp2p/stream_muxer/abc.py | 25 ++-------------- libp2p/stream_muxer/muxer_multistream.py | 4 +-- libp2p/typing.py | 6 +--- libp2p/utils.py | 7 ++--- 9 files changed, 24 insertions(+), 88 deletions(-) diff --git a/libp2p/network/stream/net_stream.py b/libp2p/network/stream/net_stream.py index ff78f5a8..7383f736 100644 --- a/libp2p/network/stream/net_stream.py +++ b/libp2p/network/stream/net_stream.py @@ -44,13 +44,12 @@ class NetStream(INetStream): """ return await self.muxed_stream.write(data) - async def close(self) -> bool: + async def close(self) -> None: """ close stream :return: true if successful """ await self.muxed_stream.close() - return True async def reset(self) -> bool: return await self.muxed_stream.reset() diff --git a/libp2p/network/stream/net_stream_interface.py b/libp2p/network/stream/net_stream_interface.py index 4df95d8a..aaa775a3 100644 --- a/libp2p/network/stream/net_stream_interface.py +++ b/libp2p/network/stream/net_stream_interface.py @@ -1,10 +1,11 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod +from libp2p.io.abc import ReadWriteCloser from libp2p.stream_muxer.abc import IMuxedConn from libp2p.typing import TProtocol -class INetStream(ABC): +class INetStream(ReadWriteCloser): mplex_conn: IMuxedConn @@ -21,28 +22,6 @@ class INetStream(ABC): :return: true if successful """ - @abstractmethod - async def read(self, n: int = -1) -> bytes: - """ - reads from the underlying muxed_stream - :param n: number of bytes to read - :return: bytes of input - """ - - @abstractmethod - async def write(self, data: bytes) -> int: - """ - write to the underlying muxed_stream - :return: number of bytes written - """ - - @abstractmethod - async def close(self) -> bool: - """ - close the underlying muxed stream - :return: true if successful - """ - @abstractmethod async def reset(self) -> bool: """ diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index bccfdac1..38cbf719 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -8,7 +8,7 @@ from libp2p.peer.peerstore import PeerStoreError from libp2p.peer.peerstore_interface import IPeerStore from libp2p.protocol_muxer.multiselect import Multiselect from libp2p.protocol_muxer.multiselect_client import MultiselectClient -from libp2p.protocol_muxer.multiselect_communicator import StreamCommunicator +from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator from libp2p.routing.interfaces import IPeerRouting from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream 
from libp2p.transport.exceptions import MuxerUpgradeFailure, SecurityUpgradeFailure @@ -161,7 +161,7 @@ class Swarm(INetwork): # Perform protocol muxing to determine protocol to use selected_protocol = await self.multiselect_client.select_one_of( - list(protocol_ids), StreamCommunicator(muxed_stream) + list(protocol_ids), MultiselectCommunicator(muxed_stream) ) # Create a net stream with the selected protocol @@ -294,7 +294,7 @@ def create_generic_protocol_handler(swarm: Swarm) -> GenericProtocolHandlerFn: async def generic_protocol_handler(muxed_stream: IMuxedStream) -> None: # Perform protocol muxing to determine protocol to use protocol, handler = await multiselect.negotiate( - StreamCommunicator(muxed_stream) + MultiselectCommunicator(muxed_stream) ) net_stream = NetStream(muxed_stream) diff --git a/libp2p/protocol_muxer/multiselect_communicator.py b/libp2p/protocol_muxer/multiselect_communicator.py index e252304c..59252c56 100644 --- a/libp2p/protocol_muxer/multiselect_communicator.py +++ b/libp2p/protocol_muxer/multiselect_communicator.py @@ -1,35 +1,19 @@ -from libp2p.network.connection.raw_connection_interface import IRawConnection -from libp2p.stream_muxer.abc import IMuxedStream +from libp2p.io.abc import ReadWriteCloser from libp2p.utils import encode_delim, read_delim from .multiselect_communicator_interface import IMultiselectCommunicator -class RawConnectionCommunicator(IMultiselectCommunicator): - conn: IRawConnection +class MultiselectCommunicator(IMultiselectCommunicator): + read_writer: ReadWriteCloser - def __init__(self, conn: IRawConnection) -> None: - self.conn = conn + def __init__(self, read_writer: ReadWriteCloser) -> None: + self.read_writer = read_writer async def write(self, msg_str: str) -> None: msg_bytes = encode_delim(msg_str.encode()) - await self.conn.write(msg_bytes) + await self.read_writer.write(msg_bytes) async def read(self) -> str: - data = await read_delim(self.conn) - return data.decode() - - -class StreamCommunicator(IMultiselectCommunicator): - stream: IMuxedStream - - def __init__(self, stream: IMuxedStream) -> None: - self.stream = stream - - async def write(self, msg_str: str) -> None: - msg_bytes = encode_delim(msg_str.encode()) - await self.stream.write(msg_bytes) - - async def read(self) -> str: - data = await read_delim(self.stream) + data = await read_delim(self.read_writer) return data.decode() diff --git a/libp2p/security/security_multistream.py b/libp2p/security/security_multistream.py index 6e69d7a0..466d60a8 100644 --- a/libp2p/security/security_multistream.py +++ b/libp2p/security/security_multistream.py @@ -6,7 +6,7 @@ from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID from libp2p.protocol_muxer.multiselect import Multiselect from libp2p.protocol_muxer.multiselect_client import MultiselectClient -from libp2p.protocol_muxer.multiselect_communicator import RawConnectionCommunicator +from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.typing import TProtocol @@ -88,7 +88,7 @@ class SecurityMultistream(ABC): :return: selected secure transport """ protocol: TProtocol - communicator = RawConnectionCommunicator(conn) + communicator = MultiselectCommunicator(conn) if initiator: # Select protocol if initiator protocol = await self.multiselect_client.select_one_of( diff --git a/libp2p/stream_muxer/abc.py 
b/libp2p/stream_muxer/abc.py index 6e7737ee..2c577c7a 100644 --- a/libp2p/stream_muxer/abc.py +++ b/libp2p/stream_muxer/abc.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING +from libp2p.io.abc import ReadWriteCloser from libp2p.peer.id import ID from libp2p.security.secure_conn_interface import ISecureConn from libp2p.stream_muxer.mplex.constants import HeaderTags @@ -76,32 +77,10 @@ class IMuxedConn(ABC): """ -class IMuxedStream(ABC): +class IMuxedStream(ReadWriteCloser): mplex_conn: IMuxedConn - @abstractmethod - async def read(self, n: int = -1) -> bytes: - """ - reads from the underlying muxed_conn - :param n: number of bytes to read - :return: bytes of input - """ - - @abstractmethod - async def write(self, data: bytes) -> int: - """ - writes to the underlying muxed_conn - :return: number of bytes written - """ - - @abstractmethod - async def close(self) -> bool: - """ - close the underlying muxed_conn - :return: true if successful - """ - @abstractmethod async def reset(self) -> bool: """ diff --git a/libp2p/stream_muxer/muxer_multistream.py b/libp2p/stream_muxer/muxer_multistream.py index 703c4e2d..b118cee2 100644 --- a/libp2p/stream_muxer/muxer_multistream.py +++ b/libp2p/stream_muxer/muxer_multistream.py @@ -6,7 +6,7 @@ from libp2p.network.typing import GenericProtocolHandlerFn from libp2p.peer.id import ID from libp2p.protocol_muxer.multiselect import Multiselect from libp2p.protocol_muxer.multiselect_client import MultiselectClient -from libp2p.protocol_muxer.multiselect_communicator import RawConnectionCommunicator +from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator from libp2p.security.secure_conn_interface import ISecureConn from libp2p.typing import TProtocol @@ -60,7 +60,7 @@ class MuxerMultistream: :return: selected muxer transport """ protocol: TProtocol - communicator = RawConnectionCommunicator(conn) + communicator = MultiselectCommunicator(conn) if conn.initiator: protocol = await self.multiselect_client.select_one_of( tuple(self.transports.keys()), communicator diff --git a/libp2p/typing.py b/libp2p/typing.py index ba776e19..be0b584e 100644 --- a/libp2p/typing.py +++ b/libp2p/typing.py @@ -1,6 +1,4 @@ -from typing import TYPE_CHECKING, Awaitable, Callable, NewType, Union - -from libp2p.network.connection.raw_connection_interface import IRawConnection +from typing import TYPE_CHECKING, Awaitable, Callable, NewType if TYPE_CHECKING: from libp2p.network.stream.net_stream_interface import INetStream # noqa: F401 @@ -8,5 +6,3 @@ if TYPE_CHECKING: TProtocol = NewType("TProtocol", str) StreamHandlerFn = Callable[["INetStream"], Awaitable[None]] - -StreamReader = Union["IMuxedStream", "INetStream", IRawConnection] diff --git a/libp2p/utils.py b/libp2p/utils.py index e1c45fd8..7374993d 100644 --- a/libp2p/utils.py +++ b/libp2p/utils.py @@ -3,7 +3,6 @@ import math from libp2p.exceptions import ParseError from libp2p.io.abc import Reader -from libp2p.typing import StreamReader # Unsigned LEB128(varint codec) # Reference: https://github.com/ethereum/py-wasm/blob/master/wasm/parsers/leb128.py @@ -31,7 +30,7 @@ def encode_uvarint(number: int) -> bytes: return buf -async def decode_uvarint_from_stream(reader: StreamReader) -> int: +async def decode_uvarint_from_stream(reader: Reader) -> int: """ https://en.wikipedia.org/wiki/LEB128 """ @@ -61,7 +60,7 @@ def encode_varint_prefixed(msg_bytes: bytes) -> bytes: return varint_len + msg_bytes -async def read_varint_prefixed_bytes(reader: StreamReader) -> bytes: +async 
def read_varint_prefixed_bytes(reader: Reader) -> bytes: len_msg = await decode_uvarint_from_stream(reader) data = await reader.read(len_msg) if len(data) != len_msg: @@ -80,7 +79,7 @@ def encode_delim(msg: bytes) -> bytes: return encode_varint_prefixed(delimited_msg) -async def read_delim(reader: StreamReader) -> bytes: +async def read_delim(reader: Reader) -> bytes: msg_bytes = await read_varint_prefixed_bytes(reader) # TODO: Investigate if it is possible to have empty `msg_bytes` if len(msg_bytes) != 0 and msg_bytes[-1:] != b"\n": From 207fa75d8f596d2988a9bcdc2150634aee23a0d9 Mon Sep 17 00:00:00 2001 From: mhchia Date: Thu, 5 Sep 2019 23:44:22 +0800 Subject: [PATCH 4/9] Add `reset` and `close` --- libp2p/stream_muxer/abc.py | 3 +- libp2p/stream_muxer/mplex/mplex.py | 48 +++++++++++++++++++++++++++--- 2 files changed, 45 insertions(+), 6 deletions(-) diff --git a/libp2p/stream_muxer/abc.py b/libp2p/stream_muxer/abc.py index 2c577c7a..547e917f 100644 --- a/libp2p/stream_muxer/abc.py +++ b/libp2p/stream_muxer/abc.py @@ -82,11 +82,10 @@ class IMuxedStream(ReadWriteCloser): mplex_conn: IMuxedConn @abstractmethod - async def reset(self) -> bool: + async def reset(self) -> None: """ closes both ends of the stream tells this remote side to hang up - :return: true if successful """ @abstractmethod diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index af1282e9..f342978e 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -173,6 +173,7 @@ class Mplex(IMuxedConn): if flag == HeaderTags.NewStream.value: if is_stream_id_seen: # `NewStream` for the same id is received twice... + # TODO: Shutdown pass await self.accept_stream(stream_id, message.decode()) elif flag in ( @@ -187,10 +188,49 @@ class Mplex(IMuxedConn): async with self.streams_lock: stream = self.streams[stream_id] await stream.incoming_data.put(message) - # elif flag in ( - # HeaderTags.CloseInitiator.value, - # HeaderTags.CloseReceiver.value - # ): + elif flag in ( + HeaderTags.CloseInitiator.value, + HeaderTags.CloseReceiver.value, + ): + if not is_stream_id_seen: + continue + stream: MplexStream + async with self.streams_lock: + stream = self.streams[stream_id] + is_local_closed: bool + async with stream.close_lock: + stream.event_remote_closed.set() + is_local_closed = stream.event_local_closed.is_set() + # If local is also closed, both sides are closed. Then, we should clean up + # this stream. + if is_local_closed: + async with self.streams_lock: + del self.streams[stream_id] + elif flag in ( + HeaderTags.ResetInitiator.value, + HeaderTags.ResetReceiver.value, + ): + if not is_stream_id_seen: + # This is *ok*. We forget the stream on reset. 
+ continue + stream: MplexStream + async with self.streams_lock: + stream = self.streams[stream_id] + async with stream.close_lock: + if not stream.event_remote_closed.is_set(): + stream.event_reset.set() + stream.event_remote_closed.set() + if not stream.event_local_closed.is_set(): + stream.event_local_closed.close() + async with self.streams_lock: + del self.streams[stream_id] + else: + # TODO: logging + print(f"message with unknown header on stream {stream_id}") + if is_stream_id_seen: + async with self.streams_lock: + stream = self.streams[stream_id] + await stream.reset() # Force context switch await asyncio.sleep(0) From 95926b7376d1476fe39d4449abaf37e1d3399e21 Mon Sep 17 00:00:00 2001 From: mhchia Date: Fri, 6 Sep 2019 01:08:42 +0800 Subject: [PATCH 5/9] Temp for mplex_stream --- libp2p/stream_muxer/mplex/exceptions.py | 12 ++++++--- libp2p/stream_muxer/mplex/mplex_stream.py | 31 ++++++++++++++++++----- 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/libp2p/stream_muxer/mplex/exceptions.py b/libp2p/stream_muxer/mplex/exceptions.py index bd4ceb56..11663e2e 100644 --- a/libp2p/stream_muxer/mplex/exceptions.py +++ b/libp2p/stream_muxer/mplex/exceptions.py @@ -5,9 +5,13 @@ class MplexError(BaseLibp2pError): pass +class MplexStreamReset(MplexError): + pass + + +class MplexStreamEOF(MplexError, EOFError): + pass + + class MplexShutdown(MplexError): pass - - -class StreamNotFound(MplexError): - pass diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index d257297f..4f2e76ce 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING from libp2p.stream_muxer.abc import IMuxedStream from .constants import HeaderTags +from .exceptions import MplexStreamReset, MplexStreamEOF from .datastructures import StreamID if TYPE_CHECKING: @@ -53,6 +54,26 @@ class MplexStream(IMuxedStream): def is_initiator(self) -> bool: return self.stream_id.is_initiator + async def _wait_for_data(self) -> None: + print("!@# _wait_for_data: 0") + done, pending = await asyncio.wait( + [ + self.event_reset.wait(), + self.event_remote_closed.wait(), + self.incoming_data.get(), + ], + return_when=asyncio.FIRST_COMPLETED, + ) + print("!@# _wait_for_data: 1") + if self.event_reset.is_set(): + raise MplexStreamReset + if self.event_remote_closed.is_set(): + while not self.incoming_data.empty(): + self._buf.extend(await self.incoming_data.get()) + raise MplexStreamEOF + data = tuple(done)[0].result() + self._buf.extend(data) + async def read(self, n: int = -1) -> bytes: """ Read up to n bytes. Read possibly returns fewer than `n` bytes, @@ -66,19 +87,17 @@ class MplexStream(IMuxedStream): raise ValueError( f"the number of bytes to read `n` must be positive or -1 to indicate read until EOF" ) - # If the buffer is empty at first, blocking wait for data. - if len(self._buf) == 0: - self._buf.extend(await self.incoming_data.get()) # FIXME: If `n == -1`, we should blocking read until EOF, instead of returning when # no message is available. # If `n >= 0`, read up to `n` bytes. # Else, read until no message is available. 
while len(self._buf) < n or n == -1: - if self.incoming_data.empty(): + # new_bytes = await self.incoming_data.get() + try: + await self._wait_for_data() + except MplexStreamEOF: break - new_bytes = await self.incoming_data.get() - self._buf.extend(new_bytes) payload: bytearray if n == -1: payload = self._buf From 649a2307769ab4b266c69d48f8131ff32f39bf9f Mon Sep 17 00:00:00 2001 From: mhchia Date: Fri, 6 Sep 2019 17:26:40 +0800 Subject: [PATCH 6/9] Fix `MplexStream.read` --- examples/chat/chat.py | 4 +- examples/echo/echo.py | 5 +- libp2p/stream_muxer/mplex/mplex_stream.py | 42 ++++---- tests/examples/test_chat.py | 28 +++--- tests/libp2p/test_libp2p.py | 51 +++++----- tests/libp2p/test_notify.py | 106 ++++++++------------ tests/protocol_muxer/test_protocol_muxer.py | 17 +--- tests/utils.py | 6 +- 8 files changed, 116 insertions(+), 143 deletions(-) diff --git a/examples/chat/chat.py b/examples/chat/chat.py index 39258b54..24c92699 100755 --- a/examples/chat/chat.py +++ b/examples/chat/chat.py @@ -11,11 +11,12 @@ from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.typing import TProtocol PROTOCOL_ID = TProtocol("/chat/1.0.0") +MAX_READ_LEN = 2 ** 32 - 1 async def read_data(stream: INetStream) -> None: while True: - read_bytes = await stream.read() + read_bytes = await stream.read(MAX_READ_LEN) if read_bytes is not None: read_string = read_bytes.decode() if read_string != "\n": @@ -24,7 +25,6 @@ async def read_data(stream: INetStream) -> None: print("\x1b[32m %s\x1b[0m " % read_string, end="") -# FIXME(mhchia): Reconsider whether we should use a thread pool here. async def write_data(stream: INetStream) -> None: loop = asyncio.get_event_loop() while True: diff --git a/examples/echo/echo.py b/examples/echo/echo.py index 06e4f177..3f3ed33e 100644 --- a/examples/echo/echo.py +++ b/examples/echo/echo.py @@ -14,6 +14,7 @@ PROTOCOL_ID = TProtocol("/echo/1.0.0") async def _echo_stream_handler(stream: INetStream) -> None: + # Wait until EOF msg = await stream.read() await stream.write(msg) await stream.close() @@ -72,13 +73,13 @@ async def run(port: int, destination: str, localhost: bool, seed: int = None) -> msg = b"hi, there!\n" await stream.write(msg) + # Notify the other side about EOF + await stream.close() response = await stream.read() print(f"Sent: {msg}") print(f"Got: {response}") - await stream.close() - def main() -> None: description = """ diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 4f2e76ce..e537dda3 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -55,7 +55,6 @@ class MplexStream(IMuxedStream): return self.stream_id.is_initiator async def _wait_for_data(self) -> None: - print("!@# _wait_for_data: 0") done, pending = await asyncio.wait( [ self.event_reset.wait(), @@ -64,16 +63,25 @@ class MplexStream(IMuxedStream): ], return_when=asyncio.FIRST_COMPLETED, ) - print("!@# _wait_for_data: 1") if self.event_reset.is_set(): raise MplexStreamReset if self.event_remote_closed.is_set(): - while not self.incoming_data.empty(): - self._buf.extend(await self.incoming_data.get()) raise MplexStreamEOF + # TODO: Handle timeout when deadline is used. 
+ data = tuple(done)[0].result() self._buf.extend(data) + async def _read_until_eof(self) -> bytes: + while True: + try: + await self._wait_for_data() + except MplexStreamEOF: + break + payload = self._buf + self._buf = self._buf[len(payload) :] + return bytes(payload) + async def read(self, n: int = -1) -> bytes: """ Read up to n bytes. Read possibly returns fewer than `n` bytes, @@ -87,22 +95,18 @@ class MplexStream(IMuxedStream): raise ValueError( f"the number of bytes to read `n` must be positive or -1 to indicate read until EOF" ) - - # FIXME: If `n == -1`, we should blocking read until EOF, instead of returning when - # no message is available. - # If `n >= 0`, read up to `n` bytes. - # Else, read until no message is available. - while len(self._buf) < n or n == -1: - # new_bytes = await self.incoming_data.get() - try: - await self._wait_for_data() - except MplexStreamEOF: - break - payload: bytearray + if self.event_reset.is_set(): + raise MplexStreamReset if n == -1: - payload = self._buf - else: - payload = self._buf[:n] + return await self._read_until_eof() + if len(self._buf) == 0: + await self._wait_for_data() + # Read up to `n` bytes. + while len(self._buf) < n: + if self.incoming_data.empty() or self.event_remote_closed.is_set(): + break + self._buf.extend(await self.incoming_data.get()) + payload = self._buf[:n] self._buf = self._buf[len(payload) :] return bytes(payload) diff --git a/tests/examples/test_chat.py b/tests/examples/test_chat.py index f461d9da..75d8ec71 100644 --- a/tests/examples/test_chat.py +++ b/tests/examples/test_chat.py @@ -10,10 +10,13 @@ PROTOCOL_ID = "/chat/1.0.0" async def hello_world(host_a, host_b): + hello_world_from_host_a = b"hello world from host a" + hello_world_from_host_b = b"hello world from host b" + async def stream_handler(stream): - read = await stream.read() - assert read == b"hello world from host b" - await stream.write(b"hello world from host a") + read = await stream.read(len(hello_world_from_host_b)) + assert read == hello_world_from_host_b + await stream.write(hello_world_from_host_a) await stream.close() host_a.set_stream_handler(PROTOCOL_ID, stream_handler) @@ -21,9 +24,9 @@ async def hello_world(host_a, host_b): # Start a stream with the destination. # Multiaddress of the destination peer is fetched from the peerstore using 'peerId'. stream = await host_b.new_stream(host_a.get_id(), [PROTOCOL_ID]) - await stream.write(b"hello world from host b") + await stream.write(hello_world_from_host_b) read = await stream.read() - assert read == b"hello world from host a" + assert read == hello_world_from_host_a await stream.close() @@ -32,11 +35,8 @@ async def connect_write(host_a, host_b): received = [] async def stream_handler(stream): - while True: - try: - received.append((await stream.read()).decode()) - except Exception: # exception is raised when other side close the stream ? - break + for message in messages: + received.append((await stream.read(len(message))).decode()) host_a.set_stream_handler(PROTOCOL_ID, stream_handler) @@ -67,12 +67,8 @@ async def connect_read(host_a, host_b): # Multiaddress of the destination peer is fetched from the peerstore using 'peerId'. stream = await host_b.new_stream(host_a.get_id(), [PROTOCOL_ID]) received = [] - # while True: Seems the close stream event from the other host is not received - for _ in range(5): - try: - received.append(await stream.read()) - except Exception: # exception is raised when other side close the stream ? 
- break + for message in messages: + received.append(await stream.read(len(message))) await stream.close() assert received == messages diff --git a/tests/libp2p/test_libp2p.py b/tests/libp2p/test_libp2p.py index bc58a8c5..b4a643d2 100644 --- a/tests/libp2p/test_libp2p.py +++ b/tests/libp2p/test_libp2p.py @@ -3,6 +3,7 @@ import pytest from libp2p.peer.peerinfo import info_from_p2p_addr from tests.utils import cleanup, set_up_nodes_by_transport_opt +from tests.constants import MAX_READ_LEN @pytest.mark.asyncio @@ -12,7 +13,7 @@ async def test_simple_messages(): async def stream_handler(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack:" + read_string await stream.write(response.encode()) @@ -28,7 +29,7 @@ async def test_simple_messages(): for message in messages: await stream.write(message.encode()) - response = (await stream.read()).decode() + response = (await stream.read(MAX_READ_LEN)).decode() assert response == ("ack:" + message) @@ -43,7 +44,7 @@ async def test_double_response(): async def stream_handler(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack1:" + read_string await stream.write(response.encode()) @@ -61,10 +62,10 @@ async def test_double_response(): for message in messages: await stream.write(message.encode()) - response1 = (await stream.read()).decode() + response1 = (await stream.read(MAX_READ_LEN)).decode() assert response1 == ("ack1:" + message) - response2 = (await stream.read()).decode() + response2 = (await stream.read(MAX_READ_LEN)).decode() assert response2 == ("ack2:" + message) # Success, terminate pending tasks. @@ -80,14 +81,14 @@ async def test_multiple_streams(): async def stream_handler_a(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack_a:" + read_string await stream.write(response.encode()) async def stream_handler_b(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack_b:" + read_string await stream.write(response.encode()) @@ -111,8 +112,8 @@ async def test_multiple_streams(): await stream_a.write(a_message.encode()) await stream_b.write(b_message.encode()) - response_a = (await stream_a.read()).decode() - response_b = (await stream_b.read()).decode() + response_a = (await stream_a.read(MAX_READ_LEN)).decode() + response_b = (await stream_b.read(MAX_READ_LEN)).decode() assert response_a == ("ack_b:" + a_message) and response_b == ( "ack_a:" + b_message @@ -129,21 +130,21 @@ async def test_multiple_streams_same_initiator_different_protocols(): async def stream_handler_a1(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack_a1:" + read_string await stream.write(response.encode()) async def stream_handler_a2(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack_a2:" + read_string await stream.write(response.encode()) async def stream_handler_a3(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack_a3:" + read_string await stream.write(response.encode()) @@ -171,9 +172,9 @@ async def 
test_multiple_streams_same_initiator_different_protocols(): await stream_a2.write(a2_message.encode()) await stream_a3.write(a3_message.encode()) - response_a1 = (await stream_a1.read()).decode() - response_a2 = (await stream_a2.read()).decode() - response_a3 = (await stream_a3.read()).decode() + response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode() + response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode() + response_a3 = (await stream_a3.read(MAX_READ_LEN)).decode() assert ( response_a1 == ("ack_a1:" + a1_message) @@ -192,28 +193,28 @@ async def test_multiple_streams_two_initiators(): async def stream_handler_a1(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack_a1:" + read_string await stream.write(response.encode()) async def stream_handler_a2(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack_a2:" + read_string await stream.write(response.encode()) async def stream_handler_b1(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack_b1:" + read_string await stream.write(response.encode()) async def stream_handler_b2(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack_b2:" + read_string await stream.write(response.encode()) @@ -249,11 +250,11 @@ async def test_multiple_streams_two_initiators(): await stream_b1.write(b1_message.encode()) await stream_b2.write(b2_message.encode()) - response_a1 = (await stream_a1.read()).decode() - response_a2 = (await stream_a2.read()).decode() + response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode() + response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode() - response_b1 = (await stream_b1.read()).decode() - response_b2 = (await stream_b2.read()).decode() + response_b1 = (await stream_b1.read(MAX_READ_LEN)).decode() + response_b2 = (await stream_b2.read(MAX_READ_LEN)).decode() assert ( response_a1 == ("ack_a1:" + a1_message) @@ -277,7 +278,7 @@ async def test_triangle_nodes_connection(): async def stream_handler(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack:" + read_string await stream.write(response.encode()) @@ -320,7 +321,7 @@ async def test_triangle_nodes_connection(): for stream in streams: await stream.write(message.encode()) - response = (await stream.read()).decode() + response = (await stream.read(MAX_READ_LEN)).decode() assert response == ("ack:" + message) diff --git a/tests/libp2p/test_notify.py b/tests/libp2p/test_notify.py index f4bd2efc..206f3e3a 100644 --- a/tests/libp2p/test_notify.py +++ b/tests/libp2p/test_notify.py @@ -16,11 +16,10 @@ from libp2p import initialize_default_swarm, new_node from libp2p.crypto.rsa import create_new_key_pair from libp2p.host.basic_host import BasicHost from libp2p.network.notifee_interface import INotifee -from tests.utils import ( - cleanup, - echo_stream_handler, - perform_two_host_set_up_custom_handler, -) +from tests.utils import cleanup, perform_two_host_set_up +from tests.constants import MAX_READ_LEN + +ACK = "ack:" class MyNotifee(INotifee): @@ -67,38 +66,9 @@ class InvalidNotifee: assert False -async def perform_two_host_simple_set_up(): - node_a = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"]) - node_b = await 
new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"]) - - async def my_stream_handler(stream): - while True: - read_string = (await stream.read()).decode() - - resp = "ack:" + read_string - await stream.write(resp.encode()) - - node_b.set_stream_handler("/echo/1.0.0", my_stream_handler) - - # Associate the peer with local ip address (see default parameters of Libp2p()) - node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - return node_a, node_b - - -async def perform_two_host_simple_set_up_custom_handler(handler): - node_a = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"]) - node_b = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"]) - - node_b.set_stream_handler("/echo/1.0.0", handler) - - # Associate the peer with local ip address (see default parameters of Libp2p()) - node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - return node_a, node_b - - @pytest.mark.asyncio async def test_one_notifier(): - node_a, node_b = await perform_two_host_set_up_custom_handler(echo_stream_handler) + node_a, node_b = await perform_two_host_set_up() # Add notifee for node_a events = [] @@ -113,11 +83,12 @@ async def test_one_notifier(): messages = ["hello", "hello"] for message in messages: + expected_resp = ACK + message await stream.write(message.encode()) - response = (await stream.read()).decode() + response = (await stream.read(len(expected_resp))).decode() - assert response == ("ack:" + message) + assert response == expected_resp # Success, terminate pending tasks. await cleanup() @@ -126,6 +97,7 @@ async def test_one_notifier(): @pytest.mark.asyncio async def test_one_notifier_on_two_nodes(): events_b = [] + messages = ["hello", "hello"] async def my_stream_handler(stream): # Ensure the connected and opened_stream events were hit in Notifee obj @@ -135,13 +107,13 @@ async def test_one_notifier_on_two_nodes(): ["connectedb", stream.mplex_conn], ["opened_streamb", stream], ] - while True: - read_string = (await stream.read()).decode() + for message in messages: + read_string = (await stream.read(len(message))).decode() - resp = "ack:" + read_string + resp = ACK + read_string await stream.write(resp.encode()) - node_a, node_b = await perform_two_host_set_up_custom_handler(my_stream_handler) + node_a, node_b = await perform_two_host_set_up(my_stream_handler) # Add notifee for node_a events_a = [] @@ -157,13 +129,13 @@ async def test_one_notifier_on_two_nodes(): # node_a assert events_a == [["connecteda", stream.mplex_conn], ["opened_streama", stream]] - messages = ["hello", "hello"] for message in messages: + expected_resp = ACK + message await stream.write(message.encode()) - response = (await stream.read()).decode() + response = (await stream.read(len(expected_resp))).decode() - assert response == ("ack:" + message) + assert response == expected_resp # Success, terminate pending tasks. 
await cleanup() @@ -172,6 +144,7 @@ async def test_one_notifier_on_two_nodes(): @pytest.mark.asyncio async def test_one_notifier_on_two_nodes_with_listen(): events_b = [] + messages = ["hello", "hello"] node_a_key_pair = create_new_key_pair() node_a_transport_opt = ["/ip4/127.0.0.1/tcp/0"] @@ -196,10 +169,9 @@ async def test_one_notifier_on_two_nodes_with_listen(): ["connectedb", stream.mplex_conn], ["opened_streamb", stream], ] - while True: - read_string = (await stream.read()).decode() - - resp = "ack:" + read_string + for message in messages: + read_string = (await stream.read(len(message))).decode() + resp = ACK + read_string await stream.write(resp.encode()) # Add notifee for node_a @@ -222,13 +194,13 @@ async def test_one_notifier_on_two_nodes_with_listen(): # node_a assert events_a == [["connecteda", stream.mplex_conn], ["opened_streama", stream]] - messages = ["hello", "hello"] for message in messages: + expected_resp = ACK + message await stream.write(message.encode()) - response = (await stream.read()).decode() + response = (await stream.read(len(expected_resp))).decode() - assert response == ("ack:" + message) + assert response == expected_resp # Success, terminate pending tasks. await cleanup() @@ -236,7 +208,7 @@ async def test_one_notifier_on_two_nodes_with_listen(): @pytest.mark.asyncio async def test_two_notifiers(): - node_a, node_b = await perform_two_host_set_up_custom_handler(echo_stream_handler) + node_a, node_b = await perform_two_host_set_up() # Add notifee for node_a events0 = [] @@ -255,11 +227,12 @@ async def test_two_notifiers(): messages = ["hello", "hello"] for message in messages: + expected_resp = ACK + message await stream.write(message.encode()) - response = (await stream.read()).decode() + response = (await stream.read(len(expected_resp))).decode() - assert response == ("ack:" + message) + assert response == expected_resp # Success, terminate pending tasks. await cleanup() @@ -269,7 +242,7 @@ async def test_two_notifiers(): async def test_ten_notifiers(): num_notifiers = 10 - node_a, node_b = await perform_two_host_set_up_custom_handler(echo_stream_handler) + node_a, node_b = await perform_two_host_set_up() # Add notifee for node_a events_lst = [] @@ -290,11 +263,12 @@ async def test_ten_notifiers(): messages = ["hello", "hello"] for message in messages: + expected_resp = ACK + message await stream.write(message.encode()) - response = (await stream.read()).decode() + response = (await stream.read(len(expected_resp))).decode() - assert response == ("ack:" + message) + assert response == expected_resp # Success, terminate pending tasks. 
await cleanup() @@ -315,12 +289,12 @@ async def test_ten_notifiers_on_two_nodes(): ["opened_streamb" + str(i), stream], ] while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() - resp = "ack:" + read_string + resp = ACK + read_string await stream.write(resp.encode()) - node_a, node_b = await perform_two_host_set_up_custom_handler(my_stream_handler) + node_a, node_b = await perform_two_host_set_up(my_stream_handler) # Add notifee for node_a and node_b events_lst_a = [] @@ -343,11 +317,12 @@ async def test_ten_notifiers_on_two_nodes(): messages = ["hello", "hello"] for message in messages: + expected_resp = ACK + message await stream.write(message.encode()) - response = (await stream.read()).decode() + response = (await stream.read(len(expected_resp))).decode() - assert response == ("ack:" + message) + assert response == expected_resp # Success, terminate pending tasks. await cleanup() @@ -357,7 +332,7 @@ async def test_ten_notifiers_on_two_nodes(): async def test_invalid_notifee(): num_notifiers = 10 - node_a, node_b = await perform_two_host_set_up_custom_handler(echo_stream_handler) + node_a, node_b = await perform_two_host_set_up() # Add notifee for node_a events_lst = [] @@ -372,11 +347,12 @@ async def test_invalid_notifee(): # given that InvalidNotifee should not have been added as a notifee) messages = ["hello", "hello"] for message in messages: + expected_resp = ACK + message await stream.write(message.encode()) - response = (await stream.read()).decode() + response = (await stream.read(len(expected_resp))).decode() - assert response == ("ack:" + message) + assert response == expected_resp # Success, terminate pending tasks. await cleanup() diff --git a/tests/protocol_muxer/test_protocol_muxer.py b/tests/protocol_muxer/test_protocol_muxer.py index 02f08bdc..8fb15371 100644 --- a/tests/protocol_muxer/test_protocol_muxer.py +++ b/tests/protocol_muxer/test_protocol_muxer.py @@ -1,7 +1,7 @@ import pytest from libp2p.protocol_muxer.exceptions import MultiselectClientError -from tests.utils import cleanup, set_up_nodes_by_transport_opt +from tests.utils import cleanup, set_up_nodes_by_transport_opt, echo_stream_handler # TODO: Add tests for multiple streams being opened on different # protocols through the same connection @@ -18,14 +18,8 @@ async def perform_simple_test( transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list) - async def stream_handler(stream): - while True: - read_string = (await stream.read()).decode() - response = "ack:" + read_string - await stream.write(response.encode()) - for protocol in protocols_with_handlers: - node_b.set_stream_handler(protocol, stream_handler) + node_b.set_stream_handler(protocol, echo_stream_handler) # Associate the peer with local ip address (see default parameters of Libp2p()) node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) @@ -33,11 +27,10 @@ async def perform_simple_test( stream = await node_a.new_stream(node_b.get_id(), protocols_for_client) messages = ["hello" + str(x) for x in range(10)] for message in messages: + expected_resp = "ack:" + message await stream.write(message.encode()) - - response = (await stream.read()).decode() - - assert response == ("ack:" + message) + response = (await stream.read(len(expected_resp))).decode() + assert response == expected_resp assert expected_selected_protocol == stream.get_protocol() diff --git a/tests/utils.py 
b/tests/utils.py index 58a08807..1f1cfc4f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,6 +6,8 @@ import multiaddr from libp2p import new_node from libp2p.peer.peerinfo import info_from_p2p_addr +from tests.constants import MAX_READ_LEN + async def connect(node1, node2): """ @@ -38,13 +40,13 @@ async def set_up_nodes_by_transport_opt(transport_opt_list): async def echo_stream_handler(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() resp = "ack:" + read_string await stream.write(resp.encode()) -async def perform_two_host_set_up_custom_handler(handler): +async def perform_two_host_set_up(handler=echo_stream_handler): transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list) From a754e7dbbe03a8d0051424fed5e28f74f86436c6 Mon Sep 17 00:00:00 2001 From: mhchia Date: Fri, 6 Sep 2019 17:59:39 +0800 Subject: [PATCH 7/9] Add the missing tests.constants --- tests/constants.py | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 tests/constants.py diff --git a/tests/constants.py b/tests/constants.py new file mode 100644 index 00000000..3d2b3b77 --- /dev/null +++ b/tests/constants.py @@ -0,0 +1,4 @@ +# Just a arbitrary large number. +# It is used when calling `MplexStream.read(MAX_READ_LEN)`, +# to avoid `MplexStream.read()`, which blocking reads until EOF. +MAX_READ_LEN = 2 ** 32 - 1 From 1cd969a2d5b991f6531b8291fcc0482b4791376a Mon Sep 17 00:00:00 2001 From: mhchia Date: Fri, 6 Sep 2019 20:02:35 +0800 Subject: [PATCH 8/9] Fix: Add typing in functions --- libp2p/pubsub/pubsub.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index b1812933..c55a1834 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -249,8 +249,9 @@ class Pubsub: # Force context switch await asyncio.sleep(0) - # FIXME: `sub_message` can be further type hinted with mypy_protobuf - def handle_subscription(self, origin_id: ID, sub_message: Any) -> None: + def handle_subscription( + self, origin_id: ID, sub_message: rpc_pb2.RPC.SubOpts + ) -> None: """ Handle an incoming subscription message from a peer. Update internal mapping to mark the peer as subscribed or unsubscribed to topics as @@ -270,8 +271,7 @@ class Pubsub: self.peer_topics[sub_message.topicid].remove(origin_id) # FIXME(mhchia): Change the function name? - # FIXME(mhchia): `publish_message` can be further type hinted with mypy_protobuf - async def handle_talk(self, publish_message: Any) -> None: + async def handle_talk(self, publish_message: rpc_pb2.Message) -> None: """ Put incoming message from a peer onto my blocking queue :param publish_message: RPC.Message format From 6c1f77dc1a946733411ff0103e8f0ae06d6ddcab Mon Sep 17 00:00:00 2001 From: mhchia Date: Fri, 6 Sep 2019 21:35:15 +0800 Subject: [PATCH 9/9] Fix: Change the `event.close` to `event.set` And add missing parts. 
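
`asyncio.Event` only exposes `set()`, `clear()`, `is_set()` and `wait()`; there
is no `close()`, so the reset branch introduced in patch 4/9 raised
`AttributeError` the first time a Reset frame arrived for a stream that was
still open locally. A minimal demonstration of the asyncio API this fix relies
on (plain asyncio, independent of this series):

    import asyncio

    event = asyncio.Event()
    event.set()            # mark one direction of the stream as closed
    assert event.is_set()
    # event.close()        # AttributeError: 'Event' object has no attribute 'close'

The other changes are in the frame-handling loop: the stream lookup is hoisted
so every branch reuses the same `MplexStream`, a Close frame for a stream whose
remote side is already closed is ignored, and a Reset frame now also sets
`event_remote_closed` before the stream entry is dropped.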
---
 libp2p/pubsub/pubsub.py                     |  1 -
 libp2p/stream_muxer/mplex/mplex.py          | 28 ++++++++++-----------
 libp2p/stream_muxer/mplex/mplex_stream.py   |  4 +--
 tests/libp2p/test_libp2p.py                 |  2 +-
 tests/libp2p/test_notify.py                 |  2 +-
 tests/protocol_muxer/test_protocol_muxer.py |  2 +-
 tests/utils.py                              |  1 -
 7 files changed, 19 insertions(+), 21 deletions(-)

diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py
index c55a1834..5c0466cd 100644
--- a/libp2p/pubsub/pubsub.py
+++ b/libp2p/pubsub/pubsub.py
@@ -3,7 +3,6 @@ import logging
 import time
 from typing import (
     TYPE_CHECKING,
-    Any,
     Awaitable,
     Callable,
     Dict,
diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py
index f342978e..1e8823a9 100644
--- a/libp2p/stream_muxer/mplex/mplex.py
+++ b/libp2p/stream_muxer/mplex/mplex.py
@@ -166,8 +166,11 @@ class Mplex(IMuxedConn):
             if channel_id is not None and flag is not None and message is not None:
                 stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
                 is_stream_id_seen: bool
+                stream: MplexStream
                 async with self.streams_lock:
                     is_stream_id_seen = stream_id in self.streams
+                    if is_stream_id_seen:
+                        stream = self.streams[stream_id]
                 # Other consequent stream message should wait until the stream get accepted
                 # TODO: Handle more tags, and refactor `HeaderTags`
                 if flag == HeaderTags.NewStream.value:
@@ -185,8 +188,6 @@ class Mplex(IMuxedConn):
                         # before. It is abnormal. Possibly disconnect?
                         # TODO: Warn and emit logs about this.
                         continue
-                    async with self.streams_lock:
-                        stream = self.streams[stream_id]
                     await stream.incoming_data.put(message)
                 elif flag in (
                     HeaderTags.CloseInitiator.value,
@@ -194,15 +195,17 @@ class Mplex(IMuxedConn):
                 ):
                     if not is_stream_id_seen:
                         continue
-                    stream: MplexStream
-                    async with self.streams_lock:
-                        stream = self.streams[stream_id]
+                    # NOTE: If the remote side is already closed, ignore the message;
+                    # technically this is a bug on the other side. We should consider killing the connection.
+                    async with stream.close_lock:
+                        if stream.event_remote_closed.is_set():
+                            continue
                     is_local_closed: bool
                     async with stream.close_lock:
                         stream.event_remote_closed.set()
                         is_local_closed = stream.event_local_closed.is_set()
                     # If local is also closed, both sides are closed. Then, we should clean up
-                    # this stream.
+                    # the entry of this stream, to prevent others from accessing it.
                     if is_local_closed:
                         async with self.streams_lock:
                             del self.streams[stream_id]
@@ -213,24 +216,21 @@ class Mplex(IMuxedConn):
                     if not is_stream_id_seen:
                         # This is *ok*. We forget the stream on reset.
                         continue
-                    stream: MplexStream
-                    async with self.streams_lock:
-                        stream = self.streams[stream_id]
                     async with stream.close_lock:
                         if not stream.event_remote_closed.is_set():
+                            # TODO: Why reset only when the remote side has not already closed?
                             stream.event_reset.set()
+                            stream.event_remote_closed.set()
+                        # If local is not closed, we should close it.
                        if not stream.event_local_closed.is_set():
-                            stream.event_local_closed.close()
+                            stream.event_local_closed.set()
                     async with self.streams_lock:
                         del self.streams[stream_id]
                 else:
                     # TODO: logging
-                    print(f"message with unknown header on stream {stream_id}")
                     if is_stream_id_seen:
-                        async with self.streams_lock:
-                            stream = self.streams[stream_id]
-                        await stream.reset()
+                        await stream.reset()
 
             # Force context switch
             await asyncio.sleep(0)
diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py
index e537dda3..18c8ff02 100644
--- a/libp2p/stream_muxer/mplex/mplex_stream.py
+++ b/libp2p/stream_muxer/mplex/mplex_stream.py
@@ -4,8 +4,8 @@ from typing import TYPE_CHECKING
 from libp2p.stream_muxer.abc import IMuxedStream
 
 from .constants import HeaderTags
-from .exceptions import MplexStreamReset, MplexStreamEOF
 from .datastructures import StreamID
+from .exceptions import MplexStreamEOF, MplexStreamReset
 
 if TYPE_CHECKING:
     from libp2p.stream_muxer.mplex.mplex import Mplex
@@ -55,7 +55,7 @@ class MplexStream(IMuxedStream):
         return self.stream_id.is_initiator
 
     async def _wait_for_data(self) -> None:
-        done, pending = await asyncio.wait(
+        done, pending = await asyncio.wait(  # type: ignore
             [
                 self.event_reset.wait(),
                 self.event_remote_closed.wait(),
diff --git a/tests/libp2p/test_libp2p.py b/tests/libp2p/test_libp2p.py
index b4a643d2..8090f5ea 100644
--- a/tests/libp2p/test_libp2p.py
+++ b/tests/libp2p/test_libp2p.py
@@ -2,8 +2,8 @@ import multiaddr
 import pytest
 
 from libp2p.peer.peerinfo import info_from_p2p_addr
-from tests.utils import cleanup, set_up_nodes_by_transport_opt
 from tests.constants import MAX_READ_LEN
+from tests.utils import cleanup, set_up_nodes_by_transport_opt
 
 
 @pytest.mark.asyncio
diff --git a/tests/libp2p/test_notify.py b/tests/libp2p/test_notify.py
index 206f3e3a..e21030ab 100644
--- a/tests/libp2p/test_notify.py
+++ b/tests/libp2p/test_notify.py
@@ -16,8 +16,8 @@ from libp2p import initialize_default_swarm, new_node
 from libp2p.crypto.rsa import create_new_key_pair
 from libp2p.host.basic_host import BasicHost
 from libp2p.network.notifee_interface import INotifee
-from tests.utils import cleanup, perform_two_host_set_up
 from tests.constants import MAX_READ_LEN
+from tests.utils import cleanup, perform_two_host_set_up
 
 ACK = "ack:"
 
diff --git a/tests/protocol_muxer/test_protocol_muxer.py b/tests/protocol_muxer/test_protocol_muxer.py
index 8fb15371..7830aaa8 100644
--- a/tests/protocol_muxer/test_protocol_muxer.py
+++ b/tests/protocol_muxer/test_protocol_muxer.py
@@ -1,7 +1,7 @@
 import pytest
 
 from libp2p.protocol_muxer.exceptions import MultiselectClientError
-from tests.utils import cleanup, set_up_nodes_by_transport_opt, echo_stream_handler
+from tests.utils import cleanup, echo_stream_handler, set_up_nodes_by_transport_opt
 
 # TODO: Add tests for multiple streams being opened on different
 # protocols through the same connection
diff --git a/tests/utils.py b/tests/utils.py
index 1f1cfc4f..a26ebc55 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -5,7 +5,6 @@ import multiaddr
 
 from libp2p import new_node
 from libp2p.peer.peerinfo import info_from_p2p_addr
-
 from tests.constants import MAX_READ_LEN
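A note on the test changes in this series: `MplexStream.read()` with no length
argument only returns at EOF, i.e. after the remote side closes or resets the
stream, so the updated tests always pass an explicit length. The pattern they
follow looks roughly like the sketch below. `request_response` and
`read_whatever_arrives` are hypothetical helpers, not part of the patches;
`MAX_READ_LEN` and the "ack:" prefix come from tests/constants.py and
tests/utils.py.

    from tests.constants import MAX_READ_LEN

    async def request_response(stream, message: str) -> None:
        expected_resp = "ack:" + message
        await stream.write(message.encode())
        # Read exactly as many bytes as the expected reply instead of calling
        # stream.read() with no argument, which would block until EOF.
        response = (await stream.read(len(expected_resp))).decode()
        assert response == expected_resp

    async def read_whatever_arrives(stream) -> bytes:
        # For echo-style handlers, read up to MAX_READ_LEN bytes so the call
        # returns once data is available instead of waiting for EOF.
        return await stream.read(MAX_READ_LEN)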