diff --git a/libp2p/network/connection/raw_connection_interface.py b/libp2p/network/connection/raw_connection_interface.py index 1e35514b..088eaec1 100644 --- a/libp2p/network/connection/raw_connection_interface.py +++ b/libp2p/network/connection/raw_connection_interface.py @@ -1,3 +1,4 @@ +import asyncio from abc import ABC, abstractmethod @@ -6,6 +7,13 @@ class IRawConnection(ABC): A Raw Connection provides a Reader and a Writer """ + initiator: bool + + # TODO: reader and writer shouldn't be exposed. + # Need better API for the consumers + reader: asyncio.StreamReader + writer: asyncio.StreamWriter + @abstractmethod async def write(self, data: bytes) -> None: pass @@ -13,3 +21,11 @@ class IRawConnection(ABC): @abstractmethod async def read(self) -> bytes: pass + + @abstractmethod + def close(self) -> None: + pass + + @abstractmethod + def next_stream_id(self) -> int: + pass diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 08970ff2..e75d8eea 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -5,13 +5,36 @@ from .utils import encode_uvarint, decode_uvarint_from_stream from .mplex_stream import MplexStream from ..muxed_connection_interface import IMuxedConn +from typing import TYPE_CHECKING, Tuple, Dict + +if TYPE_CHECKING: + from multiaddr import Multiaddr + from libp2p.security.secure_conn_interface import ISecureConn + from libp2p.network.connection.raw_connection_interface import IRawConnection + from libp2p.network.swarm import GenericProtocolHandlerFn + from libp2p.peer.id import ID + from libp2p.stream_muxer.muxed_stream_interface import IMuxedStream + class Mplex(IMuxedConn): """ reference: https://github.com/libp2p/go-mplex/blob/master/multiplex.go """ - def __init__(self, secured_conn, generic_protocol_handler, peer_id): + secured_conn: "ISecureConn" + raw_conn: "IRawConnection" + initiator: bool + generic_protocol_handler = None + peer_id: "ID" + buffers: Dict[int, asyncio.Queue[bytes]] + stream_queue: asyncio.Queue[int] + + def __init__( + self, + secured_conn: "ISecureConn", + generic_protocol_handler: "GenericProtocolHandlerFn", + peer_id: "ID", + ) -> None: """ create a new muxed connection :param conn: an instance of raw connection @@ -39,19 +62,20 @@ class Mplex(IMuxedConn): # Kick off reading asyncio.ensure_future(self.handle_incoming()) - def close(self): + def close(self) -> None: """ close the stream muxer and underlying raw connection """ self.raw_conn.close() - def is_closed(self): + def is_closed(self) -> bool: """ check connection is fully closed :return: true if successful """ + raise NotImplementedError() - async def read_buffer(self, stream_id): + async def read_buffer(self, stream_id: int) -> bytes: """ Read a message from stream_id's buffer, check raw connection for new messages :param stream_id: stream id of stream to read from @@ -69,7 +93,9 @@ class Mplex(IMuxedConn): # Stream not created yet return None - async def open_stream(self, protocol_id, multi_addr): + async def open_stream( + self, protocol_id: str, multi_addr: "Multiaddr" + ) -> "IMuxedStream": """ creates a new muxed_stream :param protocol_id: protocol_id of stream @@ -82,7 +108,7 @@ class Mplex(IMuxedConn): await self.send_message(HeaderTags.NewStream, None, stream_id) return stream - async def accept_stream(self): + async def accept_stream(self) -> None: """ accepts a muxed stream opened by the other end :return: the accepted stream @@ -91,13 +117,12 @@ class Mplex(IMuxedConn): stream = MplexStream(stream_id, False, self) asyncio.ensure_future(self.generic_protocol_handler(stream)) - async def send_message(self, flag: HeaderTags, data, stream_id): + async def send_message(self, flag: HeaderTags, data: bytes, stream_id: int) -> int: """ sends a message over the connection :param header: header to use :param data: data to send in the message :param stream_id: stream the message is in - :return: True if success """ # << by 3, then or with flag header = (stream_id << 3) | flag.value @@ -112,7 +137,7 @@ class Mplex(IMuxedConn): return await self.write_to_stream(_bytes) - async def write_to_stream(self, _bytes): + async def write_to_stream(self, _bytes: bytearray) -> int: """ writes a byte array to a raw connection :param _bytes: byte array to write @@ -122,7 +147,7 @@ class Mplex(IMuxedConn): await self.raw_conn.writer.drain() return len(_bytes) - async def handle_incoming(self): + async def handle_incoming(self) -> None: """ Read a message off of the raw connection and add it to the corresponding message buffer """ @@ -146,7 +171,7 @@ class Mplex(IMuxedConn): # Force context switch await asyncio.sleep(0) - async def read_message(self): + async def read_message(self) -> Tuple[int, int, bytes]: """ Read a single message off of the raw connection :return: stream_id, flag, message contents diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index e452fbda..a2502a88 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -10,7 +10,7 @@ class MplexStream(IMuxedStream): reference: https://github.com/libp2p/go-mplex/blob/master/stream.go """ - def __init__(self, stream_id, initiator, mplex_conn): + def __init__(self, stream_id, initiator: bool, mplex_conn): """ create new MuxedStream in muxer :param stream_id: stream stream id