mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Merge pull request #271 from mhchia/fix/pubsub-interop
Pubsub interop with go-libp2p-daemon
This commit is contained in:
@ -1,11 +1,11 @@
|
||||
import heapq
|
||||
from operator import itemgetter
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerdata import PeerData
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
|
||||
from .utils import digest
|
||||
@ -15,16 +15,16 @@ P_UDP = "udp"
|
||||
|
||||
|
||||
class KadPeerInfo(PeerInfo):
|
||||
def __init__(self, peer_id, peer_data=None):
|
||||
super(KadPeerInfo, self).__init__(peer_id, peer_data)
|
||||
def __init__(self, peer_id, addrs):
|
||||
super(KadPeerInfo, self).__init__(peer_id, addrs)
|
||||
|
||||
self.peer_id_bytes = peer_id.to_bytes()
|
||||
self.xor_id = peer_id.xor_id
|
||||
|
||||
self.addrs = peer_data.get_addrs() if peer_data else None
|
||||
self.addrs = addrs
|
||||
|
||||
self.ip = self.addrs[0].value_for_protocol(P_IP) if peer_data else None
|
||||
self.port = int(self.addrs[0].value_for_protocol(P_UDP)) if peer_data else None
|
||||
self.ip = self.addrs[0].value_for_protocol(P_IP) if addrs else None
|
||||
self.port = int(self.addrs[0].value_for_protocol(P_UDP)) if addrs else None
|
||||
|
||||
def same_home_as(self, node):
|
||||
return sorted(self.addrs) == sorted(node.addrs)
|
||||
@ -142,14 +142,14 @@ def create_kad_peerinfo(node_id_bytes=None, sender_ip=None, sender_port=None):
|
||||
node_id = (
|
||||
ID(node_id_bytes) if node_id_bytes else ID(digest(random.getrandbits(255)))
|
||||
)
|
||||
peer_data = None
|
||||
addrs: List[Multiaddr]
|
||||
if sender_ip and sender_port:
|
||||
peer_data = PeerData()
|
||||
addr = [
|
||||
addrs = [
|
||||
Multiaddr(
|
||||
"/" + P_IP + "/" + str(sender_ip) + "/" + P_UDP + "/" + str(sender_port)
|
||||
)
|
||||
]
|
||||
peer_data.add_addrs(addr)
|
||||
else:
|
||||
addrs = []
|
||||
|
||||
return KadPeerInfo(node_id, peer_data)
|
||||
return KadPeerInfo(node_id, addrs)
|
||||
|
||||
@ -7,6 +7,7 @@ from .net_stream_interface import INetStream
|
||||
class NetStream(INetStream):
|
||||
|
||||
muxed_stream: IMuxedStream
|
||||
# TODO: Why we expose `mplex_conn` here?
|
||||
mplex_conn: IMuxedConn
|
||||
protocol_id: TProtocol
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Callable, Dict, List, Sequence
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerstore import PeerStoreError
|
||||
from libp2p.peer.peerstore_interface import IPeerStore
|
||||
from libp2p.protocol_muxer.multiselect import Multiselect
|
||||
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
|
||||
@ -92,55 +93,55 @@ class Swarm(INetwork):
|
||||
:return: muxed connection
|
||||
"""
|
||||
|
||||
# Get peer info from peer store
|
||||
addrs = self.peerstore.addrs(peer_id)
|
||||
if peer_id in self.connections:
|
||||
# If muxed connection already exists for peer_id,
|
||||
# set muxed connection equal to existing muxed connection
|
||||
return self.connections[peer_id]
|
||||
|
||||
try:
|
||||
# Get peer info from peer store
|
||||
addrs = self.peerstore.addrs(peer_id)
|
||||
except PeerStoreError:
|
||||
raise SwarmException(f"No known addresses to peer {peer_id}")
|
||||
|
||||
if not addrs:
|
||||
raise SwarmException("No known addresses to peer")
|
||||
raise SwarmException(f"No known addresses to peer {peer_id}")
|
||||
|
||||
if not self.router:
|
||||
multiaddr = addrs[0]
|
||||
else:
|
||||
multiaddr = self.router.find_peer(peer_id)
|
||||
# Dial peer (connection to peer does not yet exist)
|
||||
# Transport dials peer (gets back a raw conn)
|
||||
raw_conn = await self.transport.dial(multiaddr, self.self_id)
|
||||
|
||||
if peer_id in self.connections:
|
||||
# If muxed connection already exists for peer_id,
|
||||
# set muxed connection equal to existing muxed connection
|
||||
muxed_conn = self.connections[peer_id]
|
||||
else:
|
||||
# Dial peer (connection to peer does not yet exist)
|
||||
# Transport dials peer (gets back a raw conn)
|
||||
raw_conn = await self.transport.dial(multiaddr, self.self_id)
|
||||
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure
|
||||
# the conn and then mux the conn
|
||||
try:
|
||||
secured_conn = await self.upgrader.upgrade_security(raw_conn, peer_id, True)
|
||||
except SecurityUpgradeFailure as error:
|
||||
# TODO: Add logging to indicate the failure
|
||||
await raw_conn.close()
|
||||
raise SwarmException(
|
||||
f"fail to upgrade the connection to a secured connection from {peer_id}"
|
||||
) from error
|
||||
try:
|
||||
muxed_conn = await self.upgrader.upgrade_connection(
|
||||
secured_conn, self.generic_protocol_handler, peer_id
|
||||
)
|
||||
except MuxerUpgradeFailure as error:
|
||||
# TODO: Add logging to indicate the failure
|
||||
await secured_conn.close()
|
||||
raise SwarmException(
|
||||
f"fail to upgrade the connection to a muxed connection from {peer_id}"
|
||||
) from error
|
||||
|
||||
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure
|
||||
# the conn and then mux the conn
|
||||
try:
|
||||
secured_conn = await self.upgrader.upgrade_security(
|
||||
raw_conn, peer_id, True
|
||||
)
|
||||
except SecurityUpgradeFailure as error:
|
||||
# TODO: Add logging to indicate the failure
|
||||
await raw_conn.close()
|
||||
raise SwarmException(
|
||||
f"fail to upgrade the connection to a secured connection from {peer_id}"
|
||||
) from error
|
||||
try:
|
||||
muxed_conn = await self.upgrader.upgrade_connection(
|
||||
secured_conn, self.generic_protocol_handler, peer_id
|
||||
)
|
||||
except MuxerUpgradeFailure as error:
|
||||
# TODO: Add logging to indicate the failure
|
||||
await secured_conn.close()
|
||||
raise SwarmException(
|
||||
f"fail to upgrade the connection to a muxed connection from {peer_id}"
|
||||
) from error
|
||||
# Store muxed connection in connections
|
||||
self.connections[peer_id] = muxed_conn
|
||||
|
||||
# Store muxed connection in connections
|
||||
self.connections[peer_id] = muxed_conn
|
||||
|
||||
# Call notifiers since event occurred
|
||||
for notifee in self.notifees:
|
||||
await notifee.connected(self, muxed_conn)
|
||||
# Call notifiers since event occurred
|
||||
for notifee in self.notifees:
|
||||
await notifee.connected(self, muxed_conn)
|
||||
|
||||
return muxed_conn
|
||||
|
||||
@ -152,11 +153,6 @@ class Swarm(INetwork):
|
||||
:param protocol_id: protocol id
|
||||
:return: net stream instance
|
||||
"""
|
||||
# Get peer info from peer store
|
||||
addrs = self.peerstore.addrs(peer_id)
|
||||
|
||||
if not addrs:
|
||||
raise SwarmException("No known addresses to peer")
|
||||
|
||||
muxed_conn = await self.dial_peer(peer_id)
|
||||
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
from typing import List
|
||||
from typing import List, Sequence
|
||||
|
||||
import multiaddr
|
||||
|
||||
from .id import ID
|
||||
from .peerdata import PeerData
|
||||
|
||||
|
||||
class PeerInfo:
|
||||
@ -11,9 +10,9 @@ class PeerInfo:
|
||||
peer_id: ID
|
||||
addrs: List[multiaddr.Multiaddr]
|
||||
|
||||
def __init__(self, peer_id: ID, peer_data: PeerData = None) -> None:
|
||||
def __init__(self, peer_id: ID, addrs: Sequence[multiaddr.Multiaddr]) -> None:
|
||||
self.peer_id = peer_id
|
||||
self.addrs = peer_data.get_addrs() if peer_data else None
|
||||
self.addrs = list(addrs)
|
||||
|
||||
|
||||
def info_from_p2p_addr(addr: multiaddr.Multiaddr) -> PeerInfo:
|
||||
@ -44,11 +43,7 @@ def info_from_p2p_addr(addr: multiaddr.Multiaddr) -> PeerInfo:
|
||||
if len(parts) > 1:
|
||||
addr = multiaddr.Multiaddr.join(*parts[:-1])
|
||||
|
||||
peer_data = PeerData()
|
||||
peer_data.add_addrs([addr])
|
||||
peer_data.set_protocols([p.code for p in addr.protocols()])
|
||||
|
||||
return PeerInfo(peer_id, peer_data)
|
||||
return PeerInfo(peer_id, [addr])
|
||||
|
||||
|
||||
class InvalidAddrError(ValueError):
|
||||
|
||||
@ -33,7 +33,7 @@ class PeerStore(IPeerStore):
|
||||
def peer_info(self, peer_id: ID) -> Optional[PeerInfo]:
|
||||
if peer_id in self.peer_map:
|
||||
peer_data = self.peer_map[peer_id]
|
||||
return PeerInfo(peer_id, peer_data)
|
||||
return PeerInfo(peer_id, peer_data.addrs)
|
||||
return None
|
||||
|
||||
def get_protocols(self, peer_id: ID) -> List[str]:
|
||||
|
||||
@ -2,11 +2,14 @@ from typing import Iterable, List, Sequence
|
||||
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.typing import TProtocol
|
||||
from libp2p.utils import encode_varint_prefixed
|
||||
|
||||
from .pb import rpc_pb2
|
||||
from .pubsub import Pubsub
|
||||
from .pubsub_router_interface import IPubsubRouter
|
||||
|
||||
PROTOCOL_ID = TProtocol("/floodsub/1.0.0")
|
||||
|
||||
|
||||
class FloodSub(IPubsubRouter):
|
||||
|
||||
@ -76,7 +79,7 @@ class FloodSub(IPubsubRouter):
|
||||
stream = self.pubsub.peers[peer_id]
|
||||
# FIXME: We should add a `WriteMsg` similar to write delimited messages.
|
||||
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
|
||||
await stream.write(rpc_msg.SerializeToString())
|
||||
await stream.write(encode_varint_prefixed(rpc_msg.SerializeToString()))
|
||||
|
||||
async def join(self, topic: str) -> None:
|
||||
"""
|
||||
|
||||
@ -4,13 +4,17 @@ import random
|
||||
from typing import Any, Dict, Iterable, List, Sequence, Set
|
||||
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.pubsub import floodsub
|
||||
from libp2p.typing import TProtocol
|
||||
from libp2p.utils import encode_varint_prefixed
|
||||
|
||||
from .mcache import MessageCache
|
||||
from .pb import rpc_pb2
|
||||
from .pubsub import Pubsub
|
||||
from .pubsub_router_interface import IPubsubRouter
|
||||
|
||||
PROTOCOL_ID = TProtocol("/meshsub/1.0.0")
|
||||
|
||||
|
||||
class GossipSub(IPubsubRouter):
|
||||
|
||||
@ -104,16 +108,19 @@ class GossipSub(IPubsubRouter):
|
||||
:param peer_id: id of peer to add
|
||||
:param protocol_id: router protocol the peer speaks, e.g., floodsub, gossipsub
|
||||
"""
|
||||
|
||||
# Add peer to the correct peer list
|
||||
peer_type = GossipSub.get_peer_type(protocol_id)
|
||||
|
||||
self.peers_to_protocol[peer_id] = protocol_id
|
||||
|
||||
if peer_type == "gossip":
|
||||
if protocol_id == PROTOCOL_ID:
|
||||
self.peers_gossipsub.append(peer_id)
|
||||
elif peer_type == "flood":
|
||||
elif protocol_id == floodsub.PROTOCOL_ID:
|
||||
self.peers_floodsub.append(peer_id)
|
||||
else:
|
||||
# We should never enter here. Becuase the `protocol_id` is registered by your pubsub
|
||||
# instance in multistream-select, but it is not the protocol that gossipsub supports.
|
||||
# In this case, probably we registered gossipsub to a wrong `protocol_id`
|
||||
# in multistream-select, or wrong versions.
|
||||
# TODO: Better handling
|
||||
raise Exception(f"protocol is not supported: protocol_id={protocol_id}")
|
||||
|
||||
def remove_peer(self, peer_id: ID) -> None:
|
||||
"""
|
||||
@ -167,7 +174,7 @@ class GossipSub(IPubsubRouter):
|
||||
# FIXME: We should add a `WriteMsg` similar to write delimited messages.
|
||||
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
|
||||
# TODO: Go use `sendRPC`, which possibly piggybacks gossip/control messages.
|
||||
await stream.write(rpc_msg.SerializeToString())
|
||||
await stream.write(encode_varint_prefixed(rpc_msg.SerializeToString()))
|
||||
|
||||
def _get_peers_to_send(
|
||||
self, topic_ids: Iterable[str], msg_forwarder: ID, origin: ID
|
||||
@ -264,29 +271,6 @@ class GossipSub(IPubsubRouter):
|
||||
# Forget mesh[topic]
|
||||
self.mesh.pop(topic, None)
|
||||
|
||||
# Interface Helper Functions
|
||||
@staticmethod
|
||||
def get_peer_type(protocol_id: str) -> str:
|
||||
# TODO: Do this in a better, more efficient way
|
||||
if "gossipsub" in protocol_id:
|
||||
return "gossip"
|
||||
if "floodsub" in protocol_id:
|
||||
return "flood"
|
||||
return "unknown"
|
||||
|
||||
async def deliver_messages_to_peers(
|
||||
self, peers: List[ID], msg_sender: ID, origin_id: ID, serialized_packet: bytes
|
||||
) -> None:
|
||||
for peer_id_in_topic in peers:
|
||||
# Forward to all peers that are not the
|
||||
# message sender and are not the message origin
|
||||
|
||||
if peer_id_in_topic not in (msg_sender, origin_id):
|
||||
stream = self.pubsub.peers[peer_id_in_topic]
|
||||
|
||||
# Publish the packet
|
||||
await stream.write(serialized_packet)
|
||||
|
||||
# Heartbeat
|
||||
async def heartbeat(self) -> None:
|
||||
"""
|
||||
@ -509,7 +493,7 @@ class GossipSub(IPubsubRouter):
|
||||
peer_stream = self.pubsub.peers[sender_peer_id]
|
||||
|
||||
# 4) And write the packet to the stream
|
||||
await peer_stream.write(rpc_msg)
|
||||
await peer_stream.write(encode_varint_prefixed(rpc_msg))
|
||||
|
||||
async def handle_graft(
|
||||
self, graft_msg: rpc_pb2.ControlGraft, sender_peer_id: ID
|
||||
@ -601,4 +585,4 @@ class GossipSub(IPubsubRouter):
|
||||
peer_stream = self.pubsub.peers[to_peer]
|
||||
|
||||
# Write rpc to stream
|
||||
await peer_stream.write(rpc_msg)
|
||||
await peer_stream.write(encode_varint_prefixed(rpc_msg))
|
||||
|
||||
@ -21,6 +21,7 @@ from libp2p.host.host_interface import IHost
|
||||
from libp2p.network.stream.net_stream_interface import INetStream
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.typing import TProtocol
|
||||
from libp2p.utils import encode_varint_prefixed, read_varint_prefixed_bytes
|
||||
|
||||
from .pb import rpc_pb2
|
||||
from .pubsub_notifee import PubsubNotifee
|
||||
@ -71,7 +72,7 @@ class Pubsub:
|
||||
|
||||
topic_validators: Dict[str, TopicValidator]
|
||||
|
||||
# NOTE: Be sure it is increased atomically everytime.
|
||||
# TODO: Be sure it is increased atomically everytime.
|
||||
counter: int # uint64
|
||||
|
||||
def __init__(
|
||||
@ -131,7 +132,7 @@ class Pubsub:
|
||||
# Call handle peer to keep waiting for updates to peer queue
|
||||
asyncio.ensure_future(self.handle_peer_queue())
|
||||
|
||||
def get_hello_packet(self) -> bytes:
|
||||
def get_hello_packet(self) -> rpc_pb2.RPC:
|
||||
"""
|
||||
Generate subscription message with all topics we are subscribed to
|
||||
only send hello packet if we have subscribed topics
|
||||
@ -141,7 +142,7 @@ class Pubsub:
|
||||
packet.subscriptions.extend(
|
||||
[rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)]
|
||||
)
|
||||
return packet.SerializeToString()
|
||||
return packet
|
||||
|
||||
async def continuously_read_stream(self, stream: INetStream) -> None:
|
||||
"""
|
||||
@ -152,17 +153,14 @@ class Pubsub:
|
||||
peer_id = stream.mplex_conn.peer_id
|
||||
|
||||
while True:
|
||||
incoming: bytes = (await stream.read())
|
||||
incoming: bytes = await read_varint_prefixed_bytes(stream)
|
||||
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||
rpc_incoming.ParseFromString(incoming)
|
||||
|
||||
if rpc_incoming.publish:
|
||||
# deal with RPC.publish
|
||||
for msg in rpc_incoming.publish:
|
||||
if not self._is_subscribed_to_msg(msg):
|
||||
continue
|
||||
# TODO(mhchia): This will block this read_stream loop until all data are pushed.
|
||||
# Should investigate further if this is an issue.
|
||||
asyncio.ensure_future(self.push_msg(msg_forwarder=peer_id, msg=msg))
|
||||
|
||||
if rpc_incoming.subscriptions:
|
||||
@ -220,20 +218,19 @@ class Pubsub:
|
||||
on one of the supported pubsub protocols.
|
||||
:param stream: newly created stream
|
||||
"""
|
||||
# Add peer
|
||||
# Map peer to stream
|
||||
peer_id: ID = stream.mplex_conn.peer_id
|
||||
await self.continuously_read_stream(stream)
|
||||
|
||||
async def _handle_new_peer(self, peer_id: ID) -> None:
|
||||
stream: INetStream = await self.host.new_stream(peer_id, self.protocols)
|
||||
|
||||
self.peers[peer_id] = stream
|
||||
self.router.add_peer(peer_id, stream.get_protocol())
|
||||
|
||||
# Send hello packet
|
||||
hello: bytes = self.get_hello_packet()
|
||||
|
||||
await stream.write(hello)
|
||||
# Pass stream off to stream reader
|
||||
asyncio.ensure_future(self.continuously_read_stream(stream))
|
||||
# Force context switch
|
||||
await asyncio.sleep(0)
|
||||
hello = self.get_hello_packet()
|
||||
await stream.write(encode_varint_prefixed(hello.SerializeToString()))
|
||||
# TODO: Check EOF of this stream.
|
||||
# TODO: Check if the peer in black list.
|
||||
self.router.add_peer(peer_id, stream.get_protocol())
|
||||
|
||||
async def handle_peer_queue(self) -> None:
|
||||
"""
|
||||
@ -246,25 +243,9 @@ class Pubsub:
|
||||
|
||||
peer_id: ID = await self.peer_queue.get()
|
||||
|
||||
# Open a stream to peer on existing connection
|
||||
# (we know connection exists since that's the only way
|
||||
# an element gets added to peer_queue)
|
||||
stream: INetStream = await self.host.new_stream(peer_id, self.protocols)
|
||||
|
||||
# Add Peer
|
||||
# Map peer to stream
|
||||
self.peers[peer_id] = stream
|
||||
self.router.add_peer(peer_id, stream.get_protocol())
|
||||
|
||||
# Send hello packet
|
||||
hello: bytes = self.get_hello_packet()
|
||||
await stream.write(hello)
|
||||
|
||||
# TODO: Investigate whether this should be replaced by `handlePeerEOF`
|
||||
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/49274b0e8aecdf6cad59d768e5702ff00aa48488/comm.go#L80 # noqa: E501
|
||||
# Pass stream off to stream reader
|
||||
asyncio.ensure_future(self.continuously_read_stream(stream))
|
||||
|
||||
asyncio.ensure_future(self._handle_new_peer(peer_id))
|
||||
# Force context switch
|
||||
await asyncio.sleep(0)
|
||||
|
||||
@ -365,7 +346,7 @@ class Pubsub:
|
||||
# Broadcast message
|
||||
for stream in self.peers.values():
|
||||
# Write message to stream
|
||||
await stream.write(raw_msg)
|
||||
await stream.write(encode_varint_prefixed(raw_msg))
|
||||
|
||||
async def publish(self, topic_id: str, data: bytes) -> None:
|
||||
"""
|
||||
|
||||
@ -36,11 +36,7 @@ class PubsubNotifee(INotifee):
|
||||
:param network: network the connection was opened on
|
||||
:param conn: connection that was opened
|
||||
"""
|
||||
|
||||
# Only add peer_id if we are initiator (otherwise we would end up
|
||||
# with two pubsub streams between us and the peer)
|
||||
if conn.initiator:
|
||||
await self.initiator_peers_queue.put(conn.peer_id)
|
||||
await self.initiator_peers_queue.put(conn.peer_id)
|
||||
|
||||
async def disconnected(self, network: INetwork, conn: IMuxedConn) -> None:
|
||||
pass
|
||||
|
||||
@ -9,5 +9,4 @@ if TYPE_CHECKING:
|
||||
TProtocol = NewType("TProtocol", str)
|
||||
StreamHandlerFn = Callable[["INetStream"], Awaitable[None]]
|
||||
|
||||
|
||||
StreamReader = Union["IMuxedStream", IRawConnection]
|
||||
StreamReader = Union["IMuxedStream", "INetStream", IRawConnection]
|
||||
|
||||
Reference in New Issue
Block a user