Fix and add type hints to pubsub.py

This commit is contained in:
NIC619
2019-07-24 16:29:02 +08:00
parent f0046fa3e0
commit 8eb6a230ff

View File

@ -5,12 +5,17 @@ from typing import (
Any, Any,
Dict, Dict,
List, List,
Sequence,
Tuple, Tuple,
) )
from lru import LRU from lru import LRU
from .pb import rpc_pb2
from .pubsub_notifee import PubsubNotifee
from .pubsub_router_interface import (
IPubsubRouter,
)
from libp2p.host.host_interface import ( from libp2p.host.host_interface import (
IHost, IHost,
) )
@ -21,12 +26,6 @@ from libp2p.network.stream.net_stream_interface import (
INetStream, INetStream,
) )
from .pb import rpc_pb2
from .pubsub_notifee import PubsubNotifee
from .pubsub_router_interface import (
IPubsubRouter,
)
def get_msg_id(msg: rpc_pb2.Message) -> Tuple[bytes, bytes]: def get_msg_id(msg: rpc_pb2.Message) -> Tuple[bytes, bytes]:
# NOTE: `string(from, seqno)` in Go # NOTE: `string(from, seqno)` in Go
@ -38,26 +37,32 @@ class Pubsub:
host: IHost host: IHost
my_id: ID my_id: ID
router: IPubsubRouter router: IPubsubRouter
peer_queue: asyncio.Queue
protocols: Sequence[str] # FIXME: Should be changed to `asyncio.Queue[ID]`
incoming_msgs_from_peers: asyncio.Queue() peer_queue: asyncio.Queue[str]
outgoing_messages: asyncio.Queue()
protocols: List[str]
incoming_msgs_from_peers: asyncio.Queue[rpc_pb2.Message]
outgoing_messages: asyncio.Queue[rpc_pb2.Message]
seen_messages: LRU seen_messages: LRU
my_topics: Dict[str, asyncio.Queue] my_topics: Dict[str, asyncio.Queue]
# FIXME: Should be changed to `Dict[str, List[ID]]` # FIXME: Should be changed to `Dict[str, List[ID]]`
peer_topics: Dict[str, List[str]] peer_topics: Dict[str, List[str]]
# FIXME: Should be changed to `Dict[ID, INetStream]` # FIXME: Should be changed to `Dict[ID, INetStream]`
peers: Dict[str, INetStream] peers: Dict[str, INetStream]
# NOTE: Be sure it is increased atomically everytime. # NOTE: Be sure it is increased atomically everytime.
counter: int # uint64 counter: int # uint64
def __init__( def __init__(self,
self, host: IHost,
host: IHost, router: IPubsubRouter,
router: IPubsubRouter, my_id: ID,
my_id: ID, cache_size: int = None) -> None:
cache_size: int = None) -> None:
""" """
Construct a new Pubsub object, which is responsible for handling all Construct a new Pubsub object, which is responsible for handling all
Pubsub-related messages and relaying messages as appropriate to the Pubsub-related messages and relaying messages as appropriate to the
@ -73,6 +78,7 @@ class Pubsub:
self.router.attach(self) self.router.attach(self)
# Register a notifee # Register a notifee
# FIXME: Should be changed to `asyncio.Queue[ID]`
self.peer_queue = asyncio.Queue() self.peer_queue = asyncio.Queue()
self.host.get_network().notify(PubsubNotifee(self.peer_queue)) self.host.get_network().notify(PubsubNotifee(self.peer_queue))
@ -99,9 +105,11 @@ class Pubsub:
self.my_topics = {} self.my_topics = {}
# Map of topic to peers to keep track of what peers are subscribed to # Map of topic to peers to keep track of what peers are subscribed to
# FIXME: Should be changed to `Dict[str, ID]`
self.peer_topics = {} self.peer_topics = {}
# Create peers map, which maps peer_id (as string) to stream (to a given peer) # Create peers map, which maps peer_id (as string) to stream (to a given peer)
# FIXME: Should be changed to `Dict[ID, INetStream]`
self.peers = {} self.peers = {}
self.counter = time.time_ns() self.counter = time.time_ns()
@ -114,7 +122,7 @@ class Pubsub:
Generate subscription message with all topics we are subscribed to Generate subscription message with all topics we are subscribed to
only send hello packet if we have subscribed topics only send hello packet if we have subscribed topics
""" """
packet = rpc_pb2.RPC() packet: rpc_pb2.RPC = rpc_pb2.RPC()
if self.my_topics: if self.my_topics:
for topic_id in self.my_topics: for topic_id in self.my_topics:
packet.subscriptions.extend([rpc_pb2.RPC.SubOpts( packet.subscriptions.extend([rpc_pb2.RPC.SubOpts(
@ -131,8 +139,8 @@ class Pubsub:
peer_id = stream.mplex_conn.peer_id peer_id = stream.mplex_conn.peer_id
while True: while True:
incoming = (await stream.read()) incoming: bytes = (await stream.read())
rpc_incoming = rpc_pb2.RPC() rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
rpc_incoming.ParseFromString(incoming) rpc_incoming.ParseFromString(incoming)
if rpc_incoming.publish: if rpc_incoming.publish:
@ -168,12 +176,12 @@ class Pubsub:
""" """
# Add peer # Add peer
# Map peer to stream # Map peer to stream
peer_id = stream.mplex_conn.peer_id peer_id: ID = stream.mplex_conn.peer_id
self.peers[str(peer_id)] = stream self.peers[str(peer_id)] = stream
self.router.add_peer(peer_id, stream.get_protocol()) self.router.add_peer(peer_id, stream.get_protocol())
# Send hello packet # Send hello packet
hello = self.get_hello_packet() hello: bytes = self.get_hello_packet()
await stream.write(hello) await stream.write(hello)
# Pass stream off to stream reader # Pass stream off to stream reader
@ -188,12 +196,12 @@ class Pubsub:
""" """
while True: while True:
peer_id = await self.peer_queue.get() peer_id: ID = await self.peer_queue.get()
# Open a stream to peer on existing connection # Open a stream to peer on existing connection
# (we know connection exists since that's the only way # (we know connection exists since that's the only way
# an element gets added to peer_queue) # an element gets added to peer_queue)
stream = await self.host.new_stream(peer_id, self.protocols) stream: INetStream = await self.host.new_stream(peer_id, self.protocols)
# Add Peer # Add Peer
# Map peer to stream # Map peer to stream
@ -201,7 +209,7 @@ class Pubsub:
self.router.add_peer(peer_id, stream.get_protocol()) self.router.add_peer(peer_id, stream.get_protocol())
# Send hello packet # Send hello packet
hello = self.get_hello_packet() hello: bytes = self.get_hello_packet()
await stream.write(hello) await stream.write(hello)
# Pass stream off to stream reader # Pass stream off to stream reader
@ -210,8 +218,9 @@ class Pubsub:
# Force context switch # Force context switch
await asyncio.sleep(0) await asyncio.sleep(0)
# FIXME: type of `origin_id` should be changed to `ID`
# FIXME: `sub_message` can be further type hinted with mypy_protobuf # FIXME: `sub_message` can be further type hinted with mypy_protobuf
def handle_subscription(self, origin_id: ID, sub_message: Any) -> None: def handle_subscription(self, origin_id: str, sub_message: Any) -> None:
""" """
Handle an incoming subscription message from a peer. Update internal Handle an incoming subscription message from a peer. Update internal
mapping to mark the peer as subscribed or unsubscribed to topics as mapping to mark the peer as subscribed or unsubscribed to topics as
@ -236,7 +245,7 @@ class Pubsub:
async def handle_talk(self, publish_message: Any) -> None: async def handle_talk(self, publish_message: Any) -> None:
""" """
Put incoming message from a peer onto my blocking queue Put incoming message from a peer onto my blocking queue
:param talk: RPC.Message format :param publish_message: RPC.Message format
""" """
# Check if this message has any topics that we are subscribed to # Check if this message has any topics that we are subscribed to
@ -247,7 +256,7 @@ class Pubsub:
# for each topic # for each topic
await self.my_topics[topic].put(publish_message) await self.my_topics[topic].put(publish_message)
async def subscribe(self, topic_id: str) -> asyncio.Queue: async def subscribe(self, topic_id: str) -> asyncio.Queue[rpc_pb2.Message]:
""" """
Subscribe ourself to a topic Subscribe ourself to a topic
:param topic_id: topic_id to subscribe to :param topic_id: topic_id to subscribe to
@ -261,7 +270,7 @@ class Pubsub:
self.my_topics[topic_id] = asyncio.Queue() self.my_topics[topic_id] = asyncio.Queue()
# Create subscribe message # Create subscribe message
packet = rpc_pb2.RPC() packet: rpc_pb2.RPC = rpc_pb2.RPC()
packet.subscriptions.extend([rpc_pb2.RPC.SubOpts( packet.subscriptions.extend([rpc_pb2.RPC.SubOpts(
subscribe=True, subscribe=True,
topicid=topic_id.encode('utf-8') topicid=topic_id.encode('utf-8')
@ -289,7 +298,7 @@ class Pubsub:
del self.my_topics[topic_id] del self.my_topics[topic_id]
# Create unsubscribe message # Create unsubscribe message
packet = rpc_pb2.RPC() packet: rpc_pb2.RPC = rpc_pb2.RPC()
packet.subscriptions.extend([rpc_pb2.RPC.SubOpts( packet.subscriptions.extend([rpc_pb2.RPC.SubOpts(
subscribe=False, subscribe=False,
topicid=topic_id.encode('utf-8') topicid=topic_id.encode('utf-8')
@ -301,8 +310,8 @@ class Pubsub:
# Tell router we are leaving this topic # Tell router we are leaving this topic
await self.router.leave(topic_id) await self.router.leave(topic_id)
# FIXME: `rpc_msg` can be further type hinted with mypy_protobuf # FIXME: `raw_msg` can be further type hinted with mypy_protobuf
async def message_all_peers(self, rpc_msg: Any) -> None: async def message_all_peers(self, raw_msg: Any) -> None:
""" """
Broadcast a message to peers Broadcast a message to peers
:param raw_msg: raw contents of the message to broadcast :param raw_msg: raw contents of the message to broadcast
@ -311,7 +320,7 @@ class Pubsub:
# Broadcast message # Broadcast message
for _, stream in self.peers.items(): for _, stream in self.peers.items():
# Write message to stream # Write message to stream
await stream.write(rpc_msg) await stream.write(raw_msg)
async def publish(self, topic_id: str, data: bytes) -> None: async def publish(self, topic_id: str, data: bytes) -> None:
""" """