fixme/correct-type (#746)

* fixme/correct-type

* added newsfragment and test
This commit is contained in:
Archit Dabral
2025-07-12 02:57:17 +05:30
committed by GitHub
parent dd14aad47c
commit 5fcfc677f3
6 changed files with 48 additions and 6 deletions

View File

@ -50,6 +50,11 @@ if TYPE_CHECKING:
Pubsub, Pubsub,
) )
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.pubsub.pb import ( from libp2p.pubsub.pb import (
rpc_pb2, rpc_pb2,
) )
@ -1545,9 +1550,8 @@ class IHost(ABC):
""" """
# FIXME: Replace with correct return type
@abstractmethod @abstractmethod
def get_mux(self) -> Any: def get_mux(self) -> "Multiselect":
""" """
Retrieve the muxer instance for the host. Retrieve the muxer instance for the host.
@ -2158,6 +2162,7 @@ class IMultiselectMuxer(ABC):
""" """
@abstractmethod
def get_protocols(self) -> tuple[TProtocol | None, ...]: def get_protocols(self) -> tuple[TProtocol | None, ...]:
""" """
Retrieve the protocols for which handlers have been registered. Retrieve the protocols for which handlers have been registered.
@ -2168,7 +2173,6 @@ class IMultiselectMuxer(ABC):
A tuple of registered protocol names. A tuple of registered protocol names.
""" """
return tuple(self.handlers.keys())
@abstractmethod @abstractmethod
async def negotiate( async def negotiate(

View File

@ -59,7 +59,7 @@ def _mk_identify_protobuf(
) -> Identify: ) -> Identify:
public_key = host.get_public_key() public_key = host.get_public_key()
laddrs = host.get_addrs() laddrs = host.get_addrs()
protocols = host.get_mux().get_protocols() protocols = tuple(str(p) for p in host.get_mux().get_protocols() if p is not None)
observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b"" observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b""
return Identify( return Identify(

View File

@ -101,6 +101,18 @@ class Multiselect(IMultiselectMuxer):
except trio.TooSlowError: except trio.TooSlowError:
raise MultiselectError("handshake read timeout") raise MultiselectError("handshake read timeout")
def get_protocols(self) -> tuple[TProtocol | None, ...]:
"""
Retrieve the protocols for which handlers have been registered.
Returns
-------
tuple[TProtocol, ...]
A tuple of registered protocol names.
"""
return tuple(self.handlers.keys())
async def handshake(self, communicator: IMultiselectCommunicator) -> None: async def handshake(self, communicator: IMultiselectCommunicator) -> None:
""" """
Perform handshake to agree on multiselect protocol. Perform handshake to agree on multiselect protocol.

View File

@ -292,7 +292,9 @@ class RelayDiscovery(Service):
# Get protocols with proper typing # Get protocols with proper typing
mux_protocols = mux.get_protocols() mux_protocols = mux.get_protocols()
if isinstance(mux_protocols, (list, tuple)): if isinstance(mux_protocols, (list, tuple)):
available_protocols = list(mux_protocols) available_protocols = [
p for p in mux.get_protocols() if p is not None
]
for protocol in available_protocols: for protocol in available_protocols:
try: try:
@ -312,7 +314,7 @@ class RelayDiscovery(Service):
self._protocol_cache[peer_id] = peer_protocols self._protocol_cache[peer_id] = peer_protocols
protocol_str = str(PROTOCOL_ID) protocol_str = str(PROTOCOL_ID)
for protocol in peer_protocols: for protocol in map(TProtocol, peer_protocols):
if protocol == protocol_str: if protocol == protocol_str:
return True return True
return False return False

View File

@ -0,0 +1,3 @@
Improved type safety in `get_mux()` and `get_protocols()` by returning properly typed values instead
of `Any`. Also updated `identify.py` and `discovery.py` to handle `None` values safely and
compare protocols correctly.

View File

@ -3,6 +3,7 @@ import pytest
from libp2p.custom_types import ( from libp2p.custom_types import (
TProtocol, TProtocol,
) )
from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.tools.utils import ( from libp2p.tools.utils import (
create_echo_stream_handler, create_echo_stream_handler,
) )
@ -138,3 +139,23 @@ async def test_multistream_command(security_protocol):
# Dialer asks for unspoorted command # Dialer asks for unspoorted command
with pytest.raises(ValueError, match="Command not supported"): with pytest.raises(ValueError, match="Command not supported"):
await dialer.send_command(listener.get_id(), "random") await dialer.send_command(listener.get_id(), "random")
@pytest.mark.trio
async def test_get_protocols_returns_all_registered_protocols():
ms = Multiselect()
async def dummy_handler(stream):
pass
p1 = TProtocol("/echo/1.0.0")
p2 = TProtocol("/foo/1.0.0")
p3 = TProtocol("/bar/1.0.0")
ms.add_handler(p1, dummy_handler)
ms.add_handler(p2, dummy_handler)
ms.add_handler(p3, dummy_handler)
protocols = ms.get_protocols()
assert set(protocols) == {p1, p2, p3}