mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Merge branch 'master' into feature/porting-to-trio
This commit is contained in:
@ -9,6 +9,7 @@ from typing import (
|
||||
KeysView,
|
||||
List,
|
||||
NamedTuple,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
@ -19,6 +20,7 @@ import base58
|
||||
from lru import LRU
|
||||
import trio
|
||||
|
||||
from libp2p.crypto.keys import PrivateKey
|
||||
from libp2p.exceptions import ParseError, ValidationError
|
||||
from libp2p.host.host_interface import IHost
|
||||
from libp2p.io.exceptions import IncompleteReadError
|
||||
@ -33,7 +35,7 @@ from .abc import IPubsub, ISubscriptionAPI
|
||||
from .pb import rpc_pb2
|
||||
from .pubsub_notifee import PubsubNotifee
|
||||
from .subscription import TrioSubscriptionAPI
|
||||
from .validators import signature_validator
|
||||
from .validators import PUBSUB_SIGNING_PREFIX, signature_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .abc import IPubsubRouter # noqa: F401
|
||||
@ -73,16 +75,23 @@ class Pubsub(IPubsub, Service):
|
||||
subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"]
|
||||
subscribed_topics_receive: Dict[str, "TrioSubscriptionAPI"]
|
||||
|
||||
peer_topics: Dict[str, List[ID]]
|
||||
peer_topics: Dict[str, Set[ID]]
|
||||
peers: Dict[ID, INetStream]
|
||||
|
||||
topic_validators: Dict[str, TopicValidator]
|
||||
|
||||
# TODO: Be sure it is increased atomically everytime.
|
||||
counter: int # uint64
|
||||
|
||||
# Indicate if we should enforce signature verification
|
||||
strict_signing: bool
|
||||
sign_key: PrivateKey
|
||||
|
||||
def __init__(
|
||||
self, host: IHost, router: "IPubsubRouter", cache_size: int = None
|
||||
self,
|
||||
host: IHost,
|
||||
router: "IPubsubRouter",
|
||||
cache_size: int = None,
|
||||
strict_signing: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Construct a new Pubsub object, which is responsible for handling all
|
||||
@ -126,6 +135,12 @@ class Pubsub(IPubsub, Service):
|
||||
else:
|
||||
self.cache_size = cache_size
|
||||
|
||||
self.strict_signing = strict_signing
|
||||
if strict_signing:
|
||||
self.sign_key = self.host.get_private_key()
|
||||
else:
|
||||
self.sign_key = None
|
||||
|
||||
self.seen_messages = LRU(self.cache_size)
|
||||
|
||||
# Map of topics we are subscribed to blocking queues
|
||||
@ -142,7 +157,7 @@ class Pubsub(IPubsub, Service):
|
||||
# Map of topic to topic validator
|
||||
self.topic_validators = {}
|
||||
|
||||
self.counter = time.time_ns()
|
||||
self.counter = int(time.time())
|
||||
|
||||
async def run(self) -> None:
|
||||
self.manager.run_daemon_task(self.handle_peer_queue)
|
||||
@ -239,8 +254,7 @@ class Pubsub(IPubsub, Service):
|
||||
|
||||
:param topic: the topic to remove validator from
|
||||
"""
|
||||
if topic in self.topic_validators:
|
||||
del self.topic_validators[topic]
|
||||
self.topic_validators.pop(topic, None)
|
||||
|
||||
def get_msg_validators(self, msg: rpc_pb2.Message) -> Tuple[TopicValidator, ...]:
|
||||
"""
|
||||
@ -282,24 +296,22 @@ class Pubsub(IPubsub, Service):
|
||||
logger.debug("fail to add new peer %s, error %s", peer_id, error)
|
||||
return
|
||||
|
||||
self.peers[peer_id] = stream
|
||||
|
||||
# Send hello packet
|
||||
hello = self.get_hello_packet()
|
||||
try:
|
||||
await stream.write(encode_varint_prefixed(hello.SerializeToString()))
|
||||
except StreamClosed:
|
||||
logger.debug("Fail to add new peer %s: stream closed", peer_id)
|
||||
del self.peers[peer_id]
|
||||
return
|
||||
# TODO: Check if the peer in black list.
|
||||
try:
|
||||
self.router.add_peer(peer_id, stream.get_protocol())
|
||||
except Exception as error:
|
||||
logger.debug("fail to add new peer %s, error %s", peer_id, error)
|
||||
del self.peers[peer_id]
|
||||
return
|
||||
|
||||
self.peers[peer_id] = stream
|
||||
|
||||
logger.debug("added new peer %s", peer_id)
|
||||
|
||||
def _handle_dead_peer(self, peer_id: ID) -> None:
|
||||
@ -309,19 +321,16 @@ class Pubsub(IPubsub, Service):
|
||||
|
||||
for topic in self.peer_topics:
|
||||
if peer_id in self.peer_topics[topic]:
|
||||
self.peer_topics[topic].remove(peer_id)
|
||||
self.peer_topics[topic].discard(peer_id)
|
||||
|
||||
self.router.remove_peer(peer_id)
|
||||
|
||||
logger.debug("removed dead peer %s", peer_id)
|
||||
|
||||
async def handle_peer_queue(self) -> None:
|
||||
"""
|
||||
Continuously read from peer channel and each time a new peer is found,
|
||||
open a stream to the peer using a supported pubsub protocol
|
||||
TODO: Handle failure for when the peer does not support any of the
|
||||
pubsub protocols we support
|
||||
"""
|
||||
"""Continuously read from peer queue and each time a new peer is found,
|
||||
open a stream to the peer using a supported pubsub protocol pubsub
|
||||
protocols we support."""
|
||||
async with self.peer_receive_channel:
|
||||
while self.manager.is_running:
|
||||
peer_id: ID = await self.peer_receive_channel.receive()
|
||||
@ -351,14 +360,14 @@ class Pubsub(IPubsub, Service):
|
||||
"""
|
||||
if sub_message.subscribe:
|
||||
if sub_message.topicid not in self.peer_topics:
|
||||
self.peer_topics[sub_message.topicid] = [origin_id]
|
||||
self.peer_topics[sub_message.topicid] = set([origin_id])
|
||||
elif origin_id not in self.peer_topics[sub_message.topicid]:
|
||||
# Add peer to topic
|
||||
self.peer_topics[sub_message.topicid].append(origin_id)
|
||||
self.peer_topics[sub_message.topicid].add(origin_id)
|
||||
else:
|
||||
if sub_message.topicid in self.peer_topics:
|
||||
if origin_id in self.peer_topics[sub_message.topicid]:
|
||||
self.peer_topics[sub_message.topicid].remove(origin_id)
|
||||
self.peer_topics[sub_message.topicid].discard(origin_id)
|
||||
|
||||
# FIXME(mhchia): Change the function name?
|
||||
async def handle_talk(self, publish_message: rpc_pb2.Message) -> None:
|
||||
@ -476,7 +485,13 @@ class Pubsub(IPubsub, Service):
|
||||
seqno=self._next_seqno(),
|
||||
)
|
||||
|
||||
# TODO: Sign with our signing key
|
||||
if self.strict_signing:
|
||||
priv_key = self.sign_key
|
||||
signature = priv_key.sign(
|
||||
PUBSUB_SIGNING_PREFIX.encode() + msg.SerializeToString()
|
||||
)
|
||||
msg.key = self.host.get_public_key().serialize()
|
||||
msg.signature = signature
|
||||
|
||||
await self.push_msg(self.my_id, msg)
|
||||
|
||||
@ -536,18 +551,17 @@ class Pubsub(IPubsub, Service):
|
||||
|
||||
# TODO: Check if the `from` is in the blacklist. If yes, reject.
|
||||
|
||||
# TODO: Check if signing is required and if so signature should be attached.
|
||||
|
||||
# If the message is processed before, return(i.e., don't further process the message).
|
||||
if self._is_msg_seen(msg):
|
||||
return
|
||||
|
||||
# TODO: - Validate the message. If failed, reject it.
|
||||
# Validate the signature of the message
|
||||
# FIXME: `signature_validator` is currently a stub.
|
||||
if not signature_validator(msg.key, msg.SerializeToString()):
|
||||
logger.debug("Signature validation failed for msg: %s", msg)
|
||||
return
|
||||
# Check if signing is required and if so validate the signature
|
||||
if self.strict_signing:
|
||||
# Validate the signature of the message
|
||||
if not signature_validator(msg):
|
||||
logger.debug("Signature validation failed for msg: %s", msg)
|
||||
return
|
||||
|
||||
# Validate the message with registered topic validators.
|
||||
# If the validation failed, return(i.e., don't further process the message).
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user