Fix the rest of the typing hints (#232)

* ignore kad

* fix swarm, and minor

* fix init and swarm

* ignore pb

* enable mypy

* fix basic host

* fix tcp

* fix mplex

* add typing for pb

* skip format pyi

* [mypy] no need to ignore pb now

* add typing to chat
This commit is contained in:
Chih Cheng Liang
2019-08-11 16:47:54 +08:00
committed by GitHub
parent dbb702548f
commit 28f6de37ee
21 changed files with 465 additions and 98 deletions

View File

@ -1,6 +1,7 @@
from typing import Iterable, List, Sequence
from libp2p.peer.id import ID
from libp2p.typing import TProtocol
from .pb import rpc_pb2
from .pubsub import Pubsub
@ -9,15 +10,15 @@ from .pubsub_router_interface import IPubsubRouter
class FloodSub(IPubsubRouter):
protocols: List[str]
protocols: List[TProtocol]
pubsub: Pubsub
def __init__(self, protocols: Sequence[str]) -> None:
def __init__(self, protocols: Sequence[TProtocol]) -> None:
self.protocols = list(protocols)
self.pubsub = None
def get_protocols(self) -> List[str]:
def get_protocols(self) -> List[TProtocol]:
"""
:return: the list of protocols supported by the router
"""
@ -31,7 +32,7 @@ class FloodSub(IPubsubRouter):
"""
self.pubsub = pubsub
def add_peer(self, peer_id: ID, protocol_id: str) -> None:
def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None:
"""
Notifies the router that a new peer has been connected
:param peer_id: id of peer to add
@ -43,7 +44,7 @@ class FloodSub(IPubsubRouter):
:param peer_id: id of peer to remove
"""
async def handle_rpc(self, rpc: rpc_pb2.ControlMessage, sender_peer_id: ID) -> None:
async def handle_rpc(self, rpc: rpc_pb2.RPC, sender_peer_id: ID) -> None:
"""
Invoked to process control messages in the RPC envelope.
It is invoked after subscriptions and payload messages have been processed

View File

@ -4,6 +4,7 @@ import random
from typing import Any, Dict, Iterable, List, Sequence, Set
from libp2p.peer.id import ID
from libp2p.typing import TProtocol
from .mcache import MessageCache
from .pb import rpc_pb2
@ -13,7 +14,7 @@ from .pubsub_router_interface import IPubsubRouter
class GossipSub(IPubsubRouter):
protocols: List[str]
protocols: List[TProtocol]
pubsub: Pubsub
degree: int
@ -38,7 +39,7 @@ class GossipSub(IPubsubRouter):
def __init__(
self,
protocols: Sequence[str],
protocols: Sequence[TProtocol],
degree: int,
degree_low: int,
degree_high: int,
@ -79,7 +80,7 @@ class GossipSub(IPubsubRouter):
# Interface functions
def get_protocols(self) -> List[str]:
def get_protocols(self) -> List[TProtocol]:
"""
:return: the list of protocols supported by the router
"""
@ -97,7 +98,7 @@ class GossipSub(IPubsubRouter):
# TODO: Start after delay
asyncio.ensure_future(self.heartbeat())
def add_peer(self, peer_id: ID, protocol_id: str) -> None:
def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None:
"""
Notifies the router that a new peer has been connected
:param peer_id: id of peer to add
@ -126,7 +127,7 @@ class GossipSub(IPubsubRouter):
if peer_id in self.peers_gossipsub:
self.peers_floodsub.remove(peer_id)
async def handle_rpc(self, rpc: rpc_pb2.Message, sender_peer_id: ID) -> None:
async def handle_rpc(self, rpc: rpc_pb2.RPC, sender_peer_id: ID) -> None:
"""
Invoked to process control messages in the RPC envelope.
It is invoked after subscriptions and payload messages have been processed
@ -436,7 +437,7 @@ class GossipSub(IPubsubRouter):
# RPC handlers
async def handle_ihave(self, ihave_msg: rpc_pb2.Message, sender_peer_id: ID) -> None:
async def handle_ihave(self, ihave_msg: rpc_pb2.ControlIHave, sender_peer_id: ID) -> None:
"""
Checks the seen set and requests unknown messages with an IWANT message.
"""
@ -460,7 +461,7 @@ class GossipSub(IPubsubRouter):
if msg_ids_wanted:
await self.emit_iwant(msg_ids_wanted, sender_peer_id)
async def handle_iwant(self, iwant_msg: rpc_pb2.Message, sender_peer_id: ID) -> None:
async def handle_iwant(self, iwant_msg: rpc_pb2.ControlIWant, sender_peer_id: ID) -> None:
"""
Forwards all request messages that are present in mcache to the requesting peer.
"""
@ -495,7 +496,7 @@ class GossipSub(IPubsubRouter):
# 4) And write the packet to the stream
await peer_stream.write(rpc_msg)
async def handle_graft(self, graft_msg: rpc_pb2.Message, sender_peer_id: ID) -> None:
async def handle_graft(self, graft_msg: rpc_pb2.ControlGraft, sender_peer_id: ID) -> None:
topic: str = graft_msg.topicID
# Add peer to mesh for topic
@ -506,7 +507,7 @@ class GossipSub(IPubsubRouter):
# Respond with PRUNE if not subscribed to the topic
await self.emit_prune(topic, sender_peer_id)
async def handle_prune(self, prune_msg: rpc_pb2.Message, sender_peer_id: ID) -> None:
async def handle_prune(self, prune_msg: rpc_pb2.ControlPrune, sender_peer_id: ID) -> None:
topic: str = prune_msg.topicID
# Remove peer from mesh for topic, if peer is in topic

View File

@ -0,0 +1,322 @@
# @generated by generate_proto_mypy_stubs.py. Do not edit!
import sys
from google.protobuf.descriptor import (
Descriptor as google___protobuf___descriptor___Descriptor,
EnumDescriptor as google___protobuf___descriptor___EnumDescriptor,
)
from google.protobuf.internal.containers import (
RepeatedCompositeFieldContainer as google___protobuf___internal___containers___RepeatedCompositeFieldContainer,
RepeatedScalarFieldContainer as google___protobuf___internal___containers___RepeatedScalarFieldContainer,
)
from google.protobuf.message import (
Message as google___protobuf___message___Message,
)
from typing import (
Iterable as typing___Iterable,
List as typing___List,
Optional as typing___Optional,
Text as typing___Text,
Tuple as typing___Tuple,
cast as typing___cast,
)
from typing_extensions import (
Literal as typing_extensions___Literal,
)
class RPC(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
class SubOpts(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
subscribe = ... # type: bool
topicid = ... # type: typing___Text
def __init__(self,
*,
subscribe : typing___Optional[bool] = None,
topicid : typing___Optional[typing___Text] = None,
) -> None: ...
@classmethod
def FromString(cls, s: bytes) -> RPC.SubOpts: ...
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
def HasField(self, field_name: typing_extensions___Literal[u"subscribe",u"topicid"]) -> bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"subscribe",u"topicid"]) -> None: ...
else:
def HasField(self, field_name: typing_extensions___Literal[u"subscribe",b"subscribe",u"topicid",b"topicid"]) -> bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"subscribe",b"subscribe",u"topicid",b"topicid"]) -> None: ...
@property
def subscriptions(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[RPC.SubOpts]: ...
@property
def publish(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[Message]: ...
@property
def control(self) -> ControlMessage: ...
def __init__(self,
*,
subscriptions : typing___Optional[typing___Iterable[RPC.SubOpts]] = None,
publish : typing___Optional[typing___Iterable[Message]] = None,
control : typing___Optional[ControlMessage] = None,
) -> None: ...
@classmethod
def FromString(cls, s: bytes) -> RPC: ...
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
def HasField(self, field_name: typing_extensions___Literal[u"control"]) -> bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"control",u"publish",u"subscriptions"]) -> None: ...
else:
def HasField(self, field_name: typing_extensions___Literal[u"control",b"control"]) -> bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"control",b"control",u"publish",b"publish",u"subscriptions",b"subscriptions"]) -> None: ...
class Message(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
from_id = ... # type: bytes
data = ... # type: bytes
seqno = ... # type: bytes
topicIDs = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[typing___Text]
signature = ... # type: bytes
key = ... # type: bytes
def __init__(self,
*,
from_id : typing___Optional[bytes] = None,
data : typing___Optional[bytes] = None,
seqno : typing___Optional[bytes] = None,
topicIDs : typing___Optional[typing___Iterable[typing___Text]] = None,
signature : typing___Optional[bytes] = None,
key : typing___Optional[bytes] = None,
) -> None: ...
@classmethod
def FromString(cls, s: bytes) -> Message: ...
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
def HasField(self, field_name: typing_extensions___Literal[u"data",u"from_id",u"key",u"seqno",u"signature"]) -> bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"data",u"from_id",u"key",u"seqno",u"signature",u"topicIDs"]) -> None: ...
else:
def HasField(self, field_name: typing_extensions___Literal[u"data",b"data",u"from_id",b"from_id",u"key",b"key",u"seqno",b"seqno",u"signature",b"signature"]) -> bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"data",b"data",u"from_id",b"from_id",u"key",b"key",u"seqno",b"seqno",u"signature",b"signature",u"topicIDs",b"topicIDs"]) -> None: ...
class ControlMessage(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
@property
def ihave(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[ControlIHave]: ...
@property
def iwant(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[ControlIWant]: ...
@property
def graft(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[ControlGraft]: ...
@property
def prune(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[ControlPrune]: ...
def __init__(self,
*,
ihave : typing___Optional[typing___Iterable[ControlIHave]] = None,
iwant : typing___Optional[typing___Iterable[ControlIWant]] = None,
graft : typing___Optional[typing___Iterable[ControlGraft]] = None,
prune : typing___Optional[typing___Iterable[ControlPrune]] = None,
) -> None: ...
@classmethod
def FromString(cls, s: bytes) -> ControlMessage: ...
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
def ClearField(self, field_name: typing_extensions___Literal[u"graft",u"ihave",u"iwant",u"prune"]) -> None: ...
else:
def ClearField(self, field_name: typing_extensions___Literal[u"graft",b"graft",u"ihave",b"ihave",u"iwant",b"iwant",u"prune",b"prune"]) -> None: ...
class ControlIHave(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
topicID = ... # type: typing___Text
messageIDs = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[typing___Text]
def __init__(self,
*,
topicID : typing___Optional[typing___Text] = None,
messageIDs : typing___Optional[typing___Iterable[typing___Text]] = None,
) -> None: ...
@classmethod
def FromString(cls, s: bytes) -> ControlIHave: ...
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
def HasField(self, field_name: typing_extensions___Literal[u"topicID"]) -> bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"messageIDs",u"topicID"]) -> None: ...
else:
def HasField(self, field_name: typing_extensions___Literal[u"topicID",b"topicID"]) -> bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"messageIDs",b"messageIDs",u"topicID",b"topicID"]) -> None: ...
class ControlIWant(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
messageIDs = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[typing___Text]
def __init__(self,
*,
messageIDs : typing___Optional[typing___Iterable[typing___Text]] = None,
) -> None: ...
@classmethod
def FromString(cls, s: bytes) -> ControlIWant: ...
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
def ClearField(self, field_name: typing_extensions___Literal[u"messageIDs"]) -> None: ...
else:
def ClearField(self, field_name: typing_extensions___Literal[u"messageIDs",b"messageIDs"]) -> None: ...
class ControlGraft(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
topicID = ... # type: typing___Text
def __init__(self,
*,
topicID : typing___Optional[typing___Text] = None,
) -> None: ...
@classmethod
def FromString(cls, s: bytes) -> ControlGraft: ...
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
def HasField(self, field_name: typing_extensions___Literal[u"topicID"]) -> bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"topicID"]) -> None: ...
else:
def HasField(self, field_name: typing_extensions___Literal[u"topicID",b"topicID"]) -> bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"topicID",b"topicID"]) -> None: ...
class ControlPrune(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
topicID = ... # type: typing___Text
def __init__(self,
*,
topicID : typing___Optional[typing___Text] = None,
) -> None: ...
@classmethod
def FromString(cls, s: bytes) -> ControlPrune: ...
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
def HasField(self, field_name: typing_extensions___Literal[u"topicID"]) -> bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"topicID"]) -> None: ...
else:
def HasField(self, field_name: typing_extensions___Literal[u"topicID",b"topicID"]) -> bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"topicID",b"topicID"]) -> None: ...
class TopicDescriptor(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
class AuthOpts(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
class AuthMode(int):
DESCRIPTOR: google___protobuf___descriptor___EnumDescriptor = ...
@classmethod
def Name(cls, number: int) -> str: ...
@classmethod
def Value(cls, name: str) -> TopicDescriptor.AuthOpts.AuthMode: ...
@classmethod
def keys(cls) -> typing___List[str]: ...
@classmethod
def values(cls) -> typing___List[TopicDescriptor.AuthOpts.AuthMode]: ...
@classmethod
def items(cls) -> typing___List[typing___Tuple[str, TopicDescriptor.AuthOpts.AuthMode]]: ...
NONE = typing___cast(TopicDescriptor.AuthOpts.AuthMode, 0)
KEY = typing___cast(TopicDescriptor.AuthOpts.AuthMode, 1)
WOT = typing___cast(TopicDescriptor.AuthOpts.AuthMode, 2)
NONE = typing___cast(TopicDescriptor.AuthOpts.AuthMode, 0)
KEY = typing___cast(TopicDescriptor.AuthOpts.AuthMode, 1)
WOT = typing___cast(TopicDescriptor.AuthOpts.AuthMode, 2)
mode = ... # type: TopicDescriptor.AuthOpts.AuthMode
keys = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[bytes]
def __init__(self,
*,
mode : typing___Optional[TopicDescriptor.AuthOpts.AuthMode] = None,
keys : typing___Optional[typing___Iterable[bytes]] = None,
) -> None: ...
@classmethod
def FromString(cls, s: bytes) -> TopicDescriptor.AuthOpts: ...
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
def HasField(self, field_name: typing_extensions___Literal[u"mode"]) -> bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"keys",u"mode"]) -> None: ...
else:
def HasField(self, field_name: typing_extensions___Literal[u"mode",b"mode"]) -> bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"keys",b"keys",u"mode",b"mode"]) -> None: ...
class EncOpts(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
class EncMode(int):
DESCRIPTOR: google___protobuf___descriptor___EnumDescriptor = ...
@classmethod
def Name(cls, number: int) -> str: ...
@classmethod
def Value(cls, name: str) -> TopicDescriptor.EncOpts.EncMode: ...
@classmethod
def keys(cls) -> typing___List[str]: ...
@classmethod
def values(cls) -> typing___List[TopicDescriptor.EncOpts.EncMode]: ...
@classmethod
def items(cls) -> typing___List[typing___Tuple[str, TopicDescriptor.EncOpts.EncMode]]: ...
NONE = typing___cast(TopicDescriptor.EncOpts.EncMode, 0)
SHAREDKEY = typing___cast(TopicDescriptor.EncOpts.EncMode, 1)
WOT = typing___cast(TopicDescriptor.EncOpts.EncMode, 2)
NONE = typing___cast(TopicDescriptor.EncOpts.EncMode, 0)
SHAREDKEY = typing___cast(TopicDescriptor.EncOpts.EncMode, 1)
WOT = typing___cast(TopicDescriptor.EncOpts.EncMode, 2)
mode = ... # type: TopicDescriptor.EncOpts.EncMode
keyHashes = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[bytes]
def __init__(self,
*,
mode : typing___Optional[TopicDescriptor.EncOpts.EncMode] = None,
keyHashes : typing___Optional[typing___Iterable[bytes]] = None,
) -> None: ...
@classmethod
def FromString(cls, s: bytes) -> TopicDescriptor.EncOpts: ...
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
def HasField(self, field_name: typing_extensions___Literal[u"mode"]) -> bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"keyHashes",u"mode"]) -> None: ...
else:
def HasField(self, field_name: typing_extensions___Literal[u"mode",b"mode"]) -> bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"keyHashes",b"keyHashes",u"mode",b"mode"]) -> None: ...
name = ... # type: typing___Text
@property
def auth(self) -> TopicDescriptor.AuthOpts: ...
@property
def enc(self) -> TopicDescriptor.EncOpts: ...
def __init__(self,
*,
name : typing___Optional[typing___Text] = None,
auth : typing___Optional[TopicDescriptor.AuthOpts] = None,
enc : typing___Optional[TopicDescriptor.EncOpts] = None,
) -> None: ...
@classmethod
def FromString(cls, s: bytes) -> TopicDescriptor: ...
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
def HasField(self, field_name: typing_extensions___Literal[u"auth",u"enc",u"name"]) -> bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"auth",u"enc",u"name"]) -> None: ...
else:
def HasField(self, field_name: typing_extensions___Literal[u"auth",b"auth",u"enc",b"enc",u"name",b"name"]) -> bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"auth",b"auth",u"enc",b"enc",u"name",b"name"]) -> None: ...

View File

@ -1,7 +1,18 @@
import asyncio
import logging
import time
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, NamedTuple, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
List,
NamedTuple,
Tuple,
Union,
cast,
)
from lru import LRU
@ -9,6 +20,7 @@ from libp2p.exceptions import ValidationError
from libp2p.host.host_interface import IHost
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.peer.id import ID
from libp2p.typing import TProtocol
from .pb import rpc_pb2
from .pubsub_notifee import PubsubNotifee
@ -31,7 +43,9 @@ AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
TopicValidator = NamedTuple("TopicValidator", (("validator", ValidatorFn), ("is_async", bool)))
class TopicValidator(NamedTuple):
validator: ValidatorFn
is_async: bool
class Pubsub:
@ -43,7 +57,7 @@ class Pubsub:
peer_queue: "asyncio.Queue[ID]"
protocols: List[str]
protocols: List[TProtocol]
incoming_msgs_from_peers: "asyncio.Queue[rpc_pb2.Message]"
outgoing_messages: "asyncio.Queue[rpc_pb2.Message]"
@ -192,7 +206,7 @@ class Pubsub:
Get all validators corresponding to the topics in the message.
:param msg: the message published to the topic
"""
return (
return tuple(
self.topic_validators[topic] for topic in msg.topicIDs if topic in self.topic_validators
)
@ -301,9 +315,7 @@ class Pubsub:
# Create subscribe message
packet: rpc_pb2.RPC = rpc_pb2.RPC()
packet.subscriptions.extend(
[rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id.encode("utf-8"))]
)
packet.subscriptions.extend([rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)])
# Send out subscribe message to all peers
await self.message_all_peers(packet.SerializeToString())
@ -328,9 +340,7 @@ class Pubsub:
# Create unsubscribe message
packet: rpc_pb2.RPC = rpc_pb2.RPC()
packet.subscriptions.extend(
[rpc_pb2.RPC.SubOpts(subscribe=False, topicid=topic_id.encode("utf-8"))]
)
packet.subscriptions.extend([rpc_pb2.RPC.SubOpts(subscribe=False, topicid=topic_id)])
# Send out unsubscribe message to all peers
await self.message_all_peers(packet.SerializeToString())
@ -374,12 +384,14 @@ class Pubsub:
:param msg: the message.
"""
sync_topic_validators = []
async_topic_validator_futures = []
async_topic_validator_futures: List[Awaitable[bool]] = []
for topic_validator in self.get_msg_validators(msg):
if topic_validator.is_async:
async_topic_validator_futures.append(topic_validator.validator(msg_forwarder, msg))
async_topic_validator_futures.append(
cast(Awaitable[bool], topic_validator.validator(msg_forwarder, msg))
)
else:
sync_topic_validators.append(topic_validator.validator)
sync_topic_validators.append(cast(SyncValidatorFn, topic_validator.validator))
for validator in sync_topic_validators:
if not validator(msg_forwarder, msg):

View File

@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List
from libp2p.peer.id import ID
from libp2p.typing import TProtocol
from .pb import rpc_pb2
@ -11,7 +12,7 @@ if TYPE_CHECKING:
class IPubsubRouter(ABC):
@abstractmethod
def get_protocols(self) -> List[str]:
def get_protocols(self) -> List[TProtocol]:
"""
:return: the list of protocols supported by the router
"""
@ -25,7 +26,7 @@ class IPubsubRouter(ABC):
"""
@abstractmethod
def add_peer(self, peer_id: ID, protocol_id: str) -> None:
def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None:
"""
Notifies the router that a new peer has been connected
:param peer_id: id of peer to add
@ -39,7 +40,7 @@ class IPubsubRouter(ABC):
"""
@abstractmethod
async def handle_rpc(self, rpc: rpc_pb2.ControlMessage, sender_peer_id: ID) -> None:
async def handle_rpc(self, rpc: rpc_pb2.RPC, sender_peer_id: ID) -> None:
"""
Invoked to process control messages in the RPC envelope.
It is invoked after subscriptions and payload messages have been processed