mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-09 22:50:54 +00:00
FIXME: Make TProtocol Optional[TProtocol] to keep types consistent (#770)
* FIXME: Make TProtocol Optional[TProtocol] to keep types consistent * correct test case of test_protocol_muxer * add newsfragment * unit test added --------- Co-authored-by: Manu Sheel Gupta <manusheel.edu@gmail.com>
This commit is contained in:
@ -295,6 +295,13 @@ class BasicHost(IHost):
|
|||||||
)
|
)
|
||||||
await net_stream.reset()
|
await net_stream.reset()
|
||||||
return
|
return
|
||||||
|
if protocol is None:
|
||||||
|
logger.debug(
|
||||||
|
"no protocol negotiated, closing stream from peer %s",
|
||||||
|
net_stream.muxed_conn.peer_id,
|
||||||
|
)
|
||||||
|
await net_stream.reset()
|
||||||
|
return
|
||||||
net_stream.set_protocol(protocol)
|
net_stream.set_protocol(protocol)
|
||||||
if handler is None:
|
if handler is None:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|||||||
@ -48,12 +48,11 @@ class Multiselect(IMultiselectMuxer):
|
|||||||
"""
|
"""
|
||||||
self.handlers[protocol] = handler
|
self.handlers[protocol] = handler
|
||||||
|
|
||||||
# FIXME: Make TProtocol Optional[TProtocol] to keep types consistent
|
|
||||||
async def negotiate(
|
async def negotiate(
|
||||||
self,
|
self,
|
||||||
communicator: IMultiselectCommunicator,
|
communicator: IMultiselectCommunicator,
|
||||||
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||||
) -> tuple[TProtocol, StreamHandlerFn | None]:
|
) -> tuple[TProtocol | None, StreamHandlerFn | None]:
|
||||||
"""
|
"""
|
||||||
Negotiate performs protocol selection.
|
Negotiate performs protocol selection.
|
||||||
|
|
||||||
@ -84,14 +83,14 @@ class Multiselect(IMultiselectMuxer):
|
|||||||
raise MultiselectError() from error
|
raise MultiselectError() from error
|
||||||
|
|
||||||
else:
|
else:
|
||||||
protocol = TProtocol(command)
|
protocol_to_check = None if not command else TProtocol(command)
|
||||||
if protocol in self.handlers:
|
if protocol_to_check in self.handlers:
|
||||||
try:
|
try:
|
||||||
await communicator.write(protocol)
|
await communicator.write(command)
|
||||||
except MultiselectCommunicatorError as error:
|
except MultiselectCommunicatorError as error:
|
||||||
raise MultiselectError() from error
|
raise MultiselectError() from error
|
||||||
|
|
||||||
return protocol, self.handlers[protocol]
|
return protocol_to_check, self.handlers[protocol_to_check]
|
||||||
try:
|
try:
|
||||||
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
|
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
|
||||||
except MultiselectCommunicatorError as error:
|
except MultiselectCommunicatorError as error:
|
||||||
|
|||||||
@ -134,8 +134,10 @@ class MultiselectClient(IMultiselectClient):
|
|||||||
:raise MultiselectClientError: raised when protocol negotiation failed
|
:raise MultiselectClientError: raised when protocol negotiation failed
|
||||||
:return: selected protocol
|
:return: selected protocol
|
||||||
"""
|
"""
|
||||||
|
# Represent `None` protocol as an empty string.
|
||||||
|
protocol_str = protocol if protocol is not None else ""
|
||||||
try:
|
try:
|
||||||
await communicator.write(protocol)
|
await communicator.write(protocol_str)
|
||||||
except MultiselectCommunicatorError as error:
|
except MultiselectCommunicatorError as error:
|
||||||
raise MultiselectClientError() from error
|
raise MultiselectClientError() from error
|
||||||
|
|
||||||
@ -145,7 +147,7 @@ class MultiselectClient(IMultiselectClient):
|
|||||||
except MultiselectCommunicatorError as error:
|
except MultiselectCommunicatorError as error:
|
||||||
raise MultiselectClientError() from error
|
raise MultiselectClientError() from error
|
||||||
|
|
||||||
if response == protocol:
|
if response == protocol_str:
|
||||||
return protocol
|
return protocol
|
||||||
if response == PROTOCOL_NOT_FOUND_MSG:
|
if response == PROTOCOL_NOT_FOUND_MSG:
|
||||||
raise MultiselectClientError("protocol not supported")
|
raise MultiselectClientError("protocol not supported")
|
||||||
|
|||||||
@ -30,7 +30,10 @@ class MultiselectCommunicator(IMultiselectCommunicator):
|
|||||||
"""
|
"""
|
||||||
:raise MultiselectCommunicatorError: raised when failed to write to underlying reader
|
:raise MultiselectCommunicatorError: raised when failed to write to underlying reader
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
msg_bytes = encode_delim(msg_str.encode())
|
if msg_str is None:
|
||||||
|
msg_bytes = encode_delim(b"")
|
||||||
|
else:
|
||||||
|
msg_bytes = encode_delim(msg_str.encode())
|
||||||
try:
|
try:
|
||||||
await self.read_writer.write(msg_bytes)
|
await self.read_writer.write(msg_bytes)
|
||||||
except IOException as error:
|
except IOException as error:
|
||||||
|
|||||||
@ -17,6 +17,9 @@ from libp2p.custom_types import (
|
|||||||
from libp2p.peer.id import (
|
from libp2p.peer.id import (
|
||||||
ID,
|
ID,
|
||||||
)
|
)
|
||||||
|
from libp2p.protocol_muxer.exceptions import (
|
||||||
|
MultiselectError,
|
||||||
|
)
|
||||||
from libp2p.protocol_muxer.multiselect import (
|
from libp2p.protocol_muxer.multiselect import (
|
||||||
Multiselect,
|
Multiselect,
|
||||||
)
|
)
|
||||||
@ -104,7 +107,7 @@ class SecurityMultistream(ABC):
|
|||||||
:param is_initiator: true if we are the initiator, false otherwise
|
:param is_initiator: true if we are the initiator, false otherwise
|
||||||
:return: selected secure transport
|
:return: selected secure transport
|
||||||
"""
|
"""
|
||||||
protocol: TProtocol
|
protocol: TProtocol | None
|
||||||
communicator = MultiselectCommunicator(conn)
|
communicator = MultiselectCommunicator(conn)
|
||||||
if is_initiator:
|
if is_initiator:
|
||||||
# Select protocol if initiator
|
# Select protocol if initiator
|
||||||
@ -114,5 +117,7 @@ class SecurityMultistream(ABC):
|
|||||||
else:
|
else:
|
||||||
# Select protocol if non-initiator
|
# Select protocol if non-initiator
|
||||||
protocol, _ = await self.multiselect.negotiate(communicator)
|
protocol, _ = await self.multiselect.negotiate(communicator)
|
||||||
|
if protocol is None:
|
||||||
|
raise MultiselectError("fail to negotiate a security protocol")
|
||||||
# Return transport from protocol
|
# Return transport from protocol
|
||||||
return self.transports[protocol]
|
return self.transports[protocol]
|
||||||
|
|||||||
@ -17,6 +17,9 @@ from libp2p.custom_types import (
|
|||||||
from libp2p.peer.id import (
|
from libp2p.peer.id import (
|
||||||
ID,
|
ID,
|
||||||
)
|
)
|
||||||
|
from libp2p.protocol_muxer.exceptions import (
|
||||||
|
MultiselectError,
|
||||||
|
)
|
||||||
from libp2p.protocol_muxer.multiselect import (
|
from libp2p.protocol_muxer.multiselect import (
|
||||||
Multiselect,
|
Multiselect,
|
||||||
)
|
)
|
||||||
@ -73,7 +76,7 @@ class MuxerMultistream:
|
|||||||
:param conn: conn to choose a transport over
|
:param conn: conn to choose a transport over
|
||||||
:return: selected muxer transport
|
:return: selected muxer transport
|
||||||
"""
|
"""
|
||||||
protocol: TProtocol
|
protocol: TProtocol | None
|
||||||
communicator = MultiselectCommunicator(conn)
|
communicator = MultiselectCommunicator(conn)
|
||||||
if conn.is_initiator:
|
if conn.is_initiator:
|
||||||
protocol = await self.multiselect_client.select_one_of(
|
protocol = await self.multiselect_client.select_one_of(
|
||||||
@ -81,6 +84,8 @@ class MuxerMultistream:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
protocol, _ = await self.multiselect.negotiate(communicator)
|
protocol, _ = await self.multiselect.negotiate(communicator)
|
||||||
|
if protocol is None:
|
||||||
|
raise MultiselectError("fail to negotiate a stream muxer protocol")
|
||||||
return self.transports[protocol]
|
return self.transports[protocol]
|
||||||
|
|
||||||
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
|
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
|
||||||
|
|||||||
1
newsfragments/770.internal.rst
Normal file
1
newsfragments/770.internal.rst
Normal file
@ -0,0 +1 @@
|
|||||||
|
Make TProtocol as Optional[TProtocol] to keep types consistent in py-libp2p/libp2p/protocol_muxer/multiselect.py
|
||||||
@ -1,9 +1,9 @@
|
|||||||
|
from collections import deque
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
from libp2p.abc import (
|
from libp2p.abc import IMultiselectCommunicator, INetStream
|
||||||
IMultiselectCommunicator,
|
|
||||||
)
|
|
||||||
from libp2p.custom_types import TProtocol
|
from libp2p.custom_types import TProtocol
|
||||||
from libp2p.protocol_muxer.exceptions import (
|
from libp2p.protocol_muxer.exceptions import (
|
||||||
MultiselectClientError,
|
MultiselectClientError,
|
||||||
@ -13,6 +13,10 @@ from libp2p.protocol_muxer.multiselect import Multiselect
|
|||||||
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
|
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
|
||||||
|
|
||||||
|
|
||||||
|
async def dummy_handler(stream: INetStream) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DummyMultiselectCommunicator(IMultiselectCommunicator):
|
class DummyMultiselectCommunicator(IMultiselectCommunicator):
|
||||||
"""
|
"""
|
||||||
Dummy MultiSelectCommunicator to test out negotiate timmeout.
|
Dummy MultiSelectCommunicator to test out negotiate timmeout.
|
||||||
@ -31,7 +35,7 @@ class DummyMultiselectCommunicator(IMultiselectCommunicator):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_select_one_of_timeout():
|
async def test_select_one_of_timeout() -> None:
|
||||||
ECHO = TProtocol("/echo/1.0.0")
|
ECHO = TProtocol("/echo/1.0.0")
|
||||||
communicator = DummyMultiselectCommunicator()
|
communicator = DummyMultiselectCommunicator()
|
||||||
|
|
||||||
@ -42,7 +46,7 @@ async def test_select_one_of_timeout():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_query_multistream_command_timeout():
|
async def test_query_multistream_command_timeout() -> None:
|
||||||
communicator = DummyMultiselectCommunicator()
|
communicator = DummyMultiselectCommunicator()
|
||||||
client = MultiselectClient()
|
client = MultiselectClient()
|
||||||
|
|
||||||
@ -51,9 +55,95 @@ async def test_query_multistream_command_timeout():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_negotiate_timeout():
|
async def test_negotiate_timeout() -> None:
|
||||||
communicator = DummyMultiselectCommunicator()
|
communicator = DummyMultiselectCommunicator()
|
||||||
server = Multiselect()
|
server = Multiselect()
|
||||||
|
|
||||||
with pytest.raises(MultiselectError, match="handshake read timeout"):
|
with pytest.raises(MultiselectError, match="handshake read timeout"):
|
||||||
await server.negotiate(communicator, 2)
|
await server.negotiate(communicator, 2)
|
||||||
|
|
||||||
|
|
||||||
|
class HandshakeThenHangCommunicator(IMultiselectCommunicator):
|
||||||
|
handshaked: bool
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.handshaked = False
|
||||||
|
|
||||||
|
async def write(self, msg_str: str) -> None:
|
||||||
|
if msg_str == "/multistream/1.0.0":
|
||||||
|
self.handshaked = True
|
||||||
|
return
|
||||||
|
|
||||||
|
async def read(self) -> str:
|
||||||
|
if not self.handshaked:
|
||||||
|
return "/multistream/1.0.0"
|
||||||
|
# After handshake, hang on read.
|
||||||
|
await trio.sleep_forever()
|
||||||
|
# Should not be reached.
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_negotiate_timeout_post_handshake() -> None:
|
||||||
|
communicator = HandshakeThenHangCommunicator()
|
||||||
|
server = Multiselect()
|
||||||
|
with pytest.raises(MultiselectError, match="handshake read timeout"):
|
||||||
|
await server.negotiate(communicator, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class MockCommunicator(IMultiselectCommunicator):
|
||||||
|
def __init__(self, commands_to_read: list[str]):
|
||||||
|
self.read_queue = deque(commands_to_read)
|
||||||
|
self.written_data: list[str] = []
|
||||||
|
|
||||||
|
async def write(self, msg_str: str) -> None:
|
||||||
|
self.written_data.append(msg_str)
|
||||||
|
|
||||||
|
async def read(self) -> str:
|
||||||
|
if not self.read_queue:
|
||||||
|
raise EOFError
|
||||||
|
return self.read_queue.popleft()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_negotiate_empty_string_command() -> None:
|
||||||
|
# server receives an empty string, which means client wants `None` protocol.
|
||||||
|
server = Multiselect({None: dummy_handler})
|
||||||
|
# Handshake, then empty command
|
||||||
|
communicator = MockCommunicator(["/multistream/1.0.0", ""])
|
||||||
|
protocol, handler = await server.negotiate(communicator)
|
||||||
|
assert protocol is None
|
||||||
|
assert handler == dummy_handler
|
||||||
|
# Check that server sent back handshake and the protocol confirmation (empty string)
|
||||||
|
assert communicator.written_data == ["/multistream/1.0.0", ""]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_negotiate_with_none_handler() -> None:
|
||||||
|
# server has None handler, client sends "" to select it.
|
||||||
|
server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler})
|
||||||
|
# Handshake, then empty command
|
||||||
|
communicator = MockCommunicator(["/multistream/1.0.0", ""])
|
||||||
|
protocol, handler = await server.negotiate(communicator)
|
||||||
|
assert protocol is None
|
||||||
|
assert handler == dummy_handler
|
||||||
|
# Check written data: handshake, protocol confirmation
|
||||||
|
assert communicator.written_data == ["/multistream/1.0.0", ""]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_negotiate_with_none_handler_ls() -> None:
|
||||||
|
# server has None handler, client sends "ls" then empty string.
|
||||||
|
server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler})
|
||||||
|
# Handshake, ls, empty command
|
||||||
|
communicator = MockCommunicator(["/multistream/1.0.0", "ls", ""])
|
||||||
|
protocol, handler = await server.negotiate(communicator)
|
||||||
|
assert protocol is None
|
||||||
|
assert handler == dummy_handler
|
||||||
|
# Check written data: handshake, ls response, protocol confirmation
|
||||||
|
assert communicator.written_data[0] == "/multistream/1.0.0"
|
||||||
|
assert "/proto1" in communicator.written_data[1]
|
||||||
|
# Note: `ls` should not list the `None` protocol.
|
||||||
|
assert "None" not in communicator.written_data[1]
|
||||||
|
assert "\n\n" not in communicator.written_data[1]
|
||||||
|
assert communicator.written_data[2] == ""
|
||||||
|
|||||||
@ -159,3 +159,41 @@ async def test_get_protocols_returns_all_registered_protocols():
|
|||||||
protocols = ms.get_protocols()
|
protocols = ms.get_protocols()
|
||||||
|
|
||||||
assert set(protocols) == {p1, p2, p3}
|
assert set(protocols) == {p1, p2, p3}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_negotiate_optional_tprotocol(security_protocol):
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
await perform_simple_test(
|
||||||
|
None,
|
||||||
|
[None],
|
||||||
|
[None],
|
||||||
|
security_protocol,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_negotiate_optional_tprotocol_client_none_server_no_none(
|
||||||
|
security_protocol,
|
||||||
|
):
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
await perform_simple_test(None, [None], [PROTOCOL_ECHO], security_protocol)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_negotiate_optional_tprotocol_client_none_in_list(security_protocol):
|
||||||
|
expected_selected_protocol = PROTOCOL_ECHO
|
||||||
|
await perform_simple_test(
|
||||||
|
expected_selected_protocol,
|
||||||
|
[None, PROTOCOL_ECHO],
|
||||||
|
[PROTOCOL_ECHO],
|
||||||
|
security_protocol,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_negotiate_optional_tprotocol_server_none_client_other(
|
||||||
|
security_protocol,
|
||||||
|
):
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
await perform_simple_test(None, [PROTOCOL_ECHO], [None], security_protocol)
|
||||||
|
|||||||
Reference in New Issue
Block a user