mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
FIXME: Make TProtocol Optional[TProtocol] to keep types consistent (#770)
* FIXME: Make TProtocol Optional[TProtocol] to keep types consistent * correct test case of test_protocol_muxer * add newsfragment * unit test added --------- Co-authored-by: Manu Sheel Gupta <manusheel.edu@gmail.com>
This commit is contained in:
@ -295,6 +295,13 @@ class BasicHost(IHost):
|
||||
)
|
||||
await net_stream.reset()
|
||||
return
|
||||
if protocol is None:
|
||||
logger.debug(
|
||||
"no protocol negotiated, closing stream from peer %s",
|
||||
net_stream.muxed_conn.peer_id,
|
||||
)
|
||||
await net_stream.reset()
|
||||
return
|
||||
net_stream.set_protocol(protocol)
|
||||
if handler is None:
|
||||
logger.debug(
|
||||
|
||||
@ -48,12 +48,11 @@ class Multiselect(IMultiselectMuxer):
|
||||
"""
|
||||
self.handlers[protocol] = handler
|
||||
|
||||
# FIXME: Make TProtocol Optional[TProtocol] to keep types consistent
|
||||
async def negotiate(
|
||||
self,
|
||||
communicator: IMultiselectCommunicator,
|
||||
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
) -> tuple[TProtocol, StreamHandlerFn | None]:
|
||||
) -> tuple[TProtocol | None, StreamHandlerFn | None]:
|
||||
"""
|
||||
Negotiate performs protocol selection.
|
||||
|
||||
@ -84,14 +83,14 @@ class Multiselect(IMultiselectMuxer):
|
||||
raise MultiselectError() from error
|
||||
|
||||
else:
|
||||
protocol = TProtocol(command)
|
||||
if protocol in self.handlers:
|
||||
protocol_to_check = None if not command else TProtocol(command)
|
||||
if protocol_to_check in self.handlers:
|
||||
try:
|
||||
await communicator.write(protocol)
|
||||
await communicator.write(command)
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectError() from error
|
||||
|
||||
return protocol, self.handlers[protocol]
|
||||
return protocol_to_check, self.handlers[protocol_to_check]
|
||||
try:
|
||||
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
|
||||
except MultiselectCommunicatorError as error:
|
||||
|
||||
@ -134,8 +134,10 @@ class MultiselectClient(IMultiselectClient):
|
||||
:raise MultiselectClientError: raised when protocol negotiation failed
|
||||
:return: selected protocol
|
||||
"""
|
||||
# Represent `None` protocol as an empty string.
|
||||
protocol_str = protocol if protocol is not None else ""
|
||||
try:
|
||||
await communicator.write(protocol)
|
||||
await communicator.write(protocol_str)
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
|
||||
@ -145,7 +147,7 @@ class MultiselectClient(IMultiselectClient):
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
|
||||
if response == protocol:
|
||||
if response == protocol_str:
|
||||
return protocol
|
||||
if response == PROTOCOL_NOT_FOUND_MSG:
|
||||
raise MultiselectClientError("protocol not supported")
|
||||
|
||||
@ -30,7 +30,10 @@ class MultiselectCommunicator(IMultiselectCommunicator):
|
||||
"""
|
||||
:raise MultiselectCommunicatorError: raised when failed to write to underlying reader
|
||||
""" # noqa: E501
|
||||
msg_bytes = encode_delim(msg_str.encode())
|
||||
if msg_str is None:
|
||||
msg_bytes = encode_delim(b"")
|
||||
else:
|
||||
msg_bytes = encode_delim(msg_str.encode())
|
||||
try:
|
||||
await self.read_writer.write(msg_bytes)
|
||||
except IOException as error:
|
||||
|
||||
@ -17,6 +17,9 @@ from libp2p.custom_types import (
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectError,
|
||||
)
|
||||
from libp2p.protocol_muxer.multiselect import (
|
||||
Multiselect,
|
||||
)
|
||||
@ -104,7 +107,7 @@ class SecurityMultistream(ABC):
|
||||
:param is_initiator: true if we are the initiator, false otherwise
|
||||
:return: selected secure transport
|
||||
"""
|
||||
protocol: TProtocol
|
||||
protocol: TProtocol | None
|
||||
communicator = MultiselectCommunicator(conn)
|
||||
if is_initiator:
|
||||
# Select protocol if initiator
|
||||
@ -114,5 +117,7 @@ class SecurityMultistream(ABC):
|
||||
else:
|
||||
# Select protocol if non-initiator
|
||||
protocol, _ = await self.multiselect.negotiate(communicator)
|
||||
if protocol is None:
|
||||
raise MultiselectError("fail to negotiate a security protocol")
|
||||
# Return transport from protocol
|
||||
return self.transports[protocol]
|
||||
|
||||
@ -17,6 +17,9 @@ from libp2p.custom_types import (
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectError,
|
||||
)
|
||||
from libp2p.protocol_muxer.multiselect import (
|
||||
Multiselect,
|
||||
)
|
||||
@ -73,7 +76,7 @@ class MuxerMultistream:
|
||||
:param conn: conn to choose a transport over
|
||||
:return: selected muxer transport
|
||||
"""
|
||||
protocol: TProtocol
|
||||
protocol: TProtocol | None
|
||||
communicator = MultiselectCommunicator(conn)
|
||||
if conn.is_initiator:
|
||||
protocol = await self.multiselect_client.select_one_of(
|
||||
@ -81,6 +84,8 @@ class MuxerMultistream:
|
||||
)
|
||||
else:
|
||||
protocol, _ = await self.multiselect.negotiate(communicator)
|
||||
if protocol is None:
|
||||
raise MultiselectError("fail to negotiate a stream muxer protocol")
|
||||
return self.transports[protocol]
|
||||
|
||||
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
|
||||
|
||||
1
newsfragments/770.internal.rst
Normal file
1
newsfragments/770.internal.rst
Normal file
@ -0,0 +1 @@
|
||||
Make TProtocol as Optional[TProtocol] to keep types consistent in py-libp2p/libp2p/protocol_muxer/multiselect.py
|
||||
@ -1,9 +1,9 @@
|
||||
from collections import deque
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
IMultiselectCommunicator,
|
||||
)
|
||||
from libp2p.abc import IMultiselectCommunicator, INetStream
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectClientError,
|
||||
@ -13,6 +13,10 @@ from libp2p.protocol_muxer.multiselect import Multiselect
|
||||
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
|
||||
|
||||
|
||||
async def dummy_handler(stream: INetStream) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class DummyMultiselectCommunicator(IMultiselectCommunicator):
|
||||
"""
|
||||
Dummy MultiSelectCommunicator to test out negotiate timmeout.
|
||||
@ -31,7 +35,7 @@ class DummyMultiselectCommunicator(IMultiselectCommunicator):
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_select_one_of_timeout():
|
||||
async def test_select_one_of_timeout() -> None:
|
||||
ECHO = TProtocol("/echo/1.0.0")
|
||||
communicator = DummyMultiselectCommunicator()
|
||||
|
||||
@ -42,7 +46,7 @@ async def test_select_one_of_timeout():
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_query_multistream_command_timeout():
|
||||
async def test_query_multistream_command_timeout() -> None:
|
||||
communicator = DummyMultiselectCommunicator()
|
||||
client = MultiselectClient()
|
||||
|
||||
@ -51,9 +55,95 @@ async def test_query_multistream_command_timeout():
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_timeout():
|
||||
async def test_negotiate_timeout() -> None:
|
||||
communicator = DummyMultiselectCommunicator()
|
||||
server = Multiselect()
|
||||
|
||||
with pytest.raises(MultiselectError, match="handshake read timeout"):
|
||||
await server.negotiate(communicator, 2)
|
||||
|
||||
|
||||
class HandshakeThenHangCommunicator(IMultiselectCommunicator):
|
||||
handshaked: bool
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.handshaked = False
|
||||
|
||||
async def write(self, msg_str: str) -> None:
|
||||
if msg_str == "/multistream/1.0.0":
|
||||
self.handshaked = True
|
||||
return
|
||||
|
||||
async def read(self) -> str:
|
||||
if not self.handshaked:
|
||||
return "/multistream/1.0.0"
|
||||
# After handshake, hang on read.
|
||||
await trio.sleep_forever()
|
||||
# Should not be reached.
|
||||
return ""
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_timeout_post_handshake() -> None:
|
||||
communicator = HandshakeThenHangCommunicator()
|
||||
server = Multiselect()
|
||||
with pytest.raises(MultiselectError, match="handshake read timeout"):
|
||||
await server.negotiate(communicator, 1)
|
||||
|
||||
|
||||
class MockCommunicator(IMultiselectCommunicator):
|
||||
def __init__(self, commands_to_read: list[str]):
|
||||
self.read_queue = deque(commands_to_read)
|
||||
self.written_data: list[str] = []
|
||||
|
||||
async def write(self, msg_str: str) -> None:
|
||||
self.written_data.append(msg_str)
|
||||
|
||||
async def read(self) -> str:
|
||||
if not self.read_queue:
|
||||
raise EOFError
|
||||
return self.read_queue.popleft()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_empty_string_command() -> None:
|
||||
# server receives an empty string, which means client wants `None` protocol.
|
||||
server = Multiselect({None: dummy_handler})
|
||||
# Handshake, then empty command
|
||||
communicator = MockCommunicator(["/multistream/1.0.0", ""])
|
||||
protocol, handler = await server.negotiate(communicator)
|
||||
assert protocol is None
|
||||
assert handler == dummy_handler
|
||||
# Check that server sent back handshake and the protocol confirmation (empty string)
|
||||
assert communicator.written_data == ["/multistream/1.0.0", ""]
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_with_none_handler() -> None:
|
||||
# server has None handler, client sends "" to select it.
|
||||
server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler})
|
||||
# Handshake, then empty command
|
||||
communicator = MockCommunicator(["/multistream/1.0.0", ""])
|
||||
protocol, handler = await server.negotiate(communicator)
|
||||
assert protocol is None
|
||||
assert handler == dummy_handler
|
||||
# Check written data: handshake, protocol confirmation
|
||||
assert communicator.written_data == ["/multistream/1.0.0", ""]
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_with_none_handler_ls() -> None:
|
||||
# server has None handler, client sends "ls" then empty string.
|
||||
server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler})
|
||||
# Handshake, ls, empty command
|
||||
communicator = MockCommunicator(["/multistream/1.0.0", "ls", ""])
|
||||
protocol, handler = await server.negotiate(communicator)
|
||||
assert protocol is None
|
||||
assert handler == dummy_handler
|
||||
# Check written data: handshake, ls response, protocol confirmation
|
||||
assert communicator.written_data[0] == "/multistream/1.0.0"
|
||||
assert "/proto1" in communicator.written_data[1]
|
||||
# Note: `ls` should not list the `None` protocol.
|
||||
assert "None" not in communicator.written_data[1]
|
||||
assert "\n\n" not in communicator.written_data[1]
|
||||
assert communicator.written_data[2] == ""
|
||||
|
||||
@ -159,3 +159,41 @@ async def test_get_protocols_returns_all_registered_protocols():
|
||||
protocols = ms.get_protocols()
|
||||
|
||||
assert set(protocols) == {p1, p2, p3}
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_optional_tprotocol(security_protocol):
|
||||
with pytest.raises(Exception):
|
||||
await perform_simple_test(
|
||||
None,
|
||||
[None],
|
||||
[None],
|
||||
security_protocol,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_optional_tprotocol_client_none_server_no_none(
|
||||
security_protocol,
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
await perform_simple_test(None, [None], [PROTOCOL_ECHO], security_protocol)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_optional_tprotocol_client_none_in_list(security_protocol):
|
||||
expected_selected_protocol = PROTOCOL_ECHO
|
||||
await perform_simple_test(
|
||||
expected_selected_protocol,
|
||||
[None, PROTOCOL_ECHO],
|
||||
[PROTOCOL_ECHO],
|
||||
security_protocol,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_optional_tprotocol_server_none_client_other(
|
||||
security_protocol,
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
await perform_simple_test(None, [PROTOCOL_ECHO], [None], security_protocol)
|
||||
|
||||
Reference in New Issue
Block a user