Tested against subscriptions and publish

This commit is contained in:
mhchia
2019-09-02 23:21:57 +08:00
parent 3717dc9adf
commit 194b494057
6 changed files with 46 additions and 43 deletions

View File

@ -16,12 +16,12 @@ from typing import (
from lru import LRU
from libp2p.utils import encode_varint_prefixed
from libp2p.exceptions import ValidationError
from libp2p.host.host_interface import IHost
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.peer.id import ID
from libp2p.typing import TProtocol
from libp2p.utils import encode_varint_prefixed, read_varint_prefixed_bytes
from .pb import rpc_pb2
from .pubsub_notifee import PubsubNotifee
@ -153,11 +153,14 @@ class Pubsub:
peer_id = stream.mplex_conn.peer_id
while True:
incoming: bytes = (await stream.read())
print("!@# continuously_read_stream: waiting")
incoming: bytes = await read_varint_prefixed_bytes(stream)
print(f"!@# continuously_read_stream: incoming={incoming}")
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
rpc_incoming.ParseFromString(incoming)
if rpc_incoming.publish:
print("!@# continuously_read_stream: publish")
# deal with RPC.publish
for msg in rpc_incoming.publish:
if not self._is_subscribed_to_msg(msg):
@ -167,6 +170,7 @@ class Pubsub:
asyncio.ensure_future(self.push_msg(msg_forwarder=peer_id, msg=msg))
if rpc_incoming.subscriptions:
print("!@# continuously_read_stream: subscriptions")
# deal with RPC.subscriptions
# We don't need to relay the subscription to our
# peers because a given node only needs its peers
@ -179,6 +183,7 @@ class Pubsub:
# This is necessary because `control` is an optional field in pb2.
# Ref: https://developers.google.com/protocol-buffers/docs/reference/python-generated#singular-fields-proto2 # noqa: E501
if rpc_incoming.HasField("control"):
print("!@# continuously_read_stream: control")
# Pass rpc to router so router could perform custom logic
await self.router.handle_rpc(rpc_incoming, peer_id)
@ -221,20 +226,23 @@ class Pubsub:
on one of the supported pubsub protocols.
:param stream: newly created stream
"""
# Add peer
await self.continuously_read_stream(stream)
async def _handle_new_peer(self, peer_id: ID) -> None:
# Open a stream to peer on existing connection
# (we know connection exists since that's the only way
# an element gets added to peer_queue)
stream: INetStream = await self.host.new_stream(peer_id, self.protocols)
# Map peer to stream
peer_id: ID = stream.mplex_conn.peer_id
self.peers[peer_id] = stream
self.router.add_peer(peer_id, stream.get_protocol())
# Send hello packet
hello = self.get_hello_packet()
await stream.write(hello.SerializeToString())
# Pass stream off to stream reader
asyncio.ensure_future(self.continuously_read_stream(stream))
# Force context switch
await asyncio.sleep(0)
await stream.write(encode_varint_prefixed(hello.SerializeToString()))
# TODO: Check EOF in the future in the stream's lifetime.
# TODO: Check if the peer in black list.
self.router.add_peer(peer_id, stream.get_protocol())
async def handle_peer_queue(self) -> None:
"""
@ -247,25 +255,9 @@ class Pubsub:
peer_id: ID = await self.peer_queue.get()
# Open a stream to peer on existing connection
# (we know connection exists since that's the only way
# an element gets added to peer_queue)
stream: INetStream = await self.host.new_stream(peer_id, self.protocols)
# Add Peer
# Map peer to stream
self.peers[peer_id] = stream
self.router.add_peer(peer_id, stream.get_protocol())
# Send hello packet
hello = self.get_hello_packet()
await stream.write(hello.SerializeToString())
# TODO: Investigate whether this should be replaced by `handlePeerEOF`
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/49274b0e8aecdf6cad59d768e5702ff00aa48488/comm.go#L80 # noqa: E501
# Pass stream off to stream reader
asyncio.ensure_future(self.continuously_read_stream(stream))
asyncio.ensure_future(self._handle_new_peer(peer_id))
# Force context switch
await asyncio.sleep(0)
@ -366,7 +358,7 @@ class Pubsub:
# Broadcast message
for stream in self.peers.values():
# Write message to stream
await stream.write(raw_msg)
await stream.write(encode_varint_prefixed(raw_msg))
async def publish(self, topic_id: str, data: bytes) -> None:
"""