diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index b26dd3c7..66e36e09 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -1,8 +1,9 @@ import logging -from typing import List, Sequence +from typing import TYPE_CHECKING, List, Sequence import multiaddr +from libp2p.host.defaults import get_default_protocols from libp2p.host.exceptions import StreamFailure from libp2p.network.network_interface import INetwork from libp2p.network.stream.net_stream_interface import INetStream @@ -17,6 +18,9 @@ from libp2p.typing import StreamHandlerFn, TProtocol from .host_interface import IHost +if TYPE_CHECKING: + from collections import OrderedDict + # Upon host creation, host takes in options, # including the list of addresses on which to listen. # Host then parses these options and delegates to its Network instance, @@ -38,12 +42,17 @@ class BasicHost(IHost): multiselect: Multiselect multiselect_client: MultiselectClient - def __init__(self, network: INetwork) -> None: + def __init__( + self, + network: INetwork, + default_protocols: "OrderedDict[TProtocol, StreamHandlerFn]" = None, + ) -> None: self._network = network self._network.set_stream_handler(self._swarm_stream_handler) self.peerstore = self._network.peerstore # Protocol muxing - self.multiselect = Multiselect() + default_protocols = default_protocols or get_default_protocols() + self.multiselect = Multiselect(default_protocols) self.multiselect_client = MultiselectClient() def get_id(self) -> ID: diff --git a/libp2p/host/defaults.py b/libp2p/host/defaults.py new file mode 100644 index 00000000..ab5952fc --- /dev/null +++ b/libp2p/host/defaults.py @@ -0,0 +1,11 @@ +from collections import OrderedDict +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from libp2p.typing import TProtocol, StreamHandlerFn + +DEFAULT_HOST_PROTOCOLS: "OrderedDict[TProtocol, StreamHandlerFn]" = OrderedDict() + + +def get_default_protocols() -> "OrderedDict[TProtocol, StreamHandlerFn]": + return DEFAULT_HOST_PROTOCOLS.copy() diff --git a/libp2p/protocol_muxer/multiselect.py b/libp2p/protocol_muxer/multiselect.py index 88f7e37e..06e268af 100644 --- a/libp2p/protocol_muxer/multiselect.py +++ b/libp2p/protocol_muxer/multiselect.py @@ -19,8 +19,12 @@ class Multiselect(IMultiselectMuxer): handlers: Dict[TProtocol, StreamHandlerFn] - def __init__(self) -> None: - self.handlers = {} + def __init__( + self, default_handlers: Dict[TProtocol, StreamHandlerFn] = None + ) -> None: + if not default_handlers: + default_handlers = {} + self.handlers = default_handlers def add_handler(self, protocol: TProtocol, handler: StreamHandlerFn) -> None: """ diff --git a/tests/host/test_basic_host.py b/tests/host/test_basic_host.py new file mode 100644 index 00000000..5718b6de --- /dev/null +++ b/tests/host/test_basic_host.py @@ -0,0 +1,14 @@ +from libp2p import initialize_default_swarm +from libp2p.crypto.rsa import create_new_key_pair +from libp2p.host.basic_host import BasicHost +from libp2p.host.defaults import get_default_protocols + + +def test_default_protocols(): + key_pair = create_new_key_pair() + swarm = initialize_default_swarm(key_pair) + host = BasicHost(swarm) + + mux = host.get_mux() + handlers = mux.handlers + assert handlers == get_default_protocols()