makes test_mplex_stream.py::test_mplex_stream_read_write work

This commit is contained in:
Chih Cheng Liang
2019-11-19 18:04:48 +08:00
committed by mhchia
parent c55ea0e5bb
commit a397ccdc04
13 changed files with 70 additions and 122 deletions

View File

@ -1,11 +1,11 @@
import argparse import argparse
import asyncio import asyncio
import trio_asyncio
import trio
import sys import sys
import urllib.request import urllib.request
import multiaddr import multiaddr
import trio
import trio_asyncio
from libp2p import new_node from libp2p import new_node
from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.stream.net_stream_interface import INetStream
@ -42,7 +42,9 @@ async def run(port: int, destination: str, localhost: bool) -> None:
transport_opt = f"/ip4/{ip}/tcp/{port}" transport_opt = f"/ip4/{ip}/tcp/{port}"
host = new_node(transport_opt=[transport_opt]) host = new_node(transport_opt=[transport_opt])
await trio_asyncio.run_asyncio(host.get_network().listen,multiaddr.Multiaddr(transport_opt) ) await trio_asyncio.run_asyncio(
host.get_network().listen, multiaddr.Multiaddr(transport_opt)
)
if not destination: # its the server if not destination: # its the server
@ -70,7 +72,9 @@ async def run(port: int, destination: str, localhost: bool) -> None:
# Start a stream with the destination. # Start a stream with the destination.
# Multiaddress of the destination peer is fetched from the peerstore using 'peerId'. # Multiaddress of the destination peer is fetched from the peerstore using 'peerId'.
stream = await trio_asyncio.run_asyncio(host.new_stream, *(info.peer_id, [PROTOCOL_ID])) stream = await trio_asyncio.run_asyncio(
host.new_stream, *(info.peer_id, [PROTOCOL_ID])
)
asyncio.ensure_future(read_data(stream)) asyncio.ensure_future(read_data(stream))
asyncio.ensure_future(write_data(stream)) asyncio.ensure_future(write_data(stream))
@ -119,5 +123,6 @@ def main() -> None:
trio_asyncio.run(run, *(args.port, args.destination, args.localhost)) trio_asyncio.run(run, *(args.port, args.destination, args.localhost))
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -1,9 +1,10 @@
import trio
from trio import SocketStream
from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException
import logging import logging
import trio
from trio import SocketStream
from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException
logger = logging.getLogger("libp2p.io.trio") logger = logging.getLogger("libp2p.io.trio")

View File

@ -1,9 +1,10 @@
import trio import trio
from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException from libp2p.io.exceptions import IOException
from .exceptions import RawConnError from .exceptions import RawConnError
from .raw_connection_interface import IRawConnection from .raw_connection_interface import IRawConnection
from libp2p.io.abc import ReadWriteCloser
class RawConnection(IRawConnection): class RawConnection(IRawConnection):

View File

@ -3,6 +3,7 @@ import logging
from typing import Dict, List, Optional from typing import Dict, List, Optional
from multiaddr import Multiaddr from multiaddr import Multiaddr
import trio
from libp2p.io.abc import ReadWriteCloser from libp2p.io.abc import ReadWriteCloser
from libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.connection.net_connection_interface import INetConn
@ -69,7 +70,7 @@ class Swarm(INetwork):
def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None: def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None:
self.common_stream_handler = stream_handler self.common_stream_handler = stream_handler
async def dial_peer(self, peer_id: ID) -> INetConn: async def dial_peer(self, peer_id: ID, nursery) -> INetConn:
""" """
dial_peer try to create a connection to peer_id. dial_peer try to create a connection to peer_id.
@ -121,6 +122,7 @@ class Swarm(INetwork):
try: try:
muxed_conn = await self.upgrader.upgrade_connection(secured_conn, peer_id) muxed_conn = await self.upgrader.upgrade_connection(secured_conn, peer_id)
muxed_conn.run(nursery)
except MuxerUpgradeFailure as error: except MuxerUpgradeFailure as error:
error_msg = "fail to upgrade mux for peer %s" error_msg = "fail to upgrade mux for peer %s"
logger.debug(error_msg, peer_id) logger.debug(error_msg, peer_id)
@ -135,7 +137,7 @@ class Swarm(INetwork):
return swarm_conn return swarm_conn
async def new_stream(self, peer_id: ID) -> INetStream: async def new_stream(self, peer_id: ID, nursery) -> INetStream:
""" """
:param peer_id: peer_id of destination :param peer_id: peer_id of destination
:param protocol_id: protocol id :param protocol_id: protocol id
@ -144,7 +146,7 @@ class Swarm(INetwork):
""" """
logger.debug("attempting to open a stream to peer %s", peer_id) logger.debug("attempting to open a stream to peer %s", peer_id)
swarm_conn = await self.dial_peer(peer_id) swarm_conn = await self.dial_peer(peer_id, nursery)
net_stream = await swarm_conn.new_stream() net_stream = await swarm_conn.new_stream()
logger.debug("successfully opened a stream to peer %s", peer_id) logger.debug("successfully opened a stream to peer %s", peer_id)
@ -183,11 +185,11 @@ class Swarm(INetwork):
raise SwarmException() from error raise SwarmException() from error
peer_id = secured_conn.get_remote_peer() peer_id = secured_conn.get_remote_peer()
try: try:
muxed_conn = await self.upgrader.upgrade_connection( muxed_conn = await self.upgrader.upgrade_connection(
secured_conn, peer_id secured_conn, peer_id
) )
muxed_conn.run(nursery)
except MuxerUpgradeFailure as error: except MuxerUpgradeFailure as error:
error_msg = "fail to upgrade mux for peer %s" error_msg = "fail to upgrade mux for peer %s"
logger.debug(error_msg, peer_id) logger.debug(error_msg, peer_id)
@ -198,6 +200,8 @@ class Swarm(INetwork):
await self.add_conn(muxed_conn) await self.add_conn(muxed_conn)
logger.debug("successfully opened connection to peer %s", peer_id) logger.debug("successfully opened connection to peer %s", peer_id)
event = trio.Event()
await event.wait()
try: try:
# Success # Success

View File

@ -3,6 +3,8 @@ import logging
from typing import Any # noqa: F401 from typing import Any # noqa: F401
from typing import Awaitable, Dict, List, Optional, Tuple from typing import Awaitable, Dict, List, Optional, Tuple
import trio
from libp2p.exceptions import ParseError from libp2p.exceptions import ParseError
from libp2p.io.exceptions import IncompleteReadError from libp2p.io.exceptions import IncompleteReadError
from libp2p.network.connection.exceptions import RawConnError from libp2p.network.connection.exceptions import RawConnError
@ -41,8 +43,6 @@ class Mplex(IMuxedConn):
event_shutting_down: asyncio.Event event_shutting_down: asyncio.Event
event_closed: asyncio.Event event_closed: asyncio.Event
_tasks: List["asyncio.Future[Any]"]
def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None: def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None:
""" """
create a new muxed connection. create a new muxed connection.
@ -66,10 +66,8 @@ class Mplex(IMuxedConn):
self.event_shutting_down = asyncio.Event() self.event_shutting_down = asyncio.Event()
self.event_closed = asyncio.Event() self.event_closed = asyncio.Event()
self._tasks = [] def run(self, nursery):
nursery.start_soon(self.handle_incoming)
# Kick off reading
self._tasks.append(asyncio.ensure_future(self.handle_incoming()))
@property @property
def is_initiator(self) -> bool: def is_initiator(self) -> bool:
@ -123,7 +121,6 @@ class Mplex(IMuxedConn):
await self.send_message(HeaderTags.NewStream, name.encode(), stream_id) await self.send_message(HeaderTags.NewStream, name.encode(), stream_id)
return stream return stream
async def accept_stream(self) -> IMuxedStream: async def accept_stream(self) -> IMuxedStream:
"""accepts a muxed stream opened by the other end.""" """accepts a muxed stream opened by the other end."""
return await self.new_stream_queue.get() return await self.new_stream_queue.get()
@ -169,7 +166,7 @@ class Mplex(IMuxedConn):
logger.debug("mplex unavailable while waiting for incoming: %s", e) logger.debug("mplex unavailable while waiting for incoming: %s", e)
break break
# Force context switch # Force context switch
await asyncio.sleep(0) await trio.sleep(0)
# If we enter here, it means this connection is shutting down. # If we enter here, it means this connection is shutting down.
# We should clean things up. # We should clean things up.
await self._cleanup() await self._cleanup()
@ -184,9 +181,7 @@ class Mplex(IMuxedConn):
# FIXME: No timeout is used in Go implementation. # FIXME: No timeout is used in Go implementation.
try: try:
header = await decode_uvarint_from_stream(self.secured_conn) header = await decode_uvarint_from_stream(self.secured_conn)
message = await asyncio.wait_for( message = await read_varint_prefixed_bytes(self.secured_conn)
read_varint_prefixed_bytes(self.secured_conn), timeout=5
)
except (ParseError, RawConnError, IncompleteReadError) as error: except (ParseError, RawConnError, IncompleteReadError) as error:
raise MplexUnavailable( raise MplexUnavailable(
"failed to read messages correctly from the underlying connection" "failed to read messages correctly from the underlying connection"

View File

@ -1,8 +1,10 @@
import trio
import asyncio import asyncio
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import trio
from libp2p.stream_muxer.abc import IMuxedStream from libp2p.stream_muxer.abc import IMuxedStream
from libp2p.utils import IQueue, TrioQueue
from .constants import HeaderTags from .constants import HeaderTags
from .datastructures import StreamID from .datastructures import StreamID
@ -26,7 +28,7 @@ class MplexStream(IMuxedStream):
close_lock: trio.Lock close_lock: trio.Lock
# NOTE: `dataIn` is size of 8 in Go implementation. # NOTE: `dataIn` is size of 8 in Go implementation.
incoming_data: "asyncio.Queue[bytes]" incoming_data: IQueue[bytes]
event_local_closed: trio.Event event_local_closed: trio.Event
event_remote_closed: trio.Event event_remote_closed: trio.Event
@ -50,69 +52,13 @@ class MplexStream(IMuxedStream):
self.event_remote_closed = trio.Event() self.event_remote_closed = trio.Event()
self.event_reset = trio.Event() self.event_reset = trio.Event()
self.close_lock = trio.Lock() self.close_lock = trio.Lock()
self.incoming_data = asyncio.Queue() self.incoming_data = TrioQueue()
self._buf = bytearray() self._buf = bytearray()
@property @property
def is_initiator(self) -> bool: def is_initiator(self) -> bool:
return self.stream_id.is_initiator return self.stream_id.is_initiator
async def _wait_for_data(self) -> None:
task_event_reset = asyncio.ensure_future(self.event_reset.wait())
task_incoming_data_get = asyncio.ensure_future(self.incoming_data.get())
task_event_remote_closed = asyncio.ensure_future(
self.event_remote_closed.wait()
)
done, pending = await asyncio.wait( # type: ignore
[ # type: ignore
task_event_reset,
task_incoming_data_get,
task_event_remote_closed,
],
return_when=asyncio.FIRST_COMPLETED,
)
for fut in pending:
fut.cancel()
if task_event_reset in done:
if self.event_reset.is_set():
raise MplexStreamReset
else:
# However, it is abnormal that `Event.wait` is unblocked without any of the flag
# is set. The task is probably cancelled.
raise Exception(
"Should not enter here. "
f"It is probably because {task_event_remote_closed} is cancelled."
)
if task_incoming_data_get in done:
data = task_incoming_data_get.result()
self._buf.extend(data)
return
if task_event_remote_closed in done:
if self.event_remote_closed.is_set():
raise MplexStreamEOF
else:
# However, it is abnormal that `Event.wait` is unblocked without any of the flag
# is set. The task is probably cancelled.
raise Exception(
"Should not enter here. "
f"It is probably because {task_event_remote_closed} is cancelled."
)
# TODO: Handle timeout when deadline is used.
async def _read_until_eof(self) -> bytes:
while True:
try:
await self._wait_for_data()
except MplexStreamEOF:
break
payload = self._buf
self._buf = self._buf[len(payload) :]
return bytes(payload)
async def read(self, n: int = -1) -> bytes: async def read(self, n: int = -1) -> bytes:
""" """
Read up to n bytes. Read possibly returns fewer than `n` bytes, if Read up to n bytes. Read possibly returns fewer than `n` bytes, if
@ -128,20 +74,7 @@ class MplexStream(IMuxedStream):
) )
if self.event_reset.is_set(): if self.event_reset.is_set():
raise MplexStreamReset raise MplexStreamReset
if n == -1: return await self.incoming_data.get()
return await self._read_until_eof()
if len(self._buf) == 0 and self.incoming_data.empty():
await self._wait_for_data()
# Now we are sure we have something to read.
# Try to put enough incoming data into `self._buf`.
while len(self._buf) < n:
try:
self._buf.extend(self.incoming_data.get_nowait())
except asyncio.QueueEmpty:
break
payload = self._buf[:n]
self._buf = self._buf[len(payload) :]
return bytes(payload)
async def write(self, data: bytes) -> int: async def write(self, data: bytes) -> int:
""" """

View File

@ -17,7 +17,7 @@ from libp2p.typing import StreamHandlerFn, TProtocol
from .constants import MAX_READ_LEN from .constants import MAX_READ_LEN
async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm) -> None: async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm, nursery: trio.Nursery) -> None:
peer_id = swarm_1.get_peer_id() peer_id = swarm_1.get_peer_id()
addrs = tuple( addrs = tuple(
addr addr
@ -25,7 +25,7 @@ async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm) -> None:
for addr in transport.get_addrs() for addr in transport.get_addrs()
) )
swarm_0.peerstore.add_addrs(peer_id, addrs, 10000) swarm_0.peerstore.add_addrs(peer_id, addrs, 10000)
await swarm_0.dial_peer(peer_id) await swarm_0.dial_peer(peer_id, nursery)
assert swarm_0.get_peer_id() in swarm_1.connections assert swarm_0.get_peer_id() in swarm_1.connections
assert swarm_1.get_peer_id() in swarm_0.connections assert swarm_1.get_peer_id() in swarm_0.connections
@ -43,7 +43,9 @@ async def set_up_nodes_by_transport_opt(
nodes_list = [] nodes_list = []
for transport_opt in transport_opt_list: for transport_opt in transport_opt_list:
node = new_node(transport_opt=transport_opt) node = new_node(transport_opt=transport_opt)
await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0]), nursery=nursery) await node.get_network().listen(
multiaddr.Multiaddr(transport_opt[0]), nursery=nursery
)
nodes_list.append(node) nodes_list.append(node)
return tuple(nodes_list) return tuple(nodes_list)

View File

@ -1,18 +1,18 @@
import asyncio import asyncio
import trio import logging
from socket import socket from socket import socket
from typing import List from typing import List
from multiaddr import Multiaddr from multiaddr import Multiaddr
import trio
from libp2p.io.trio import TrioReadWriteCloser
from libp2p.network.connection.raw_connection import RawConnection from libp2p.network.connection.raw_connection import RawConnection
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.transport.exceptions import OpenConnectionError from libp2p.transport.exceptions import OpenConnectionError
from libp2p.transport.listener_interface import IListener from libp2p.transport.listener_interface import IListener
from libp2p.transport.transport_interface import ITransport from libp2p.transport.transport_interface import ITransport
from libp2p.transport.typing import THandler from libp2p.transport.typing import THandler
from libp2p.io.trio import TrioReadWriteCloser
import logging
logger = logging.getLogger("libp2p.transport.tcp") logger = logging.getLogger("libp2p.transport.tcp")
@ -44,11 +44,9 @@ class TCPListener(IListener):
listeners = await nursery.start( listeners = await nursery.start(
serve_tcp, serve_tcp,
*( handler,
handler, int(maddr.value_for_protocol("tcp")),
int(maddr.value_for_protocol("tcp")), maddr.value_for_protocol("ip4"),
maddr.value_for_protocol("ip4"),
),
) )
# self.server = await asyncio.start_server( # self.server = await asyncio.start_server(
# self.handler, # self.handler,
@ -57,7 +55,6 @@ class TCPListener(IListener):
# ) # )
socket = listeners[0].socket socket = listeners[0].socket
self.multiaddrs.append(_multiaddr_from_socket(socket)) self.multiaddrs.append(_multiaddr_from_socket(socket))
logger.debug("Multiaddrs %s", self.multiaddrs)
return True return True

View File

@ -1,10 +1,9 @@
from typing import Awaitable, Callable, Mapping, Type from typing import Awaitable, Callable, Mapping, Type
from libp2p.io.abc import ReadWriteCloser
from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.security.secure_transport_interface import ISecureTransport
from libp2p.stream_muxer.abc import IMuxedConn from libp2p.stream_muxer.abc import IMuxedConn
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from libp2p.io.abc import ReadWriteCloser
THandler = Callable[[ReadWriteCloser], Awaitable[None]] THandler = Callable[[ReadWriteCloser], Awaitable[None]]
TSecurityOptions = Mapping[TProtocol, ISecureTransport] TSecurityOptions = Mapping[TProtocol, ISecureTransport]

View File

@ -1,14 +1,14 @@
import itertools import itertools
import math import math
from typing import Generic, TypeVar
import trio
from libp2p.exceptions import ParseError from libp2p.exceptions import ParseError
from libp2p.io.abc import Reader from libp2p.io.abc import Reader
from .io.utils import read_exactly from .io.utils import read_exactly
from typing import Generic, TypeVar
import trio
# Unsigned LEB128(varint codec) # Unsigned LEB128(varint codec)
# Reference: https://github.com/ethereum/py-wasm/blob/master/wasm/parsers/leb128.py # Reference: https://github.com/ethereum/py-wasm/blob/master/wasm/parsers/leb128.py

View File

@ -1,6 +1,6 @@
import trio
import multiaddr import multiaddr
import pytest import pytest
import trio
from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.tools.constants import MAX_READ_LEN from libp2p.tools.constants import MAX_READ_LEN
@ -24,11 +24,11 @@ async def test_simple_messages(nursery):
# Associate the peer with local ip address (see default parameters of Libp2p()) # Associate the peer with local ip address (see default parameters of Libp2p())
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"])
messages = ["hello" + str(x) for x in range(10)] messages = ["hello" + str(x) for x in range(10)]
for message in messages: for message in messages:
await stream.write(message.encode()) await stream.write(message.encode())
response = (await stream.read(MAX_READ_LEN)).decode() response = (await stream.read(MAX_READ_LEN)).decode()

View File

@ -1,20 +1,31 @@
import asyncio import asyncio
import pytest import pytest
import trio
from libp2p.stream_muxer.mplex.exceptions import ( from libp2p.stream_muxer.mplex.exceptions import (
MplexStreamClosed, MplexStreamClosed,
MplexStreamEOF, MplexStreamEOF,
MplexStreamReset, MplexStreamReset,
) )
from libp2p.tools.constants import MAX_READ_LEN from libp2p.tools.constants import MAX_READ_LEN, LISTEN_MADDR
from libp2p.tools.factories import SwarmFactory
from libp2p.tools.utils import connect_swarm
DATA = b"data_123" DATA = b"data_123"
@pytest.mark.asyncio @pytest.mark.trio
async def test_mplex_stream_read_write(mplex_stream_pair): async def test_mplex_stream_read_write(nursery):
stream_0, stream_1 = mplex_stream_pair swarm0, swarm1 = SwarmFactory(), SwarmFactory()
await swarm0.listen(LISTEN_MADDR, nursery=nursery)
await swarm1.listen(LISTEN_MADDR, nursery=nursery)
await connect_swarm(swarm0, swarm1, nursery)
conn_0 = swarm0.connections[swarm1.get_peer_id()]
conn_1 = swarm1.connections[swarm0.get_peer_id()]
stream_0 = await conn_0.muxed_conn.open_stream()
await trio.sleep(1)
stream_1 = tuple(conn_1.muxed_conn.streams.values())[0]
await stream_0.write(DATA) await stream_0.write(DATA)
assert (await stream_1.read(MAX_READ_LEN)) == DATA assert (await stream_1.read(MAX_READ_LEN)) == DATA

View File

@ -1,5 +1,6 @@
import trio
import pytest import pytest
import trio
from libp2p.utils import TrioQueue from libp2p.utils import TrioQueue
@ -16,4 +17,3 @@ async def test_trio_queue():
result = await nursery.start(queue_get) result = await nursery.start(queue_get)
assert result == 123 assert result == 123