Make Mplex and SwarmConn not Service

After second thoughts, they seem not a good candidate of `Service`.
The shutdown logic becomes simpler by making them not `Service`.
This commit is contained in:
mhchia
2020-01-07 21:50:03 +08:00
parent eab59482c0
commit eef241e70e
7 changed files with 43 additions and 61 deletions

View File

@ -1,6 +1,8 @@
from abc import abstractmethod from abc import abstractmethod
from typing import Tuple from typing import Tuple
import trio
from libp2p.io.abc import Closer from libp2p.io.abc import Closer
from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.stream_muxer.abc import IMuxedConn from libp2p.stream_muxer.abc import IMuxedConn
@ -8,11 +10,12 @@ from libp2p.stream_muxer.abc import IMuxedConn
class INetConn(Closer): class INetConn(Closer):
muxed_conn: IMuxedConn muxed_conn: IMuxedConn
event_started: trio.Event
@abstractmethod @abstractmethod
async def new_stream(self) -> INetStream: async def new_stream(self) -> INetStream:
... ...
@abstractmethod @abstractmethod
async def get_streams(self) -> Tuple[INetStream, ...]: def get_streams(self) -> Tuple[INetStream, ...]:
... ...

View File

@ -1,6 +1,5 @@
from typing import TYPE_CHECKING, Set, Tuple from typing import TYPE_CHECKING, Set, Tuple
from async_service import Service
import trio import trio
from libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.connection.net_connection_interface import INetConn
@ -17,10 +16,11 @@ Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee
""" """
class SwarmConn(INetConn, Service): class SwarmConn(INetConn):
muxed_conn: IMuxedConn muxed_conn: IMuxedConn
swarm: "Swarm" swarm: "Swarm"
streams: Set[NetStream] streams: Set[NetStream]
event_started: trio.Event
event_closed: trio.Event event_closed: trio.Event
def __init__(self, muxed_conn: IMuxedConn, swarm: "Swarm") -> None: def __init__(self, muxed_conn: IMuxedConn, swarm: "Swarm") -> None:
@ -28,6 +28,7 @@ class SwarmConn(INetConn, Service):
self.swarm = swarm self.swarm = swarm
self.streams = set() self.streams = set()
self.event_closed = trio.Event() self.event_closed = trio.Event()
self.event_started = trio.Event()
@property @property
def is_closed(self) -> bool: def is_closed(self) -> bool:
@ -38,8 +39,6 @@ class SwarmConn(INetConn, Service):
return return
self.event_closed.set() self.event_closed.set()
await self._cleanup() await self._cleanup()
# Cancel service
await self.manager.stop()
async def _cleanup(self) -> None: async def _cleanup(self) -> None:
self.swarm.remove_conn(self) self.swarm.remove_conn(self)
@ -57,13 +56,14 @@ class SwarmConn(INetConn, Service):
self._notify_disconnected() self._notify_disconnected()
async def _handle_new_streams(self) -> None: async def _handle_new_streams(self) -> None:
while self.manager.is_running: self.event_started.set()
while True:
try: try:
stream = await self.muxed_conn.accept_stream() stream = await self.muxed_conn.accept_stream()
# Asynchronously handle the accepted stream, to avoid blocking the next stream. # Asynchronously handle the accepted stream, to avoid blocking the next stream.
except MuxedConnUnavailable: except MuxedConnUnavailable:
break break
self.manager.run_task(self._handle_muxed_stream, stream) self.swarm.manager.run_task(self._handle_muxed_stream, stream)
await self.close() await self.close()
@ -87,15 +87,14 @@ class SwarmConn(INetConn, Service):
def _notify_disconnected(self) -> None: def _notify_disconnected(self) -> None:
self.swarm.notify_disconnected(self) self.swarm.notify_disconnected(self)
async def run(self) -> None: async def start(self) -> None:
self.manager.run_task(self._handle_new_streams) await self._handle_new_streams()
await self.manager.wait_finished()
async def new_stream(self) -> NetStream: async def new_stream(self) -> NetStream:
muxed_stream = await self.muxed_conn.open_stream() muxed_stream = await self.muxed_conn.open_stream()
return self._add_stream(muxed_stream) return self._add_stream(muxed_stream)
async def get_streams(self) -> Tuple[NetStream, ...]: def get_streams(self) -> Tuple[NetStream, ...]:
return tuple(self.streams) return tuple(self.streams)
def remove_stream(self, stream: NetStream) -> None: def remove_stream(self, stream: NetStream) -> None:

View File

@ -2,7 +2,6 @@ import logging
from typing import Dict, List, Optional from typing import Dict, List, Optional
from multiaddr import Multiaddr from multiaddr import Multiaddr
import trio
from libp2p.io.abc import ReadWriteCloser from libp2p.io.abc import ReadWriteCloser
from libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.connection.net_connection_interface import INetConn
@ -44,7 +43,6 @@ class Swarm(INetworkService):
common_stream_handler: Optional[StreamHandlerFn] common_stream_handler: Optional[StreamHandlerFn]
notifees: List[INotifee] notifees: List[INotifee]
event_closed: trio.Event
def __init__( def __init__(
self, self,
@ -63,8 +61,6 @@ class Swarm(INetworkService):
# Create Notifee array # Create Notifee array
self.notifees = [] self.notifees = []
self.event_closed = trio.Event()
self.common_stream_handler = None self.common_stream_handler = None
async def run(self) -> None: async def run(self) -> None:
@ -158,13 +154,11 @@ class Swarm(INetworkService):
try: try:
muxed_conn = await self.upgrader.upgrade_connection(secured_conn, peer_id) muxed_conn = await self.upgrader.upgrade_connection(secured_conn, peer_id)
self.manager.run_child_service(muxed_conn)
except MuxerUpgradeFailure as error: except MuxerUpgradeFailure as error:
error_msg = "fail to upgrade mux for peer %s" error_msg = "fail to upgrade mux for peer %s"
logger.debug(error_msg, peer_id) logger.debug(error_msg, peer_id)
await secured_conn.close() await secured_conn.close()
raise SwarmException(error_msg % peer_id) from error raise SwarmException(error_msg % peer_id) from error
logger.debug("upgraded mux for peer %s", peer_id) logger.debug("upgraded mux for peer %s", peer_id)
swarm_conn = await self.add_conn(muxed_conn) swarm_conn = await self.add_conn(muxed_conn)
@ -226,7 +220,6 @@ class Swarm(INetworkService):
muxed_conn = await self.upgrader.upgrade_connection( muxed_conn = await self.upgrader.upgrade_connection(
secured_conn, peer_id secured_conn, peer_id
) )
self.manager.run_child_service(muxed_conn)
except MuxerUpgradeFailure as error: except MuxerUpgradeFailure as error:
error_msg = "fail to upgrade mux for peer %s" error_msg = "fail to upgrade mux for peer %s"
logger.debug(error_msg, peer_id) logger.debug(error_msg, peer_id)
@ -235,8 +228,8 @@ class Swarm(INetworkService):
logger.debug("upgraded mux for peer %s", peer_id) logger.debug("upgraded mux for peer %s", peer_id)
await self.add_conn(muxed_conn) await self.add_conn(muxed_conn)
logger.debug("successfully opened connection to peer %s", peer_id) logger.debug("successfully opened connection to peer %s", peer_id)
# NOTE: This is a intentional barrier to prevent from the handler exiting and # NOTE: This is a intentional barrier to prevent from the handler exiting and
# closing the connection. # closing the connection.
await self.manager.wait_finished() await self.manager.wait_finished()
@ -261,26 +254,12 @@ class Swarm(INetworkService):
return False return False
async def close(self) -> None: async def close(self) -> None:
if self.event_closed.is_set():
return
self.event_closed.set()
# Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501
async with trio.open_nursery() as nursery:
for conn in self.connections.values():
nursery.start_soon(conn.close)
async with trio.open_nursery() as nursery:
for listener in self.listeners.values():
nursery.start_soon(listener.close)
# Cancel tasks
await self.manager.stop() await self.manager.stop()
logger.debug("swarm successfully closed") logger.debug("swarm successfully closed")
async def close_peer(self, peer_id: ID) -> None: async def close_peer(self, peer_id: ID) -> None:
if peer_id not in self.connections: if peer_id not in self.connections:
return return
# TODO: Should be changed to close multisple connections,
# if we have several connections per peer in the future.
connection = self.connections[peer_id] connection = self.connections[peer_id]
# NOTE: `connection.close` will delete `peer_id` from `self.connections` # NOTE: `connection.close` will delete `peer_id` from `self.connections`
# and `notify_disconnected` for us. # and `notify_disconnected` for us.
@ -293,12 +272,14 @@ class Swarm(INetworkService):
and start to monitor the connection for its new streams and and start to monitor the connection for its new streams and
disconnection.""" disconnection."""
swarm_conn = SwarmConn(muxed_conn, self) swarm_conn = SwarmConn(muxed_conn, self)
manager = self.manager.run_child_service(swarm_conn) self.manager.run_task(muxed_conn.start)
await muxed_conn.event_started.wait()
self.manager.run_task(swarm_conn.start)
await swarm_conn.event_started.wait()
# Store muxed_conn with peer id # Store muxed_conn with peer id
self.connections[muxed_conn.peer_id] = swarm_conn self.connections[muxed_conn.peer_id] = swarm_conn
# Call notifiers since event occurred # Call notifiers since event occurred
self.notify_connected(swarm_conn) self.notify_connected(swarm_conn)
await manager.wait_started()
return swarm_conn return swarm_conn
def remove_conn(self, swarm_conn: SwarmConn) -> None: def remove_conn(self, swarm_conn: SwarmConn) -> None:
@ -307,8 +288,6 @@ class Swarm(INetworkService):
peer_id = swarm_conn.muxed_conn.peer_id peer_id = swarm_conn.muxed_conn.peer_id
if peer_id not in self.connections: if peer_id not in self.connections:
return return
# TODO: Should be changed to remove the exact connection,
# if we have several connections per peer in the future.
del self.connections[peer_id] del self.connections[peer_id]
# Notifee # Notifee

View File

@ -1,18 +1,19 @@
from abc import abstractmethod from abc import ABC, abstractmethod
from async_service import ServiceAPI import trio
from libp2p.io.abc import ReadWriteCloser from libp2p.io.abc import ReadWriteCloser
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
class IMuxedConn(ServiceAPI): class IMuxedConn(ABC):
""" """
reference: https://github.com/libp2p/go-stream-muxer/blob/master/muxer.go reference: https://github.com/libp2p/go-stream-muxer/blob/master/muxer.go
""" """
peer_id: ID peer_id: ID
event_started: trio.Event
@abstractmethod @abstractmethod
def __init__(self, conn: ISecureConn, peer_id: ID) -> None: def __init__(self, conn: ISecureConn, peer_id: ID) -> None:
@ -27,7 +28,11 @@ class IMuxedConn(ServiceAPI):
@property @property
@abstractmethod @abstractmethod
def is_initiator(self) -> bool: def is_initiator(self) -> bool:
pass """if this connection is the initiator."""
@abstractmethod
async def start(self) -> None:
"""start the multiplexer."""
@abstractmethod @abstractmethod
async def close(self) -> None: async def close(self) -> None:

View File

@ -2,7 +2,6 @@ import logging
import math import math
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
from async_service import Service
import trio import trio
from libp2p.exceptions import ParseError from libp2p.exceptions import ParseError
@ -29,7 +28,7 @@ MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0")
logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex") logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex")
class Mplex(IMuxedConn, Service): class Mplex(IMuxedConn):
""" """
reference: https://github.com/libp2p/go-mplex/blob/master/multiplex.go reference: https://github.com/libp2p/go-mplex/blob/master/multiplex.go
""" """
@ -45,6 +44,7 @@ class Mplex(IMuxedConn, Service):
event_shutting_down: trio.Event event_shutting_down: trio.Event
event_closed: trio.Event event_closed: trio.Event
event_started: trio.Event
def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None: def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None:
""" """
@ -73,10 +73,10 @@ class Mplex(IMuxedConn, Service):
self.new_stream_send_channel, self.new_stream_receive_channel = channels self.new_stream_send_channel, self.new_stream_receive_channel = channels
self.event_shutting_down = trio.Event() self.event_shutting_down = trio.Event()
self.event_closed = trio.Event() self.event_closed = trio.Event()
self.event_started = trio.Event()
async def run(self) -> None: async def start(self) -> None:
self.manager.run_task(self.handle_incoming) await self.handle_incoming()
await self.manager.wait_finished()
@property @property
def is_initiator(self) -> bool: def is_initiator(self) -> bool:
@ -91,7 +91,6 @@ class Mplex(IMuxedConn, Service):
await self.secured_conn.close() await self.secured_conn.close()
# Blocked until `close` is finally set. # Blocked until `close` is finally set.
await self.event_closed.wait() await self.event_closed.wait()
await self.manager.stop()
@property @property
def is_closed(self) -> bool: def is_closed(self) -> bool:
@ -178,8 +177,8 @@ class Mplex(IMuxedConn, Service):
async def handle_incoming(self) -> None: async def handle_incoming(self) -> None:
"""Read a message off of the secured connection and add it to the """Read a message off of the secured connection and add it to the
corresponding message buffer.""" corresponding message buffer."""
self.event_started.set()
while self.manager.is_running: while True:
try: try:
await self._handle_incoming_message() await self._handle_incoming_message()
except MplexUnavailable as e: except MplexUnavailable as e:

View File

@ -188,9 +188,7 @@ class MplexStream(IMuxedStream):
if self.is_initiator if self.is_initiator
else HeaderTags.ResetReceiver else HeaderTags.ResetReceiver
) )
self.muxed_conn.manager.run_task( await self.muxed_conn.send_message(flag, None, self.stream_id)
self.muxed_conn.send_message, flag, None, self.stream_id
)
self.event_local_closed.set() self.event_local_closed.set()
self.event_remote_closed.set() self.event_remote_closed.set()

View File

@ -14,7 +14,6 @@ async def test_swarm_conn_close(swarm_conn_pair):
await trio.sleep(0.1) await trio.sleep(0.1)
await wait_all_tasks_blocked() await wait_all_tasks_blocked()
await conn_0.manager.wait_finished()
assert conn_0.is_closed assert conn_0.is_closed
assert conn_1.is_closed assert conn_1.is_closed
@ -26,22 +25,22 @@ async def test_swarm_conn_close(swarm_conn_pair):
async def test_swarm_conn_streams(swarm_conn_pair): async def test_swarm_conn_streams(swarm_conn_pair):
conn_0, conn_1 = swarm_conn_pair conn_0, conn_1 = swarm_conn_pair
assert len(await conn_0.get_streams()) == 0 assert len(conn_0.get_streams()) == 0
assert len(await conn_1.get_streams()) == 0 assert len(conn_1.get_streams()) == 0
stream_0_0 = await conn_0.new_stream() stream_0_0 = await conn_0.new_stream()
await trio.sleep(0.01) await trio.sleep(0.01)
assert len(await conn_0.get_streams()) == 1 assert len(conn_0.get_streams()) == 1
assert len(await conn_1.get_streams()) == 1 assert len(conn_1.get_streams()) == 1
stream_0_1 = await conn_0.new_stream() stream_0_1 = await conn_0.new_stream()
await trio.sleep(0.01) await trio.sleep(0.01)
assert len(await conn_0.get_streams()) == 2 assert len(conn_0.get_streams()) == 2
assert len(await conn_1.get_streams()) == 2 assert len(conn_1.get_streams()) == 2
conn_0.remove_stream(stream_0_0) conn_0.remove_stream(stream_0_0)
assert len(await conn_0.get_streams()) == 1 assert len(conn_0.get_streams()) == 1
conn_0.remove_stream(stream_0_1) conn_0.remove_stream(stream_0_1)
assert len(await conn_0.get_streams()) == 0 assert len(conn_0.get_streams()) == 0
# Nothing happen if `stream_0_1` is not present or already removed. # Nothing happen if `stream_0_1` is not present or already removed.
conn_0.remove_stream(stream_0_1) conn_0.remove_stream(stream_0_1)