mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-11 07:30:55 +00:00
fixme/correct-type (#746)
* fixme/correct-type * added newsfragment and test
This commit is contained in:
@ -50,6 +50,11 @@ if TYPE_CHECKING:
|
|||||||
Pubsub,
|
Pubsub,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from libp2p.protocol_muxer.multiselect import Multiselect
|
||||||
|
|
||||||
from libp2p.pubsub.pb import (
|
from libp2p.pubsub.pb import (
|
||||||
rpc_pb2,
|
rpc_pb2,
|
||||||
)
|
)
|
||||||
@ -1545,9 +1550,8 @@ class IHost(ABC):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# FIXME: Replace with correct return type
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_mux(self) -> Any:
|
def get_mux(self) -> "Multiselect":
|
||||||
"""
|
"""
|
||||||
Retrieve the muxer instance for the host.
|
Retrieve the muxer instance for the host.
|
||||||
|
|
||||||
@ -2158,6 +2162,7 @@ class IMultiselectMuxer(ABC):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def get_protocols(self) -> tuple[TProtocol | None, ...]:
|
def get_protocols(self) -> tuple[TProtocol | None, ...]:
|
||||||
"""
|
"""
|
||||||
Retrieve the protocols for which handlers have been registered.
|
Retrieve the protocols for which handlers have been registered.
|
||||||
@ -2168,7 +2173,6 @@ class IMultiselectMuxer(ABC):
|
|||||||
A tuple of registered protocol names.
|
A tuple of registered protocol names.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return tuple(self.handlers.keys())
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def negotiate(
|
async def negotiate(
|
||||||
|
|||||||
@ -59,7 +59,7 @@ def _mk_identify_protobuf(
|
|||||||
) -> Identify:
|
) -> Identify:
|
||||||
public_key = host.get_public_key()
|
public_key = host.get_public_key()
|
||||||
laddrs = host.get_addrs()
|
laddrs = host.get_addrs()
|
||||||
protocols = host.get_mux().get_protocols()
|
protocols = tuple(str(p) for p in host.get_mux().get_protocols() if p is not None)
|
||||||
|
|
||||||
observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b""
|
observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b""
|
||||||
return Identify(
|
return Identify(
|
||||||
|
|||||||
@ -101,6 +101,18 @@ class Multiselect(IMultiselectMuxer):
|
|||||||
except trio.TooSlowError:
|
except trio.TooSlowError:
|
||||||
raise MultiselectError("handshake read timeout")
|
raise MultiselectError("handshake read timeout")
|
||||||
|
|
||||||
|
def get_protocols(self) -> tuple[TProtocol | None, ...]:
|
||||||
|
"""
|
||||||
|
Retrieve the protocols for which handlers have been registered.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
tuple[TProtocol, ...]
|
||||||
|
A tuple of registered protocol names.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return tuple(self.handlers.keys())
|
||||||
|
|
||||||
async def handshake(self, communicator: IMultiselectCommunicator) -> None:
|
async def handshake(self, communicator: IMultiselectCommunicator) -> None:
|
||||||
"""
|
"""
|
||||||
Perform handshake to agree on multiselect protocol.
|
Perform handshake to agree on multiselect protocol.
|
||||||
|
|||||||
@ -292,7 +292,9 @@ class RelayDiscovery(Service):
|
|||||||
# Get protocols with proper typing
|
# Get protocols with proper typing
|
||||||
mux_protocols = mux.get_protocols()
|
mux_protocols = mux.get_protocols()
|
||||||
if isinstance(mux_protocols, (list, tuple)):
|
if isinstance(mux_protocols, (list, tuple)):
|
||||||
available_protocols = list(mux_protocols)
|
available_protocols = [
|
||||||
|
p for p in mux.get_protocols() if p is not None
|
||||||
|
]
|
||||||
|
|
||||||
for protocol in available_protocols:
|
for protocol in available_protocols:
|
||||||
try:
|
try:
|
||||||
@ -312,7 +314,7 @@ class RelayDiscovery(Service):
|
|||||||
|
|
||||||
self._protocol_cache[peer_id] = peer_protocols
|
self._protocol_cache[peer_id] = peer_protocols
|
||||||
protocol_str = str(PROTOCOL_ID)
|
protocol_str = str(PROTOCOL_ID)
|
||||||
for protocol in peer_protocols:
|
for protocol in map(TProtocol, peer_protocols):
|
||||||
if protocol == protocol_str:
|
if protocol == protocol_str:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|||||||
3
newsfragments/746.bugfix.rst
Normal file
3
newsfragments/746.bugfix.rst
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
Improved type safety in `get_mux()` and `get_protocols()` by returning properly typed values instead
|
||||||
|
of `Any`. Also updated `identify.py` and `discovery.py` to handle `None` values safely and
|
||||||
|
compare protocols correctly.
|
||||||
@ -3,6 +3,7 @@ import pytest
|
|||||||
from libp2p.custom_types import (
|
from libp2p.custom_types import (
|
||||||
TProtocol,
|
TProtocol,
|
||||||
)
|
)
|
||||||
|
from libp2p.protocol_muxer.multiselect import Multiselect
|
||||||
from libp2p.tools.utils import (
|
from libp2p.tools.utils import (
|
||||||
create_echo_stream_handler,
|
create_echo_stream_handler,
|
||||||
)
|
)
|
||||||
@ -138,3 +139,23 @@ async def test_multistream_command(security_protocol):
|
|||||||
# Dialer asks for unspoorted command
|
# Dialer asks for unspoorted command
|
||||||
with pytest.raises(ValueError, match="Command not supported"):
|
with pytest.raises(ValueError, match="Command not supported"):
|
||||||
await dialer.send_command(listener.get_id(), "random")
|
await dialer.send_command(listener.get_id(), "random")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_get_protocols_returns_all_registered_protocols():
|
||||||
|
ms = Multiselect()
|
||||||
|
|
||||||
|
async def dummy_handler(stream):
|
||||||
|
pass
|
||||||
|
|
||||||
|
p1 = TProtocol("/echo/1.0.0")
|
||||||
|
p2 = TProtocol("/foo/1.0.0")
|
||||||
|
p3 = TProtocol("/bar/1.0.0")
|
||||||
|
|
||||||
|
ms.add_handler(p1, dummy_handler)
|
||||||
|
ms.add_handler(p2, dummy_handler)
|
||||||
|
ms.add_handler(p3, dummy_handler)
|
||||||
|
|
||||||
|
protocols = ms.get_protocols()
|
||||||
|
|
||||||
|
assert set(protocols) == {p1, p2, p3}
|
||||||
|
|||||||
Reference in New Issue
Block a user