From 5fcfc677f31e17b9bd842425bab66169dc71a339 Mon Sep 17 00:00:00 2001 From: Archit Dabral <147427717+Minimega12121@users.noreply.github.com> Date: Sat, 12 Jul 2025 02:57:17 +0530 Subject: [PATCH] fixme/correct-type (#746) * fixme/correct-type * added newsfragment and test --- libp2p/abc.py | 10 ++++++--- libp2p/identity/identify/identify.py | 2 +- libp2p/protocol_muxer/multiselect.py | 12 +++++++++++ libp2p/relay/circuit_v2/discovery.py | 6 ++++-- newsfragments/746.bugfix.rst | 3 +++ .../protocol_muxer/test_protocol_muxer.py | 21 +++++++++++++++++++ 6 files changed, 48 insertions(+), 6 deletions(-) create mode 100644 newsfragments/746.bugfix.rst diff --git a/libp2p/abc.py b/libp2p/abc.py index 70c4ab71..3adb04aa 100644 --- a/libp2p/abc.py +++ b/libp2p/abc.py @@ -50,6 +50,11 @@ if TYPE_CHECKING: Pubsub, ) +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from libp2p.protocol_muxer.multiselect import Multiselect + from libp2p.pubsub.pb import ( rpc_pb2, ) @@ -1545,9 +1550,8 @@ class IHost(ABC): """ - # FIXME: Replace with correct return type @abstractmethod - def get_mux(self) -> Any: + def get_mux(self) -> "Multiselect": """ Retrieve the muxer instance for the host. @@ -2158,6 +2162,7 @@ class IMultiselectMuxer(ABC): """ + @abstractmethod def get_protocols(self) -> tuple[TProtocol | None, ...]: """ Retrieve the protocols for which handlers have been registered. @@ -2168,7 +2173,6 @@ class IMultiselectMuxer(ABC): A tuple of registered protocol names. """ - return tuple(self.handlers.keys()) @abstractmethod async def negotiate( diff --git a/libp2p/identity/identify/identify.py b/libp2p/identity/identify/identify.py index 5d066e37..15367c43 100644 --- a/libp2p/identity/identify/identify.py +++ b/libp2p/identity/identify/identify.py @@ -59,7 +59,7 @@ def _mk_identify_protobuf( ) -> Identify: public_key = host.get_public_key() laddrs = host.get_addrs() - protocols = host.get_mux().get_protocols() + protocols = tuple(str(p) for p in host.get_mux().get_protocols() if p is not None) observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b"" return Identify( diff --git a/libp2p/protocol_muxer/multiselect.py b/libp2p/protocol_muxer/multiselect.py index 3f6ef02f..8d311391 100644 --- a/libp2p/protocol_muxer/multiselect.py +++ b/libp2p/protocol_muxer/multiselect.py @@ -101,6 +101,18 @@ class Multiselect(IMultiselectMuxer): except trio.TooSlowError: raise MultiselectError("handshake read timeout") + def get_protocols(self) -> tuple[TProtocol | None, ...]: + """ + Retrieve the protocols for which handlers have been registered. + + Returns + ------- + tuple[TProtocol, ...] + A tuple of registered protocol names. + + """ + return tuple(self.handlers.keys()) + async def handshake(self, communicator: IMultiselectCommunicator) -> None: """ Perform handshake to agree on multiselect protocol. diff --git a/libp2p/relay/circuit_v2/discovery.py b/libp2p/relay/circuit_v2/discovery.py index 734a7869..a35eacdc 100644 --- a/libp2p/relay/circuit_v2/discovery.py +++ b/libp2p/relay/circuit_v2/discovery.py @@ -292,7 +292,9 @@ class RelayDiscovery(Service): # Get protocols with proper typing mux_protocols = mux.get_protocols() if isinstance(mux_protocols, (list, tuple)): - available_protocols = list(mux_protocols) + available_protocols = [ + p for p in mux.get_protocols() if p is not None + ] for protocol in available_protocols: try: @@ -312,7 +314,7 @@ class RelayDiscovery(Service): self._protocol_cache[peer_id] = peer_protocols protocol_str = str(PROTOCOL_ID) - for protocol in peer_protocols: + for protocol in map(TProtocol, peer_protocols): if protocol == protocol_str: return True return False diff --git a/newsfragments/746.bugfix.rst b/newsfragments/746.bugfix.rst new file mode 100644 index 00000000..71970b48 --- /dev/null +++ b/newsfragments/746.bugfix.rst @@ -0,0 +1,3 @@ +Improved type safety in `get_mux()` and `get_protocols()` by returning properly typed values instead +of `Any`. Also updated `identify.py` and `discovery.py` to handle `None` values safely and +compare protocols correctly. diff --git a/tests/core/protocol_muxer/test_protocol_muxer.py b/tests/core/protocol_muxer/test_protocol_muxer.py index b089390b..1d6a0f86 100644 --- a/tests/core/protocol_muxer/test_protocol_muxer.py +++ b/tests/core/protocol_muxer/test_protocol_muxer.py @@ -3,6 +3,7 @@ import pytest from libp2p.custom_types import ( TProtocol, ) +from libp2p.protocol_muxer.multiselect import Multiselect from libp2p.tools.utils import ( create_echo_stream_handler, ) @@ -138,3 +139,23 @@ async def test_multistream_command(security_protocol): # Dialer asks for unspoorted command with pytest.raises(ValueError, match="Command not supported"): await dialer.send_command(listener.get_id(), "random") + + +@pytest.mark.trio +async def test_get_protocols_returns_all_registered_protocols(): + ms = Multiselect() + + async def dummy_handler(stream): + pass + + p1 = TProtocol("/echo/1.0.0") + p2 = TProtocol("/foo/1.0.0") + p3 = TProtocol("/bar/1.0.0") + + ms.add_handler(p1, dummy_handler) + ms.add_handler(p2, dummy_handler) + ms.add_handler(p3, dummy_handler) + + protocols = ms.get_protocols() + + assert set(protocols) == {p1, p2, p3}