diff --git a/libp2p/__init__.py b/libp2p/__init__.py index 6d03ce69..74fde55b 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -269,12 +269,8 @@ def new_host( listen_addrs=listen_addrs, ) - if enable_mDNS: - mdns = MDNSDiscovery(swarm) - mdns.start() - if disc_opt is not None: - return RoutedHost(swarm, disc_opt) - return BasicHost(swarm) + return RoutedHost(swarm, disc_opt, enable_mDNS) + return BasicHost(swarm, enable_mDNS) __version__ = __version("libp2p") diff --git a/libp2p/discovery/mdns/listener.py b/libp2p/discovery/mdns/listener.py index de75b1c1..4f7ded0d 100644 --- a/libp2p/discovery/mdns/listener.py +++ b/libp2p/discovery/mdns/listener.py @@ -36,6 +36,7 @@ class PeerListener(ServiceListener): def add_service(self, zc: Zeroconf, type_: str, name: str) -> None: if name == self.service_name: return + logger.debug(f"Adding service: {name}") info = zc.get_service_info(type_, name, timeout=5000) if not info: return @@ -47,6 +48,9 @@ class PeerListener(ServiceListener): logger.debug("Discovered Peer:", peer_info.peer_id) def remove_service(self, zc: Zeroconf, type_: str, name: str) -> None: + if name == self.service_name: + return + logger.debug(f"Removing service: {name}") peer_id = self.discovered_services.pop(name) self.peerstore.clear_addrs(peer_id) logger.debug(f"Removed Peer: {peer_id}") diff --git a/libp2p/discovery/mdns/mdns.py b/libp2p/discovery/mdns/mdns.py index e2a89463..2af6ab36 100644 --- a/libp2p/discovery/mdns/mdns.py +++ b/libp2p/discovery/mdns/mdns.py @@ -68,5 +68,6 @@ class MDNSDiscovery: def stop(self) -> None: """Unregister this peer and clean up zeroconf resources.""" + logger.debug("Stopping mDNS discovery") self.broadcaster.unregister() self.zeroconf.close() diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 1dea876d..ccb37dc2 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -29,6 +29,7 @@ from libp2p.custom_types import ( StreamHandlerFn, TProtocol, ) +from libp2p.discovery.mdns.mdns import MDNSDiscovery from libp2p.host.defaults import ( get_default_protocols, ) @@ -89,6 +90,7 @@ class BasicHost(IHost): def __init__( self, network: INetworkService, + enable_mDNS: bool = False, default_protocols: Optional["OrderedDict[TProtocol, StreamHandlerFn]"] = None, ) -> None: self._network = network @@ -98,6 +100,8 @@ class BasicHost(IHost): default_protocols = default_protocols or get_default_protocols(self) self.multiselect = Multiselect(dict(default_protocols.items())) self.multiselect_client = MultiselectClient() + if enable_mDNS: + self.mDNS = MDNSDiscovery(network) def get_id(self) -> ID: """ @@ -162,7 +166,14 @@ class BasicHost(IHost): network = self.get_network() async with background_trio_service(network): await network.listen(*listen_addrs) - yield + if self.mDNS is not None: + logger.debug("Starting mDNS Discovery") + self.mDNS.start() + try: + yield + finally: + if self.mDNS is not None: + self.mDNS.stop() return _run() diff --git a/libp2p/host/routed_host.py b/libp2p/host/routed_host.py index b637e1eb..166a15ec 100644 --- a/libp2p/host/routed_host.py +++ b/libp2p/host/routed_host.py @@ -18,8 +18,10 @@ from libp2p.peer.peerinfo import ( class RoutedHost(BasicHost): _router: IPeerRouting - def __init__(self, network: INetworkService, router: IPeerRouting): - super().__init__(network) + def __init__( + self, network: INetworkService, router: IPeerRouting, enable_mDNS: bool = False + ): + super().__init__(network, enable_mDNS) self._router = router async def connect(self, peer_info: PeerInfo) -> None: